import numpy as np
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
import math
import logging
from collections import defaultdict
from typing import Union, Dict, AnyStr
from metadrive.engine.logger import get_logger, set_log_level
from metadrive.type import MetaDriveType


class PolicyVehicle(DefaultVehicle):
    """A vehicle driven by an externally supplied policy.

    Each instance can be linked to a logged "expert" vehicle from the
    scenario data (via ``expert_vehicle_id``) and carries a fixed
    destination that is appended to its observation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None              # external policy consulted by act()
        self.destination = None         # (x, y) goal position
        self.expert_vehicle_id = None   # ID of the expert vehicle this agent replaces

    def set_policy(self, policy):
        """Attach the policy that produces actions for this vehicle."""
        self.policy = policy

    def set_destination(self, des):
        """Set the (x, y) destination this vehicle should drive towards."""
        self.destination = des

    def set_expert_vehicle_id(self, vid):
        """Remember which logged expert vehicle this agent replaces."""
        self.expert_vehicle_id = vid

    def act(self, observation, policy=None):
        """Return an action for ``observation``.

        Fix: the ``policy`` argument was previously ignored. An explicitly
        passed policy now takes precedence over ``self.policy``; with no
        policy available at all, a random action is sampled.
        """
        active_policy = policy if policy is not None else self.policy
        if active_policy is not None:
            return active_policy.act(observation)
        return self.action_space.sample()

    def before_step(self, action):
        """Cache the previous kinematic state, then apply ``action``."""
        self.last_position = self.position      # 2D vector
        self.last_velocity = self.velocity      # 2D vector
        self.last_speed = self.speed            # scalar
        self.last_heading_dir = self.heading
        if action is not None:
            self.last_current_action.append(action)
            self._set_action(action)

    def is_done(self):
        # TODO: arrival / crash detection is not implemented yet.
        pass


# Register the custom class so MetaDrive treats it as a default vehicle type.
vehicle_class_to_type[PolicyVehicle] = "default"


class MultiAgentScenarioEnv(ScenarioEnv):
    """ScenarioEnv variant that strips the logged expert vehicles out of a
    scenario and re-spawns them as policy-controlled ``PolicyVehicle``
    agents, optionally filtered by object type and lane membership.
    """

    @classmethod
    def default_config(cls):
        config = super().default_config()
        config.update(dict(
            data_directory=None,
            num_controlled_agents=3,
            horizon=1000,
            filter_offroad_vehicles=True,   # drop vehicles that spawn off-lane
            lane_tolerance=3.0,             # lane-membership tolerance (meters)
            replay_mode=False,              # replay expert trajectories instead of simulating
            specific_scenario_id=None,      # pin a single scenario id (replay mode)
            use_scenario_duration=False,    # use the scenario's own length as horizon
            # Object-type filtering options.
            spawn_vehicles=True,            # spawn vehicle objects
            spawn_pedestrians=True,         # spawn pedestrian objects
            spawn_cyclists=True,            # spawn cyclist objects
        ))
        return config

    def __init__(self, config, agent2policy):
        """``agent2policy`` is the shared policy handed to every spawned agent."""
        self.policy = agent2policy
        self.controlled_agents = {}         # agent_id -> PolicyVehicle
        self.controlled_agent_ids = []
        self.obs_list = []
        self.round = 0                      # steps elapsed since reset()
        self.expert_trajectories = {}       # object_id -> full expert trajectory
        self.replay_mode = config.get("replay_mode", False)
        self.scenario_max_duration = 0      # longest expert trajectory, in steps
        super().__init__(config)

    def reset(self, seed: Union[None, int] = None):
        """Reset the environment, harvest expert data, and spawn round-0 agents."""
        self.round = 0
        if self.logger is None:
            self.logger = get_logger()
            log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
            set_log_level(log_level)

        # Clean up every custom spawned object before each reset.
        if hasattr(self, 'engine') and self.engine is not None:
            if hasattr(self, 'controlled_agents') and self.controlled_agents:
                # Deregister from the agent manager first ...
                if hasattr(self.engine, 'agent_manager'):
                    for agent_id in list(self.controlled_agents.keys()):
                        if agent_id in self.engine.agent_manager.active_agents:
                            self.engine.agent_manager.active_agents.pop(agent_id)
                # ... then destroy the objects themselves.
                for agent_id, vehicle in list(self.controlled_agents.items()):
                    try:
                        self.engine.clear_objects([vehicle.id])
                    except Exception:
                        # Best effort: the object may already have been removed.
                        pass
                self.controlled_agents.clear()
                self.controlled_agent_ids.clear()

        self.lazy_init()
        self._reset_global_seed(seed)
        if self.engine is None:
            raise ValueError("Broken MetaDrive instance.")

        # If a specific scenario id is requested, point start_scenario_index at it.
        if self.config.get("specific_scenario_id") is not None:
            scenario_id = self.config.get("specific_scenario_id")
            self.config["start_scenario_index"] = scenario_id
            if self.config.get("debug", False):
                self.logger.info(f"Using specific scenario ID: {scenario_id}")

        # Initialize the engine and the lane graph before touching expert data.
        self.engine.reset()
        self.reset_sensors()
        self.engine.taskMgr.step()
        self.lanes = self.engine.map_manager.current_map.road_network.graph

        # Harvest expert data (self.lanes is available from here on).
        _obj_to_clean_this_frame = []
        self.car_birth_info_list = []
        self.expert_trajectories.clear()
        total_vehicles = 0
        total_pedestrians = 0
        total_cyclists = 0
        filtered_vehicles = 0
        filtered_by_type = 0
        self.scenario_max_duration = 0  # recomputed below

        for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
            if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
                continue

            obj_type = track["type"]

            # Tally object types for the debug statistics.
            if obj_type == MetaDriveType.VEHICLE:
                total_vehicles += 1
            elif obj_type == MetaDriveType.PEDESTRIAN:
                total_pedestrians += 1
            elif obj_type == MetaDriveType.CYCLIST:
                total_cyclists += 1

            # Drop object types disabled by the configuration.
            if obj_type == MetaDriveType.VEHICLE and not self.config.get("spawn_vehicles", True):
                _obj_to_clean_this_frame.append(scenario_id)
                filtered_by_type += 1
                if self.config.get("debug", False):
                    self.logger.debug(f"Filtering VEHICLE {track['metadata']['object_id']} - spawn_vehicles=False")
                continue
            if obj_type == MetaDriveType.PEDESTRIAN and not self.config.get("spawn_pedestrians", True):
                _obj_to_clean_this_frame.append(scenario_id)
                filtered_by_type += 1
                if self.config.get("debug", False):
                    self.logger.debug(f"Filtering PEDESTRIAN {track['metadata']['object_id']} - spawn_pedestrians=False")
                continue
            if obj_type == MetaDriveType.CYCLIST and not self.config.get("spawn_cyclists", True):
                _obj_to_clean_this_frame.append(scenario_id)
                filtered_by_type += 1
                if self.config.get("debug", False):
                    self.logger.debug(f"Filtering CYCLIST {track['metadata']['object_id']} - spawn_cyclists=False")
                continue

            # Only vehicles are processed further (pedestrians and cyclists
            # are currently filter-only).
            if track["type"] == MetaDriveType.VEHICLE:
                valid = track['state']['valid']
                first_show = np.argmax(valid) if valid.any() else -1
                last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
                if first_show == -1 or last_show == -1:
                    continue

                # The longest visible trajectory defines the scenario duration.
                self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)

                # Initial 2D position at the first visible frame.
                initial_position = (
                    track['state']['position'][first_show, 0],
                    track['state']['position'][first_show, 1]
                )

                # Lane filtering: skip vehicles that start off-road.
                if self.config.get("filter_offroad_vehicles", True):
                    if not self._is_position_on_lane(initial_position):
                        filtered_vehicles += 1
                        _obj_to_clean_this_frame.append(scenario_id)
                        if self.config.get("debug", False):
                            self.logger.debug(
                                f"Filtering vehicle {track['metadata']['object_id']} - "
                                f"not on lane at position {initial_position}"
                            )
                        continue

                # Store the full expert trajectory. Only 2D positions are kept
                # (z zeroed) so MetaDrive resolves the height itself.
                # NOTE(review): assumes positions are (T, 3) arrays — confirm
                # against the scenario data format.
                object_id = track['metadata']['object_id']
                positions_2d = track['state']['position'].copy()
                positions_2d[:, 2] = 0
                self.expert_trajectories[object_id] = {
                    'positions': positions_2d,
                    'headings': track['state']['heading'].copy(),
                    'velocities': track['state']['velocity'].copy(),
                    'valid': track['state']['valid'].copy(),
                }

                # Save the spawn record consumed by _spawn_controlled_agents().
                self.car_birth_info_list.append({
                    'id': object_id,
                    'show_time': first_show,
                    'begin': initial_position,
                    'heading': track['state']['heading'][first_show],
                    'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
                    'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
                })

                # The original expert vehicle is removed in both replay and
                # simulation modes.
                _obj_to_clean_this_frame.append(scenario_id)

        # Remove expert vehicles and filtered objects from the traffic data.
        for scenario_id in _obj_to_clean_this_frame:
            self.engine.traffic_manager.current_traffic_data.pop(scenario_id)

        # Emit statistics.
        if self.config.get("debug", False):
            self.logger.info(f"=== 对象统计 ===")
            self.logger.info(f"车辆 (VEHICLE): 总数={total_vehicles}, 车道过滤={filtered_vehicles}, 保留={total_vehicles - filtered_vehicles}")
            self.logger.info(f"行人 (PEDESTRIAN): 总数={total_pedestrians}")
            self.logger.info(f"自行车 (CYCLIST): 总数={total_cyclists}")
            self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
            self.logger.info(f"场景时长: {self.scenario_max_duration} 步")

        # Optionally clamp the horizon to the scenario's own duration.
        if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
            original_horizon = self.config["horizon"]
            self.config["horizon"] = self.scenario_max_duration
            if self.config.get("debug", False):
                self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")

        if self.top_down_renderer is not None:
            self.top_down_renderer.clear()
            self.engine.top_down_renderer = None

        self.dones = {}
        self.episode_rewards = defaultdict(float)
        self.episode_lengths = defaultdict(int)
        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()

        super().reset(seed)  # initialize the scenario itself
        self._spawn_controlled_agents()
        return self._get_all_obs()

    def _is_position_on_lane(self, position, tolerance=None):
        """Return True when ``position`` lies on, or within ``tolerance``
        meters of, any lane of the current map.

        Falls back to True (no filtering) when the lane graph has not been
        initialized yet; returns False on lane-API errors.
        """
        if tolerance is None:
            tolerance = self.config.get("lane_tolerance", 3.0)

        # The lane graph is built in reset(); without it we cannot filter.
        if not hasattr(self, 'lanes') or self.lanes is None:
            if self.config.get("debug", False):
                self.logger.warning("Lanes not initialized, skipping lane check")
            return True

        position_2d = np.array(position[:2]) if len(position) > 2 else np.array(position)
        try:
            for lane in self.lanes.values():
                # Exact hit on the lane surface.
                if lane.lane.point_on_lane(position_2d):
                    return True
                # Otherwise, distance from the lane's start-end segment.
                lane_start = np.array(lane.lane.start)[:2]
                lane_end = np.array(lane.lane.end)[:2]
                lane_vec = lane_end - lane_start
                lane_length = np.linalg.norm(lane_vec)
                if lane_length < 1e-6:
                    continue  # degenerate (zero-length) lane
                lane_vec_normalized = lane_vec / lane_length
                point_vec = position_2d - lane_start
                projection = np.dot(point_vec, lane_vec_normalized)
                if 0 <= projection <= lane_length:
                    closest_point = lane_start + projection * lane_vec_normalized
                    distance = np.linalg.norm(position_2d - closest_point)
                    if distance <= tolerance:
                        return True
        except Exception as e:
            if self.config.get("debug", False):
                self.logger.warning(f"Lane check error: {e}")
            return False
        return False

    def _spawn_controlled_agents(self):
        """Spawn every vehicle whose scheduled show time equals the current round."""
        for car in self.car_birth_info_list:
            if car['show_time'] == self.round:
                agent_id = f"controlled_{car['id']}"
                vehicle_config = {}
                vehicle = self.engine.spawn_object(
                    PolicyVehicle,
                    vehicle_config=vehicle_config,
                    position=car['begin'],
                    heading=car['heading']
                )

                # Reset the vehicle state to the expert's first visible frame.
                reset_kwargs = {
                    'position': car['begin'],
                    'heading': car['heading']
                }
                # Inherit the expert's initial velocity when enabled.
                if car.get('velocity') is not None:
                    reset_kwargs['velocity'] = car['velocity']
                vehicle.reset(**reset_kwargs)

                # Wire up policy, destination, and expert linkage.
                vehicle.set_policy(self.policy)
                vehicle.set_destination(car['end'])
                vehicle.set_expert_vehicle_id(car['id'])

                self.controlled_agents[agent_id] = vehicle
                self.controlled_agent_ids.append(agent_id)

                # Register with the engine's active agents.
                self.engine.agent_manager.active_agents[agent_id] = vehicle

                if self.config.get("debug", False):
                    self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")

    def _get_all_obs(self):
        """Build the observation list for all controlled agents.

        Each observation is the concatenation of: position (2) + velocity (2)
        + heading (1) + lidar (80) + side lidar (10) + lane-line lidar (10)
        + traffic-light code (1) + destination (2).
        """
        self.obs_list = []
        for agent_id, vehicle in self.controlled_agents.items():
            state = vehicle.get_state()

            # Traffic-light state of the lane the vehicle currently occupies:
            # 0 = none/unknown, 1 = green, 2 = yellow, 3 = red.
            traffic_light = 0
            for lane in self.lanes.values():
                if lane.lane.point_on_lane(state['position'][:2]):
                    if self.engine.light_manager.has_traffic_light(lane.lane.index):
                        traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
                        if traffic_light == 'TRAFFIC_LIGHT_GREEN':
                            traffic_light = 1
                        elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
                            traffic_light = 2
                        elif traffic_light == 'TRAFFIC_LIGHT_RED':
                            traffic_light = 3
                        else:
                            traffic_light = 0
                    break

            lidar = self.engine.get_sensor("lidar").perceive(
                num_lasers=80, distance=30, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.dynamic_world)
            side_lidar = self.engine.get_sensor("side_detector").perceive(
                num_lasers=10, distance=8, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.static_world)
            lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(
                num_lasers=10, distance=3, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.static_world)

            obs = (list(state['position'][:2]) + list(state['velocity']) +
                   [state['heading_theta']] + lidar[0] + side_lidar[0] +
                   lane_line_lidar[0] + [traffic_light] + list(vehicle.destination))
            self.obs_list.append(obs)
        return self.obs_list

    def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
        """Advance the simulation one tick with the given per-agent actions.

        Returns ``(obs, rewards, dones, infos)``; rewards are always 0.0 and
        per-agent dones always False (only ``__all__`` terminates episodes).
        """
        self.round += 1

        # Apply actions.
        for agent_id, action in action_dict.items():
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].before_step(action)

        # Physics step.
        self.engine.step()

        # Post-processing.
        for agent_id in action_dict:
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].after_step()

        # Spawn vehicles scheduled for this step.
        self._spawn_controlled_agents()

        obs = self._get_all_obs()
        rewards = {aid: 0.0 for aid in self.controlled_agents}
        dones = {aid: False for aid in self.controlled_agents}

        # Replay mode finishes once every expert trajectory has been played.
        replay_finished = False
        if self.replay_mode and self.config.get("use_scenario_duration", False):
            if self.round >= self.scenario_max_duration:
                replay_finished = True
                if self.config.get("debug", False):
                    self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")

        dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished
        infos = {aid: {} for aid in self.controlled_agents}
        return obs, rewards, dones, infos

    def close(self):
        """Destroy every spawned vehicle, then close the underlying environment."""
        if hasattr(self, 'controlled_agents') and self.controlled_agents:
            for agent_id, vehicle in list(self.controlled_agents.items()):
                # Fix: engine.get_objects() returns a dict keyed by object id,
                # so the old check `vehicle in self.engine.get_objects()`
                # compared the object against id strings and never matched,
                # making the cleanup a no-op. Test membership with the id.
                if vehicle.id in self.engine.get_objects():
                    self.engine.clear_objects([vehicle.id])
            self.controlled_agents.clear()
            self.controlled_agent_ids.clear()
        super().close()