import argparse
from scenario_env import MultiAgentScenarioEnv
from simple_idm_policy import ConstantVelocityPolicy
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader

WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"

# Width of the "=====" separator lines printed to the console.
_BANNER_WIDTH = 50


def _print_mode_banner(title, scenario_id, use_scenario_duration):
    """Print the run-mode banner shared by replay and simulation modes."""
    print("=" * _BANNER_WIDTH)
    print(title)
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * _BANNER_WIDTH)


def _print_episode_header(episode, num_episodes, scenario_id):
    """Print the separator block that introduces one episode."""
    print(f"\n{'=' * _BANNER_WIDTH}")
    print(f"回合 {episode + 1}/{num_episodes}")
    if scenario_id is not None:
        print(f"场景ID: {scenario_id}")
    print(f"{'=' * _BANNER_WIDTH}")


def _build_env_config(data_dir, *, horizon, render, debug, scenario_id,
                      use_scenario_duration, spawn_vehicles, spawn_pedestrians,
                      spawn_cyclists, replay_mode, reactive_traffic, **extra):
    """Build the MultiAgentScenarioEnv config dict shared by both run modes.

    Args:
        data_dir: Root data directory; scenarios are read from its
            "training_20s" sub-directory.
        replay_mode: Value for the env's "replay_mode" flag.
        reactive_traffic: Value for the env's "reactive_traffic" flag.
        **extra: Additional mode-specific config keys (override defaults).

    Returns:
        dict: The config passed to ``MultiAgentScenarioEnv``.
    """
    config = {
        "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
        "is_multi_agent": True,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": reactive_traffic,
        "manual_control": False,
        "filter_offroad_vehicles": True,  # enable lane-based filtering
        "lane_tolerance": 3.0,
        "replay_mode": replay_mode,
        "debug": debug,
        "specific_scenario_id": scenario_id,
        "use_scenario_duration": use_scenario_duration,
        # Object-type filtering.
        "spawn_vehicles": spawn_vehicles,
        "spawn_pedestrians": spawn_pedestrians,
        "spawn_cyclists": spawn_cyclists,
    }
    config.update(extra)
    return config


def _attach_replay_policies(env):
    """Give every controlled vehicle that has an expert trajectory a ReplayPolicy.

    Returns:
        dict: agent id -> ReplayPolicy, for the vehicles that received one.
    """
    replay_policies = {}
    for agent_id, vehicle in env.controlled_agents.items():
        vehicle_id = vehicle.expert_vehicle_id
        if vehicle_id not in env.expert_trajectories:
            continue
        policy = ReplayPolicy(env.expert_trajectories[vehicle_id], vehicle_id)
        vehicle.set_policy(policy)
        replay_policies[agent_id] = policy
    return replay_policies


def _apply_expert_states(env, replay_policies):
    """Write each replayed vehicle's expert state for ``env.round`` directly.

    The state is set directly on the vehicle, bypassing the physics engine.
    Only the x/y components of position and velocity are used, to keep the
    vehicle on the ground.
    """
    # Iterate over a snapshot of the items.
    for agent_id, vehicle in list(env.controlled_agents.items()):
        policy = replay_policies.get(agent_id)
        if policy is None or vehicle.expert_vehicle_id not in env.expert_trajectories:
            continue
        target_state = policy.get_target_state(env.round)
        if target_state is None:
            continue
        vehicle.set_position(target_state['position'][:2])
        vehicle.set_heading_theta(target_state['heading'])
        # [:2] is a no-op for 2-element velocities and drops z for 3-element ones.
        vehicle.set_velocity(target_state['velocity'][:2])


