完成回放模式与仿真模式,过滤非车道生成车辆,增加对于行人自行车的过滤功能
This commit is contained in:
BIN
Env/__pycache__/replay_policy.cpython-310.pyc
Normal file
BIN
Env/__pycache__/replay_policy.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Env/__pycache__/scenario_env.cpython-310.pyc
Normal file
BIN
Env/__pycache__/scenario_env.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Env/__pycache__/scenario_env.cpython-313.pyc
Normal file
BIN
Env/__pycache__/scenario_env.cpython-313.pyc
Normal file
Binary file not shown.
BIN
Env/__pycache__/simple_idm_policy.cpython-310.pyc
Normal file
BIN
Env/__pycache__/simple_idm_policy.cpython-310.pyc
Normal file
Binary file not shown.
62
Env/replay_policy.py
Normal file
62
Env/replay_policy.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
|
||||
class ReplayPolicy:
    """Strict replay policy: exposes a vehicle's expert trajectory step by step.

    The policy does not drive the vehicle itself. ``act`` returns a neutral
    action, and the caller is expected to read ``get_target_state`` and write
    that state onto the vehicle directly.
    """

    def __init__(self, expert_trajectory, vehicle_id):
        """
        Args:
            expert_trajectory: Mapping holding the per-step sequences
                ``positions``, ``headings``, ``velocities`` and ``valid``.
            vehicle_id: Identifier of the replayed vehicle (kept for debugging).
        """
        self.trajectory = expert_trajectory
        self.vehicle_id = vehicle_id
        self.current_step = 0

    def act(self, observation=None):
        """Return the neutral action ``[0.0, 0.0]``.

        In replay mode the actual vehicle state is set by the caller, so the
        observation is ignored.
        """
        return [0.0, 0.0]

    def get_target_state(self, step):
        """Return the recorded state at ``step``.

        Args:
            step: Time-step index into the trajectory.

        Returns:
            dict with ``position``, ``heading`` and ``velocity`` for that step,
            or ``None`` when ``step`` is outside the trajectory or the step is
            flagged invalid.
        """
        valid = self.trajectory['valid']

        # A negative index would silently wrap around to the end of the
        # trajectory; treat it as out of range instead.
        if step < 0 or step >= len(valid):
            return None

        if not valid[step]:
            return None

        return {
            'position': self.trajectory['positions'][step],
            'heading': self.trajectory['headings'][step],
            'velocity': self.trajectory['velocities'][step],
        }

    def is_finished(self, step):
        """Tell whether the trajectory has nothing left to play from ``step`` on.

        Args:
            step: Current time-step index.

        Returns:
            bool: ``True`` when ``step`` is past the end of the trajectory or
            no step from ``step`` onwards is valid.
        """
        valid = self.trajectory['valid']

        # Past the end of the recording.
        if step >= len(valid):
            return True

        # Clamp so a negative step inspects the whole trajectory rather than
        # only its tail (negative slice starts count from the end).
        first = max(step, 0)
        return not any(valid[first:])
