# Source file: MAGAIL4AutoDrive/Env/run_multiagent_env.py
# (Copied from a web repository view: 363 lines, 13 KiB, Python.)
# NOTE: the repository viewer warned that this file contains ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If that is not intentional, review the comments and string literals below.

import argparse
import os

from metadrive.engine.asset_loader import AssetLoader

from replay_policy import ReplayPolicy
from scenario_env import MultiAgentScenarioEnv
from simple_idm_policy import ConstantVelocityPolicy
# Default dataset root. It can be overridden through the WAYMO_DATA_DIR environment
# variable (and per run with the --data_dir command-line option) instead of
# editing this machine-specific path.
WAYMO_DATA_DIR = os.environ.get("WAYMO_DATA_DIR", r"/home/huangfukk/mdsn/exp_converted")
def _get_replay_policy(env, replay_policies, agent_id, vehicle):
    """Return the ReplayPolicy driving ``vehicle``, attaching a new one if needed.

    ``replay_policies`` maps ``agent_id -> (expert_vehicle_id, ReplayPolicy)`` and is
    updated in place. A cached policy is reused only while ``agent_id`` still refers
    to the same expert vehicle. As a result, agents that join
    ``env.controlled_agents`` after ``env.reset`` are also replayed instead of being
    silently skipped.

    Returns:
        The ReplayPolicy for the vehicle, or None when ``env.expert_trajectories``
        holds no trajectory for the vehicle's ``expert_vehicle_id``.
    """
    vehicle_id = vehicle.expert_vehicle_id
    cached = replay_policies.get(agent_id)
    if cached is not None and cached[0] == vehicle_id:
        return cached[1]
    if vehicle_id not in env.expert_trajectories:
        # Nothing to replay for this vehicle; drop any stale entry for this agent id.
        replay_policies.pop(agent_id, None)
        return None
    policy = ReplayPolicy(env.expert_trajectories[vehicle_id], vehicle_id)
    vehicle.set_policy(policy)
    replay_policies[agent_id] = (vehicle_id, policy)
    return policy


def _apply_expert_states(env, replay_policies):
    """Overwrite every controlled vehicle's state with its expert state at ``env.round``.

    States are written directly, bypassing the physics engine. Only the x/y part of
    the expert position is used, with the stated intent of keeping vehicles on the
    ground. The velocity is likewise truncated to its first two components.
    """
    # Iterate over a snapshot so the loop is unaffected if the mapping changes.
    for agent_id, vehicle in list(env.controlled_agents.items()):
        policy = _get_replay_policy(env, replay_policies, agent_id, vehicle)
        if policy is None:
            continue
        target_state = policy.get_target_state(env.round)
        if target_state is None:
            continue
        vehicle.set_position(target_state['position'][:2])
        vehicle.set_heading_theta(target_state['heading'])
        # [:2] leaves a 2-component vector unchanged and drops z from a 3-component one.
        vehicle.set_velocity(target_state['velocity'][:2])


