363 lines
13 KiB
Python
363 lines
13 KiB
Python
import argparse
|
||
from scenario_env import MultiAgentScenarioEnv
|
||
from simple_idm_policy import ConstantVelocityPolicy
|
||
from replay_policy import ReplayPolicy
|
||
from metadrive.engine.asset_loader import AssetLoader
|
||
|
||
# Default data root handed to AssetLoader.file_path(..., "training_20s");
# can be overridden on the command line with --data_dir.
WAYMO_DATA_DIR = "/home/huangfukk/mdsn/exp_converted"
|
||
|
||
|
||
def _bind_replay_policies(env, bindings):
    """Attach a ReplayPolicy to every controlled agent that has an expert trajectory.

    Meant to be called right after ``env.reset`` and again before every step, so
    agents that show up in ``env.controlled_agents`` later in the episode are
    replayed as well, not only those present at reset.

    Args:
        env: The environment being replayed; reads ``controlled_agents`` and
            ``expert_trajectories``.
        bindings: Dict mapping agent_id -> (vehicle, ReplayPolicy); updated in place.
    """
    agents = env.controlled_agents
    # Forget agents that are no longer controlled, so their vehicle objects are not
    # kept alive and a reused agent_id is never paired with a stale policy.
    for gone in [aid for aid in bindings if aid not in agents]:
        del bindings[gone]

    for agent_id, vehicle in agents.items():
        bound = bindings.get(agent_id)
        if bound is not None and bound[0] is vehicle:
            continue  # already bound to this exact vehicle object
        vehicle_id = vehicle.expert_vehicle_id
        if vehicle_id not in env.expert_trajectories:
            # No recorded trajectory for this vehicle: leave it unbound.
            bindings.pop(agent_id, None)
            continue
        policy = ReplayPolicy(env.expert_trajectories[vehicle_id], vehicle_id)
        vehicle.set_policy(policy)
        bindings[agent_id] = (vehicle, policy)


def _apply_expert_states(env, bindings):
    """Write each bound vehicle's expert position, heading and velocity for ``env.round``.

    Only the x/y components of position and velocity are applied, to keep the
    vehicle on the ground. Vehicles whose policy returns ``None`` for the current
    round are left untouched.
    """
    for agent_id, vehicle in list(env.controlled_agents.items()):
        bound = bindings.get(agent_id)
        if bound is None:
            continue
        target_state = bound[1].get_target_state(env.round)
        if target_state is None:
            continue
        # State is set directly rather than reached through actions.
        vehicle.set_position(target_state['position'][:2])
        vehicle.set_heading_theta(target_state['heading'])
        # [:2] is a no-op for vectors with <= 2 components, so no length check is needed.
        vehicle.set_velocity(target_state['velocity'][:2])


