# MAGAIL4AutoDrive/Env/run_multiagent_env.py
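"""Run the MetaDrive multi-agent scenario environment on converted Waymo data.

Two modes are supported (see run_replay_mode / run_simulation_mode below):
  * replay     - controlled vehicles strictly follow the expert trajectories
  * simulation - vehicles spawn at the expert initial poses and are driven by a
                 policy (a ConstantVelocityPolicy in this script)
"""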
import argparse
from scenario_env import MultiAgentScenarioEnv
from simple_idm_policy import ConstantVelocityPolicy
from replay_policy import ReplayPolicy
from metadrive.engine.asset_loader import AssetLoader
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
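
# NOTE: WAYMO_DATA_DIR is a local path to the converted dataset; both modes below
# load scenarios from its "training_20s" subdirectory via AssetLoader.file_path().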
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """
    Replay mode: strictly replay the expert trajectories.

    Args:
        data_dir: data directory
        num_episodes: number of episodes (ignored if scenario_id is specified)
        horizon: maximum number of steps (set automatically if use_scenario_duration=True)
        render: whether to render
        debug: whether to enable debug mode
        scenario_id: specific scenario ID (optional)
        use_scenario_duration: whether to use the scenario's original duration
        spawn_vehicles: whether to spawn vehicles (default True)
        spawn_pedestrians: whether to spawn pedestrians (default True)
        spawn_cyclists: whether to spawn cyclists (default True)
    """
print("=" * 50)
print("运行模式: 专家轨迹回放 (Replay Mode)")
if scenario_id is not None:
print(f"指定场景ID: {scenario_id}")
if use_scenario_duration:
print("使用场景原始时长")
print("=" * 50)
# 如果指定了场景ID只运行1个回合
if scenario_id is not None:
num_episodes = 1
    # Create the environment outside the episode loop to avoid repeated construction
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": False,  # Reactive traffic is not needed in replay mode
            "manual_control": False,
            "filter_offroad_vehicles": True,  # Enable lane-based filtering
            "lane_tolerance": 3.0,
            "replay_mode": True,  # Mark as replay mode
            "debug": debug,
            "specific_scenario_id": scenario_id,  # Specific scenario ID
            "use_scenario_duration": use_scenario_duration,  # Use the scenario's duration
            # Object type filtering
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=None  # Replay mode does not need a shared policy
    )
    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"Episode {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"Scenario ID: {scenario_id}")
            print(f"{'='*50}")

            # When no scenario is specified, use the seed to iterate over different scenarios
            seed = scenario_id if scenario_id is not None else episode
            obs = env.reset(seed=seed)

            # Assign a ReplayPolicy to each controlled vehicle
            replay_policies = {}
            for agent_id, vehicle in env.controlled_agents.items():
                vehicle_id = vehicle.expert_vehicle_id
                if vehicle_id in env.expert_trajectories:
                    replay_policy = ReplayPolicy(
                        env.expert_trajectories[vehicle_id],
                        vehicle_id
                    )
                    vehicle.set_policy(replay_policy)
                    replay_policies[agent_id] = replay_policy

            # Print scenario information
            actual_horizon = env.config["horizon"]
            print("Initialization complete:")
            print(f"  Controlled vehicles: {len(env.controlled_agents)}")
            print(f"  Expert trajectories: {len(env.expert_trajectories)}")
            print(f"  Scenario duration: {env.scenario_max_duration}")
            print(f"  Actual horizon: {actual_horizon}")

            step_count = 0
            active_vehicles_count = []
            while True:
                # In replay mode, set vehicle states directly from the expert trajectories
                for agent_id, vehicle in list(env.controlled_agents.items()):
                    vehicle_id = vehicle.expert_vehicle_id
                    if vehicle_id in env.expert_trajectories and agent_id in replay_policies:
                        target_state = replay_policies[agent_id].get_target_state(env.round)
                        if target_state is not None:
                            # Set the vehicle state directly (bypassing the physics engine).
                            # Use only the x-y coordinates so the vehicle stays on the ground.
                            position_2d = target_state['position'][:2]
                            vehicle.set_position(position_2d)
                            vehicle.set_heading_theta(target_state['heading'])
                            vehicle.set_velocity(target_state['velocity'][:2] if len(target_state['velocity']) > 2 else target_state['velocity'])

                # Step the environment with empty actions
                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
                obs, rewards, dones, infos = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                active_vehicles_count.append(len(env.controlled_agents))

                # Print status every 50 steps
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} active vehicles")
                    # In debug mode, also print vehicle height information
                    if debug and len(env.controlled_agents) > 0:
                        sample_vehicle = list(env.controlled_agents.values())[0]
                        z_pos = sample_vehicle.position[2] if len(sample_vehicle.position) > 2 else 0
                        print(f"  [DEBUG] sample vehicle height: z={z_pos:.3f}m")

                if dones["__all__"]:
                    print("\nEpisode summary:")
                    print(f"  Total steps: {step_count}")
                    print(f"  Max concurrent vehicles: {max(active_vehicles_count) if active_vehicles_count else 0}")
                    print(f"  Average vehicle count: {sum(active_vehicles_count) / len(active_vehicles_count) if active_vehicles_count else 0:.1f}")
                    if use_scenario_duration:
                        print(f"  Scenario fully replayed: {'Yes' if step_count >= env.scenario_max_duration else 'No'}")
                    break
    finally:
        # Make sure the environment is closed properly
        env.close()

    print("\n" + "=" * 50)
    print("Replay complete!")
    print("=" * 50)

def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                        scenario_id=None, use_scenario_duration=False,
                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """
    Simulation mode: vehicles are controlled by a custom policy.
    Vehicles spawn at the initial poses from the expert data and are then driven by the policy.

    Args:
        data_dir: data directory
        num_episodes: number of episodes
        horizon: maximum number of steps
        render: whether to render
        debug: whether to enable debug mode
        scenario_id: specific scenario ID (optional)
        use_scenario_duration: whether to use the scenario's original duration
        spawn_vehicles: whether to spawn vehicles (default True)
        spawn_pedestrians: whether to spawn pedestrians (default True)
        spawn_cyclists: whether to spawn cyclists (default True)
    """
    print("=" * 50)
    print("Run mode: Policy Simulation (Simulation Mode)")
    if scenario_id is not None:
        print(f"Specified scenario ID: {scenario_id}")
    if use_scenario_duration:
        print("Using the scenario's original duration")
    print("=" * 50)

    # If a scenario ID is specified, run only one episode
    if scenario_id is not None:
        num_episodes = 1
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "num_controlled_agents": 3,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": True,
            "manual_control": False,
            "filter_offroad_vehicles": True,  # Enable lane-based filtering
            "lane_tolerance": 3.0,
            "replay_mode": False,  # Simulation mode
            "debug": debug,
            "specific_scenario_id": scenario_id,  # Specific scenario ID
            "use_scenario_duration": use_scenario_duration,  # Use the scenario's duration
            # Object type filtering
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=ConstantVelocityPolicy(target_speed=50)
    )
    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"Episode {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"Scenario ID: {scenario_id}")
            print(f"{'='*50}")

            seed = scenario_id if scenario_id is not None else episode
            obs = env.reset(seed=seed)

            actual_horizon = env.config["horizon"]
            print("Initialization complete:")
            print(f"  Controlled vehicles: {len(env.controlled_agents)}")
            print(f"  Scenario duration: {env.scenario_max_duration}")
            print(f"  Actual horizon: {actual_horizon}")

            step_count = 0
            total_reward = 0.0

            while True:
                # Generate actions with each vehicle's policy
                actions = {
                    aid: env.controlled_agents[aid].policy.act()
                    for aid in env.controlled_agents
                }
                obs, rewards, dones, infos = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                total_reward += sum(rewards.values())

                # Print status every 50 steps
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} active vehicles")

                if dones["__all__"]:
                    print("\nEpisode summary:")
                    print(f"  Total steps: {step_count}")
                    print(f"  Total reward: {total_reward:.2f}")
                    break
    finally:
        env.close()

    print("\n" + "=" * 50)
    print("Simulation complete!")
    print("=" * 50)

def main():
    parser = argparse.ArgumentParser(description="Runner script for the MetaDrive multi-agent environment")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["replay", "simulation"],
        default="simulation",
        help="Run mode: replay = expert trajectory replay, simulation = policy simulation (default: simulation)"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=WAYMO_DATA_DIR,
        help=f"Path to the data directory (default: {WAYMO_DATA_DIR})"
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=1,
        help="Number of episodes to run (default: 1)"
    )
    parser.add_argument(
        "--horizon",
        type=int,
        default=300,
        help="Maximum steps per episode (default: 300; set automatically if --use_scenario_duration is enabled)"
    )
    parser.add_argument(
        "--no_render",
        action="store_true",
        help="Disable rendering (faster execution)"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode (verbose logging)"
    )
    parser.add_argument(
        "--scenario_id",
        type=int,
        default=None,
        help="Specific scenario ID (optional; if given, only that scenario is run)"
    )
    parser.add_argument(
        "--use_scenario_duration",
        action="store_true",
        help="Use the scenario's original duration as the horizon (stop automatically)"
    )
    parser.add_argument(
        "--no_vehicles",
        action="store_true",
        help="Do not spawn vehicles"
    )
    parser.add_argument(
        "--no_pedestrians",
        action="store_true",
        help="Do not spawn pedestrians"
    )
    parser.add_argument(
        "--no_cyclists",
        action="store_true",
        help="Do not spawn cyclists"
    )

    args = parser.parse_args()
    if args.mode == "replay":
        run_replay_mode(
            data_dir=args.data_dir,
            num_episodes=args.episodes,
            horizon=args.horizon,
            render=not args.no_render,
            debug=args.debug,
            scenario_id=args.scenario_id,
            use_scenario_duration=args.use_scenario_duration,
            spawn_vehicles=not args.no_vehicles,
            spawn_pedestrians=not args.no_pedestrians,
            spawn_cyclists=not args.no_cyclists
        )
    else:
        run_simulation_mode(
            data_dir=args.data_dir,
            num_episodes=args.episodes,
            horizon=args.horizon,
            render=not args.no_render,
            debug=args.debug,
            scenario_id=args.scenario_id,
            use_scenario_duration=args.use_scenario_duration,
            spawn_vehicles=not args.no_vehicles,
            spawn_pedestrians=not args.no_pedestrians,
            spawn_cyclists=not args.no_cyclists
        )

if __name__ == "__main__":
    main()
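
# Example invocations (illustrative values; adjust --data_dir to your local dataset path):
#   python run_multiagent_env.py --mode replay --scenario_id 3 --use_scenario_duration
#   python run_multiagent_env.py --mode simulation --episodes 5 --no_render --no_pedestrians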