def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Replay expert trajectories exactly (replay mode).

    Each controlled vehicle that has an expert trajectory is driven by a
    ReplayPolicy. On every step its position, heading and velocity are overwritten
    with the expert state for the current round, and the env is then stepped with
    zero actions.

    Args:
        data_dir: Data root; scenarios are loaded from its ``training_20s`` subfolder.
        num_episodes: Number of episodes (ignored, i.e. forced to 1, when
            ``scenario_id`` is given).
        horizon: Maximum number of steps, passed to the env config. According to
            the original docs it is set automatically when
            ``use_scenario_duration`` is True; the value actually in use is read
            back from ``env.config["horizon"]``.
        render: Whether to render a top-down view every step.
        debug: Debug mode, forwarded to the env. It also prints a sample
            vehicle's height every 50 steps.
        scenario_id: Optional scenario ID; when given it is used as the reset seed.
        use_scenario_duration: Whether to use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    print("=" * 50)
    print("运行模式: 专家轨迹回放 (Replay Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * 50)

    # A specific scenario is replayed exactly once.
    if scenario_id is not None:
        num_episodes = 1

    # Build the environment once, outside the episode loop, to avoid re-creating it.
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": False,  # replayed traffic does not need to react
            "manual_control": False,
            "filter_offroad_vehicles": True,  # enable lane-based filtering
            "lane_tolerance": 3.0,
            "replay_mode": True,  # mark the env as running in replay mode
            "debug": debug,
            "specific_scenario_id": scenario_id,
            "use_scenario_duration": use_scenario_duration,
            # Object-type filters.
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=None,  # every vehicle gets its own ReplayPolicy instead
    )
    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(f"{'='*50}")

            # Without an explicit scenario the episode index is the seed, so that
            # successive episodes walk through different scenarios.
            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            # agent_id -> (expert_vehicle_id, ReplayPolicy)
            replay_policies = {}
            for agent_id, vehicle in env.controlled_agents.items():
                _get_replay_policy(env, replay_policies, agent_id, vehicle)

            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 专家轨迹数: {len(env.expert_trajectories)}")
            print(f" 场景时长: {env.scenario_max_duration}")
            print(f" 实际Horizon: {env.config['horizon']}")

            step_count = 0
            active_vehicles_count = []
            while True:
                _apply_expert_states(env, replay_policies)
                # The actions are placeholders: vehicle states were set just above.
                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
                _, _, dones, _ = env.step(actions)
                if render:
                    env.render(mode="topdown")

                step_count += 1
                active_vehicles_count.append(len(env.controlled_agents))

                # Periodic status report every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
                    if debug and env.controlled_agents:
                        sample_vehicle = next(iter(env.controlled_agents.values()))
                        position = sample_vehicle.position
                        z_pos = position[2] if len(position) > 2 else 0
                        print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")

                if dones["__all__"]:
                    if active_vehicles_count:
                        peak_vehicles = max(active_vehicles_count)
                        mean_vehicles = sum(active_vehicles_count) / len(active_vehicles_count)
                    else:
                        peak_vehicles = 0
                        mean_vehicles = 0
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 最大同时车辆数: {peak_vehicles}")
                    print(f" 平均车辆数: {mean_vehicles:.1f}")
                    if use_scenario_duration:
                        # Previously both branches were empty strings, so nothing was reported.
                        fully_replayed = step_count >= env.scenario_max_duration
                        print(f" 场景完整回放: {'是' if fully_replayed else '否'}")
                    break
    finally:
        # Always release the environment, even if the episode loop raised.
        env.close()

    # Only announce completion after a normal finish.
    print("\n" + "=" * 50)
    print("回放完成!")
    print("=" * 50)
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                        scenario_id=None, use_scenario_duration=False,
                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True,
                        num_controlled_agents=3, target_speed=50):
    """Run the environment with vehicles controlled by a policy (simulation mode).

    Vehicles are spawned from the initial poses in the expert data and are then
    controlled by a ConstantVelocityPolicy.

    Args:
        data_dir: Data root; scenarios are loaded from its ``training_20s`` subfolder.
        num_episodes: Number of episodes (forced to 1 when ``scenario_id`` is given).
        horizon: Maximum number of steps, passed to the env config; the value
            actually in use is read back from ``env.config["horizon"]``.
        render: Whether to render a top-down view every step.
        debug: Debug mode, forwarded to the env config.
        scenario_id: Optional scenario ID; when given it is used as the reset seed.
        use_scenario_duration: Whether to use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
        num_controlled_agents: Value for the env's ``num_controlled_agents``
            config key (default 3, the previously hard-coded value).
        target_speed: ``target_speed`` passed to ConstantVelocityPolicy (default
            50, the previously hard-coded value; units as interpreted by that
            policy).
    """
    print("=" * 50)
    print("运行模式: 策略仿真 (Simulation Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * 50)

    # A specific scenario is run exactly once.
    if scenario_id is not None:
        num_episodes = 1

    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "num_controlled_agents": num_controlled_agents,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": True,
            "manual_control": False,
            "filter_offroad_vehicles": True,  # enable lane-based filtering
            "lane_tolerance": 3.0,
            "replay_mode": False,  # simulation mode
            "debug": debug,
            "specific_scenario_id": scenario_id,
            "use_scenario_duration": use_scenario_duration,
            # Object-type filters.
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=ConstantVelocityPolicy(target_speed=target_speed),
    )
    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(f"{'='*50}")

            # Without an explicit scenario the episode index is the seed.
            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 场景时长: {env.scenario_max_duration}")
            print(f" 实际Horizon: {env.config['horizon']}")

            step_count = 0
            total_reward = 0.0
            while True:
                # Each vehicle's attached policy produces its own action.
                actions = {
                    aid: vehicle.policy.act()
                    for aid, vehicle in env.controlled_agents.items()
                }
                _, rewards, dones, _ = env.step(actions)
                if render:
                    env.render(mode="topdown")

                step_count += 1
                total_reward += sum(rewards.values())

                # Periodic status report every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                if dones["__all__"]:
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 总奖励: {total_reward:.2f}")
                    break
    finally:
        # Always release the environment, even if the episode loop raised.
        env.close()

    # Only announce completion after a normal finish.
    print("\n" + "=" * 50)
    print("仿真完成!")
    print("=" * 50)
def main(argv=None):
    """Command-line entry point: parse the arguments and run the selected mode.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what happens when the file
            is run as a script.
    """

    def positive_int(text):
        """argparse ``type=`` callable that only accepts integers >= 1."""
        # A ValueError from int() is reported by argparse as an invalid value.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"必须是正整数: {text}")
        return value

    parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["replay", "simulation"],
        default="simulation",
        help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=WAYMO_DATA_DIR,
        help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})"
    )
    parser.add_argument(
        "--episodes",
        # Zero or negative values used to silently run nothing; reject them.
        type=positive_int,
        default=1,
        help="运行回合数 (默认: 1)"
    )
    parser.add_argument(
        "--horizon",
        type=int,
        default=300,
        help="每回合最大步数 (默认: 300如果启用 --use_scenario_duration 则自动设置)"
    )
    parser.add_argument(
        "--no_render",
        action="store_true",
        help="禁用渲染(加速运行)"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="启用调试模式(显示详细日志)"
    )
    parser.add_argument(
        "--scenario_id",
        type=int,
        default=None,
        help="指定场景ID可选如指定则只运行该场景"
    )
    parser.add_argument(
        "--use_scenario_duration",
        action="store_true",
        help="使用场景原始时长作为horizon自动停止"
    )
    parser.add_argument(
        "--no_vehicles",
        action="store_true",
        help="禁止生成车辆"
    )
    parser.add_argument(
        "--no_pedestrians",
        action="store_true",
        help="禁止生成行人"
    )
    parser.add_argument(
        "--no_cyclists",
        action="store_true",
        help="禁止生成自行车"
    )
    args = parser.parse_args(argv)

    # Both run modes share the same signature, so choose one and pass the same kwargs.
    runner = run_replay_mode if args.mode == "replay" else run_simulation_mode
    runner(
        data_dir=args.data_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        render=not args.no_render,
        debug=args.debug,
        scenario_id=args.scenario_id,
        use_scenario_duration=args.use_scenario_duration,
        spawn_vehicles=not args.no_vehicles,
        spawn_pedestrians=not args.no_pedestrians,
        spawn_cyclists=not args.no_cyclists,
    )
if __name__ == "__main__":
    # Only run the CLI when this file is executed as a script, not on import.
    main()