From 113e86bda23586c708b5c7ef22eb04c0fc72d36c Mon Sep 17 00:00:00 2001
From: huangfu <3045324663@qq.com>
Date: Mon, 27 Oct 2025 00:27:47 +0800
Subject: [PATCH] Complete replay and simulation modes; filter vehicles
 spawned off-lane; add filtering for pedestrians and cyclists
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Env/__pycache__/replay_policy.cpython-310.pyc |  Bin 0 -> 1990 bytes
 Env/__pycache__/scenario_env.cpython-310.pyc  |  Bin 0 -> 11180 bytes
 Env/__pycache__/scenario_env.cpython-313.pyc  |  Bin 0 -> 20798 bytes
 .../simple_idm_policy.cpython-310.pyc         |  Bin 0 -> 770 bytes
 Env/replay_policy.py                          |  62 +++
 Env/run_multiagent_env.py                     | 362 +++++++++++++++-
 Env/scenario_env.py                           | 313 ++++++++++++--
 README.md                                     | 407 +++++++++++++++++-
 8 files changed, 1065 insertions(+), 79 deletions(-)
 create mode 100644 Env/__pycache__/replay_policy.cpython-310.pyc
 create mode 100644 Env/__pycache__/scenario_env.cpython-310.pyc
 create mode 100644 Env/__pycache__/scenario_env.cpython-313.pyc
 create mode 100644 Env/__pycache__/simple_idm_policy.cpython-310.pyc
 create mode 100644 Env/replay_policy.py

diff --git a/Env/__pycache__/replay_policy.cpython-310.pyc b/Env/__pycache__/replay_policy.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fed7d1be07dedd8819c7dc27fa593d9d7de09902
GIT binary patch
literal 1990
[base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/Env/__pycache__/scenario_env.cpython-310.pyc b/Env/__pycache__/scenario_env.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31f5e12183cb60ccfe1b7d1400b46f4ad13eda37
GIT binary patch
literal 11180
[base85-encoded binary data omitted]

literal 0
HcmV?d00001
diff --git a/Env/replay_policy.py b/Env/replay_policy.py
new file mode 100644
index 0000000..f52d957
--- /dev/null
+++ b/Env/replay_policy.py
@@ -0,0 +1,62 @@
+import numpy as np
+
+class ReplayPolicy:
+    """
+    Strict replay policy: replays vehicle states frame by frame from expert trajectory data.
+    """
+
+    def __init__(self, expert_trajectory, vehicle_id):
+        """
+        Args:
+            expert_trajectory: expert trajectory dict containing positions, headings, velocities, valid
+            vehicle_id: vehicle ID (for debugging)
+        """
+        self.trajectory = expert_trajectory
+        self.vehicle_id = vehicle_id
+        self.current_step = 0
+
+    def act(self, observation=None):
+        """
+        Return an action. In replay mode this is a no-op action;
+        the actual state is set directly by the environment.
+        """
+        return [0.0, 0.0]
+
+    def get_target_state(self, step):
+        """
+        Get the target state for a given timestep.
+
+        Args:
+            step: timestep
+
+        Returns:
+            dict with position, heading, velocity, or None if the step is invalid.
+        """
+        if step >= len(self.trajectory['valid']):
+            return None
+
+        if not self.trajectory['valid'][step]:
+            return None
+
+        return {
+            'position': self.trajectory['positions'][step],
+            'heading': self.trajectory['headings'][step],
+            'velocity': self.trajectory['velocities'][step]
+        }
+
+    def is_finished(self, step):
+        """
+        Whether the trajectory has finished playing.
+
+        Args:
+            step: current timestep
+
+        Returns:
+            bool: True if the trajectory is fully played or no valid frame remains.
+        """
+        # Past the end of the trajectory
+        if step >= len(self.trajectory['valid']):
+            return True
+
+        # The current step and all later steps are invalid
+        return not any(self.trajectory['valid'][step:])
\ No newline at end of file
diff --git a/Env/run_multiagent_env.py b/Env/run_multiagent_env.py
index a6a21bc..7f8e70d 100644
--- a/Env/run_multiagent_env.py
+++ b/Env/run_multiagent_env.py
@@ -1,40 +1,362 @@
+import argparse
 from scenario_env import MultiAgentScenarioEnv
-from Env.simple_idm_policy import ConstantVelocityPolicy
+from simple_idm_policy import ConstantVelocityPolicy
+from replay_policy import ReplayPolicy
 from metadrive.engine.asset_loader import AssetLoader
 
-WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env"
+WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
 
-def main():
+
+def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
+                    scenario_id=None, use_scenario_duration=False,
+                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
+    """
+    Replay mode: strictly follow the expert trajectories.
+
+    Args:
+        data_dir: data directory
+        num_episodes: number of episodes (ignored if scenario_id is given)
+        horizon: max steps (set automatically when use_scenario_duration=True)
+        render: whether to render
+        debug: debug mode
+        scenario_id: specific scenario ID (optional)
+        use_scenario_duration: use the scenario's original duration
+        spawn_vehicles: whether to spawn vehicles (default True)
+        spawn_pedestrians: whether to spawn pedestrians (default True)
+        spawn_cyclists: whether to spawn cyclists (default True)
+    """
+    print("=" * 50)
+    print("Run mode: expert trajectory replay (Replay Mode)")
+    if scenario_id is not None:
+        print(f"Specified scenario ID: {scenario_id}")
+    if use_scenario_duration:
+        print("Using original scenario duration")
+    print("=" * 50)
+
+    # If a scenario ID is given, run a single episode
+    if scenario_id is not None:
+        num_episodes = 1
+
+    # ✅ Create the environment outside the loop to avoid recreating it each episode
+    env = MultiAgentScenarioEnv(
+        config={
+            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
+            "is_multi_agent": True,
+            "horizon": horizon,
+            "use_render": render,
+            "sequential_seed": True,
+            "reactive_traffic": False,  # no reactive traffic needed in replay mode
+            "manual_control": False,
+            "filter_offroad_vehicles": True,  # enable lane filtering
+            "lane_tolerance": 3.0,
+            "replay_mode": True,  # mark as replay mode
+            "debug": debug,
+            "specific_scenario_id": scenario_id,  # specific scenario ID
+            "use_scenario_duration": use_scenario_duration,  # use scenario duration
+            # object type filtering
+            "spawn_vehicles": spawn_vehicles,
+            "spawn_pedestrians": spawn_pedestrians,
+            "spawn_cyclists": spawn_cyclists,
+        },
+        agent2policy=None  # replay mode needs no shared policy
+    )
+
+    try:
+        for episode in range(num_episodes):
+            print(f"\n{'='*50}")
+            print(f"Episode {episode + 1}/{num_episodes}")
+            if scenario_id is not None:
+                print(f"Scenario ID: {scenario_id}")
+            print(f"{'='*50}")
+
+            # ✅ If no specific scenario is given, use the seed to iterate over scenarios
+            seed = scenario_id if scenario_id is not None else episode
+            obs = env.reset(seed=seed)
+
+            # Assign a ReplayPolicy to each vehicle
+            replay_policies = {}
+            for agent_id, vehicle in env.controlled_agents.items():
+                vehicle_id = vehicle.expert_vehicle_id
+                if vehicle_id in env.expert_trajectories:
+                    replay_policy = ReplayPolicy(
+                        env.expert_trajectories[vehicle_id],
+                        vehicle_id
+                    )
+                    vehicle.set_policy(replay_policy)
+                    replay_policies[agent_id] = replay_policy
+
+            # Print scenario info
+            actual_horizon = env.config["horizon"]
+            print(f"Initialization complete:")
+            print(f"  Controlled vehicles: {len(env.controlled_agents)}")
+            print(f"  Expert trajectories: {len(env.expert_trajectories)}")
+            print(f"  Scenario duration: {env.scenario_max_duration} steps")
+            print(f"  Actual horizon: {actual_horizon} steps")
+
+            step_count = 0
+            active_vehicles_count = []
+
+            while True:
+                # In replay mode, set vehicle states directly from the expert trajectories
+                for agent_id, vehicle in list(env.controlled_agents.items()):
+                    vehicle_id = vehicle.expert_vehicle_id
+                    if vehicle_id in env.expert_trajectories and agent_id in replay_policies:
+                        target_state = replay_policies[agent_id].get_target_state(env.round)
+                        if target_state is not None:
+                            # Set the vehicle state directly (bypassing the physics engine).
+                            # Use only the xy coordinates so the vehicle stays on the ground.
+                            position_2d = target_state['position'][:2]
+                            vehicle.set_position(position_2d)
+                            vehicle.set_heading_theta(target_state['heading'])
+                            vehicle.set_velocity(target_state['velocity'][:2] if len(target_state['velocity']) > 2 else target_state['velocity'])
+
+                # Step with empty actions
+                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
+                obs, rewards, dones, infos = env.step(actions)
+
+                if render:
+                    env.render(mode="topdown")
+
+                step_count += 1
+                active_vehicles_count.append(len(env.controlled_agents))
+
+                # Print a status line every 50 steps
+                if step_count % 50 == 0:
+                    print(f"Step {step_count}: {len(env.controlled_agents)} active vehicles")
+
+                    # In debug mode, print vehicle height info
+                    if debug and len(env.controlled_agents) > 0:
+                        sample_vehicle = list(env.controlled_agents.values())[0]
+                        z_pos = sample_vehicle.position[2] if len(sample_vehicle.position) > 2 else 0
+                        print(f"  [DEBUG] sample vehicle height: z={z_pos:.3f}m")
+
+                if dones["__all__"]:
+                    print(f"\nEpisode summary:")
+                    print(f"  Total steps: {step_count}")
+                    print(f"  Max concurrent vehicles: {max(active_vehicles_count) if active_vehicles_count else 0}")
+                    print(f"  Average vehicles: {sum(active_vehicles_count) / len(active_vehicles_count) if active_vehicles_count else 0:.1f}")
+                    if use_scenario_duration:
+                        print(f"  Full scenario replayed: {'yes' if step_count >= env.scenario_max_duration else 'no'}")
+                    break
+    finally:
+        # ✅ Make sure the environment is closed properly
+        env.close()
+
+    print("\n" + "=" * 50)
+    print("Replay finished!")
+    print("=" * 50)
+
+
+def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
+                        scenario_id=None, use_scenario_duration=False,
+                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
+    """
+    Simulation mode: vehicles are controlled by a custom policy.
+    Vehicles spawn at the initial poses from the expert data, then the policy takes over.
+
+    Args:
+        data_dir: data directory
+        num_episodes: number of episodes
+        horizon: max steps
+        render: whether to render
+        debug: debug mode
+        scenario_id: specific scenario ID (optional)
+        use_scenario_duration: use the scenario's original duration
+        spawn_vehicles: whether to spawn vehicles (default True)
+        spawn_pedestrians: whether to spawn pedestrians (default True)
+        spawn_cyclists: whether to spawn cyclists (default True)
+    """
+    print("=" * 50)
+    print("Run mode: policy simulation (Simulation Mode)")
+    if scenario_id is not None:
+        print(f"Specified scenario ID: {scenario_id}")
+    if use_scenario_duration:
+        print("Using original scenario duration")
+    print("=" * 50)
+
+    # If a scenario ID is given, run a single episode
+    if scenario_id is not None:
+        num_episodes = 1
+
+    env = MultiAgentScenarioEnv(
         config={
-            # "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
-            "data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
+            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
             "is_multi_agent": True,
             "num_controlled_agents": 3,
-            "horizon": 300,
-            "use_render": True,
+            "horizon": horizon,
+            "use_render": render,
             "sequential_seed": True,
             "reactive_traffic": True,
-            "manual_control": True,
+            "manual_control": False,
+            "filter_offroad_vehicles": True,  # enable lane filtering
+            "lane_tolerance": 3.0,
+            "replay_mode": False,  # simulation mode
+            "debug": debug,
+            "specific_scenario_id": scenario_id,  # specific scenario ID
+            "use_scenario_duration": use_scenario_duration,  # use scenario duration
+            # object type filtering
+            "spawn_vehicles": spawn_vehicles,
+            "spawn_pedestrians": spawn_pedestrians,
+            "spawn_cyclists": spawn_cyclists,
         },
         agent2policy=ConstantVelocityPolicy(target_speed=50)
     )
 
-    obs = env.reset(0
-                    )
-    for step in range(10000):
-        actions = {
-            aid: env.controlled_agents[aid].policy.act()
-            for aid in env.controlled_agents
-        }
+    try:
+        for episode in range(num_episodes):
+            print(f"\n{'='*50}")
+            print(f"Episode {episode + 1}/{num_episodes}")
+            if scenario_id is not None:
+                print(f"Scenario ID: {scenario_id}")
+            print(f"{'='*50}")
 
-        obs, rewards, dones, infos = env.step(actions)
-        env.render(mode="topdown")
+            seed = scenario_id if scenario_id is not None else episode
+            obs = env.reset(seed=seed)
 
-        if dones["__all__"]:
-            break
+            actual_horizon = env.config["horizon"]
+            print(f"Initialization complete:")
+            print(f"  Controlled vehicles: {len(env.controlled_agents)}")
+            print(f"  Scenario duration: {env.scenario_max_duration} steps")
+            print(f"  Actual horizon: {actual_horizon} steps")
 
-    env.close()
+            step_count = 0
+            total_reward = 0.0
+
+            while True:
+                # Generate actions with the policy
+                actions = {
+                    aid: env.controlled_agents[aid].policy.act()
+                    for aid in env.controlled_agents
+                }
+
+                obs, rewards, dones, infos = env.step(actions)
+
+                if render:
+                    env.render(mode="topdown")
+
+                step_count += 1
+                total_reward += sum(rewards.values())
+
+                # Print a status line every 50 steps
+                if step_count % 50 == 0:
+                    print(f"Step {step_count}: {len(env.controlled_agents)} active vehicles")
+
+                if dones["__all__"]:
+                    print(f"\nEpisode summary:")
+                    print(f"  Total steps: {step_count}")
+                    print(f"  Total reward: {total_reward:.2f}")
+                    break
+    finally:
+        env.close()
+
+    print("\n" + "=" * 50)
+    print("Simulation finished!")
+    print("=" * 50)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="MetaDrive multi-agent environment runner")
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["replay", "simulation"],
+        default="simulation",
+        help="Run mode: replay = expert trajectory replay, simulation = policy simulation (default: simulation)"
+    )
+
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        default=WAYMO_DATA_DIR,
+        help=f"Data directory path (default: {WAYMO_DATA_DIR})"
+    )
+
+    parser.add_argument(
+        "--episodes",
+        type=int,
+        default=1,
+        help="Number of episodes (default: 1)"
+    )
+
+    parser.add_argument(
+        "--horizon",
+        type=int,
+        default=300,
+        help="Max steps per episode (default: 300; set automatically with --use_scenario_duration)"
+    )
+
+    parser.add_argument(
+        "--no_render",
+        action="store_true",
+        help="Disable rendering (faster)"
+    )
+
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        help="Enable debug mode (verbose logging)"
+    )
+
+    parser.add_argument(
+        "--scenario_id",
+        type=int,
+        default=None,
+        help="Specific scenario ID (optional; if given, only that scenario is run)"
+    )
+
+    parser.add_argument(
+        "--use_scenario_duration",
+        action="store_true",
+        help="Use the scenario's original duration as the horizon (stops automatically)"
+    )
+
+    parser.add_argument(
+        "--no_vehicles",
+        action="store_true",
+        help="Do not spawn vehicles"
+    )
+
+    parser.add_argument(
+        "--no_pedestrians",
+        action="store_true",
+        help="Do not spawn pedestrians"
+    )
+
+    parser.add_argument(
+        "--no_cyclists",
+        action="store_true",
+        help="Do not spawn cyclists"
+    )
+
+    args = parser.parse_args()
+
+    if args.mode == "replay":
+        run_replay_mode(
+            data_dir=args.data_dir,
+            num_episodes=args.episodes,
+            horizon=args.horizon,
+            render=not args.no_render,
+            debug=args.debug,
+            scenario_id=args.scenario_id,
+            use_scenario_duration=args.use_scenario_duration,
+            spawn_vehicles=not args.no_vehicles,
+            spawn_pedestrians=not args.no_pedestrians,
+            spawn_cyclists=not args.no_cyclists
+        )
+    else:
+        run_simulation_mode(
+            data_dir=args.data_dir,
+            num_episodes=args.episodes,
+            horizon=args.horizon,
+            render=not args.no_render,
+            debug=args.debug,
+            scenario_id=args.scenario_id,
+            use_scenario_duration=args.use_scenario_duration,
+            spawn_vehicles=not args.no_vehicles,
+            spawn_pedestrians=not args.no_pedestrians,
+            spawn_cyclists=not args.no_cyclists
+        )
 
 
 if __name__ == "__main__":
     main()
diff --git a/Env/scenario_env.py b/Env/scenario_env.py
index ec1e3b9..85a616d 100644
--- a/Env/scenario_env.py
+++ b/Env/scenario_env.py
@@ -15,6 +15,7 @@ class PolicyVehicle(DefaultVehicle):
         super().__init__(*args, **kwargs)
         self.policy = None
         self.destination = None
+        self.expert_vehicle_id = None  # ID of the associated expert vehicle
 
     def set_policy(self, policy):
         self.policy = policy
@@ -22,6 +23,9 @@ class PolicyVehicle(DefaultVehicle):
     def set_destination(self, des):
         self.destination = des
 
+    def set_expert_vehicle_id(self, vid):
+        self.expert_vehicle_id = vid
+
     def act(self, observation, policy=None):
         if self.policy is not None:
             return self.policy.act(observation)
@@ -53,6 +57,15 @@ class MultiAgentScenarioEnv(ScenarioEnv):
             data_directory=None,
             num_controlled_agents=3,
             horizon=1000,
+            filter_offroad_vehicles=True,  # lane-filter switch
+            lane_tolerance=3.0,  # lane-detection tolerance (meters)
+            replay_mode=False,  # replay-mode switch
+            specific_scenario_id=None,  # new: specific scenario ID (replay mode only)
+            use_scenario_duration=False,  # new: use the scenario's original duration as the horizon
+            # object type filtering options
+            spawn_vehicles=True,  # spawn vehicles
+            spawn_pedestrians=True,  # spawn pedestrians
+            spawn_cyclists=True,  # spawn cyclists
         ))
         return config
 
@@ -62,50 +75,180 @@ class MultiAgentScenarioEnv(ScenarioEnv):
         self.controlled_agent_ids = []
         self.obs_list = []
         self.round = 0
+        self.expert_trajectories = {}  # full expert trajectories
+        self.replay_mode = config.get("replay_mode", False)
+        self.scenario_max_duration = 0  # actual max duration of the scenario
         super().__init__(config)
 
     def reset(self, seed: Union[None, int] = None):
         self.round = 0
+
         if self.logger is None:
             self.logger = get_logger()
-            log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
-            set_log_level(log_level)
+        log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
+        set_log_level(log_level)
 
+        # ✅ Key fix: before every reset, clean up all custom-spawned objects
+        if hasattr(self, 'engine') and self.engine is not None:
+            if hasattr(self, 'controlled_agents') and self.controlled_agents:
+                # First remove them from the agent_manager
+                if hasattr(self.engine, 'agent_manager'):
+                    for agent_id in list(self.controlled_agents.keys()):
+                        if agent_id in self.engine.agent_manager.active_agents:
+                            self.engine.agent_manager.active_agents.pop(agent_id)
+
+                # Then clear the objects
+                for agent_id, vehicle in list(self.controlled_agents.items()):
+                    try:
+                        self.engine.clear_objects([vehicle.id])
+                    except Exception:
+                        pass
+
+                self.controlled_agents.clear()
+                self.controlled_agent_ids.clear()
+
         self.lazy_init()
         self._reset_global_seed(seed)
+
         if self.engine is None:
             raise ValueError("Broken MetaDrive instance.")
 
-        # Record each expert vehicle's pose, then clear them all, keeping only
-        # the pose info needed to respawn the vehicles later
-        _obj_to_clean_this_frame = []
-        self.car_birth_info_list = []
-        for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
-            if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
-                continue
-            else:
-                if track["type"] == MetaDriveType.VEHICLE:
-                    _obj_to_clean_this_frame.append(scenario_id)
-                    valid = track['state']['valid']
-                    first_show = np.argmax(valid) if valid.any() else -1
-                    last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
-                    # id, spawn time, spawn position, spawn heading, destination
-                    self.car_birth_info_list.append({
-                        'id': track['metadata']['object_id'],
-                        'show_time': first_show,
-                        'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
-                        'heading': track['state']['heading'][first_show],
-                        'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
-                    })
-
-        for scenario_id in _obj_to_clean_this_frame:
-            self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
+        # If a scenario ID is specified, override start_scenario_index
+        if self.config.get("specific_scenario_id") is not None:
+            scenario_id = self.config.get("specific_scenario_id")
+            self.config["start_scenario_index"] = scenario_id
+            if self.config.get("debug", False):
+                self.logger.info(f"Using specific scenario ID: {scenario_id}")
 
+        # ✅ Initialize the engine and lanes first
         self.engine.reset()
         self.reset_sensors()
         self.engine.taskMgr.step()
-
         self.lanes = self.engine.map_manager.current_map.road_network.graph
 
+        # Record the expert data (self.lanes is initialized by now)
+        _obj_to_clean_this_frame = []
+        self.car_birth_info_list = []
+        self.expert_trajectories.clear()
+        total_vehicles = 0
+        total_pedestrians = 0
+        total_cyclists = 0
+        filtered_vehicles = 0
+        filtered_by_type = 0
+        self.scenario_max_duration = 0  # reset the scenario duration
+
+        for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
+            if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
+                continue
+
+            # object type filtering
+            obj_type = track["type"]
+
+            # count object types
+            if obj_type == MetaDriveType.VEHICLE:
+                total_vehicles += 1
+            elif obj_type == MetaDriveType.PEDESTRIAN:
+                total_pedestrians += 1
+            elif obj_type == MetaDriveType.CYCLIST:
+                total_cyclists += 1
+
+            # filter object types according to the config
+            if obj_type == MetaDriveType.VEHICLE and not self.config.get("spawn_vehicles", True):
+                _obj_to_clean_this_frame.append(scenario_id)
+                filtered_by_type += 1
+                if self.config.get("debug", False):
+                    self.logger.debug(f"Filtering VEHICLE {track['metadata']['object_id']} - spawn_vehicles=False")
+                continue
+
+            if obj_type == MetaDriveType.PEDESTRIAN and not self.config.get("spawn_pedestrians", True):
+                _obj_to_clean_this_frame.append(scenario_id)
+                filtered_by_type += 1
+                if self.config.get("debug", False):
+                    self.logger.debug(f"Filtering PEDESTRIAN {track['metadata']['object_id']} - spawn_pedestrians=False")
+                continue
+
+            if obj_type == MetaDriveType.CYCLIST and not self.config.get("spawn_cyclists", True):
+                _obj_to_clean_this_frame.append(scenario_id)
+                filtered_by_type += 1
+                if self.config.get("debug", False):
+                    self.logger.debug(f"Filtering CYCLIST {track['metadata']['object_id']} - spawn_cyclists=False")
+                continue
+
+            # Only vehicles are processed further (pedestrians and cyclists are only filtered for now)
+            if track["type"] == MetaDriveType.VEHICLE:
+                valid = track['state']['valid']
+                first_show = np.argmax(valid) if valid.any() else -1
+                last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
+
+                if first_show == -1 or last_show == -1:
+                    continue
+
+                # update the scenario's max duration
+                self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)
+
+                # initial vehicle position
+                initial_position = (
+                    track['state']['position'][first_show, 0],
+                    track['state']['position'][first_show, 1]
+                )
+
+                # lane filtering
+                if self.config.get("filter_offroad_vehicles", True):
+                    if not self._is_position_on_lane(initial_position):
+                        filtered_vehicles += 1
+                        _obj_to_clean_this_frame.append(scenario_id)
+                        if self.config.get("debug", False):
+                            self.logger.debug(
+                                f"Filtering vehicle {track['metadata']['object_id']} - "
+                                f"not on lane at position {initial_position}"
+                            )
+                        continue
+
+                # Store the full expert trajectory (2D positions only, to avoid height issues)
+                object_id = track['metadata']['object_id']
+                positions_2d = track['state']['position'].copy()
+                positions_2d[:, 2] = 0  # zero the z coordinate; MetaDrive handles height automatically
+
+                self.expert_trajectories[object_id] = {
+                    'positions': positions_2d,
+                    'headings': track['state']['heading'].copy(),
+                    'velocities': track['state']['velocity'].copy(),
+                    'valid': track['state']['valid'].copy(),
+                }
+
+                # save the spawn info
+                self.car_birth_info_list.append({
+                    'id': object_id,
+                    'show_time': first_show,
+                    'begin': initial_position,
+                    'heading': track['state']['heading'][first_show],
+                    'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
+                    'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
+                })
+
+                # Remove the original expert vehicle in both replay and simulation modes
+                _obj_to_clean_this_frame.append(scenario_id)
+
+        # clear expert vehicles and filtered objects
+        for scenario_id in _obj_to_clean_this_frame:
+            self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
+
+        # print statistics
+        if self.config.get("debug", False):
+            self.logger.info(f"=== Object statistics ===")
+            self.logger.info(f"VEHICLE: total={total_vehicles}, lane-filtered={filtered_vehicles}, kept={total_vehicles - filtered_vehicles}")
+            self.logger.info(f"PEDESTRIAN: total={total_pedestrians}")
+            self.logger.info(f"CYCLIST: total={total_cyclists}")
+            self.logger.info(f"Type filtering: {filtered_by_type} objects")
+            self.logger.info(f"Scenario duration: {self.scenario_max_duration} steps")
+
+        # If scenario-duration control is enabled, update the horizon
+        if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
+            original_horizon = self.config["horizon"]
+            self.config["horizon"] = self.scenario_max_duration
+            if self.config.get("debug", False):
+                self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")
+
         if self.top_down_renderer is not None:
             self.top_down_renderer.clear()
             self.engine.top_down_renderer = None
@@ -113,7 +256,6 @@ class MultiAgentScenarioEnv(ScenarioEnv):
         self.dones = {}
         self.episode_rewards = defaultdict(float)
         self.episode_lengths = defaultdict(int)
-
         self.controlled_agents.clear()
         self.controlled_agent_ids.clear()
 
@@ -122,37 +264,92 @@ class MultiAgentScenarioEnv(ScenarioEnv):
 
         return self._get_all_obs()
 
+    def _is_position_on_lane(self, position, tolerance=None):
+        if tolerance is None:
+            tolerance = self.config.get("lane_tolerance", 3.0)
+
+        # make sure self.lanes is initialized
+        if not hasattr(self, 'lanes') or self.lanes is None:
+            if self.config.get("debug", False):
+                self.logger.warning("Lanes not initialized, skipping lane check")
+            return True
+
+        position_2d = np.array(position[:2]) if len(position) > 2 else np.array(position)
+
+        try:
+            for lane in self.lanes.values():
+                if lane.lane.point_on_lane(position_2d):
+                    return True
+
+                lane_start = np.array(lane.lane.start)[:2]
+                lane_end = np.array(lane.lane.end)[:2]
+                lane_vec = lane_end - lane_start
+                lane_length = np.linalg.norm(lane_vec)
+
+                if lane_length < 1e-6:
+                    continue
+
+                lane_vec_normalized = lane_vec / lane_length
+                point_vec = position_2d - lane_start
+                projection = np.dot(point_vec, lane_vec_normalized)
+
+                if 0 <= projection <= lane_length:
+                    closest_point = lane_start + projection * lane_vec_normalized
+                    distance = np.linalg.norm(position_2d - closest_point)
+                    if distance <= tolerance:
+                        return True
+        except Exception as e:
+            if self.config.get("debug", False):
+                self.logger.warning(f"Lane check error: {e}")
+            return False
+
+        return False
+
     def _spawn_controlled_agents(self):
-        # ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
-        # ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
         for car in self.car_birth_info_list:
             if car['show_time'] == self.round:
                 agent_id = f"controlled_{car['id']}"
-
+                vehicle_config = {}
                 vehicle = self.engine.spawn_object(
                     PolicyVehicle,
-                    vehicle_config={},
+                    vehicle_config=vehicle_config,
                     position=car['begin'],
                     heading=car['heading']
                 )
-                vehicle.reset(position=car['begin'], heading=car['heading'])
 
+                # reset the vehicle state
+                reset_kwargs = {
+                    'position': car['begin'],
+                    'heading': car['heading']
+                }
+
+                # if velocity inheritance is enabled, set the initial velocity
+                if car.get('velocity') is not None:
+                    reset_kwargs['velocity'] = car['velocity']
+
+                vehicle.reset(**reset_kwargs)
+
+                # set the policy and destination
                 vehicle.set_policy(self.policy)
                 vehicle.set_destination(car['end'])
+                vehicle.set_expert_vehicle_id(car['id'])
 
                 self.controlled_agents[agent_id] = vehicle
                 self.controlled_agent_ids.append(agent_id)
 
-                # ✅ Key: register with the engine's active_agents so the vehicle joins physics updates
+                # register with the engine's active_agents
                 self.engine.agent_manager.active_agents[agent_id] = vehicle
 
+                if self.config.get("debug", False):
+                    self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")
+
     def _get_all_obs(self):
-        # position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
         self.obs_list = []
+
         for agent_id, vehicle in self.controlled_agents.items():
             state = vehicle.get_state()
-
             traffic_light = 0
+
             for lane in self.lanes.values():
                 if lane.lane.point_on_lane(state['position'][:2]):
                     if self.engine.light_manager.has_traffic_light(lane.lane.index):
@@ -168,37 +365,69 @@ class MultiAgentScenarioEnv(ScenarioEnv):
                     break
 
             lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
-                                                         physics_world=self.engine.physics_world.dynamic_world)
+                                                             physics_world=self.engine.physics_world.dynamic_world)
             side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
-                                                              base_vehicle=vehicle,
-                                                              physics_world=self.engine.physics_world.static_world)
+                                                                          base_vehicle=vehicle,
+                                                                          physics_world=self.engine.physics_world.static_world)
             lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
-                                                                        base_vehicle=vehicle,
-                                                                        physics_world=self.engine.physics_world.static_world)
+                                                                                    base_vehicle=vehicle,
+                                                                                    physics_world=self.engine.physics_world.static_world)
 
-        obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
+            obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
                + lidar[0] + side_lidar[0] + lane_line_lidar[0]
                + [traffic_light] + list(vehicle.destination))
+
             self.obs_list.append(obs)
+
         return self.obs_list
 
     def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
         self.round += 1
+
+        # apply actions
         for agent_id, action in action_dict.items():
             if agent_id in self.controlled_agents:
                 self.controlled_agents[agent_id].before_step(action)
 
+        # physics step
         self.engine.step()
+
+        # post-processing
         for agent_id in action_dict:
             if agent_id in self.controlled_agents:
                 self.controlled_agents[agent_id].after_step()
 
+        # spawn newly appearing vehicles
         self._spawn_controlled_agents()
+
+        # collect observations
         obs = self._get_all_obs()
+
         rewards = {aid: 0.0 for aid in self.controlled_agents}
         dones = {aid: False for aid in self.controlled_agents}
-        dones["__all__"] = self.episode_step >= self.config["horizon"]
+
+        # ✅ Fix: add a completion check for replay mode
+        replay_finished = False
+        if self.replay_mode and self.config.get("use_scenario_duration", False):
+            # check whether all expert trajectories have been fully replayed
+            if self.round >= self.scenario_max_duration:
+                replay_finished = True
+                if self.config.get("debug", False):
+                    self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")
+
+        dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished
+
         infos = {aid: {} for aid in self.controlled_agents}
+
         return obs, rewards, dones, infos
+
+    def close(self):
+        # ✅ clean up all spawned vehicles
+        if hasattr(self, 'controlled_agents') and self.controlled_agents:
+            for agent_id, vehicle in list(self.controlled_agents.items()):
+                if vehicle.id in self.engine.get_objects():
+                    self.engine.clear_objects([vehicle.id])
+            self.controlled_agents.clear()
+            self.controlled_agent_ids.clear()
+
+        super().close()
\ No newline at end of file
diff --git a/README.md b/README.md
index d0b11e2..35a5e65 100644
--- a/README.md
+++ b/README.md
@@ -1,28 +1,401 @@
-# MAGAIL4AutoDrive
-### 1.1 Environment setup
-The core environment code lives in the `Env` folder; running `run_multiagent_env.py` starts the multi-agent interaction environment. The script's core job is to read each agent's (vehicle's) action commands and pass them to `env.step()` to advance the simulation.
+# MAGAIL4AutoDrive - Multi-Agent Autonomous Driving Environment
-
-A first version of the vehicle-spawning function `Env.senario_env.MultiAgentScenarioEnv.reset()` is implemented. Its logic: first read each vehicle's initial pose from the expert dataset; then clean the raw data, dropping the vehicle agent instances while keeping the core parameters (vehicle ID, initial spawn position, heading angle, spawn timestamp, destination coordinates); finally call `_spawn_controlled_agents()` to spawn controllable vehicles, driven by the autonomous-driving algorithm, at the recorded times and places.
+
+A MetaDrive-based multi-agent driving simulation and replay environment, supporting expert-trajectory replay from the Waymo Open Dataset and custom-policy simulation.
-
-Key problem to solve: some vehicles spawn at offset positions (e.g., on a lawn), presumably caused by recording errors in the expert data or by special annotations for simulated parking areas. The plan is to add a lane-region check that tests whether each vehicle's initial spawn position lies within a valid lane and filters out vehicles spawned off-lane, keeping environment initialization sane.
+
+## 📋 Table of Contents
+
+- [Overview](#overview)
+- [Features](#features)
+- [Requirements](#requirements)
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [Usage Guide](#usage-guide)
+- [Project Structure](#project-structure)
+- [Configuration](#configuration)
+- [FAQ](#faq)
+
+## Overview
+
+MAGAIL4AutoDrive is a multi-agent autonomous-driving environment built on MetaDrive 0.4.3, designed for imitation learning and reinforcement learning research. It loads scenarios from real-world datasets (such as the Waymo Open Dataset) and provides two core run modes:
+
+- **Replay Mode**: replays expert trajectories exactly, for data visualization and validation
+- **Simulation Mode**: controls vehicles with a custom policy, for algorithm training and testing
+
+## Features
+
+### Core
+- ✅ **Multi-agent support**: control multiple vehicles simultaneously
+- ✅ **Expert trajectory replay**: faithfully replay expert driving from the Waymo dataset
+- ✅ **Custom policy interface**: plug in arbitrary controllers (IDM, RL, ...)
+- ✅ **Lane-aware filtering**: automatically drop vehicles that are not on a lane
+- ✅ **Scenario-duration control**: use the dataset's original scenario length or a custom horizon
+- ✅ **Rich sensors**: LiDAR, side detector, lane-line detector, camera, dashboard
+
+### Advanced
+- 🎯 Run a specific scenario by ID
+- 🔄 Automatic scenario switching (fixed)
+- 📊 Detailed debug logging
+- 🚗 Dynamic vehicle spawning and management
+- 🎮 Rendered or headless execution
+
+## Requirements
+
+### System
+- **OS**: Ubuntu 18.04+ / macOS 10.14+ / Windows 10+
+- **Python**: 3.8 - 3.10
+- **GPU**: optional but recommended (faster rendering)
+
+### Dependencies
+```
+metadrive-simulator==0.4.3
+numpy>=1.19.0
+pygame>=2.0.0
+```
+
+## Installation
+
+### 1. Create a Conda environment
+```
+conda create -n metadrive python=3.10
+conda activate metadrive
+```
+
+### 2. Install MetaDrive
+```
+pip install metadrive-simulator==0.4.3
+```
+
+### 3. Clone the project
+```
+git clone https://github.com/your-username/MAGAIL4AutoDrive.git
+cd MAGAIL4AutoDrive/Env
+```
+
+### 4. Prepare the dataset
+Convert the Waymo dataset to MetaDrive format and place it under the project directory:
+```
+MAGAIL4AutoDrive/Env/
+├── exp_converted/
+│   ├── scenario_0/
+│   ├── scenario_1/
+│   └── ...
+```
+
+## Quick Start
+
+### Replay mode (try this first)
+```
+# Replay the first scenario using its original duration
+python run_multiagent_env.py --mode replay --episodes 1 --use_scenario_duration
-### 1.2 Observation collection
-Observations are collected by `Env.senario_env.MultiAgentScenarioEnv._get_all_obs()`, which iterates over all controllable vehicles and gathers multi-dimensional observations. Currently implemented dimensions: real-time vehicle position, heading angle, speed, lidar point clouds (obstacle and lane-line features), and navigation information (since the scenarios are simple, the destination coordinates are used directly as the navigation input).
-
-Traffic-light collection needs improvement: the current scheme matches the lane index a vehicle is on to the corresponding traffic-light instance, but has two problems: some traffic-light instances report a state of `None`, and when a single lane is split into segments, vehicles in some regions cannot obtain the light state.
+
+# Replay a specific scenario
+python run_multiagent_env.py --mode replay --scenario_id 0 --use_scenario_duration
+
+# Replay several scenarios
+python run_multiagent_env.py --mode replay --episodes 3 --use_scenario_duration
+```
+
+### Simulation mode
+```
-### 1.3 Algorithm module
-The core innovation of this approach is a modified GAIL discriminator that handles the dynamically varying input length of multi-agent scenes (the number of vehicles is not fixed) and classifies the interaction scene as a whole, which is what training in a multi-agent autonomous-driving environment requires. The algorithm's core code lives in the `Algorithm.bert.Bert` class and works as follows:
+# Run the simulation with the default policy
+python run_multiagent_env.py --mode simulation --episodes 1
-
-1. Input handling: the input is an (N, input_dim) matrix, where N is the number of vehicles in the scene and input_dim is the fixed per-vehicle observation dimension; `input_dim` is set when the `Bert` class is initialized so that input dimensions match;
-2. Embedding and positional encoding: a `projection` linear layer maps the per-vehicle observation dimension to the embedding dimension (`embed_dim`), and a learnable positional encoding (`pos_embed`) is added to capture temporal and spatial relations in the observation sequence;
-3. Transformer feature extraction: the embedded features pass through a multi-layer `Transformer` (depth controlled by `num_layers`) for high-order feature interaction and abstraction;
-4. Classification head: two aggregation and classification schemes are provided. With `CLS` mode on, a learnable `CLS` token is prepended before the embedding layer and its output feature is fed to a fully connected layer for classification; with `CLS` mode off, the Transformer outputs for all vehicles are mean-pooled over the sequence dimension and the pooled global feature is classified. The classifier supports an optional `Tanh` activation to match different output distributions.
+
+# Run without rendering (faster training)
+python run_multiagent_env.py --mode simulation --episodes 5 --no_render
+```
-
-### 1.4 Action execution
-During the current testing phase the environment reuses Tengda's action-execution framework: each controllable vehicle gets its own `policy` model, the vehicle's observation is fed into its `policy` to get an action, which is passed to `env.step()`; in the `before_step` stage, `_set_action()` binds the action to the vehicle instance, and the MetaDrive simulator finally performs the physics and rendering.
-
-The planned improvement is a parameter-shared unified model: all vehicles share one `policy` model maintained globally through parameter sharing. This has three advantages: it avoids the training bias of independent per-vehicle models (e.g., inconsistent training progress); it solves model management when the number of vehicles changes (new vehicles need no extra initialization, and removed vehicles lose no training state); and it allows actions to be computed in parallel, significantly speeding up each decision step for large-scale multi-agent training.
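+
+### Programmatic use
+
+The commands above wrap `run_replay_mode()` / `run_simulation_mode()` in `run_multiagent_env.py`. To drive the environment from your own script, the following minimal sketch mirrors the replay loop from that file; the dataset path is a placeholder, and vehicles that appear mid-episode are skipped for brevity:
+
+```python
+from metadrive.engine.asset_loader import AssetLoader
+from scenario_env import MultiAgentScenarioEnv
+from replay_policy import ReplayPolicy
+
+env = MultiAgentScenarioEnv(
+    config={
+        # placeholder path; point this at your converted dataset
+        "data_directory": AssetLoader.file_path("/path/to/exp_converted", "training_20s", unix_style=False),
+        "is_multi_agent": True,
+        "use_render": False,
+        "sequential_seed": True,
+        "replay_mode": True,
+        "use_scenario_duration": True,
+    },
+    agent2policy=None,  # replay mode sets states directly; no shared policy
+)
+obs = env.reset(seed=0)
+
+# one ReplayPolicy per spawned vehicle, keyed by agent id
+policies = {aid: ReplayPolicy(env.expert_trajectories[v.expert_vehicle_id], v.expert_vehicle_id)
+            for aid, v in env.controlled_agents.items()
+            if v.expert_vehicle_id in env.expert_trajectories}
+
+while True:
+    # move every vehicle to its recorded pose for the current frame
+    for aid, vehicle in list(env.controlled_agents.items()):
+        if aid in policies:
+            target = policies[aid].get_target_state(env.round)
+            if target is not None:
+                vehicle.set_position(target['position'][:2])
+                vehicle.set_heading_theta(target['heading'])
+                vehicle.set_velocity(target['velocity'][:2])
+    obs, rewards, dones, infos = env.step({aid: [0.0, 0.0] for aid in env.controlled_agents})
+    if dones["__all__"]:
+        break
+env.close()
+```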
+
+## Usage Guide
+
+### Command-line arguments
+
+| Argument | Type | Default | Description |
+|------|------|--------|------|
+| `--mode` | str | simulation | Run mode: `replay` or `simulation` |
+| `--data_dir` | str | current dir | Path to the Waymo data directory |
+| `--episodes` | int | 1 | Number of episodes |
+| `--horizon` | int | 300 | Max steps per episode |
+| `--no_render` | flag | False | Disable rendering (faster) |
+| `--debug` | flag | False | Enable debug mode |
+| `--scenario_id` | int | None | Specific scenario ID |
+| `--use_scenario_duration` | flag | False | Use the scenario's original duration |
+| `--no_vehicles` | flag | False | Do not spawn vehicles |
+| `--no_pedestrians` | flag | False | Do not spawn pedestrians |
+| `--no_cyclists` | flag | False | Do not spawn cyclists |
+
+### Replay mode in detail
+
+Replay mode sets vehicle states strictly from the expert trajectories, without physics-engine control. Main uses:
+- dataset visualization
+- data-quality validation
+- producing demo videos
+
+```bash
+# Full example
+python run_multiagent_env.py \
+    --mode replay \
+    --episodes 1 \
+    --use_scenario_duration \
+    --debug
+
+# Replay vehicles only, without pedestrians and cyclists
+python run_multiagent_env.py \
+    --mode replay \
+    --use_scenario_duration \
+    --no_pedestrians \
+    --no_cyclists
+```
+
+**Important**: always enable `--use_scenario_duration` in replay mode; otherwise the episode keeps running after the scenario has finished playing.
+
+### Simulation mode in detail
+
+Simulation mode drives the vehicles with a custom policy and is suited to algorithm development and testing:
+
+```bash
+# Basic simulation
+python run_multiagent_env.py --mode simulation
+
+# Long training runs (no rendering)
+python run_multiagent_env.py \
+    --mode simulation \
+    --episodes 100 \
+    --horizon 500 \
+    --no_render
+
+# Vehicles only (focus on vehicle-vehicle interaction)
+python run_multiagent_env.py \
+    --mode simulation \
+    --no_pedestrians \
+    --no_cyclists
+```
+
+### Custom policies
+
+Modify `simple_idm_policy.py` or create a new policy class:
+
+```python
+class CustomPolicy:
+    def __init__(self, **kwargs):
+        # initialize policy parameters
+        pass
+
+    def act(self, observation=None):
+        # return an action [steering, acceleration]
+        # steering: [-1, 1]
+        # acceleration: [-1, 1]
+        return [0.0, 0.5]
+```
+
+Use it in `run_multiagent_env.py`:
+```python
+from custom_policy import CustomPolicy
+
+env = MultiAgentScenarioEnv(
+    config={...},
+    agent2policy=CustomPolicy()
+)
+```
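+
+The `observation` passed to `act()` is the flat per-vehicle list built by `_get_all_obs()` in `scenario_env.py`. With the default sensor setup (80 lidar, 10 side-detector and 10 lane-line lasers) it can be unpacked as below; this helper is illustrative and not part of the repository:
+
+```python
+def split_obs(obs):
+    """Unpack one vehicle's flat observation from _get_all_obs().
+
+    Default layout: 2 position + 2 velocity + 1 heading + 80 lidar
+    + 10 side detector + 10 lane-line detector + 1 traffic light
+    + 2 destination = 108 values.
+    """
+    return {
+        "position": obs[0:2],
+        "velocity": obs[2:4],
+        "heading_theta": obs[4],
+        "lidar": obs[5:85],
+        "side_detector": obs[85:95],
+        "lane_line_detector": obs[95:105],
+        "traffic_light": obs[105],
+        "destination": obs[106:108],
+    }
+```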
+
+## Project Structure
+
+```
+MAGAIL4AutoDrive/Env/
+├── run_multiagent_env.py      # main entry script
+├── scenario_env.py            # multi-agent scenario environment
+├── replay_policy.py           # expert trajectory replay policy
+├── simple_idm_policy.py       # IDM policy implementation
+├── utils.py                   # utilities
+├── ENHANCED_USAGE_GUIDE.md    # detailed usage guide
+├── README.md                  # this document
+└── exp_converted/             # Waymo dataset (prepare it yourself)
+    ├── scenario_0/
+    ├── scenario_1/
+    └── ...
+```
+
+### Core files
+
+**run_multiagent_env.py**
+- main entry script
+- parses command-line arguments
+- orchestrates the replay and simulation run modes
+
+**scenario_env.py**
+- custom multi-agent environment class
+- vehicle spawning and management
+- lane-filtering logic
+- observation-space definition
+
+**replay_policy.py**
+- expert trajectory replay policy
+- per-frame state lookup
+- trajectory-completion check
+
+**simple_idm_policy.py**
+- simple constant-velocity policy example
+- a template for custom policies
+
+## Configuration
+
+### Environment config
+
+Editable in `default_config()` of `scenario_env.py`:
+
+```python
+config.update(dict(
+    data_directory=None,            # data directory
+    num_controlled_agents=3,        # number of controlled vehicles (simulation mode only)
+    horizon=1000,                   # max steps
+    filter_offroad_vehicles=True,   # filter vehicles that are off-lane
+    lane_tolerance=3.0,             # lane tolerance (meters)
+    replay_mode=False,              # replay-mode flag
+    specific_scenario_id=None,      # specific scenario ID
+    use_scenario_duration=False,    # use the scenario's original duration
+    # object type filtering options
+    spawn_vehicles=True,            # spawn vehicles
+    spawn_pedestrians=True,         # spawn pedestrians
+    spawn_cyclists=True,            # spawn cyclists
+))
+```
+
+### Sensors
+
+Enabled by default (configurable at environment init):
+- **LiDAR**: 80 lasers, 30 m range
+- **Side detector**: 10 lasers, 8 m range
+- **Lane-line detector**: 10 lasers, 3 m range
+- **Main camera**: 1200x900 resolution
+- **Dashboard**: vehicle state information
+
+## FAQ
+
+### Q1: Why does replay keep running past the dataset's last frame?
+**A**: Add the `--use_scenario_duration` flag. The fixed version of `scenario_env.py` also adds an automatic completion check.
+
+### Q2: How do I switch between scenarios?
+**A**:
+- Option 1: pass `--scenario_id` to pick a scenario
+- Option 2: pass `--episodes N` to iterate over N scenarios
+
+### Q3: Why are some vehicles missing?
+**A**: Lane filtering is enabled (`filter_offroad_vehicles=True`), so vehicles that are not on a lane are dropped. Adjust the tolerance via `lane_tolerance` or disable the feature.
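+
+For reference, the geometry behind this filter (`_is_position_on_lane()` in `scenario_env.py`) accepts a spawn point if it lies on a lane, or if its perpendicular distance to a lane's straight start-to-end segment is within `lane_tolerance`. A standalone sketch of that distance test, with made-up coordinates:
+
+```python
+import numpy as np
+
+def near_segment(point, seg_start, seg_end, tolerance=3.0):
+    """Project the point onto the segment and compare the distance to tolerance."""
+    p, a, b = (np.asarray(x, dtype=float) for x in (point, seg_start, seg_end))
+    ab = b - a
+    length = np.linalg.norm(ab)
+    if length < 1e-6:
+        return False  # degenerate segment, same guard as the env code
+    t = np.dot(p - a, ab / length)  # projection along the segment
+    if not 0 <= t <= length:
+        return False  # projection falls outside the segment
+    closest = a + t * (ab / length)
+    return np.linalg.norm(p - closest) <= tolerance
+
+# a point 2 m beside a straight 50 m lane passes the default 3 m tolerance
+print(near_segment((10.0, 2.0), (0.0, 0.0), (50.0, 0.0)))  # True
+```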
+
+### Q4: How do I speed things up?
+**A**:
+- disable rendering with `--no_render`
+- reduce `num_controlled_agents`
+- use GPU acceleration
+
+### Q5: How do I control which object types appear in a scenario?
+**A**: Use the object-filter flags:
+```bash
+# Vehicles only, no pedestrians or cyclists
+python run_multiagent_env.py --mode replay --no_pedestrians --no_cyclists
+
+# Pedestrians and cyclists only, no vehicles (special scenarios)
+python run_multiagent_env.py --mode replay --no_vehicles
+
+# Debug mode to inspect the filter statistics
+python run_multiagent_env.py --mode replay --debug --no_pedestrians
+```
+
+### Q6: Why do some vehicles spawn in mid-air?
+**A**: Fixed in v1.2.0. Vehicle positions now use only the 2D (x, y) coordinates with z set to 0, letting MetaDrive handle the height so vehicles stay on the ground.
+
+### Q7: How do I export observation data?
+**A**: Add saving logic in `run_multiagent_env.py`:
+```python
+import pickle
+
+obs_data = []
+while True:
+    obs, rewards, dones, infos = env.step(actions)
+    obs_data.append(obs)
+    if dones["__all__"]:
+        break
+
+with open('observations.pkl', 'wb') as f:
+    pickle.dump(obs_data, f)
+```
+
+## Changelog
+
+### v1.2.0 (2025-10-26)
+- ✅ Fixed vehicle spawn height (vehicles floating in mid-air)
+- ✅ Added object-type filtering (vehicles / pedestrians / cyclists)
+- ✅ New command-line flags: `--no_vehicles`, `--no_pedestrians`, `--no_cyclists`
+- ✅ Better debug output with per-type object statistics
+- ✅ Position handling uses 2D coordinates only, avoiding height issues
+
+### v1.1.0 (2025-10-26)
+- ✅ Fixed replay running past the scenario duration
+- ✅ Added automatic scenario switching
+- ✅ Improved `replay_policy.py` with a new `is_finished()` method
+- ✅ Refined the done logic in `scenario_env.py`
+- ✅ Fixed object cleanup across multiple episodes
+
+### v1.0.0 (initial release)
+- basic multi-agent environment
+- replay and simulation modes
+- lane filtering
+- Waymo dataset support
+
+## Contributing
+
+Issues and pull requests are welcome!
+
+### Filing an issue
+- describe the problem and the reproduction steps in detail
+- attach run logs and error messages
+- state your environment (OS, Python version, etc.)
+
+### Submitting a PR
+- fork the project
+- create a feature branch: `git checkout -b feature/your-feature`
+- commit your changes: `git commit -m 'Add some feature'`
+- push the branch: `git push origin feature/your-feature`
+- open a pull request
+
+## License
+
+Released under the MIT License.
+
+## Acknowledgments
+
+- [MetaDrive](https://github.com/metadriverse/metadrive) - an excellent driving-simulation platform
+- [Waymo Open Dataset](https://waymo.com/open/) - a high-quality autonomous-driving dataset
+
+## Contact
+
+For questions or suggestions:
+- GitHub Issues: [project issues page]
+- Email: huangfukk@xxx.com
+
+---
+
+**Happy Driving! 🚗💨**