|
||||
@@ -1,40 +1,362 @@
|
||||
import argparse
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from Env.simple_idm_policy import ConstantVelocityPolicy
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env"
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
|
||||
|
||||
def main():
|
||||
|
||||
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
                    scenario_id=None, use_scenario_duration=False,
                    spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
    """Run the environment in replay mode, driving vehicles from expert trajectories.

    Before every environment step, each controlled vehicle that has an expert
    trajectory gets its position, heading and velocity overwritten with the
    recorded state for the current round; the environment is then stepped
    with neutral actions.

    Args:
        data_dir: Data directory; the scenarios are read from its
            ``training_20s`` sub-path.
        num_episodes: Number of episodes (forced to 1 when ``scenario_id`` is given).
        horizon: Maximum number of steps per episode (passed to the env config).
        render: Whether to render (top-down) after every step.
        debug: Whether to enable debug output.
        scenario_id: Optional scenario id; also used as the reset seed.
        use_scenario_duration: Whether the scenario's own duration should end the episode.
        spawn_vehicles: Whether vehicles are spawned (default True).
        spawn_pedestrians: Whether pedestrians are spawned (default True).
        spawn_cyclists: Whether cyclists are spawned (default True).
    """
    print("=" * 50)
    print("运行模式: 专家轨迹回放 (Replay Mode)")
    if scenario_id is not None:
        print(f"指定场景ID: {scenario_id}")
    if use_scenario_duration:
        print("使用场景原始时长")
    print("=" * 50)

    # A specific scenario is only ever run once.
    if scenario_id is not None:
        num_episodes = 1

    # The environment is created once, outside the episode loop, and reused.
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
            "is_multi_agent": True,
            "horizon": horizon,
            "use_render": render,
            "sequential_seed": True,
            "reactive_traffic": False,  # not needed when replaying
            "manual_control": False,
            "filter_offroad_vehicles": True,  # enable lane filtering
            "lane_tolerance": 3.0,
            "replay_mode": True,  # mark the env as replaying
            "debug": debug,
            "specific_scenario_id": scenario_id,
            "use_scenario_duration": use_scenario_duration,
            # Object-type filtering.
            "spawn_vehicles": spawn_vehicles,
            "spawn_pedestrians": spawn_pedestrians,
            "spawn_cyclists": spawn_cyclists,
        },
        agent2policy=None  # replay mode has no shared policy
    )

    try:
        for episode in range(num_episodes):
            print(f"\n{'='*50}")
            print(f"回合 {episode + 1}/{num_episodes}")
            if scenario_id is not None:
                print(f"场景ID: {scenario_id}")
            print(f"{'='*50}")

            # Without a specific scenario, the episode index is the seed so
            # that successive episodes walk through different scenarios.
            seed = scenario_id if scenario_id is not None else episode
            obs = env.reset(seed=seed)

            replay_policies = {}

            def ensure_policy(agent_id, vehicle):
                """Return the agent's ReplayPolicy, creating it on first sight.

                Returns None when the vehicle has no expert trajectory.
                """
                policy = replay_policies.get(agent_id)
                if policy is None:
                    vehicle_id = vehicle.expert_vehicle_id
                    if vehicle_id not in env.expert_trajectories:
                        return None
                    policy = ReplayPolicy(env.expert_trajectories[vehicle_id], vehicle_id)
                    vehicle.set_policy(policy)
                    replay_policies[agent_id] = policy
                return policy

            # Attach a ReplayPolicy to every vehicle present right after reset.
            for agent_id, vehicle in env.controlled_agents.items():
                ensure_policy(agent_id, vehicle)

            # Scenario summary.
            actual_horizon = env.config["horizon"]
            print("初始化完成:")
            print(f" 可控车辆数: {len(env.controlled_agents)}")
            print(f" 专家轨迹数: {len(env.expert_trajectories)}")
            print(f" 场景时长: {env.scenario_max_duration} 步")
            print(f" 实际Horizon: {actual_horizon} 步")

            step_count = 0
            active_vehicles_count = []

            while True:
                # Overwrite each vehicle's state with the expert state for
                # the current round. Vehicles that were spawned by an earlier
                # env.step() (i.e. after reset) get their policy here; a
                # one-off assignment after reset would leave them unreplayed.
                for agent_id, vehicle in list(env.controlled_agents.items()):
                    policy = ensure_policy(agent_id, vehicle)
                    if policy is None:
                        continue

                    target_state = policy.get_target_state(env.round)
                    if target_state is None:
                        continue

                    # Set the state directly. Only x/y are used for the
                    # position so the vehicle stays on the ground.
                    vehicle.set_position(target_state['position'][:2])
                    vehicle.set_heading_theta(target_state['heading'])
                    velocity = target_state['velocity']
                    if len(velocity) > 2:
                        velocity = velocity[:2]
                    vehicle.set_velocity(velocity)

                # Step with neutral actions.
                actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
                obs, rewards, dones, infos = env.step(actions)

                if render:
                    env.render(mode="topdown")

                step_count += 1
                active_vehicles_count.append(len(env.controlled_agents))

                # Status line every 50 steps.
                if step_count % 50 == 0:
                    print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")

                    # In debug mode also report one vehicle's height.
                    if debug and len(env.controlled_agents) > 0:
                        sample_vehicle = list(env.controlled_agents.values())[0]
                        z_pos = sample_vehicle.position[2] if len(sample_vehicle.position) > 2 else 0
                        print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")

                if dones["__all__"]:
                    print("\n回合结束统计:")
                    print(f" 总步数: {step_count}")
                    print(f" 最大同时车辆数: {max(active_vehicles_count) if active_vehicles_count else 0}")
                    print(f" 平均车辆数: {sum(active_vehicles_count) / len(active_vehicles_count) if active_vehicles_count else 0:.1f}")
                    if use_scenario_duration:
                        print(f" 场景完整回放: {'是' if step_count >= env.scenario_max_duration else '否'}")
                    break
    finally:
        # Always release the environment, even if an episode raised.
        env.close()

    print("\n" + "=" * 50)
    print("回放完成!")
    print("=" * 50)