def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Replay mode: move every controlled vehicle strictly along its expert trajectory.

    Policies are (re)bound on every step, so vehicles that become controlled
    mid-episode are replayed too.

    Args:
        data_dir: Data root directory; its ``training_20s`` subfolder is loaded.
        num_episodes: Number of episodes (forced to 1 when scenario_id is given).
        horizon: Maximum number of steps; the original docs state it is set
            automatically when use_scenario_duration=True (handled by the env).
        render: Whether to render a top-down view each step.
        debug: Debug mode; also prints a sample vehicle height every 50 steps.
        scenario_id: Optional specific scenario id to run.
        use_scenario_duration: Whether to use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    print("=" * 50)
    print("运行模式: 专家轨迹回放 (Replay Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * 50)

    # A specific scenario is only run once.
    if scenario_id is not None:
        num_episodes = 1

    # Create the environment once, outside the episode loop, and reuse it.
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": False,  # replay does not need reactive traffic
            "manual_control": False,
            "filter_offroad_vehicles": True,  # enable lane-based filtering
            "lane_tolerance": 3.0,
            "replay_mode": True,  # mark the env as running in replay mode
            "debug": debug,
            "specific_scenario_id": scenario_id,
            "use_scenario_duration": use_scenario_duration,
            # Object-type filtering
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=None,  # no shared policy; each vehicle gets its own ReplayPolicy
    )

    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(f"{'='*50}")

            # Without a specific scenario, the episode index is the seed, used to
            # iterate over different scenarios.
            seed = scenario_id if scenario_id is not None else episode
            env.reset(seed=seed)

            # agent_id -> (vehicle, ReplayPolicy); refreshed before every step.
            replay_bindings = {}
            _bind_replay_policies(env, replay_bindings)

            actual_horizon = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 专家轨迹数: {len(env.expert_trajectories)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {actual_horizon} 步")

            step_count = 0
            active_vehicles_count = []

            while True:
                # Pick up agents that appeared since the last step, then snap every
                # bound vehicle onto its expert state.
                _bind_replay_policies(env, replay_bindings)
                _apply_expert_states(env, replay_bindings)

                # States were written directly above, so step with zero actions.
                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
                _, _, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                active_vehicles_count.append(len(env.controlled_agents))

                # Status line every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                    # In debug mode also report the height of one sample vehicle.
                    if debug and len(env.controlled_agents) > 0:
                        sample_vehicle = next(iter(env.controlled_agents.values()))
                        position = sample_vehicle.position
                        z_pos = position[2] if len(position) > 2 else 0
                        print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")

                if dones["__all__"]:
                    mean_active = (sum(active_vehicles_count) / len(active_vehicles_count)
                                   if active_vehicles_count else 0)
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 最大同时车辆数: {max(active_vehicles_count, default=0)}")
                    print(f" 平均车辆数: {mean_active:.1f}")
                    if use_scenario_duration:
                        print(f" 场景完整回放: {'是' if step_count >= env.scenario_max_duration else '否'}")
                    break
    finally:
        # Always close the environment, even if an episode raises.
        env.close()

    print("\n" + "=" * 50)
    print("回放完成!")
    print("=" * 50)
|
||
|
||
|
||
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                        scenario_id=None, use_scenario_duration=False,
                        spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Simulation mode: vehicles are driven by a policy instead of replayed.

    Vehicles are spawned from the initial poses in the expert data and then
    controlled by ``ConstantVelocityPolicy(target_speed=50)``; up to three agents
    are requested via ``num_controlled_agents``.

    Args:
        data_dir: Data root directory; its ``training_20s`` subfolder is loaded.
        num_episodes: Number of episodes (forced to 1 when scenario_id is given).
        horizon: Maximum number of steps per episode.
        render: Whether to render a top-down view each step.
        debug: Passed through to the environment config.
        scenario_id: Optional specific scenario id to run.
        use_scenario_duration: Whether to use the scenario's original duration.
        spawn_vehicles: Whether to spawn vehicles (default True).
        spawn_pedestrians: Whether to spawn pedestrians (default True).
        spawn_cyclists: Whether to spawn cyclists (default True).
    """
    rule = "=" * 50

    print(rule)
    print("运行模式: 策略仿真 (Simulation Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print(rule)

    # A specific scenario is only run once.
    episodes = num_episodes if scenario_id is None else 1

    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
        "is_multi_agent": True,
        "num_controlled_agents": 3,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": True,
        "manual_control": False,
        "filter_offroad_vehicles": True,  # enable lane-based filtering
        "lane_tolerance": 3.0,
        "replay_mode": False,  # simulation, not replay
        "debug": debug,
        "specific_scenario_id": scenario_id,
        "use_scenario_duration": use_scenario_duration,
        # Object-type filtering
        "spawn_vehicles": spawn_vehicles,
        "spawn_pedestrians": spawn_pedestrians,
        "spawn_cyclists": spawn_cyclists,
    }
    env = MultiAgentScenarioEnv(
        config=env_config,
        agent2policy=ConstantVelocityPolicy(target_speed=50),
    )

    try:
        for ep_idx in range(episodes):
            print(f"\n{rule}")
            print(f"回合 {ep_idx + 1}/{episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(rule)

            env.reset(seed=ep_idx if scenario_id is None else scenario_id)

            horizon_now = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {horizon_now} 步")

            steps = 0
            reward_sum = 0.0
            finished = False

            while not finished:
                # Each agent's own policy produces its action for this step.
                actions = {
                    aid: agent.policy.act()
                    for aid, agent in env.controlled_agents.items()
                }
                _, rewards, dones, _ = env.step(actions)

                if render:
                    env.render(mode="topdown")

                steps += 1
                reward_sum += sum(rewards.values())

                # Status line every 50 steps.
                if steps % 50 == 0:
                    print(f"Step {steps}: {len(env.controlled_agents)} 辆活跃车辆")

                finished = dones["__all__"]

            print("\n回合结束统计:")
            print(f" 总步数: {steps}")
            print(f" 总奖励: {reward_sum:.2f}")
    finally:
        env.close()

    print("\n" + rule)
    print("仿真完成!")
    print(rule)
|
||
|
||
|
||
def main():
    """Command-line entry point: parse arguments and run the selected mode."""
    parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")

    parser.add_argument("--mode", type=str, choices=["replay", "simulation"],
                        default="simulation",
                        help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)")
    parser.add_argument("--data_dir", type=str, default=WAYMO_DATA_DIR,
                        help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})")
    parser.add_argument("--episodes", type=int, default=1,
                        help="运行回合数 (默认: 1)")
    parser.add_argument("--horizon", type=int, default=300,
                        help="每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)")
    parser.add_argument("--no_render", action="store_true",
                        help="禁用渲染(加速运行)")
    parser.add_argument("--debug", action="store_true",
                        help="启用调试模式(显示详细日志)")
    parser.add_argument("--scenario_id", type=int, default=None,
                        help="指定场景ID(可选,如指定则只运行该场景)")
    parser.add_argument("--use_scenario_duration", action="store_true",
                        help="使用场景原始时长作为horizon(自动停止)")

    # One opt-out switch per object category; each is inverted into a spawn_* flag below.
    for flag, help_text in (
        ("--no_vehicles", "禁止生成车辆"),
        ("--no_pedestrians", "禁止生成行人"),
        ("--no_cyclists", "禁止生成自行车"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args()

    # Both runners share the same keyword interface; only the entry point differs.
    runner = run_replay_mode if args.mode == "replay" else run_simulation_mode
    runner(
        data_dir=args.data_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        render=not args.no_render,
        debug=args.debug,
        scenario_id=args.scenario_id,
        use_scenario_duration=args.use_scenario_duration,
        spawn_vehicles=not args.no_vehicles,
        spawn_pedestrians=not args.no_pedestrians,
        spawn_cyclists=not args.no_cyclists,
    )


if __name__ == "__main__":
    main()