2025-09-28 18:57:04 +08:00
|
|
|
|
import numpy as np
|
|
|
|
|
|
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
|
|
|
|
|
|
from metadrive.envs.scenario_env import ScenarioEnv
|
|
|
|
|
|
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
|
|
|
|
|
|
import math
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
from typing import Union, Dict, AnyStr
|
|
|
|
|
|
from metadrive.engine.logger import get_logger, set_log_level
|
|
|
|
|
|
from metadrive.type import MetaDriveType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PolicyVehicle(DefaultVehicle):
    """A ``DefaultVehicle`` that carries its own policy, destination and expert id.

    The environment attaches a policy object (anything exposing
    ``act(observation)``), a destination and the id of the expert track the
    vehicle was created from.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``DefaultVehicle`` and start with nothing attached."""
        super().__init__(*args, **kwargs)
        self.policy = None
        self.destination = None
        self.expert_vehicle_id = None  # id of the associated expert vehicle

    def set_policy(self, policy):
        """Attach the policy object queried by :meth:`act`."""
        self.policy = policy

    def set_destination(self, des):
        """Store the destination of this vehicle."""
        self.destination = des

    def set_expert_vehicle_id(self, vid):
        """Store the id of the expert vehicle this agent is associated with."""
        self.expert_vehicle_id = vid

    def act(self, observation, policy=None):
        """Return an action for ``observation``.

        :param observation: observation handed to the policy's ``act`` method.
        :param policy: optional per-call override; when ``None`` the policy
            attached with :meth:`set_policy` is used.
        :return: the policy's action, or a random sample from
            ``self.action_space`` when no policy is available at all.
        """
        # The argument used to be ignored; honour it when the caller gives one.
        active_policy = policy if policy is not None else self.policy
        if active_policy is not None:
            return active_policy.act(observation)
        return self.action_space.sample()

    def before_step(self, action):
        """Snapshot the current kinematic state, then apply ``action``.

        :param action: action to apply; ``None`` only records the snapshot.
        """
        self.last_position = self.position  # 2D vector
        self.last_velocity = self.velocity  # 2D vector
        self.last_speed = self.speed  # scalar
        self.last_heading_dir = self.heading
        if action is not None:
            self.last_current_action.append(action)
            self._set_action(action)

    def is_done(self):
        """Placeholder for the arrive-or-crash termination check.

        Not implemented yet: always returns ``None``.
        """
        # arrive or crash
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register PolicyVehicle in MetaDrive's class-to-type table under the "default" type key.
vehicle_class_to_type[PolicyVehicle] = "default"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiAgentScenarioEnv(ScenarioEnv):
    """Scenario environment that replaces recorded vehicles with policy-driven ones.

    On ``reset`` every non-SDC vehicle track of the current scenario is read,
    optionally filtered (by object type and by whether it starts on a lane),
    stored as an expert trajectory and removed from the traffic data.  A
    ``PolicyVehicle`` is then spawned for each stored track at the step where
    the track first becomes valid.

    NOTE(review): this class was reconstructed from a source whose indentation
    had been lost; nesting that could not be derived from the code itself is
    flagged with ``NOTE(review)`` comments and should be confirmed.
    """

    # Integer codes written into the observation for each traffic-light status.
    _TRAFFIC_LIGHT_CODES = {
        'TRAFFIC_LIGHT_GREEN': 1,
        'TRAFFIC_LIGHT_YELLOW': 2,
        'TRAFFIC_LIGHT_RED': 3,
    }

    @classmethod
    def default_config(cls):
        """Return the parent config extended with the options of this env."""
        config = super().default_config()
        config.update(dict(
            data_directory=None,
            num_controlled_agents=3,
            horizon=1000,
            filter_offroad_vehicles=True,  # drop vehicles that do not start on a lane
            lane_tolerance=3.0,  # lane detection tolerance (metres)
            replay_mode=False,  # replay mode switch
            specific_scenario_id=None,  # force this scenario index on reset
            use_scenario_duration=False,  # use the recorded scenario length as horizon
            # Read by reset(); declared here so that callers can override it.
            inherit_expert_velocity=False,
            # Object type filters
            spawn_vehicles=True,
            spawn_pedestrians=True,
            spawn_cyclists=True,
        ))
        return config

    def __init__(self, config, agent2policy):
        """Create the environment.

        :param config: user config; ``replay_mode`` is read from it directly.
        :param agent2policy: policy object handed to every spawned vehicle.
        """
        self.policy = agent2policy
        self.controlled_agents = {}
        self.controlled_agent_ids = []
        self.obs_list = []
        self.round = 0
        self.expert_trajectories = {}  # full expert trajectories keyed by object id
        self.replay_mode = config.get("replay_mode", False)
        self.scenario_max_duration = 0  # length of the current scenario in steps
        super().__init__(config)

    def _despawn_controlled_agents(self):
        """Remove every vehicle spawned by this env from the engine and forget it.

        Cleanup is best-effort: a vehicle the engine refuses to clear is
        logged at debug level and skipped, so one failure cannot block a
        reset or a close.
        """
        agents = getattr(self, "controlled_agents", None)
        engine = getattr(self, "engine", None)
        if agents and engine is not None:
            # Unregister from the agent manager first ...
            manager = getattr(engine, "agent_manager", None)
            if manager is not None:
                for agent_id in list(agents):
                    if agent_id in manager.active_agents:
                        manager.active_agents.pop(agent_id)
            # ... then destroy the objects themselves.
            for agent_id, vehicle in list(agents.items()):
                try:
                    engine.clear_objects([vehicle.id])
                except Exception as exc:
                    logger = getattr(self, "logger", None)
                    if logger is not None:
                        logger.debug("Could not clear vehicle %s: %s", agent_id, exc)
        if agents is not None:
            agents.clear()
        agent_ids = getattr(self, "controlled_agent_ids", None)
        if agent_ids is not None:
            agent_ids.clear()

    def reset(self, seed: Union[None, int] = None):
        """Reset the scenario, rebuild the spawn list and return the first observations.

        :param seed: forwarded to ``_reset_global_seed`` and to the parent reset.
        :return: list of observations, one per vehicle spawned at step 0.
        :raises ValueError: if the engine is missing after lazy initialisation.
        """
        self.round = 0

        if self.logger is None:
            self.logger = get_logger()
        # NOTE(review): placed outside the ``if`` above; confirm against the
        # original indentation.
        log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
        set_log_level(log_level)

        # Clear every object spawned by this env before each reset.
        self._despawn_controlled_agents()

        self.lazy_init()
        self._reset_global_seed(seed)

        if self.engine is None:
            raise ValueError("Broken MetaDrive instance.")

        debug = self.config.get("debug", False)

        # When a scenario id is configured, start from that scenario.
        forced_scenario = self.config.get("specific_scenario_id")
        if forced_scenario is not None:
            self.config["start_scenario_index"] = forced_scenario
            if debug:
                self.logger.info(f"Using specific scenario ID: {forced_scenario}")

        # Initialise the engine first so the lane graph is available below.
        self.engine.reset()
        self.reset_sensors()
        self.engine.taskMgr.step()
        self.lanes = self.engine.map_manager.current_map.road_network.graph

        # Record the expert data (self.lanes is initialised by now).
        _obj_to_clean_this_frame = []
        self.car_birth_info_list = []
        self.expert_trajectories.clear()
        total_vehicles = 0
        total_pedestrians = 0
        total_cyclists = 0
        filtered_vehicles = 0
        filtered_by_type = 0
        self.scenario_max_duration = 0  # reset the scenario length

        # object type -> (label used in the log message, config switch)
        type_filters = {
            MetaDriveType.VEHICLE: ("VEHICLE", "spawn_vehicles"),
            MetaDriveType.PEDESTRIAN: ("PEDESTRIAN", "spawn_pedestrians"),
            MetaDriveType.CYCLIST: ("CYCLIST", "spawn_cyclists"),
        }

        traffic_manager = self.engine.traffic_manager
        for track_key, track in traffic_manager.current_traffic_data.items():
            if track_key == traffic_manager.sdc_scenario_id:
                continue

            obj_type = track["type"]

            # Count object types.
            if obj_type == MetaDriveType.VEHICLE:
                total_vehicles += 1
            elif obj_type == MetaDriveType.PEDESTRIAN:
                total_pedestrians += 1
            elif obj_type == MetaDriveType.CYCLIST:
                total_cyclists += 1

            # Drop object types that are switched off in the config.
            if obj_type in type_filters:
                label, switch = type_filters[obj_type]
                if not self.config.get(switch, True):
                    _obj_to_clean_this_frame.append(track_key)
                    filtered_by_type += 1
                    if debug:
                        self.logger.debug(f"Filtering {label} {track['metadata']['object_id']} - {switch}=False")
                    continue

            # Only vehicles are processed further (pedestrians and cyclists
            # are only filtered for now).
            if obj_type != MetaDriveType.VEHICLE:
                continue

            state = track['state']
            valid = state['valid']
            if not valid.any():
                # The track is never valid: nothing to spawn.
                continue
            # Plain ints so numpy scalars do not leak into the config.
            first_show = int(np.argmax(valid))
            last_show = int(len(valid) - 1 - np.argmax(valid[::-1]))

            # Track the longest scenario duration.
            self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)

            # Initial position of the vehicle.
            initial_position = (
                state['position'][first_show, 0],
                state['position'][first_show, 1]
            )

            # Lane filter.
            if self.config.get("filter_offroad_vehicles", True):
                if not self._is_position_on_lane(initial_position):
                    filtered_vehicles += 1
                    _obj_to_clean_this_frame.append(track_key)
                    if debug:
                        self.logger.debug(
                            f"Filtering vehicle {track['metadata']['object_id']} - "
                            f"not on lane at position {initial_position}"
                        )
                    continue

            # Store the full expert trajectory; z is zeroed so that MetaDrive
            # handles the height itself.
            object_id = track['metadata']['object_id']
            positions_2d = state['position'].copy()
            positions_2d[:, 2] = 0

            self.expert_trajectories[object_id] = {
                'positions': positions_2d,
                'headings': state['heading'].copy(),
                'velocities': state['velocity'].copy(),
                'valid': state['valid'].copy(),
            }

            # Remember how and when to spawn the vehicle.
            self.car_birth_info_list.append({
                'id': object_id,
                'show_time': first_show,
                'begin': initial_position,
                'heading': state['heading'][first_show],
                'velocity': state['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
                'end': (state['position'][last_show, 0], state['position'][last_show, 1])
            })

            # The original expert vehicle is removed in both replay and
            # simulation mode.
            _obj_to_clean_this_frame.append(track_key)

        # Remove the expert vehicles and the filtered objects.
        for track_key in _obj_to_clean_this_frame:
            traffic_manager.current_traffic_data.pop(track_key)

        # Print statistics.
        if debug:
            self.logger.info("=== 对象统计 ===")
            self.logger.info(f"车辆 (VEHICLE): 总数={total_vehicles}, 车道过滤={filtered_vehicles}, 保留={total_vehicles - filtered_vehicles}")
            self.logger.info(f"行人 (PEDESTRIAN): 总数={total_pedestrians}")
            self.logger.info(f"自行车 (CYCLIST): 总数={total_cyclists}")
            self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
            self.logger.info(f"场景时长: {self.scenario_max_duration} 步")

        # Use the scenario duration as horizon when that option is enabled.
        if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
            original_horizon = self.config["horizon"]
            self.config["horizon"] = self.scenario_max_duration
            if debug:
                self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")

        if self.top_down_renderer is not None:
            self.top_down_renderer.clear()
            self.engine.top_down_renderer = None

        self.dones = {}
        self.episode_rewards = defaultdict(float)
        self.episode_lengths = defaultdict(int)
        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()

        super().reset(seed)  # initialise the scenario
        self._spawn_controlled_agents()

        return self._get_all_obs()

    def _is_position_on_lane(self, position, tolerance=None):
        """Tell whether ``position`` lies on, or close to, any lane of the map.

        A lane matches when its own ``point_on_lane`` accepts the point, or
        when the point projects onto the straight segment between the lane's
        start and end and is at most ``tolerance`` away from it.

        :param position: sequence whose first two entries are x and y.
        :param tolerance: maximum distance to the start-end segment; defaults
            to the ``lane_tolerance`` config value.
        :return: ``True`` on a match or when the lanes are not initialised
            (the check is skipped); ``False`` otherwise, including when the
            check raises.
        """
        if tolerance is None:
            tolerance = self.config.get("lane_tolerance", 3.0)

        # Make sure self.lanes has been initialised.
        if getattr(self, 'lanes', None) is None:
            if self.config.get("debug", False):
                self.logger.warning("Lanes not initialized, skipping lane check")
            return True

        position_2d = np.array(position[:2])

        try:
            for lane in self.lanes.values():
                if lane.lane.point_on_lane(position_2d):
                    return True

                lane_start = np.array(lane.lane.start)[:2]
                lane_end = np.array(lane.lane.end)[:2]
                lane_vec = lane_end - lane_start
                lane_length = np.linalg.norm(lane_vec)

                if lane_length < 1e-6:
                    # Degenerate lane: no direction to project onto.
                    continue

                lane_dir = lane_vec / lane_length
                projection = np.dot(position_2d - lane_start, lane_dir)

                if 0 <= projection <= lane_length:
                    closest_point = lane_start + projection * lane_dir
                    if np.linalg.norm(position_2d - closest_point) <= tolerance:
                        return True
        except Exception as e:
            if self.config.get("debug", False):
                self.logger.warning(f"Lane check error: {e}")
            return False

        return False

    def _spawn_controlled_agents(self):
        """Spawn a ``PolicyVehicle`` for every stored track that starts at the current round."""
        for car in self.car_birth_info_list:
            if car['show_time'] != self.round:
                continue

            agent_id = f"controlled_{car['id']}"
            vehicle_config = {}
            vehicle = self.engine.spawn_object(
                PolicyVehicle,
                vehicle_config=vehicle_config,
                position=car['begin'],
                heading=car['heading']
            )

            # Reset the vehicle state.
            reset_kwargs = {
                'position': car['begin'],
                'heading': car['heading']
            }
            # Pass the initial velocity along when velocity inheritance stored one.
            if car.get('velocity') is not None:
                reset_kwargs['velocity'] = car['velocity']
            vehicle.reset(**reset_kwargs)

            # Attach policy, destination and expert id.
            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])
            vehicle.set_expert_vehicle_id(car['id'])

            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)

            # Register in the engine's active_agents.
            self.engine.agent_manager.active_agents[agent_id] = vehicle

            if self.config.get("debug", False):
                self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")

    def _get_all_obs(self):
        """Build one flat observation list per controlled vehicle.

        Each observation concatenates: x, y, the velocity components, heading
        theta, the lidar, side-detector and lane-line-detector readings, a
        traffic-light code (0 none/unknown, 1 green, 2 yellow, 3 red) and the
        destination of the vehicle.

        :return: ``self.obs_list``, ordered like ``self.controlled_agents``.
        """
        self.obs_list = []

        for agent_id, vehicle in self.controlled_agents.items():
            state = vehicle.get_state()
            traffic_light = 0

            for lane in self.lanes.values():
                if not lane.lane.point_on_lane(state['position'][:2]):
                    continue
                if self.engine.light_manager.has_traffic_light(lane.lane.index):
                    status = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
                    traffic_light = self._TRAFFIC_LIGHT_CODES.get(status, 0)
                    # NOTE(review): nesting of this break was inferred; confirm
                    # it belongs to the has_traffic_light branch.
                    break

            lidar = self.engine.get_sensor("lidar").perceive(
                num_lasers=80, distance=30, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.dynamic_world)
            side_lidar = self.engine.get_sensor("side_detector").perceive(
                num_lasers=10, distance=8, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.static_world)
            lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(
                num_lasers=10, distance=3, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.static_world)

            obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
                   + lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
                   + list(vehicle.destination))

            self.obs_list.append(obs)

        return self.obs_list

    def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
        """Advance the simulation by one step.

        :param action_dict: agent id -> action; ids that are not currently
            controlled are ignored.
        :return: ``(obs, rewards, dones, infos)``.  ``obs`` is the list from
            ``_get_all_obs``; the others are dicts keyed by agent id, and
            ``dones`` also carries ``"__all__"``.
        """
        self.round += 1

        # Apply the actions.
        for agent_id, action in action_dict.items():
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].before_step(action)

        # Step the physics engine.
        self.engine.step()

        # Post-processing.
        for agent_id in action_dict:
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].after_step()

        # Spawn the vehicles that appear at this round.
        self._spawn_controlled_agents()

        # Collect observations.
        obs = self._get_all_obs()

        rewards = {aid: 0.0 for aid in self.controlled_agents}
        dones = {aid: False for aid in self.controlled_agents}

        # In replay mode the episode also ends once the recorded scenario
        # has been played to its end.
        replay_finished = False
        if self.replay_mode and self.config.get("use_scenario_duration", False):
            if self.round >= self.scenario_max_duration:
                replay_finished = True
                if self.config.get("debug", False):
                    self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")

        dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished

        infos = {aid: {} for aid in self.controlled_agents}

        return obs, rewards, dones, infos

    def close(self):
        """Remove all spawned vehicles, then close the parent environment."""
        # The previous check tested the vehicle object against a collection
        # keyed by object id, so it never matched and nothing was cleared.
        self._despawn_controlled_agents()
        super().close()
|