|
||||
|
||||
|
||||
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||
scenario_id=None, use_scenario_duration=False,
|
||||
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
|
||||
"""
|
||||
仿真模式:使用自定义策略控制车辆
|
||||
车辆根据专家数据的初始位姿生成,然后由策略控制
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录
|
||||
num_episodes: 回合数
|
||||
horizon: 最大步数
|
||||
render: 是否渲染
|
||||
debug: 是否调试模式
|
||||
scenario_id: 指定场景ID(可选)
|
||||
use_scenario_duration: 是否使用场景原始时长
|
||||
spawn_vehicles: 是否生成车辆(默认True)
|
||||
spawn_pedestrians: 是否生成行人(默认True)
|
||||
spawn_cyclists: 是否生成自行车(默认True)
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("运行模式: 策略仿真 (Simulation Mode)")
|
||||
if scenario_id is not None:
|
||||
print(f"指定场景ID: {scenario_id}")
|
||||
if use_scenario_duration:
|
||||
print("使用场景原始时长")
|
||||
print("=" * 50)
|
||||
|
||||
# 如果指定了场景ID,只运行1个回合
|
||||
if scenario_id is not None:
|
||||
num_episodes = 1
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
"use_render": True,
|
||||
"horizon": horizon,
|
||||
"use_render": render,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
"manual_control": True,
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": False, # 仿真模式
|
||||
"debug": debug,
|
||||
"specific_scenario_id": scenario_id, # 指定场景ID
|
||||
"use_scenario_duration": use_scenario_duration, # 使用场景时长
|
||||
# 对象类型过滤
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0
|
||||
)
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
try:
|
||||
for episode in range(num_episodes):
|
||||
print(f"\n{'='*50}")
|
||||
print(f"回合 {episode + 1}/{num_episodes}")
|
||||
if scenario_id is not None:
|
||||
print(f"场景ID: {scenario_id}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
env.render(mode="topdown")
|
||||
seed = scenario_id if scenario_id is not None else episode
|
||||
obs = env.reset(seed=seed)
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
actual_horizon = env.config["horizon"]
|
||||
print(f"初始化完成:")
|
||||
print(f" 可控车辆数: {len(env.controlled_agents)}")
|
||||
print(f" 场景时长: {env.scenario_max_duration} 步")
|
||||
print(f" 实际Horizon: {actual_horizon} 步")
|
||||
|
||||
env.close()
|
||||
step_count = 0
|
||||
total_reward = 0.0
|
||||
|
||||
while True:
|
||||
# 使用策略生成动作
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
if render:
|
||||
env.render(mode="topdown")
|
||||
|
||||
step_count += 1
|
||||
total_reward += sum(rewards.values())
|
||||
|
||||
# 每50步打印一次状态
|
||||
if step_count % 50 == 0:
|
||||
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f"\n回合结束统计:")
|
||||
print(f" 总步数: {step_count}")
|
||||
print(f" 总奖励: {total_reward:.2f}")
|
||||
break
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("仿真完成!")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
def main():
    """Parse the command-line options and launch the selected run mode.

    ``--mode replay`` dispatches to ``run_replay_mode``; any other accepted
    value dispatches to ``run_simulation_mode``. Both receive the same set of
    keyword arguments derived from the parsed options.
    """
    # Options in the order they are registered (and shown in --help).
    option_specs = (
        ("--mode", {
            "type": str,
            "choices": ["replay", "simulation"],
            "default": "simulation",
            "help": "运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)",
        }),
        ("--data_dir", {
            "type": str,
            "default": WAYMO_DATA_DIR,
            "help": f"数据目录路径 (默认: {WAYMO_DATA_DIR})",
        }),
        ("--episodes", {
            "type": int,
            "default": 1,
            "help": "运行回合数 (默认: 1)",
        }),
        ("--horizon", {
            "type": int,
            "default": 300,
            "help": "每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)",
        }),
        ("--no_render", {
            "action": "store_true",
            "help": "禁用渲染(加速运行)",
        }),
        ("--debug", {
            "action": "store_true",
            "help": "启用调试模式(显示详细日志)",
        }),
        ("--scenario_id", {
            "type": int,
            "default": None,
            "help": "指定场景ID(可选,如指定则只运行该场景)",
        }),
        ("--use_scenario_duration", {
            "action": "store_true",
            "help": "使用场景原始时长作为horizon(自动停止)",
        }),
        ("--no_vehicles", {
            "action": "store_true",
            "help": "禁止生成车辆",
        }),
        ("--no_pedestrians", {
            "action": "store_true",
            "help": "禁止生成行人",
        }),
        ("--no_cyclists", {
            "action": "store_true",
            "help": "禁止生成自行车",
        }),
    )

    cli = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # The "no_*" switches are inverted into positive enable flags.
    run_kwargs = {
        "data_dir": opts.data_dir,
        "num_episodes": opts.episodes,
        "horizon": opts.horizon,
        "render": not opts.no_render,
        "debug": opts.debug,
        "scenario_id": opts.scenario_id,
        "use_scenario_duration": opts.use_scenario_duration,
        "spawn_vehicles": not opts.no_vehicles,
        "spawn_pedestrians": not opts.no_pedestrians,
        "spawn_cyclists": not opts.no_cyclists,
    }

    runner = run_replay_mode if opts.mode == "replay" else run_simulation_mode
    runner(**run_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,6 +15,7 @@ class PolicyVehicle(DefaultVehicle):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.policy = None
|
||||
self.destination = None
|
||||
self.expert_vehicle_id = None # 关联专家车辆ID
|
||||
|
||||
def set_policy(self, policy):
|
||||
self.policy = policy
|
||||
@@ -22,6 +23,9 @@ class PolicyVehicle(DefaultVehicle):
|
||||
def set_destination(self, des):
|
||||
self.destination = des
|
||||
|
||||
def set_expert_vehicle_id(self, vid):
    """Store ``vid`` as the id of the expert vehicle associated with this vehicle."""
    self.expert_vehicle_id = vid
|
||||
|
||||
def act(self, observation, policy=None):
|
||||
if self.policy is not None:
|
||||
return self.policy.act(observation)
|
||||
@@ -53,6 +57,15 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
data_directory=None,
|
||||
num_controlled_agents=3,
|
||||
horizon=1000,
|
||||
filter_offroad_vehicles=True, # 车道过滤开关
|
||||
lane_tolerance=3.0, # 车道检测容差(米)
|
||||
replay_mode=False, # 回放模式开关
|
||||
specific_scenario_id=None, # 新增:指定场景ID(仅回放模式)
|
||||
use_scenario_duration=False, # 新增:使用场景原始时长作为horizon
|
||||
# 对象类型过滤选项
|
||||
spawn_vehicles=True, # 是否生成车辆
|
||||
spawn_pedestrians=True, # 是否生成行人
|
||||
spawn_cyclists=True, # 是否生成自行车
|
||||
))
|
||||
return config
|
||||
|
||||
@@ -62,50 +75,180 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.controlled_agent_ids = []
|
||||
self.obs_list = []
|
||||
self.round = 0
|
||||
self.expert_trajectories = {} # 存储完整专家轨迹
|
||||
self.replay_mode = config.get("replay_mode", False)
|
||||
self.scenario_max_duration = 0 # 场景实际最大时长
|
||||
super().__init__(config)
|
||||
|
||||
def reset(self, seed: Union[None, int] = None):
|
||||
self.round = 0
|
||||
|
||||
if self.logger is None:
|
||||
self.logger = get_logger()
|
||||
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
|
||||
set_log_level(log_level)
|
||||
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
|
||||
set_log_level(log_level)
|
||||
|
||||
# ✅ 关键修复:在每次 reset 前清理所有自定义生成的对象
|
||||
if hasattr(self, 'engine') and self.engine is not None:
|
||||
if hasattr(self, 'controlled_agents') and self.controlled_agents:
|
||||
# 先从 agent_manager 中移除
|
||||
if hasattr(self.engine, 'agent_manager'):
|
||||
for agent_id in list(self.controlled_agents.keys()):
|
||||
if agent_id in self.engine.agent_manager.active_agents:
|
||||
self.engine.agent_manager.active_agents.pop(agent_id)
|
||||
|
||||
# 然后清理对象
|
||||
for agent_id, vehicle in list(self.controlled_agents.items()):
|
||||
try:
|
||||
self.engine.clear_objects([vehicle.id])
|
||||
except:
|
||||
pass
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
self.lazy_init()
|
||||
self._reset_global_seed(seed)
|
||||
|
||||
if self.engine is None:
|
||||
raise ValueError("Broken MetaDrive instance.")
|
||||
|
||||
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
|
||||
_obj_to_clean_this_frame = []
|
||||
self.car_birth_info_list = []
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
|
||||
continue
|
||||
else:
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
# id,出现时间,出生点坐标,出生朝向,目的地
|
||||
self.car_birth_info_list.append({
|
||||
'id': track['metadata']['object_id'],
|
||||
'show_time': first_show,
|
||||
'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
for scenario_id in _obj_to_clean_this_frame:
|
||||
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
|
||||
# 如果指定了场景ID,修改start_scenario_index
|
||||
if self.config.get("specific_scenario_id") is not None:
|
||||
scenario_id = self.config.get("specific_scenario_id")
|
||||
self.config["start_scenario_index"] = scenario_id
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Using specific scenario ID: {scenario_id}")
|
||||
|
||||
# ✅ 先初始化引擎和 lanes
|
||||
self.engine.reset()
|
||||
self.reset_sensors()
|
||||
self.engine.taskMgr.step()
|
||||
|
||||
self.lanes = self.engine.map_manager.current_map.road_network.graph
|
||||
|
||||
# 记录专家数据(现在 self.lanes 已经初始化)
|
||||
_obj_to_clean_this_frame = []
|
||||
self.car_birth_info_list = []
|
||||
self.expert_trajectories.clear()
|
||||
total_vehicles = 0
|
||||
total_pedestrians = 0
|
||||
total_cyclists = 0
|
||||
filtered_vehicles = 0
|
||||
filtered_by_type = 0
|
||||
self.scenario_max_duration = 0 # 重置场景时长
|
||||
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
|
||||
continue
|
||||
|
||||
# 对象类型过滤
|
||||
obj_type = track["type"]
|
||||
|
||||
# 统计对象类型
|
||||
if obj_type == MetaDriveType.VEHICLE:
|
||||
total_vehicles += 1
|
||||
elif obj_type == MetaDriveType.PEDESTRIAN:
|
||||
total_pedestrians += 1
|
||||
elif obj_type == MetaDriveType.CYCLIST:
|
||||
total_cyclists += 1
|
||||
|
||||
# 根据配置过滤对象类型
|
||||
if obj_type == MetaDriveType.VEHICLE and not self.config.get("spawn_vehicles", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering VEHICLE {track['metadata']['object_id']} - spawn_vehicles=False")
|
||||
continue
|
||||
|
||||
if obj_type == MetaDriveType.PEDESTRIAN and not self.config.get("spawn_pedestrians", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering PEDESTRIAN {track['metadata']['object_id']} - spawn_pedestrians=False")
|
||||
continue
|
||||
|
||||
if obj_type == MetaDriveType.CYCLIST and not self.config.get("spawn_cyclists", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering CYCLIST {track['metadata']['object_id']} - spawn_cyclists=False")
|
||||
continue
|
||||
|
||||
# 只处理车辆类型(行人和自行车暂时只做过滤)
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
|
||||
if first_show == -1 or last_show == -1:
|
||||
continue
|
||||
|
||||
# 更新场景最大时长
|
||||
self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)
|
||||
|
||||
# 获取车辆初始位置
|
||||
initial_position = (
|
||||
track['state']['position'][first_show, 0],
|
||||
track['state']['position'][first_show, 1]
|
||||
)
|
||||
|
||||
# 车道过滤
|
||||
if self.config.get("filter_offroad_vehicles", True):
|
||||
if not self._is_position_on_lane(initial_position):
|
||||
filtered_vehicles += 1
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(
|
||||
f"Filtering vehicle {track['metadata']['object_id']} - "
|
||||
f"not on lane at position {initial_position}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 存储完整专家轨迹(只使用2D位置,避免高度问题)
|
||||
object_id = track['metadata']['object_id']
|
||||
positions_2d = track['state']['position'].copy()
|
||||
positions_2d[:, 2] = 0 # 将z坐标设为0,让MetaDrive自动处理高度
|
||||
|
||||
self.expert_trajectories[object_id] = {
|
||||
'positions': positions_2d,
|
||||
'headings': track['state']['heading'].copy(),
|
||||
'velocities': track['state']['velocity'].copy(),
|
||||
'valid': track['state']['valid'].copy(),
|
||||
}
|
||||
|
||||
# 保存车辆生成信息
|
||||
self.car_birth_info_list.append({
|
||||
'id': object_id,
|
||||
'show_time': first_show,
|
||||
'begin': initial_position,
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
# 在回放和仿真模式下都清除原始专家车辆
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
|
||||
# 清除专家车辆和过滤的对象
|
||||
for scenario_id in _obj_to_clean_this_frame:
|
||||
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
|
||||
|
||||
# 输出统计信息
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"=== 对象统计 ===")
|
||||
self.logger.info(f"车辆 (VEHICLE): 总数={total_vehicles}, 车道过滤={filtered_vehicles}, 保留={total_vehicles - filtered_vehicles}")
|
||||
self.logger.info(f"行人 (PEDESTRIAN): 总数={total_pedestrians}")
|
||||
self.logger.info(f"自行车 (CYCLIST): 总数={total_cyclists}")
|
||||
self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
|
||||
self.logger.info(f"场景时长: {self.scenario_max_duration} 步")
|
||||
|
||||
# 如果启用场景时长控制,更新horizon
|
||||
if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
|
||||
original_horizon = self.config["horizon"]
|
||||
self.config["horizon"] = self.scenario_max_duration
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")
|
||||
|
||||
if self.top_down_renderer is not None:
|
||||
self.top_down_renderer.clear()
|
||||
self.engine.top_down_renderer = None
|
||||
@@ -113,7 +256,6 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.dones = {}
|
||||
self.episode_rewards = defaultdict(float)
|
||||
self.episode_lengths = defaultdict(int)
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
@@ -122,37 +264,92 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
|
||||
return self._get_all_obs()
|
||||
|
||||
def _is_position_on_lane(self, position, tolerance=None):
    """Check whether ``position`` lies on, or within ``tolerance`` of, any lane.

    A lane accepts the position when its own ``point_on_lane`` test passes,
    or when the position projects onto the straight segment from the lane's
    ``start`` to its ``end`` and is no farther than ``tolerance`` from it.

    Args:
        position: Sequence with at least x and y; extra components are ignored.
        tolerance: Maximum distance to the start-end segment. Defaults to the
            ``lane_tolerance`` config entry (3.0 when absent).

    Returns:
        bool: True when some lane accepts the position, or when ``self.lanes``
        is not initialised (the check is skipped); False otherwise.
    """
    if tolerance is None:
        tolerance = self.config.get("lane_tolerance", 3.0)

    debug = self.config.get("debug", False)

    # Without lanes there is nothing to test against; let the position pass.
    if not hasattr(self, 'lanes') or self.lanes is None:
        if debug:
            self.logger.warning("Lanes not initialized, skipping lane check")
        return True

    position_2d = np.array(position[:2]) if len(position) > 2 else np.array(position)

    for lane in self.lanes.values():
        # Errors are isolated per lane: one lane that cannot be evaluated
        # must not reject a position that a later lane would accept.
        try:
            if lane.lane.point_on_lane(position_2d):
                return True

            lane_start = np.array(lane.lane.start)[:2]
            lane_end = np.array(lane.lane.end)[:2]
            lane_vec = lane_end - lane_start
            lane_length = np.linalg.norm(lane_vec)

            # Degenerate (zero-length) segment: no direction to project on.
            if lane_length < 1e-6:
                continue

            lane_vec_normalized = lane_vec / lane_length
            projection = np.dot(position_2d - lane_start, lane_vec_normalized)

            # Only positions whose projection falls inside the segment count.
            if 0 <= projection <= lane_length:
                closest_point = lane_start + projection * lane_vec_normalized
                distance = np.linalg.norm(position_2d - closest_point)
                if distance <= tolerance:
                    return True
        except Exception as e:
            if debug:
                self.logger.warning(f"Lane check error: {e}")
            continue

    return False
|
||||
|
||||
def _spawn_controlled_agents(self):
|
||||
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
|
||||
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
|
||||
for car in self.car_birth_info_list:
|
||||
if car['show_time'] == self.round:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
|
||||
vehicle_config = {}
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config={},
|
||||
vehicle_config=vehicle_config,
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
)
|
||||
vehicle.reset(position=car['begin'], heading=car['heading'])
|
||||
|
||||
# 重置车辆状态
|
||||
reset_kwargs = {
|
||||
'position': car['begin'],
|
||||
'heading': car['heading']
|
||||
}
|
||||
|
||||
# 如果启用速度继承,设置初始速度
|
||||
if car.get('velocity') is not None:
|
||||
reset_kwargs['velocity'] = car['velocity']
|
||||
|
||||
vehicle.reset(**reset_kwargs)
|
||||
|
||||
# 设置策略和目的地
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
vehicle.set_expert_vehicle_id(car['id'])
|
||||
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# ✅ 关键:注册到引擎的 active_agents,才能参与物理更新
|
||||
# 注册到引擎的 active_agents
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")
|
||||
|
||||
def _get_all_obs(self):
|
||||
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
|
||||
self.obs_list = []
|
||||
|
||||
for agent_id, vehicle in self.controlled_agents.items():
|
||||
state = vehicle.get_state()
|
||||
|
||||
traffic_light = 0
|
||||
|
||||
for lane in self.lanes.values():
|
||||
if lane.lane.point_on_lane(state['position'][:2]):
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
@@ -168,37 +365,69 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
break
|
||||
|
||||
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world)
|
||||
physics_world=self.engine.physics_world.dynamic_world)
|
||||
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
|
||||
obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
|
||||
obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
|
||||
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
+ list(vehicle.destination))
|
||||
|
||||
self.obs_list.append(obs)
|
||||
|
||||
return self.obs_list
|
||||
|
||||
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
    """Advance the environment by one round.

    Args:
        action_dict: Mapping from agent id to action. Entries whose id is not
            a currently controlled agent are ignored.

    Returns:
        tuple ``(obs, rewards, dones, infos)``: ``obs`` is the result of
        ``_get_all_obs()``; ``rewards`` is 0.0 for every controlled agent;
        ``dones`` is False for every controlled agent plus an ``"__all__"``
        entry; ``infos`` is an empty dict per controlled agent.
    """
    self.round += 1

    # Apply the actions.
    for agent_id, action in action_dict.items():
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].before_step(action)

    # Step the engine.
    self.engine.step()

    # Post-step hooks.
    for agent_id in action_dict:
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].after_step()

    # Spawn the vehicles scheduled for this round, then observe.
    self._spawn_controlled_agents()
    obs = self._get_all_obs()

    rewards = {aid: 0.0 for aid in self.controlled_agents}
    dones = {aid: False for aid in self.controlled_agents}

    # In replay mode with scenario-duration control, the episode also ends
    # once the round counter reaches the scenario's duration.
    replay_finished = False
    if self.replay_mode and self.config.get("use_scenario_duration", False):
        if self.round >= self.scenario_max_duration:
            replay_finished = True
            if self.config.get("debug", False):
                self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")

    # Single assignment of the global done flag (horizon reached or replay over).
    dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished

    infos = {aid: {} for aid in self.controlled_agents}

    return obs, rewards, dones, infos
|
||||
|
||||
def close(self):
    """Remove the vehicles this environment spawned, then close the base env.

    The parent ``close()`` is always called, even when no engine is
    available or no vehicles were spawned.
    """
    agents = getattr(self, 'controlled_agents', None)
    if agents:
        engine = getattr(self, 'engine', None)
        if engine is not None:
            # NOTE(review): assumes get_objects() is keyed by object id, as
            # clear_objects([vehicle.id]) implies — confirm against MetaDrive.
            # Membership is therefore tested with the id, not the object.
            spawned = engine.get_objects()
            stale_ids = [vehicle.id for vehicle in agents.values() if vehicle.id in spawned]
            if stale_ids:
                engine.clear_objects(stale_ids)
        agents.clear()
        self.controlled_agent_ids.clear()

    super().close()
|
||||
Reference in New Issue
Block a user