2025-10-27 00:27:47 +08:00
|
|
|
|
import argparse
|
2025-09-28 18:57:04 +08:00
|
|
|
|
from scenario_env import MultiAgentScenarioEnv
|
2025-10-27 00:27:47 +08:00
|
|
|
|
from simple_idm_policy import ConstantVelocityPolicy
|
|
|
|
|
|
from replay_policy import ReplayPolicy
|
2025-09-28 18:57:04 +08:00
|
|
|
|
from metadrive.engine.asset_loader import AssetLoader
|
|
|
|
|
|
|
2025-10-27 00:27:47 +08:00
|
|
|
|
# Default data directory, used as the fallback value of the --data_dir CLI option.
# NOTE: this is a machine-specific absolute path; pass --data_dir on other hosts.
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Replay mode: move every controlled vehicle along its recorded expert trajectory.

    On each step the position, heading and velocity of every controlled vehicle
    that has a ReplayPolicy are overwritten with the state returned by that
    policy for the current ``env.round``; the environment is then stepped with
    zero actions.

    Args:
        data_dir: Data directory; the ``training_20s`` entry beneath it is
            handed to the environment as ``data_directory``.
        num_episodes: Number of episodes to run (forced to 1 when
            ``scenario_id`` is given).
        horizon: Maximum number of steps, passed to the environment config.
            Presumably replaced by the environment when
            ``use_scenario_duration`` is True -- confirm in ``scenario_env``.
        render: Whether to render (top-down view after every step).
        debug: Whether to enable debug mode (also forwarded to the environment).
        scenario_id: Optional scenario ID; also used as the reset seed.
        use_scenario_duration: Whether to use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    banner = "=" * 50
    single_scenario = scenario_id is not None

    print(banner)
    print("运行模式: 专家轨迹回放 (Replay Mode)")
    if single_scenario:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print(banner)

    # A specific scenario is only ever run for a single episode.
    if single_scenario:
        num_episodes = 1

    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
        "is_multi_agent": True,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": False,  # reactive traffic is not needed when replaying
        "manual_control": False,
        "filter_offroad_vehicles": True,  # enable lane filtering
        "lane_tolerance": 3.0,
        "replay_mode": True,  # mark the environment as running in replay mode
        "debug": debug,
        "specific_scenario_id": scenario_id,  # requested scenario ID (may be None)
        "use_scenario_duration": use_scenario_duration,  # use the scenario's own duration
        # Object-type filters
        "spawn_vehicles": spawn_vehicles,
        "spawn_pedestrians": spawn_pedestrians,
        "spawn_cyclists": spawn_cyclists,
    }
    # The environment is built once, outside the episode loop, so it is not
    # re-created for every episode. No shared policy is needed in replay mode.
    env = MultiAgentScenarioEnv(config=env_config, agent2policy=None)

    try:
        for episode in range(num_episodes):
            print(f"\n{banner}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if single_scenario:
                print(f"场景ID: {scenario_id}")
            print(banner)

            # Without an explicit scenario, the episode index is the seed so
            # that successive episodes walk through different scenarios.
            env.reset(seed=scenario_id if single_scenario else episode)

            # Give each controlled vehicle that has an expert trajectory its
            # own ReplayPolicy.
            replay_policies = {}
            for agent_id, vehicle in env.controlled_agents.items():
                vehicle_id = vehicle.expert_vehicle_id
                if vehicle_id not in env.expert_trajectories:
                    continue
                policy = ReplayPolicy(env.expert_trajectories[vehicle_id], vehicle_id)
                vehicle.set_policy(policy)
                replay_policies[agent_id] = policy

            # Report the scenario that was just loaded.
            effective_horizon = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 专家轨迹数: {len(env.expert_trajectories)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {effective_horizon} 步")

            step_count = 0
            agent_counts = []

            while True:
                # Overwrite each vehicle's state straight from its recorded
                # trajectory (the original author's intent: bypass the physics
                # engine). Iterate over a snapshot of the agent mapping.
                for agent_id, vehicle in list(env.controlled_agents.items()):
                    vehicle_id = vehicle.expert_vehicle_id
                    if vehicle_id not in env.expert_trajectories or agent_id not in replay_policies:
                        continue
                    target_state = replay_policies[agent_id].get_target_state(env.round)
                    if target_state is None:
                        continue
                    # Only x/y are used; per the original note this keeps the
                    # vehicle on the ground.
                    vehicle.set_position(target_state['position'][:2])
                    vehicle.set_heading_theta(target_state['heading'])
                    velocity = target_state['velocity']
                    if len(velocity) > 2:
                        velocity = velocity[:2]
                    vehicle.set_velocity(velocity)

                # Step with a zero action for every agent (fresh list per agent).
                actions = {agent_id: [0.0, 0.0] for agent_id in env.controlled_agents}
                _, _, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                agent_counts.append(len(env.controlled_agents))

                # Print a status line every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                    # In debug mode also print the height of one sample vehicle.
                    if debug and len(env.controlled_agents) > 0:
                        sample_vehicle = next(iter(env.controlled_agents.values()))
                        sample_position = sample_vehicle.position
                        z_pos = sample_position[2] if len(sample_position) > 2 else 0
                        print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")

                if not dones["__all__"]:
                    continue

                # Episode finished: print the summary and leave the step loop.
                peak_agents = max(agent_counts) if agent_counts else 0
                mean_agents = sum(agent_counts) / len(agent_counts) if agent_counts else 0
                print("\n回合结束统计:")
                print(f" 总步数: {step_count}")
                print(f" 最大同时车辆数: {peak_agents}")
                print(f" 平均车辆数: {mean_agents:.1f}")
                if use_scenario_duration:
                    fully_replayed = '是' if step_count >= env.scenario_max_duration else '否'
                    print(f" 场景完整回放: {fully_replayed}")
                break
    finally:
        # Always close the environment, even if an episode raised.
        env.close()

    print("\n" + banner)
    print("回放完成!")
    print(banner)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                        scenario_id=None, use_scenario_duration=False,
                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Simulation mode: control the vehicles with a policy instead of replaying them.

    The environment is created with ``replay_mode`` False and a
    ``ConstantVelocityPolicy(target_speed=50)`` as ``agent2policy``. On every
    step each controlled agent's ``policy.act()`` supplies its action. Per the
    original description, vehicles are spawned at the initial poses found in
    the expert data and are then driven by the policy.

    Args:
        data_dir: Data directory; the ``training_20s`` entry beneath it is
            handed to the environment as ``data_directory``.
        num_episodes: Number of episodes (forced to 1 when ``scenario_id`` is
            given).
        horizon: Maximum number of steps, passed to the environment config.
        render: Whether to render (top-down view after every step).
        debug: Whether to enable debug mode (forwarded to the environment).
        scenario_id: Optional scenario ID; also used as the reset seed.
        use_scenario_duration: Whether to use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    banner = "=" * 50
    single_scenario = scenario_id is not None

    print(banner)
    print("运行模式: 策略仿真 (Simulation Mode)")
    if single_scenario:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print(banner)

    # A specific scenario is only ever run for a single episode.
    if single_scenario:
        num_episodes = 1

    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
        "is_multi_agent": True,
        "num_controlled_agents": 3,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": True,
        "manual_control": False,
        "filter_offroad_vehicles": True,  # enable lane filtering
        "lane_tolerance": 3.0,
        "replay_mode": False,  # simulation mode
        "debug": debug,
        "specific_scenario_id": scenario_id,  # requested scenario ID (may be None)
        "use_scenario_duration": use_scenario_duration,  # use the scenario's own duration
        # Object-type filters
        "spawn_vehicles": spawn_vehicles,
        "spawn_pedestrians": spawn_pedestrians,
        "spawn_cyclists": spawn_cyclists,
    }
    env = MultiAgentScenarioEnv(
        config=env_config,
        agent2policy=ConstantVelocityPolicy(target_speed=50),
    )

    try:
        for episode in range(num_episodes):
            print(f"\n{banner}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if single_scenario:
                print(f"场景ID: {scenario_id}")
            print(banner)

            # The episode index is the seed unless a scenario was requested.
            env.reset(seed=scenario_id if single_scenario else episode)

            effective_horizon = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {effective_horizon} 步")

            step_count = 0
            total_reward = 0.0

            while True:
                # Ask each controlled agent's policy for its action.
                actions = {}
                for agent_id in env.controlled_agents:
                    actions[agent_id] = env.controlled_agents[agent_id].policy.act()

                _, rewards, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                total_reward += sum(rewards.values())

                # Print a status line every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                if not dones["__all__"]:
                    continue

                # Episode finished: print the summary and leave the step loop.
                print("\n回合结束统计:")
                print(f" 总步数: {step_count}")
                print(f" 总奖励: {total_reward:.2f}")
                break
    finally:
        # Always close the environment, even if an episode raised.
        env.close()

    print("\n" + banner)
    print("仿真完成!")
    print(banner)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse the command-line options and run the selected mode.

    ``--mode replay`` dispatches to ``run_replay_mode``; any other accepted
    value (i.e. ``simulation``, the default) dispatches to
    ``run_simulation_mode``. Both runners receive the same keyword arguments;
    the ``--no_*`` switches are inverted into the positive ``render`` /
    ``spawn_*`` flags.
    """
    parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps its layout.
    option_table = (
        ("--mode", dict(
            type=str,
            choices=["replay", "simulation"],
            default="simulation",
            help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)",
        )),
        ("--data_dir", dict(
            type=str,
            default=WAYMO_DATA_DIR,
            help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})",
        )),
        ("--episodes", dict(
            type=int,
            default=1,
            help="运行回合数 (默认: 1)",
        )),
        ("--horizon", dict(
            type=int,
            default=300,
            help="每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)",
        )),
        ("--no_render", dict(
            action="store_true",
            help="禁用渲染(加速运行)",
        )),
        ("--debug", dict(
            action="store_true",
            help="启用调试模式(显示详细日志)",
        )),
        ("--scenario_id", dict(
            type=int,
            default=None,
            help="指定场景ID(可选,如指定则只运行该场景)",
        )),
        ("--use_scenario_duration", dict(
            action="store_true",
            help="使用场景原始时长作为horizon(自动停止)",
        )),
        ("--no_vehicles", dict(
            action="store_true",
            help="禁止生成车辆",
        )),
        ("--no_pedestrians", dict(
            action="store_true",
            help="禁止生成行人",
        )),
        ("--no_cyclists", dict(
            action="store_true",
            help="禁止生成自行车",
        )),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Both runners share one keyword signature, so build the arguments once.
    run_kwargs = dict(
        data_dir=args.data_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        render=not args.no_render,
        debug=args.debug,
        scenario_id=args.scenario_id,
        use_scenario_duration=args.use_scenario_duration,
        spawn_vehicles=not args.no_vehicles,
        spawn_pedestrians=not args.no_pedestrians,
        spawn_cyclists=not args.no_cyclists,
    )

    runner = run_replay_mode if args.mode == "replay" else run_simulation_mode
    runner(**run_kwargs)
|
2025-09-28 18:57:04 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|