Files
MAGAIL4AutoDrive/Env/scenario_env.py

433 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
import math
import logging
from collections import defaultdict
from typing import Union, Dict, AnyStr
from metadrive.engine.logger import get_logger, set_log_level
from metadrive.type import MetaDriveType
class PolicyVehicle(DefaultVehicle):
    """DefaultVehicle that carries its own policy, a destination and the id of the expert it stands in for."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None  # object exposing act(observation); while None, act() samples randomly
        self.destination = None  # destination position assigned via set_destination()
        self.expert_vehicle_id = None  # id of the associated expert vehicle

    def set_policy(self, policy):
        """Attach the policy consulted by act()."""
        self.policy = policy

    def set_destination(self, des):
        """Store the destination of this vehicle."""
        self.destination = des

    def set_expert_vehicle_id(self, vid):
        """Remember which expert vehicle this agent replaces."""
        self.expert_vehicle_id = vid

    def act(self, observation, policy=None):
        """Return an action for ``observation``.

        Delegates to the attached policy when one is set; otherwise samples from
        ``self.action_space``. The ``policy`` argument is accepted but not used.
        """
        attached = self.policy
        if attached is None:
            return self.action_space.sample()
        return attached.act(observation)

    def before_step(self, action):
        """Snapshot the current kinematic state, record ``action`` (if any) and apply it."""
        # Snapshot of the state before this step is simulated.
        self.last_position = self.position  # 2D vector
        self.last_velocity = self.velocity  # 2D vector
        self.last_speed = self.speed  # scalar
        self.last_heading_dir = self.heading
        # Only real actions go into the action history; _set_action is called either way.
        if action is not None:
            self.last_current_action.append(action)
        self._set_action(action)

    def is_done(self):
        """Placeholder for an arrival/crash check; not implemented, always returns None."""
        return None
# Register PolicyVehicle under the "default" vehicle type name in MetaDrive's class -> type mapping,
# presumably so code that looks up a vehicle's type name also accepts this subclass — confirm in metadrive.
vehicle_class_to_type[PolicyVehicle] = "default"
class MultiAgentScenarioEnv(ScenarioEnv):
    """ScenarioEnv in which the logged expert vehicles are replaced by policy-controlled PolicyVehicles.

    On every ``reset`` the expert vehicle tracks of the loaded scenario are stored in
    ``expert_trajectories`` and removed from the traffic manager's replay data. They are then
    re-spawned as :class:`PolicyVehicle` agents at the frame where each expert first becomes
    valid. Vehicle, pedestrian and cyclist tracks can also be filtered out by type, and
    vehicles that do not start on a lane can be dropped.
    """

    # Traffic-light status string -> integer code placed in the observation vector (0 = none/unknown).
    _TRAFFIC_LIGHT_CODES = {
        'TRAFFIC_LIGHT_GREEN': 1,
        'TRAFFIC_LIGHT_YELLOW': 2,
        'TRAFFIC_LIGHT_RED': 3,
    }

    # (track type, label used in debug logs, config flag that keeps tracks of that type)
    _TYPE_SPAWN_FLAGS = (
        (MetaDriveType.VEHICLE, "VEHICLE", "spawn_vehicles"),
        (MetaDriveType.PEDESTRIAN, "PEDESTRIAN", "spawn_pedestrians"),
        (MetaDriveType.CYCLIST, "CYCLIST", "spawn_cyclists"),
    )

    @classmethod
    def default_config(cls):
        """Return ScenarioEnv's default config extended with this env's options."""
        config = super().default_config()
        config.update(dict(
            data_directory=None,
            num_controlled_agents=3,
            horizon=1000,
            filter_offroad_vehicles=True,  # drop expert vehicles whose first position is not on a lane
            lane_tolerance=3.0,  # lane-check tolerance (meters)
            replay_mode=False,  # replay-mode switch
            specific_scenario_id=None,  # written into start_scenario_index on reset (intended for replay mode)
            use_scenario_duration=False,  # use the scenario's recorded length as the horizon
            # Read by reset(); declared here so it can be enabled through the env config.
            inherit_expert_velocity=False,  # spawn controlled vehicles with the expert's initial velocity
            # Object-type filters
            spawn_vehicles=True,  # keep VEHICLE tracks
            spawn_pedestrians=True,  # keep PEDESTRIAN tracks
            spawn_cyclists=True,  # keep CYCLIST tracks
        ))
        return config

    def __init__(self, config, agent2policy):
        """Create the env.

        :param config: user config dict, passed on to ScenarioEnv.__init__
        :param agent2policy: policy object attached to every spawned PolicyVehicle
        """
        self.policy = agent2policy
        self.controlled_agents = {}  # agent_id -> PolicyVehicle
        self.controlled_agent_ids = []  # agent ids in spawn order
        self.obs_list = []
        self.round = 0  # number of step() calls since the last reset()
        self.expert_trajectories = {}  # expert object_id -> full logged trajectory arrays
        self.car_birth_info_list = []  # spawn records, rebuilt on every reset()
        self.lanes = None  # lane graph of the current map, set on reset()
        self.replay_mode = config.get("replay_mode", False)
        self.scenario_max_duration = 0  # frames until the last kept expert vehicle disappears
        super().__init__(config)

    def reset(self, seed: Union[None, int] = None):
        """Load a scenario, turn its expert vehicles into controlled agents and return the observations.

        :param seed: scenario seed forwarded to the base env
        :return: list of per-agent observation vectors (see _get_all_obs)
        """
        self.round = 0
        if self.logger is None:
            self.logger = get_logger()
        log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
        set_log_level(log_level)

        # Controlled vehicles are spawned directly through engine.spawn_object, so they are
        # removed explicitly before the engine is reset.
        self._clear_controlled_agents()

        self.lazy_init()
        self._reset_global_seed(seed)
        if self.engine is None:
            raise ValueError("Broken MetaDrive instance.")

        # If a specific scenario was requested, start from it.
        scenario_id = self.config.get("specific_scenario_id")
        if scenario_id is not None:
            self.config["start_scenario_index"] = scenario_id
            if self.config.get("debug", False):
                self.logger.info(f"Using specific scenario ID: {scenario_id}")

        # Reset the engine first so the map (and thus self.lanes) exists for the lane filter.
        self.engine.reset()
        self.reset_sensors()
        self.engine.taskMgr.step()
        self.lanes = self.engine.map_manager.current_map.road_network.graph

        self._extract_expert_vehicles()

        # Optionally use the scenario's recorded duration as the horizon.
        if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
            original_horizon = self.config["horizon"]
            self.config["horizon"] = self.scenario_max_duration
            if self.config.get("debug", False):
                self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")

        if self.top_down_renderer is not None:
            self.top_down_renderer.clear()
            self.engine.top_down_renderer = None
        self.dones = {}
        self.episode_rewards = defaultdict(float)
        self.episode_lengths = defaultdict(int)
        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()
        super().reset(seed)  # initialize the scenario in the base env
        self._spawn_controlled_agents()
        return self._get_all_obs()

    def _clear_controlled_agents(self):
        """Despawn every controlled vehicle and empty the controlled-agent bookkeeping.

        Vehicles are removed from the agent manager's ``active_agents`` (where
        _spawn_controlled_agents registers them). They are cleared from the engine only if the
        engine still holds them. Does nothing to the engine when it is not initialized.
        """
        engine = getattr(self, 'engine', None)
        if engine is not None and self.controlled_agents:
            agent_manager = getattr(engine, 'agent_manager', None)
            if agent_manager is not None:
                for agent_id in list(self.controlled_agents):
                    if agent_id in agent_manager.active_agents:
                        agent_manager.active_agents.pop(agent_id)
            for agent_id, vehicle in list(self.controlled_agents.items()):
                # engine.get_objects() maps object ids to objects, so membership is tested by id.
                if vehicle.id not in engine.get_objects():
                    continue
                try:
                    engine.clear_objects([vehicle.id])
                except Exception as exc:  # best effort: a failed despawn must not block reset/close
                    get_logger().debug(f"Failed to clear controlled vehicle {agent_id}: {exc}")
        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()

    def _disabled_type_filter(self, obj_type):
        """Return ``(label, flag)`` when tracks of ``obj_type`` are switched off by config, else None."""
        for filtered_type, label, flag in self._TYPE_SPAWN_FLAGS:
            if obj_type == filtered_type:
                return None if self.config.get(flag, True) else (label, flag)
        return None

    def _extract_expert_vehicles(self):
        """Record expert vehicle tracks, queue them for spawning and strip them from traffic replay.

        Populates ``expert_trajectories``, ``car_birth_info_list`` and ``scenario_max_duration``.
        It then removes from the traffic manager's ``current_traffic_data`` every track that was
        filtered out or converted into a controlled agent. The SDC track is never touched, and
        pedestrian/cyclist tracks are only subject to type filtering.
        """
        debug = self.config.get("debug", False)
        traffic_manager = self.engine.traffic_manager
        tracks_to_remove = []
        self.car_birth_info_list = []
        self.expert_trajectories.clear()
        type_totals = defaultdict(int)  # track type -> number of non-SDC tracks
        filtered_vehicles = 0  # vehicles dropped by the lane check
        filtered_by_type = 0  # objects dropped by the spawn_* flags
        self.scenario_max_duration = 0

        for track_id, track in traffic_manager.current_traffic_data.items():
            if track_id == traffic_manager.sdc_scenario_id:
                continue
            obj_type = track["type"]
            type_totals[obj_type] += 1

            disabled = self._disabled_type_filter(obj_type)
            if disabled is not None:
                label, flag = disabled
                tracks_to_remove.append(track_id)
                filtered_by_type += 1
                if debug:
                    self.logger.debug(f"Filtering {label} {track['metadata']['object_id']} - {flag}=False")
                continue

            # Only vehicles become controlled agents.
            if obj_type != MetaDriveType.VEHICLE:
                continue

            valid = track['state']['valid']
            if not valid.any():
                continue  # never valid in this scenario: nothing to spawn
            first_show = np.argmax(valid)
            last_show = len(valid) - 1 - np.argmax(valid[::-1])

            self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)
            initial_position = (
                track['state']['position'][first_show, 0],
                track['state']['position'][first_show, 1],
            )

            if self.config.get("filter_offroad_vehicles", True) and not self._is_position_on_lane(initial_position):
                filtered_vehicles += 1
                tracks_to_remove.append(track_id)
                if debug:
                    self.logger.debug(
                        f"Filtering vehicle {track['metadata']['object_id']} - "
                        f"not on lane at position {initial_position}"
                    )
                continue

            # Store the full expert trajectory; z is zeroed so MetaDrive handles the height itself.
            object_id = track['metadata']['object_id']
            positions_2d = track['state']['position'].copy()
            positions_2d[:, 2] = 0
            self.expert_trajectories[object_id] = {
                'positions': positions_2d,
                'headings': track['state']['heading'].copy(),
                'velocities': track['state']['velocity'].copy(),
                'valid': track['state']['valid'].copy(),
            }
            self.car_birth_info_list.append({
                'id': object_id,
                'show_time': first_show,
                'begin': initial_position,
                'heading': track['state']['heading'][first_show],
                'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
                'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1]),
            })
            # The logged expert vehicle is removed in both replay and simulation mode.
            tracks_to_remove.append(track_id)

        for track_id in tracks_to_remove:
            traffic_manager.current_traffic_data.pop(track_id)

        if debug:
            self.logger.info("=== 对象统计 ===")
            self.logger.info(
                f"车辆 (VEHICLE): 总数={type_totals[MetaDriveType.VEHICLE]}, 车道过滤={filtered_vehicles}, "
                f"保留={len(self.car_birth_info_list)}"
            )
            self.logger.info(f"行人 (PEDESTRIAN): 总数={type_totals[MetaDriveType.PEDESTRIAN]}")
            self.logger.info(f"自行车 (CYCLIST): 总数={type_totals[MetaDriveType.CYCLIST]}")
            self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
            self.logger.info(f"场景时长: {self.scenario_max_duration}")

    def _is_position_on_lane(self, position, tolerance=None):
        """Return True if ``position`` lies on, or within ``tolerance`` meters of, some lane.

        A lane matches if its ``point_on_lane`` accepts the point. A lane also matches if the
        point projects inside the straight start->end segment of the lane and lies at most
        ``tolerance`` from it.

        :param position: (x, y) or (x, y, z); only x and y are used
        :param tolerance: distance tolerance; defaults to config["lane_tolerance"]
        :return: True on a match, and also when lanes are not initialized yet; False when no
            lane matches or the check raises
        """
        if tolerance is None:
            tolerance = self.config.get("lane_tolerance", 3.0)
        # Without a lane graph nothing can be checked, so the position is accepted.
        if not hasattr(self, 'lanes') or self.lanes is None:
            if self.config.get("debug", False):
                self.logger.warning("Lanes not initialized, skipping lane check")
            return True

        position_2d = np.array(position[:2]) if len(position) > 2 else np.array(position)
        try:
            for node in self.lanes.values():
                lane = node.lane
                if lane.point_on_lane(position_2d):
                    return True
                # Fallback: perpendicular distance to the straight start->end segment.
                lane_start = np.array(lane.start)[:2]
                lane_end = np.array(lane.end)[:2]
                lane_vec = lane_end - lane_start
                lane_length = np.linalg.norm(lane_vec)
                if lane_length < 1e-6:
                    continue  # degenerate lane
                direction = lane_vec / lane_length
                projection = np.dot(position_2d - lane_start, direction)
                if 0 <= projection <= lane_length:
                    closest_point = lane_start + projection * direction
                    if np.linalg.norm(position_2d - closest_point) <= tolerance:
                        return True
        except Exception as e:
            if self.config.get("debug", False):
                self.logger.warning(f"Lane check error: {e}")
            return False
        return False

    def _spawn_controlled_agents(self):
        """Spawn a PolicyVehicle for every expert whose first valid frame equals the current round.

        Each vehicle is registered as ``controlled_<expert id>`` in controlled_agents,
        controlled_agent_ids and the engine agent manager's active_agents.
        """
        for car in self.car_birth_info_list:
            if car['show_time'] != self.round:
                continue
            agent_id = f"controlled_{car['id']}"
            vehicle = self.engine.spawn_object(
                PolicyVehicle,
                vehicle_config={},
                position=car['begin'],
                heading=car['heading'],
            )
            # Reset the vehicle state; the expert's initial velocity is passed only when inherited.
            reset_kwargs = {'position': car['begin'], 'heading': car['heading']}
            if car.get('velocity') is not None:
                reset_kwargs['velocity'] = car['velocity']
            vehicle.reset(**reset_kwargs)

            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])
            vehicle.set_expert_vehicle_id(car['id'])
            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)
            self.engine.agent_manager.active_agents[agent_id] = vehicle
            if self.config.get("debug", False):
                self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")

    def _traffic_light_code(self, position):
        """Return the traffic-light code of the first lane containing ``position``.

        0 = no containing lane, no light on it, or an unrecognised status; 1 = green, 2 = yellow, 3 = red.
        """
        for node in self.lanes.values():
            lane = node.lane
            if lane.point_on_lane(position):
                light_manager = self.engine.light_manager
                if light_manager.has_traffic_light(lane.index):
                    status = light_manager._lane_index_to_obj[lane.index].status
                    return self._TRAFFIC_LIGHT_CODES.get(status, 0)
                return 0
        return 0

    def _get_all_obs(self):
        """Build one flat observation list per controlled agent, in controlled_agents order.

        Each list holds: position (x, y), velocity, heading_theta, lidar readings (80 lasers,
        30 m, dynamic world), side-detector readings (10 lasers, 8 m), lane-line readings
        (10 lasers, 3 m), the traffic-light code, and the destination.
        The result is also stored in ``self.obs_list``.
        """
        self.obs_list = []
        for vehicle in self.controlled_agents.values():
            state = vehicle.get_state()
            traffic_light = self._traffic_light_code(state['position'][:2])

            lidar = self.engine.get_sensor("lidar").perceive(
                num_lasers=80, distance=30, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.dynamic_world)
            side_lidar = self.engine.get_sensor("side_detector").perceive(
                num_lasers=10, distance=8, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.static_world)
            lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(
                num_lasers=10, distance=3, base_vehicle=vehicle,
                physics_world=self.engine.physics_world.static_world)

            obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
                   + lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
                   + list(vehicle.destination))
            self.obs_list.append(obs)
        return self.obs_list

    def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
        """Apply actions, advance the simulation one step and spawn newly appearing agents.

        :param action_dict: agent_id -> action; ids that are not controlled agents are ignored
        :return: (obs, rewards, dones, infos). obs is the list from _get_all_obs; rewards are
            always 0.0 and per-agent dones always False; dones["__all__"] marks the end of the
            episode; infos are empty dicts.
        """
        self.round += 1
        for agent_id, action in action_dict.items():
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].before_step(action)

        self.engine.step()

        for agent_id in action_dict:
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].after_step()

        self._spawn_controlled_agents()
        obs = self._get_all_obs()
        rewards = {aid: 0.0 for aid in self.controlled_agents}
        dones = {aid: False for aid in self.controlled_agents}

        # In replay mode with scenario-length episodes, finish once every expert track has played.
        replay_finished = False
        if self.replay_mode and self.config.get("use_scenario_duration", False):
            if self.round >= self.scenario_max_duration:
                replay_finished = True
                if self.config.get("debug", False):
                    self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")

        # self.round is the step counter here. MetaDrive advances episode_step in
        # engine.before_step(), which this method never calls, so episode_step would stay at 0.
        dones["__all__"] = self.round >= self.config["horizon"] or replay_finished
        infos = {aid: {} for aid in self.controlled_agents}
        return obs, rewards, dones, infos

    def close(self):
        """Despawn all controlled vehicles, then close the base env."""
        self._clear_controlled_agents()
        super().close()