Files
MAGAIL4AutoDrive/Env/scenario_env.py

503 lines
22 KiB
Python
Raw Normal View History

2025-09-28 18:57:04 +08:00
import numpy as np
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
import math
import logging
from collections import defaultdict
from typing import Union, Dict, AnyStr
from metadrive.engine.logger import get_logger, set_log_level
from metadrive.type import MetaDriveType
class PolicyVehicle(DefaultVehicle):
    """``DefaultVehicle`` that carries its own policy and destination.

    Actions come from the attached policy (``policy.act(observation)``) or, when
    no policy is attached, from a random sample of ``self.action_space``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None  # object exposing act(observation); set via set_policy()
        self.destination = None  # target position; set via set_destination()

    def set_policy(self, policy):
        """Attach the policy used by ``act``."""
        self.policy = policy

    def set_destination(self, des):
        """Store the destination position of this vehicle."""
        self.destination = des

    def act(self, observation, policy=None):
        """
        Return an action for ``observation``.

        Args:
            observation: passed unchanged to the attached policy's ``act``.
            policy: ignored; kept for interface compatibility. The policy set via
                ``set_policy`` is always the one used.

        Returns:
            The attached policy's action, or ``self.action_space.sample()`` when no
            policy has been attached.
        """
        if self.policy is None:
            return self.action_space.sample()
        return self.policy.act(observation)

    def before_step(self, action):
        """
        Save the pre-step kinematic state, then record and apply ``action``.

        When ``action`` is None, no new action is recorded or applied, so
        ``_set_action`` is never called with None.
        """
        self.last_position = self.position  # 2D vector
        self.last_velocity = self.velocity  # 2D vector
        self.last_speed = self.speed  # scalar
        self.last_heading_dir = self.heading
        if action is None:
            return
        self.last_current_action.append(action)
        self._set_action(action)

    def is_done(self):
        """Placeholder for arrival / crash detection; not implemented (returns None)."""
        # TODO: report arrival at self.destination or a crash.
        pass
# Register PolicyVehicle in MetaDrive's vehicle class -> type-name table under "default",
# presumably so type lookups treat it like DefaultVehicle (TODO confirm against MetaDrive usage).
vehicle_class_to_type[PolicyVehicle] = "default"
class MultiAgentScenarioEnv(ScenarioEnv):
    """Multi-agent environment built on MetaDrive's ``ScenarioEnv``.

    On every reset, the logged (expert) vehicles of the current scenario, except
    the SDC, are removed from the traffic manager's track data. Their birth
    information is recorded: first valid frame, start pose, and last valid
    position. Each of them is then spawned as a ``PolicyVehicle`` driven by
    ``agent2policy`` at the env step equal to its first valid frame.
    """

    # Traffic light status string -> (observation code, label used in debug output).
    _TRAFFIC_LIGHT_CODES = {
        'TRAFFIC_LIGHT_GREEN': (1, "绿灯"),
        'TRAFFIC_LIGHT_YELLOW': (2, "黄灯"),
        'TRAFFIC_LIGHT_RED': (3, "红灯"),
    }

    @classmethod
    def default_config(cls):
        """Extend ``ScenarioEnv``'s default config with the multi-agent options below."""
        config = super().default_config()
        config.update(dict(
            data_directory=None,
            num_controlled_agents=3,
            horizon=1000,  # episode ends ("__all__" done) after this many env steps
            # Lane detection / filtering
            filter_offroad_vehicles=True,  # drop vehicles whose spawn point is not on any lane
            lane_tolerance=3.0,  # lane tolerance in metres; tolerance-based check is currently disabled
            max_controlled_vehicles=None,  # cap on controlled vehicles; None means unlimited
            # Debug switches
            debug_traffic_light=False,  # print traffic-light detection details
            debug_lane_filter=False,  # print lane-filter details
        ))
        return config

    def __init__(self, config, agent2policy):
        """
        Args:
            config: environment config forwarded to ``ScenarioEnv.__init__``
                (``None`` is treated as an empty dict).
            agent2policy: policy attached to every spawned ``PolicyVehicle``.
        """
        if config is None:
            config = {}
        self.policy = agent2policy
        self.controlled_agents = {}  # agent_id -> PolicyVehicle
        self.controlled_agent_ids = []
        self.obs_list = []
        self.round = 0  # number of step() calls since the last reset
        # Debug switches, read from the user config (defaults match default_config()).
        self.debug_traffic_light = config.get("debug_traffic_light", False)
        self.debug_lane_filter = config.get("debug_lane_filter", False)
        super().__init__(config)

    def reset(self, seed: Union[None, int] = None):
        """
        Reset the scenario and spawn the controlled vehicles that appear at step 0.

        Args:
            seed: scenario index to load; ``None`` lets ``_reset_global_seed`` choose.

        Returns:
            list: one observation per controlled agent (see ``_get_all_obs``).

        Raises:
            ValueError: if the MetaDrive engine is not available.
        """
        self.round = 0
        if self.logger is None:
            self.logger = get_logger()
            log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
            set_log_level(log_level)
        self.lazy_init()
        self._reset_global_seed(seed)
        if self.engine is None:
            raise ValueError("Broken MetaDrive instance.")
        # Pin the scenario selected above. super().reset() below calls
        # _reset_global_seed() again; with seed=None it would pick (or, with
        # sequential seeding, advance to) a different scenario than the one whose
        # tracks are collected and filtered here.
        seed = self.current_seed

        # Remove controlled vehicles of the previous episode before engine.reset().
        self.before_reset()
        self._collect_expert_vehicles()

        self.engine.reset()
        self.reset_sensors()
        self.engine.taskMgr.step()
        self.lanes = self.engine.map_manager.current_map.road_network.graph

        if self.debug_lane_filter or self.debug_traffic_light:
            self._print_scene_debug_info()

        # Lane information is available now, so spawn positions can be filtered.
        self._apply_spawn_filters()

        if self.top_down_renderer is not None:
            self.top_down_renderer.clear()
            self.engine.top_down_renderer = None
        self.dones = {}
        self.episode_rewards = defaultdict(float)
        self.episode_lengths = defaultdict(int)

        # The parent reset clears and initialises the scene.
        super().reset(seed)
        # Spawn the controlled vehicles whose first valid frame is step 0.
        self._spawn_controlled_agents()
        return self._get_all_obs()

    def _collect_expert_vehicles(self):
        """
        Record birth info of every logged vehicle except the SDC and remove its track.

        Fills ``self.car_birth_info_list`` with dicts holding ``id``, ``show_time``
        (first valid frame), ``begin`` ((x, y) at that frame), ``heading`` (heading at
        that frame) and ``end`` ((x, y) at the last valid frame). The vehicle tracks
        are popped from the traffic manager's track data.
        """
        traffic_manager = self.engine.traffic_manager
        tracks = traffic_manager.current_traffic_data
        sdc_id = traffic_manager.sdc_scenario_id
        to_remove = []
        self.car_birth_info_list = []
        for scenario_id, track in tracks.items():
            if scenario_id == sdc_id or track["type"] != MetaDriveType.VEHICLE:
                continue
            to_remove.append(scenario_id)
            valid = track['state']['valid']
            if not valid.any():
                # Never valid in the log: there is no pose to spawn from. Skipping it
                # keeps it from taking a max_controlled_vehicles slot without spawning.
                continue
            first_show = int(np.argmax(valid))
            last_show = len(valid) - 1 - int(np.argmax(valid[::-1]))
            positions = track['state']['position']
            self.car_birth_info_list.append({
                'id': track['metadata']['object_id'],
                'show_time': first_show,
                'begin': (positions[first_show, 0], positions[first_show, 1]),
                'heading': track['state']['heading'][first_show],
                'end': (positions[last_show, 0], positions[last_show, 1]),
            })
        # Pop after iterating so the dict is not mutated during iteration.
        for scenario_id in to_remove:
            tracks.pop(scenario_id)

    def _print_scene_debug_info(self):
        """Print lane statistics and, if traffic-light debugging is on, light statistics."""
        print("\n📍 场景信息统计:")
        print(f" - 总车道数: {len(self.lanes)}")
        if not self.debug_traffic_light:
            return
        light_manager = self.engine.light_manager
        traffic_light_lanes = [
            lane.lane.index for lane in self.lanes.values()
            if light_manager.has_traffic_light(lane.lane.index)
        ]
        print(f" - 有红绿灯的车道数: {len(traffic_light_lanes)}")
        if traffic_light_lanes:
            print(f" 车道索引: {traffic_light_lanes[:5]}" +
                  (f" ... 共{len(traffic_light_lanes)}" if len(traffic_light_lanes) > 5 else ""))
        else:
            print(" ⚠️ 场景中没有红绿灯!")

    def _apply_spawn_filters(self):
        """Apply the off-road filter and the ``max_controlled_vehicles`` cap to the birth list."""
        total_cars_before = len(self.car_birth_info_list)
        valid_count, filtered_count, filtered_list = self._filter_valid_spawn_positions()
        if filtered_count > 0:
            self.logger.warning(f"车辆生成位置过滤: 原始 {total_cars_before} 辆, "
                                f"有效 {valid_count} 辆, 过滤 {filtered_count}")
            for filtered_car in filtered_list[:5]:  # only show the first 5
                self.logger.debug(f" - 过滤车辆 ID={filtered_car['id']}, "
                                  f"位置={filtered_car['position']}, "
                                  f"原因={filtered_car['reason']}")
            if filtered_count > 5:
                self.logger.debug(f" - ... 还有 {filtered_count - 5} 辆车被过滤")
        # The cap is applied after filtering, keeping the first entries in track order.
        max_vehicles = self.config.get("max_controlled_vehicles", None)
        if max_vehicles is not None and len(self.car_birth_info_list) > max_vehicles:
            self.car_birth_info_list = self.car_birth_info_list[:max_vehicles]
            self.logger.info(f"限制最大车辆数为 {max_vehicles}")
        self.logger.info(f"最终生成 {len(self.car_birth_info_list)} 辆可控车辆")

    def _is_position_on_lane(self, position, tolerance=None):
        """
        Check whether a position lies on any lane of the current map.

        Args:
            position: (x, y) position; any further components are ignored.
            tolerance: tolerance in metres; ``None`` means the ``lane_tolerance``
                config value. It is currently only reported in debug output
                because the tolerance-based check below is disabled.

        Returns:
            bool: True if some lane contains the point, or if lane information
            has not been initialised yet. False otherwise (e.g. grass, parking lots).
        """
        if getattr(self, 'lanes', None) is None:
            if self.debug_lane_filter:
                print(" ⚠️ 车道信息未初始化,默认允许")
            return True  # lane info not initialised yet: allow spawning
        if tolerance is None:
            tolerance = self.config.get("lane_tolerance", 3.0)
        position_2d = (position[0], position[1])
        if self.debug_lane_filter:
            print(f" 🔍 检测位置 ({position_2d[0]:.2f}, {position_2d[1]:.2f}), 容差={tolerance}m")

        # Method 1: strict containment check against every lane.
        checked_lanes = 0
        for lane in self.lanes.values():
            checked_lanes += 1
            try:
                on_lane = lane.lane.point_on_lane(position_2d)
            except Exception:
                continue  # skip lanes that cannot answer the query
            if on_lane:
                if self.debug_lane_filter:
                    print(f" ✅ 在车道上 (车道{lane.lane.index}, 检查了{checked_lanes}条)")
                return True
        if self.debug_lane_filter:
            print(f" ❌ 不在任何车道上 (检查了{checked_lanes}条车道)")

        # Method 2 (disabled): accept points within `tolerance` metres laterally of a
        # lane centre line, to allow for lane edges. Uncomment to enable.
        # if tolerance > 0:
        #     for lane in self.lanes.values():
        #         try:
        #             lane_obj = lane.lane
        #             s, lateral = lane_obj.local_coordinates(position_2d)
        #             if abs(lateral) <= tolerance and 0 <= s <= lane_obj.length:
        #                 return True
        #         except Exception:
        #             continue
        return False

    def _filter_valid_spawn_positions(self):
        """
        Drop entries of ``self.car_birth_info_list`` whose spawn point is not on a lane.

        Nothing is filtered when ``filter_offroad_vehicles`` is disabled.

        Returns:
            tuple: (number of kept cars, number of filtered cars, list of
            ``{'id', 'position', 'reason'}`` dicts describing the filtered cars).
        """
        if not self.config.get("filter_offroad_vehicles", True):
            if self.debug_lane_filter:
                print("🚫 车道过滤已禁用")
            return len(self.car_birth_info_list), 0, []
        total = len(self.car_birth_info_list)
        if self.debug_lane_filter:
            print(f"\n🔍 开始车道过滤: 共 {total} 辆车待检测")
        tolerance = self.config.get("lane_tolerance", 3.0)
        valid_cars = []
        filtered_cars = []
        for idx, car in enumerate(self.car_birth_info_list, start=1):
            if self.debug_lane_filter:
                print(f"\n车辆 {idx}/{total}: ID={car['id']}")
            if self._is_position_on_lane(car['begin'], tolerance=tolerance):
                valid_cars.append(car)
                if self.debug_lane_filter:
                    print(" ✅ 保留")
            else:
                filtered_cars.append({
                    'id': car['id'],
                    'position': car['begin'],
                    'reason': '生成位置不在有效车道上(可能在草坪/停车场等区域)'
                })
                if self.debug_lane_filter:
                    print(" ❌ 过滤 (原因: 不在车道上)")
        self.car_birth_info_list = valid_cars
        if self.debug_lane_filter:
            print(f"\n📊 过滤结果: 保留 {len(valid_cars)} 辆, 过滤 {len(filtered_cars)}")
        return len(valid_cars), len(filtered_cars), filtered_cars

    def _spawn_controlled_agents(self):
        """Spawn a ``PolicyVehicle`` for every recorded car whose ``show_time`` equals ``self.round``."""
        for car in self.car_birth_info_list:
            if car['show_time'] != self.round:
                continue
            agent_id = f"controlled_{car['id']}"
            vehicle = self.engine.spawn_object(
                PolicyVehicle,
                vehicle_config={},
                position=car['begin'],
                heading=car['heading']
            )
            vehicle.reset(position=car['begin'], heading=car['heading'])
            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])
            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)
            # Register with the engine's agent manager so the vehicle takes part in updates.
            # NOTE(review): in some MetaDrive versions `active_agents` is a property that
            # builds a new dict on every access, which would make this assignment a no-op.
            # Confirm against the installed version.
            self.engine.agent_manager.active_agents[agent_id] = vehicle

    def before_reset(self):
        """Remove all controlled vehicles from the engine and forget them (best effort)."""
        if not (hasattr(self, 'controlled_agents') and hasattr(self, 'engine')):
            return
        engine = self.engine
        if hasattr(engine, 'clear_objects') and self.controlled_agents:
            # clear_objects() is keyed by engine object ids, not by our
            # "controlled_<id>" agent ids.
            object_ids = [vehicle.id for vehicle in self.controlled_agents.values()]
            try:
                engine.clear_objects(object_ids)
            except Exception as err:
                # Best effort: keep resetting even if some objects were already released.
                get_logger().warning(f"Failed to clear controlled vehicles: {type(err).__name__}: {err}")
        if hasattr(engine, 'agent_manager'):
            active_agents = engine.agent_manager.active_agents
            for agent_id in self.controlled_agents:
                active_agents.pop(agent_id, None)
        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()

    def _traffic_light_code_for_lane(self, lane_index, method_tag):
        """
        Look up the traffic light attached to ``lane_index``.

        Args:
            lane_index: lane index as understood by ``engine.light_manager``.
            method_tag: label used in debug output ("方法1" / "方法2").

        Returns:
            int or None: 1/2/3 for green/yellow/red, 0 when the light's status is
            None, or None when the lane has no light or its status is not one of
            the known values (the caller then keeps searching).
        """
        debug = self.debug_traffic_light
        light_manager = self.engine.light_manager
        has_light = light_manager.has_traffic_light(lane_index)
        if debug:
            print(f" has_traffic_light = {has_light}")
        if not has_light:
            if debug:
                print(" 该车道没有红绿灯")
            return None
        status = light_manager._lane_index_to_obj[lane_index].status
        if debug:
            print(f" status = {status}")
        if status is None:
            if debug:
                print(f" ⚠️ {method_tag}: 红绿灯状态为None")
            return 0
        code_and_label = self._TRAFFIC_LIGHT_CODES.get(status)
        if code_and_label is None:
            return None
        code, label = code_and_label
        if debug:
            print(f" ✅ {method_tag}成功: {label}")
        return code

    def _get_traffic_light_state(self, vehicle):
        """
        Return the traffic light state of the lane the vehicle is on.

        Two strategies are tried in order:
        1. the lane reported by the vehicle's navigation module;
        2. a scan over all lanes for the first one containing the vehicle position.
           This is the fallback when navigation is unavailable or inconclusive,
           e.g. with segmented lanes.
        Exceptions raised by either strategy are swallowed and only reported in
        debug mode.

        Returns:
            int: 0=no light / unknown, 1=green, 2=yellow, 3=red
        """
        traffic_light = 0
        debug = self.debug_traffic_light
        position_2d = vehicle.get_state()['position'][:2]
        if debug:
            print(f"\n🚦 检测车辆红绿灯 - 位置: ({position_2d[0]:.1f}, {position_2d[1]:.1f})")

        # Method 1: lane from the navigation module.
        try:
            navigation = getattr(vehicle, 'navigation', None)
            if navigation is None:
                if debug:
                    has_nav = hasattr(vehicle, 'navigation')
                    print(f" 方法1-导航模块: 不可用 (hasattr={has_nav}, not_none={navigation is not None})")
            else:
                current_lane = navigation.current_lane
                if debug:
                    print(" 方法1-导航模块:")
                    print(f" current_lane = {current_lane}")
                    print(f" lane_index = {current_lane.index if current_lane else 'None'}")
                if current_lane:
                    code = self._traffic_light_code_for_lane(current_lane.index, "方法1")
                    if code is not None:
                        return code
                elif debug:
                    print(" 导航模块current_lane为None")
        except Exception as e:
            if debug:
                print(f" ❌ 方法1异常: {type(e).__name__}: {e}")

        # Method 2: scan all lanes for the one containing the vehicle.
        try:
            if debug:
                print(f" 方法2-遍历车道: 开始遍历 {len(self.lanes)} 条车道")
            found_lane = False
            checked_lanes = 0
            for lane in self.lanes.values():
                try:
                    checked_lanes += 1
                    if not lane.lane.point_on_lane(position_2d):
                        continue
                    found_lane = True
                    if debug:
                        print(f" ✓ 找到车辆所在车道: {lane.lane.index} (检查了{checked_lanes}条)")
                    code = self._traffic_light_code_for_lane(lane.lane.index, "方法2")
                    if code is not None:
                        return code
                    break  # only the first lane containing the vehicle is consulted
                except Exception:
                    continue
            if debug and not found_lane:
                print(f" ⚠️ 未找到车辆所在车道 (检查了{checked_lanes}条)")
        except Exception as e:
            if debug:
                print(f" ❌ 方法2异常: {type(e).__name__}: {e}")

        if debug:
            print(f" 结果: 返回 {traffic_light} (无红绿灯/未知)")
        return traffic_light

    def _get_all_obs(self):
        """
        Build one flat observation list per controlled agent.

        Each observation concatenates the following, in order:
        - position (x, y), velocity and ``heading_theta``;
        - the readings of the "lidar", "side_detector" and "lane_line_detector" sensors;
        - the traffic light code from ``_get_traffic_light_state``;
        - the vehicle's destination.

        Observations follow the insertion order of ``self.controlled_agents``.

        Returns:
            list: the observations, also stored in ``self.obs_list``.
        """
        self.obs_list = []
        if not self.controlled_agents:
            return self.obs_list
        lidar_sensor = self.engine.get_sensor("lidar")
        side_sensor = self.engine.get_sensor("side_detector")
        lane_line_sensor = self.engine.get_sensor("lane_line_detector")
        dynamic_world = self.engine.physics_world.dynamic_world
        static_world = self.engine.physics_world.static_world
        for vehicle in self.controlled_agents.values():
            state = vehicle.get_state()
            traffic_light = self._get_traffic_light_state(vehicle)
            lidar = lidar_sensor.perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
                                          physics_world=dynamic_world)
            side_lidar = side_sensor.perceive(num_lasers=10, distance=8, base_vehicle=vehicle,
                                              physics_world=static_world)
            lane_line_lidar = lane_line_sensor.perceive(num_lasers=10, distance=3, base_vehicle=vehicle,
                                                        physics_world=static_world)
            # list() on every part guarantees concatenation, never element-wise array addition.
            obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
                   + list(lidar[0]) + list(side_lidar[0]) + list(lane_line_lidar[0]) + [traffic_light]
                   + list(vehicle.destination))
            self.obs_list.append(obs)
        return self.obs_list

    def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
        """
        Apply actions, advance the engine one step, and spawn newly appearing cars.

        Args:
            action_dict: agent_id -> action; ids that are not controlled agents are ignored.

        Returns:
            tuple: (obs list, rewards dict, dones dict including "__all__", infos dict).
            Rewards are always 0.0 and per-agent dones are always False.
        """
        self.round += 1
        for agent_id, action in action_dict.items():
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].before_step(action)
        self.engine.step()
        for agent_id in action_dict:
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].after_step()
        self._spawn_controlled_agents()
        obs = self._get_all_obs()
        rewards = {aid: 0.0 for aid in self.controlled_agents}
        dones = {aid: False for aid in self.controlled_agents}
        # Use the env's own step counter. This method calls engine.step() directly and
        # never engine.before_step()/after_step(), which is presumably where MetaDrive
        # advances episode_step, so self.episode_step would not grow here.
        dones["__all__"] = self.round >= self.config["horizon"]
        infos = {aid: {} for aid in self.controlled_agents}
        return obs, rewards, dones, infos