def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Replay mode: play back expert trajectories exactly.

    Args:
        data_dir: Data directory.
        num_episodes: Number of episodes (ignored, forced to 1, if
            ``scenario_id`` is given).
        horizon: Maximum number of steps (the env may override it when
            ``use_scenario_duration`` is True).
        render: Whether to render.
        debug: Whether to enable debug mode (also prints a sample vehicle's
            height every 50 steps).
        scenario_id: Optional specific scenario id to run.
        use_scenario_duration: Use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    _print_mode_banner("运行模式: 专家轨迹回放 (Replay Mode)", scenario_id, use_scenario_duration)

    # A specific scenario id means exactly one episode.
    if scenario_id is not None:
        num_episodes = 1

    # Create the environment once, outside the episode loop, so it is not rebuilt per episode.
    env = MultiAgentScenarioEnv(
        config=_build_env_config(
            data_dir,
            horizon=horizon,
            render=render,
            debug=debug,
            scenario_id=scenario_id,
            use_scenario_duration=use_scenario_duration,
            spawn_vehicles=spawn_vehicles,
            spawn_pedestrians=spawn_pedestrians,
            spawn_cyclists=spawn_cyclists,
            replay_mode=True,
            reactive_traffic=False,  # no reactive traffic needed when replaying
        ),
        agent2policy=None,  # policies are attached per vehicle after reset
    )

    try:
        for episode in range(num_episodes):
            _print_episode_header(episode, num_episodes, scenario_id)

            # Without a specific scenario, use the episode index as seed to walk scenarios.
            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            replay_policies = _attach_replay_policies(env)

            actual_horizon = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 专家轨迹数: {len(env.expert_trajectories)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {actual_horizon} 步")

            step_count = 0
            active_vehicles_count = []

            while True:
                _apply_expert_states(env, replay_policies)

                # Step with null actions; vehicle states were already set above.
                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
                _, _, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                active_vehicles_count.append(len(env.controlled_agents))

                # Periodic status every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
                    if debug and len(env.controlled_agents) > 0:
                        sample_vehicle = next(iter(env.controlled_agents.values()))
                        position = sample_vehicle.position
                        z_pos = position[2] if len(position) > 2 else 0
                        print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")

                if dones["__all__"]:
                    peak = max(active_vehicles_count) if active_vehicles_count else 0
                    average = (sum(active_vehicles_count) / len(active_vehicles_count)
                               if active_vehicles_count else 0)
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 最大同时车辆数: {peak}")
                    print(f" 平均车辆数: {average:.1f}")
                    if use_scenario_duration:
                        print(f" 场景完整回放: {'是' if step_count >= env.scenario_max_duration else '否'}")
                    break
    finally:
        # Always release the environment, even if an episode raises.
        env.close()

    print("\n" + "=" * _BANNER_WIDTH)
    print("回放完成!")
    print("=" * _BANNER_WIDTH)


def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                        scenario_id=None, use_scenario_duration=False,
                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True,
                        num_controlled_agents=3, target_speed=50):
    """Simulation mode: vehicles are driven by a custom policy.

    Vehicles are spawned at the initial poses from the expert data and are
    then controlled by the policy.

    Args:
        data_dir: Data directory.
        num_episodes: Number of episodes (forced to 1 if ``scenario_id`` is given).
        horizon: Maximum number of steps.
        render: Whether to render.
        debug: Whether to enable debug mode.
        scenario_id: Optional specific scenario id to run.
        use_scenario_duration: Use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
        num_controlled_agents: Value for the env's "num_controlled_agents"
            config key (default 3).
        target_speed: ``target_speed`` passed to ConstantVelocityPolicy
            (default 50).
    """
    _print_mode_banner("运行模式: 策略仿真 (Simulation Mode)", scenario_id, use_scenario_duration)

    # A specific scenario id means exactly one episode.
    if scenario_id is not None:
        num_episodes = 1

    env = MultiAgentScenarioEnv(
        config=_build_env_config(
            data_dir,
            horizon=horizon,
            render=render,
            debug=debug,
            scenario_id=scenario_id,
            use_scenario_duration=use_scenario_duration,
            spawn_vehicles=spawn_vehicles,
            spawn_pedestrians=spawn_pedestrians,
            spawn_cyclists=spawn_cyclists,
            replay_mode=False,
            reactive_traffic=True,
            num_controlled_agents=num_controlled_agents,
        ),
        agent2policy=ConstantVelocityPolicy(target_speed=target_speed),
    )

    try:
        for episode in range(num_episodes):
            _print_episode_header(episode, num_episodes, scenario_id)

            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            actual_horizon = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {actual_horizon} 步")

            step_count = 0
            total_reward = 0.0

            while True:
                # Each controlled vehicle's own policy produces its action.
                actions = {
                    aid: vehicle.policy.act()
                    for aid, vehicle in env.controlled_agents.items()
                }
                _, rewards, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                total_reward += sum(rewards.values())

                # Periodic status every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                if dones["__all__"]:
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 总奖励: {total_reward:.2f}")
                    break
    finally:
        # Always release the environment, even if an episode raises.
        env.close()

    print("\n" + "=" * _BANNER_WIDTH)
    print("仿真完成!")
    print("=" * _BANNER_WIDTH)


def main():
    """Parse command-line arguments and dispatch to the selected run mode."""
    parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["replay", "simulation"],
        default="simulation",
        help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=WAYMO_DATA_DIR,
        help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})",
    )
    parser.add_argument("--episodes", type=int, default=1, help="运行回合数 (默认: 1)")
    parser.add_argument(
        "--horizon",
        type=int,
        default=300,
        help="每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)",
    )
    parser.add_argument("--no_render", action="store_true", help="禁用渲染(加速运行)")
    parser.add_argument("--debug", action="store_true", help="启用调试模式(显示详细日志)")
    parser.add_argument(
        "--scenario_id",
        type=int,
        default=None,
        help="指定场景ID(可选,如指定则只运行该场景)",
    )
    parser.add_argument(
        "--use_scenario_duration",
        action="store_true",
        help="使用场景原始时长作为horizon(自动停止)",
    )
    parser.add_argument("--no_vehicles", action="store_true", help="禁止生成车辆")
    parser.add_argument("--no_pedestrians", action="store_true", help="禁止生成行人")
    parser.add_argument("--no_cyclists", action="store_true", help="禁止生成自行车")
    args = parser.parse_args()

    # Both modes share the same keyword interface; pick the runner once.
    run_mode = run_replay_mode if args.mode == "replay" else run_simulation_mode
    run_mode(
        data_dir=args.data_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        render=not args.no_render,
        debug=args.debug,
        scenario_id=args.scenario_id,
        use_scenario_duration=args.use_scenario_duration,
        spawn_vehicles=not args.no_vehicles,
        spawn_pedestrians=not args.no_pedestrians,
        spawn_cyclists=not args.no_cyclists,
    )


if __name__ == "__main__":
    main()