修改了算法代码,并建立了一个简单的训练脚本.修改bert处理二维输入,移除PPO的permute参数

This commit is contained in:
2025-10-22 16:56:12 +08:00
parent b626702cbb
commit 3f7e183c4b
101 changed files with 3837 additions and 39 deletions

131
test_vehicle_movement.py Normal file
View File

@@ -0,0 +1,131 @@
"""
测试车辆是否能正常运动
使用固定的前进动作,观察车辆运动
"""
import sys
import time
sys.path.append('Env')
from Env.scenario_env import MultiAgentScenarioEnv
from Env.simple_idm_policy import ConstantVelocityPolicy
from metadrive.engine.asset_loader import AssetLoader
class FixedForwardPolicy:
    """Fixed drive-forward policy used to make sure vehicles move."""

    def act(self):
        """Return a fresh ``[steering, throttle]`` action: straight, full throttle."""
        steering = 0.0
        throttle = 1.0
        return [steering, throttle]
def _planar_distance(start, end):
    """Return the straight-line distance between two ``(x, y)`` points."""
    return ((end[0] - start[0]) ** 2 + (end[1] - start[1]) ** 2) ** 0.5


def main(max_steps=500, step_delay=0.05):
    """Run a fixed straight/full-throttle episode and report how far a vehicle moved.

    Builds a ``MultiAgentScenarioEnv``, resets it with seed 0, feeds the same
    ``[0.0, 1.0]`` action to every controlled agent on each step, renders the
    top-down view, and prints the first controlled vehicle's position, speed
    and travelled distance every 50 steps plus a final summary.

    Args:
        max_steps: Upper bound on simulation steps; also passed to the
            environment as its ``horizon``.
        step_delay: Seconds to sleep after each rendered frame so the motion
            is easier to watch.

    The environment is always closed, including on early return, on errors
    and on Ctrl+C.
    """
    print("=" * 60)
    print("🚗 测试车辆运动")
    print("=" * 60)

    # Create the environment.
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(
                "/home/huangfukk/MAGAIL4AutoDrive/Env",
                "exp_converted",
                unix_style=False,
            ),
            "is_multi_agent": True,
            "num_controlled_agents": 3,
            "horizon": max_steps,
            "use_render": True,
            "sequential_seed": True,
            "reactive_traffic": True,
            "manual_control": False,
            "filter_offroad_vehicles": True,
            "lane_tolerance": 3.0,
            "max_controlled_vehicles": 3,
            "debug_lane_filter": False,
            "debug_traffic_light": False,
        },
        agent2policy=FixedForwardPolicy(),
    )

    # try/finally guarantees env.close() runs on every exit path.
    try:
        # Reset the environment.
        env.reset(0)
        print("\n✅ 环境初始化完成")
        print(f" 可控车辆数: {len(env.controlled_agents)}")
        if not env.controlled_agents:
            print("❌ 没有可控车辆!")
            return

        # Track the first controlled vehicle for the whole run.
        first_vehicle = next(iter(env.controlled_agents.values()))
        initial_pos = [first_vehicle.position[0], first_vehicle.position[1]]
        print("\n🚗 第一辆车初始状态:")
        print(f" 位置: {initial_pos}")
        print(f" 速度: {first_vehicle.speed:.2f} m/s")
        print("\n🎬 开始运行... (固定动作: 直行+满油门)")
        print(" 按Ctrl+C停止\n")

        # Fixed action: go straight with full throttle -> [steering, throttle].
        fixed_action = [0.0, 1.0]
        speed_total = 0.0
        steps_run = 0

        for step in range(max_steps):
            # Every controlled vehicle receives the same fixed action.
            actions = {aid: fixed_action for aid in env.controlled_agents}
            _, _, dones, _ = env.step(actions)

            # Sample the speed on every step so the summary can report a real mean.
            speed_total += first_vehicle.speed
            steps_run += 1

            env.render(mode="topdown")
            time.sleep(step_delay)  # slow playback down so motion is visible

            # Print a status report every 50 steps (skipping step 0).
            if step % 50 == 0 and step > 0:
                current_pos = [first_vehicle.position[0], first_vehicle.position[1]]
                distance = _planar_distance(initial_pos, current_pos)
                print(f"步数 {step:3d}:")
                print(f" 当前位置: ({current_pos[0]:.2f}, {current_pos[1]:.2f})")
                print(f" 当前速度: {first_vehicle.speed:.2f} m/s")
                print(f" 移动距离: {distance:.2f} m")
                print()

            if dones.get("__all__", False):
                print(f"✅ Episode完成于步数 {step}")
                break

        # Final statistics.
        final_pos = [first_vehicle.position[0], first_vehicle.position[1]]
        total_distance = _planar_distance(initial_pos, final_pos)
        # Fall back to the current speed if no step was executed (max_steps <= 0).
        mean_speed = speed_total / steps_run if steps_run else first_vehicle.speed

        print("\n" + "=" * 60)
        print("📊 运动统计:")
        print(f" 初始位置: ({initial_pos[0]:.2f}, {initial_pos[1]:.2f})")
        print(f" 最终位置: ({final_pos[0]:.2f}, {final_pos[1]:.2f})")
        print(f" 总移动距离: {total_distance:.2f} m")
        print(f" 平均速度: {mean_speed:.2f} m/s")
        if total_distance < 1.0:
            print("\n❌ 警告: 车辆几乎没有移动!")
        else:
            print("\n✅ 车辆正常运动")
        print("=" * 60)
    finally:
        env.close()
# Script entry point: run the movement test and report how it ended.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C is an expected way to end the run, so exit normally.
        print("\n\n⏹️ 用户中断")
    except Exception as e:
        import traceback

        print(f"\n❌ 错误: {e}")
        traceback.print_exc()
        # Signal the failure to the caller instead of exiting with status 0.
        raise SystemExit(1) from None