Files
MAGAIL4AutoDrive/Env/scenario_env.py

503 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
import math
import logging
from collections import defaultdict
from typing import Union, Dict, AnyStr
from metadrive.engine.logger import get_logger, set_log_level
from metadrive.type import MetaDriveType
class PolicyVehicle(DefaultVehicle):
    """A MetaDrive DefaultVehicle that carries its own policy and destination.

    The environment attaches a policy with ``set_policy`` and a destination with
    ``set_destination``; ``act`` queries the policy, or samples a random action
    from ``self.action_space`` when no policy is available.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None  # object exposing ``act(observation)``; set via set_policy
        self.destination = None  # target position; set via set_destination

    def set_policy(self, policy):
        """Attach the policy used by ``act``."""
        self.policy = policy

    def set_destination(self, des):
        """Store the destination of this vehicle."""
        self.destination = des

    def act(self, observation, policy=None):
        """Return an action for ``observation``.

        Args:
            observation: passed unchanged to the policy's ``act``.
            policy: optional policy used for this call instead of the attached one.

        Returns:
            The chosen policy's action, or ``self.action_space.sample()`` when neither
            ``policy`` nor the attached policy is set.
        """
        active_policy = policy if policy is not None else self.policy
        if active_policy is not None:
            return active_policy.act(observation)
        return self.action_space.sample()

    def before_step(self, action):
        """Snapshot pre-step kinematics and apply ``action``.

        Stores position, velocity, speed and heading into the ``last_*`` attributes.
        When ``action`` is not None it is appended to ``last_current_action`` and
        forwarded to ``_set_action``.
        """
        self.last_position = self.position  # 2D vector
        self.last_velocity = self.velocity  # 2D vector
        self.last_speed = self.speed  # scalar
        self.last_heading_dir = self.heading
        if action is not None:
            self.last_current_action.append(action)
            self._set_action(action)

    def is_done(self):
        """Termination check (arrival or crash) -- not implemented yet; returns None."""
        # TODO: report arrival at self.destination or a crash.
        return None
# Register PolicyVehicle in MetaDrive's vehicle-class -> type-name table under the "default" name,
# presumably so MetaDrive code that looks vehicles up by class accepts this subclass -- confirm.
vehicle_class_to_type[PolicyVehicle] = "default"
class MultiAgentScenarioEnv(ScenarioEnv):
    """Multi-agent wrapper around MetaDrive's ScenarioEnv.

    On every reset the logged (non-ego) vehicle tracks are recorded -- first valid
    frame, pose at that frame, last valid position -- and removed from the traffic
    data. The vehicles are then re-spawned as controllable ``PolicyVehicle`` objects
    at the step matching their first valid frame, optionally discarding vehicles whose
    birth position is not on any lane.
    """

    # MetaDrive traffic-light status string -> integer code used in observations.
    _TRAFFIC_LIGHT_CODES = {
        'TRAFFIC_LIGHT_GREEN': 1,
        'TRAFFIC_LIGHT_YELLOW': 2,
        'TRAFFIC_LIGHT_RED': 3,
    }
    # Integer code -> label, used only in debug output.
    _TRAFFIC_LIGHT_LABELS = {1: '绿灯', 2: '黄灯', 3: '红灯'}

    @classmethod
    def default_config(cls):
        config = super().default_config()
        config.update(dict(
            data_directory=None,
            num_controlled_agents=3,
            horizon=1000,
            # Lane detection / filtering
            filter_offroad_vehicles=True,  # drop vehicles whose birth position is not on a lane
            lane_tolerance=3.0,  # metres; currently only shown in debug output (fallback disabled)
            max_controlled_vehicles=None,  # cap on controlled vehicles; None = no cap
            # Debug switches
            debug_traffic_light=False,  # print traffic-light detection details
            debug_lane_filter=False,  # print lane-filter details
        ))
        return config

    def __init__(self, config, agent2policy):
        """
        Args:
            config: user config dict forwarded to ScenarioEnv; None is treated as {}.
            agent2policy: policy object given to every spawned PolicyVehicle.
        """
        config = {} if config is None else config
        self.policy = agent2policy
        self.controlled_agents = {}  # agent_id -> PolicyVehicle
        self.controlled_agent_ids = []  # agent ids in spawn order
        self.obs_list = []
        self.round = 0  # number of step() calls since the last reset()
        # Set before super().__init__() in case the base constructor calls methods that read them.
        self.debug_traffic_light = config.get("debug_traffic_light", False)
        self.debug_lane_filter = config.get("debug_lane_filter", False)
        super().__init__(config)

    def reset(self, seed: Union[None, int] = None):
        """Start a new episode and spawn the controlled vehicles that appear at frame 0.

        Args:
            seed: scenario seed forwarded to MetaDrive.

        Returns:
            list: one observation per controlled vehicle (see ``_get_all_obs``).

        Raises:
            ValueError: if the MetaDrive engine is not available.
        """
        self.round = 0
        if self.logger is None:
            self.logger = get_logger()
            log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
            set_log_level(log_level)
        self.lazy_init()
        self._reset_global_seed(seed)
        if self.engine is None:
            raise ValueError("Broken MetaDrive instance.")

        # Clean up last episode's controlled vehicles before engine.reset().
        self.before_reset()
        # Remember where each logged vehicle is born, then remove it from the traffic data
        # so it can be re-spawned as a controllable agent.
        self._record_and_remove_expert_vehicles()

        self.engine.reset()
        self.reset_sensors()
        self.engine.taskMgr.step()
        self.lanes = self.engine.map_manager.current_map.road_network.graph

        if self.debug_lane_filter or self.debug_traffic_light:
            self._print_scene_summary()

        # Lane filtering needs self.lanes, so it runs after the map is loaded.
        total_cars_before = len(self.car_birth_info_list)
        valid_count, filtered_count, filtered_list = self._filter_valid_spawn_positions()
        if filtered_count > 0:
            self.logger.warning(f"车辆生成位置过滤: 原始 {total_cars_before} 辆, "
                                f"有效 {valid_count} 辆, 过滤 {filtered_count}")
            for filtered_car in filtered_list[:5]:  # only show the first 5
                self.logger.debug(f" - 过滤车辆 ID={filtered_car['id']}, "
                                  f"位置={filtered_car['position']}, "
                                  f"原因={filtered_car['reason']}")
            if filtered_count > 5:
                self.logger.debug(f" - ... 还有 {filtered_count - 5} 辆车被过滤")

        # The vehicle cap is applied after filtering.
        max_vehicles = self.config.get("max_controlled_vehicles", None)
        if max_vehicles is not None and len(self.car_birth_info_list) > max_vehicles:
            self.car_birth_info_list = self.car_birth_info_list[:max_vehicles]
            self.logger.info(f"限制最大车辆数为 {max_vehicles}")
        self.logger.info(f"最终生成 {len(self.car_birth_info_list)} 辆可控车辆")

        if self.top_down_renderer is not None:
            self.top_down_renderer.clear()
            self.engine.top_down_renderer = None

        self.dones = {}
        self.episode_rewards = defaultdict(float)
        self.episode_lengths = defaultdict(int)

        # The parent reset clears and rebuilds the scene.
        super().reset(seed)
        # Re-spawn the vehicles that appear at frame 0.
        self._spawn_controlled_agents()
        return self._get_all_obs()

    def _record_and_remove_expert_vehicles(self):
        """Record birth info of every non-ego logged vehicle, then drop its track.

        Fills ``self.car_birth_info_list`` with one dict per vehicle: 'id' (metadata
        object_id), 'show_time' (first valid frame), 'begin' (x, y) and 'heading' at that
        frame, and 'end' (x, y) at the last valid frame. Tracks that are never valid are
        still removed but get no entry: they have no pose to spawn from and their
        show_time could never match a step index.
        """
        traffic_manager = self.engine.traffic_manager
        ego_id = traffic_manager.sdc_scenario_id
        removed_ids = []
        self.car_birth_info_list = []
        for scenario_id, track in traffic_manager.current_traffic_data.items():
            if scenario_id == ego_id or track["type"] != MetaDriveType.VEHICLE:
                continue
            removed_ids.append(scenario_id)
            valid_frames = np.flatnonzero(track['state']['valid'])
            if valid_frames.size == 0:
                continue
            first_show = int(valid_frames[0])
            last_show = int(valid_frames[-1])
            positions = track['state']['position']
            self.car_birth_info_list.append({
                'id': track['metadata']['object_id'],
                'show_time': first_show,
                'begin': (positions[first_show, 0], positions[first_show, 1]),
                'heading': track['state']['heading'][first_show],
                'end': (positions[last_show, 0], positions[last_show, 1]),
            })
        # Pop only after iterating, so the dict does not change size during iteration.
        for scenario_id in removed_ids:
            traffic_manager.current_traffic_data.pop(scenario_id)

    def _print_scene_summary(self):
        """Print the lane count and, with traffic-light debugging on, the lanes that have a light."""
        print(f"\n📍 场景信息统计:")
        print(f" - 总车道数: {len(self.lanes)}")
        if not self.debug_traffic_light:
            return
        light_manager = self.engine.light_manager
        traffic_light_lanes = [lane.lane.index for lane in self.lanes.values()
                               if light_manager.has_traffic_light(lane.lane.index)]
        print(f" - 有红绿灯的车道数: {len(traffic_light_lanes)}")
        if len(traffic_light_lanes) > 0:
            print(f" 车道索引: {traffic_light_lanes[:5]}" +
                  (f" ... 共{len(traffic_light_lanes)}" if len(traffic_light_lanes) > 5 else ""))
        else:
            print(f" ⚠️ 场景中没有红绿灯!")

    def _is_position_on_lane(self, position, tolerance=None):
        """
        Check whether a position lies on any lane of the current map.

        Args:
            position: sequence whose first two entries are (x, y).
            tolerance: lateral tolerance in metres; None means config["lane_tolerance"].
                NOTE: the tolerance-based fallback is disabled, so this value only
                appears in debug output and does not affect the result.

        Returns:
            bool: True if some lane's ``point_on_lane`` accepts the position, or if lane
            information is not loaded yet; False otherwise (e.g. grass, parking lots).
        """
        if not hasattr(self, 'lanes') or self.lanes is None:
            if self.debug_lane_filter:
                print(f" ⚠️ 车道信息未初始化,默认允许")
            return True  # lanes not loaded yet: allow spawning by default
        if tolerance is None:
            tolerance = self.config.get("lane_tolerance", 3.0)
        position_2d = (position[0], position[1])
        if self.debug_lane_filter:
            print(f" 🔍 检测位置 ({position_2d[0]:.2f}, {position_2d[1]:.2f}), 容差={tolerance}m")
        checked_lanes = 0
        for lane in self.lanes.values():
            try:
                checked_lanes += 1
                if lane.lane.point_on_lane(position_2d):
                    if self.debug_lane_filter:
                        print(f" ✅ 在车道上 (车道{lane.lane.index}, 检查了{checked_lanes}条)")
                    return True
            except Exception:
                # A lane whose geometry cannot be queried is skipped instead of failing the check.
                continue
        if self.debug_lane_filter:
            print(f" ❌ 不在任何车道上 (检查了{checked_lanes}条车道)")
        # NOTE: a tolerance fallback (accept when lane.local_coordinates gives |lateral| <= tolerance
        # and 0 <= s <= lane length) was intentionally disabled; add it here if it is needed again.
        return False

    def _filter_valid_spawn_positions(self):
        """
        Drop entries of ``self.car_birth_info_list`` whose 'begin' position is not on a lane.

        Does nothing when config["filter_offroad_vehicles"] is False.

        Returns:
            tuple: (number kept, number filtered, list of dicts with 'id', 'position'
            and 'reason' for each filtered vehicle).
        """
        if not self.config.get("filter_offroad_vehicles", True):
            if self.debug_lane_filter:
                print(f"🚫 车道过滤已禁用")
            return len(self.car_birth_info_list), 0, []
        if self.debug_lane_filter:
            print(f"\n🔍 开始车道过滤: 共 {len(self.car_birth_info_list)} 辆车待检测")
        valid_cars = []
        filtered_cars = []
        tolerance = self.config.get("lane_tolerance", 3.0)
        for idx, car in enumerate(self.car_birth_info_list):
            if self.debug_lane_filter:
                print(f"\n车辆 {idx+1}/{len(self.car_birth_info_list)}: ID={car['id']}")
            if self._is_position_on_lane(car['begin'], tolerance=tolerance):
                valid_cars.append(car)
                if self.debug_lane_filter:
                    print(f" ✅ 保留")
            else:
                filtered_cars.append({
                    'id': car['id'],
                    'position': car['begin'],
                    'reason': '生成位置不在有效车道上(可能在草坪/停车场等区域)'
                })
                if self.debug_lane_filter:
                    print(f" ❌ 过滤 (原因: 不在车道上)")
        self.car_birth_info_list = valid_cars
        if self.debug_lane_filter:
            print(f"\n📊 过滤结果: 保留 {len(valid_cars)} 辆, 过滤 {len(filtered_cars)}")
        return len(valid_cars), len(filtered_cars), filtered_cars

    def _spawn_controlled_agents(self):
        """Spawn a PolicyVehicle for every recorded vehicle whose show_time equals ``self.round``."""
        for car in self.car_birth_info_list:
            if car['show_time'] != self.round:
                continue
            agent_id = f"controlled_{car['id']}"
            vehicle = self.engine.spawn_object(
                PolicyVehicle,
                vehicle_config={},
                position=car['begin'],
                heading=car['heading']
            )
            vehicle.reset(position=car['begin'], heading=car['heading'])
            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])
            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)
            # Key step: register in the engine's active_agents so the vehicle takes part in physics updates.
            self.engine.agent_manager.active_agents[agent_id] = vehicle

    def before_reset(self):
        """Destroy the previous episode's controlled vehicles and forget them (called before engine.reset())."""
        if not (hasattr(self, 'controlled_agents') and hasattr(self, 'engine')):
            return
        engine = self.engine
        if hasattr(engine, 'clear_objects') and self.controlled_agents:
            # engine.clear_objects() looks up engine object ids; the "controlled_<id>" keys of
            # self.controlled_agents are this env's own agent ids, unknown to the engine.
            object_ids = [vehicle.id for vehicle in self.controlled_agents.values()]
            try:
                engine.clear_objects(object_ids)
            except Exception as e:
                get_logger().warning(
                    f"Failed to clear controlled vehicles before reset: {type(e).__name__}: {e}")
        if hasattr(engine, 'agent_manager'):
            for agent_id in self.controlled_agents:
                engine.agent_manager.active_agents.pop(agent_id, None)
        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()

    def _light_code_for_lane(self, lane_index, method_tag):
        """Look up the traffic-light code of one lane.

        Args:
            lane_index: lane index passed to the light manager.
            method_tag: label ("方法1"/"方法2") used only in debug output.

        Returns:
            1/2/3 for green/yellow/red, 0 when the lane's light has status None, and
            None when the lane has no light or an unrecognised status (caller keeps looking).
        """
        light_manager = self.engine.light_manager
        has_light = light_manager.has_traffic_light(lane_index)
        if self.debug_traffic_light:
            print(f" has_traffic_light = {has_light}")
        if not has_light:
            if self.debug_traffic_light:
                print(f" 该车道没有红绿灯")
            return None
        status = light_manager._lane_index_to_obj[lane_index].status
        if self.debug_traffic_light:
            print(f" status = {status}")
        if status is None:
            if self.debug_traffic_light:
                print(f" ⚠️ {method_tag}: 红绿灯状态为None")
            return 0
        code = self._TRAFFIC_LIGHT_CODES.get(status)
        if code is not None and self.debug_traffic_light:
            print(f" ✅ {method_tag}成功: {self._TRAFFIC_LIGHT_LABELS[code]}")
        return code

    def _get_traffic_light_state(self, vehicle):
        """
        Return the traffic-light state at the vehicle's current position.

        Method 1 asks the vehicle's navigation module for its current lane. If that gives
        no answer, method 2 scans all lanes for one containing the position (handles lane
        segmentation). Errors in either method are swallowed (printed in debug mode).

        Returns:
            int: 0 = no light / unknown, 1 = green, 2 = yellow, 3 = red.
        """
        traffic_light = 0
        position_2d = vehicle.get_state()['position'][:2]
        if self.debug_traffic_light:
            print(f"\n🚦 检测车辆红绿灯 - 位置: ({position_2d[0]:.1f}, {position_2d[1]:.1f})")

        # Method 1: lane from the navigation module (cheap).
        try:
            if hasattr(vehicle, 'navigation') and vehicle.navigation is not None:
                current_lane = vehicle.navigation.current_lane
                if self.debug_traffic_light:
                    print(f" 方法1-导航模块:")
                    print(f" current_lane = {current_lane}")
                    print(f" lane_index = {current_lane.index if current_lane else 'None'}")
                if current_lane:
                    code = self._light_code_for_lane(current_lane.index, "方法1")
                    if code is not None:
                        return code
                elif self.debug_traffic_light:
                    print(f" 导航模块current_lane为None")
            elif self.debug_traffic_light:
                has_nav = hasattr(vehicle, 'navigation')
                nav_not_none = vehicle.navigation is not None if has_nav else False
                print(f" 方法1-导航模块: 不可用 (hasattr={has_nav}, not_none={nav_not_none})")
        except Exception as e:
            if self.debug_traffic_light:
                print(f" ❌ 方法1异常: {type(e).__name__}: {e}")

        # Method 2: scan every lane for one containing the position (fallback).
        try:
            if self.debug_traffic_light:
                print(f" 方法2-遍历车道: 开始遍历 {len(self.lanes)} 条车道")
            found_lane = False
            checked_lanes = 0
            for lane in self.lanes.values():
                try:
                    checked_lanes += 1
                    if not lane.lane.point_on_lane(position_2d):
                        continue
                    found_lane = True
                    if self.debug_traffic_light:
                        print(f" ✓ 找到车辆所在车道: {lane.lane.index} (检查了{checked_lanes}条)")
                    code = self._light_code_for_lane(lane.lane.index, "方法2")
                    if code is not None:
                        return code
                    break  # only the first lane containing the position is consulted
                except Exception:
                    continue
            if self.debug_traffic_light and not found_lane:
                print(f" ⚠️ 未找到车辆所在车道 (检查了{checked_lanes}条)")
        except Exception as e:
            if self.debug_traffic_light:
                print(f" ❌ 方法2异常: {type(e).__name__}: {e}")

        if self.debug_traffic_light:
            print(f" 结果: 返回 {traffic_light} (无红绿灯/未知)")
        return traffic_light

    def _get_all_obs(self):
        """Build one flat observation list per controlled vehicle and store it in ``self.obs_list``.

        Each observation concatenates, in order:
        - position (x, y), velocity (vx, vy) and heading_theta;
        - the lidar output (80 lasers, 30 m, dynamic world);
        - the side detector (10 lasers, 8 m, static world);
        - the lane-line detector (10 lasers, 3 m, static world);
        - the traffic-light code from ``_get_traffic_light_state``;
        - the vehicle's destination.
        """
        self.obs_list = []
        if not self.controlled_agents:
            return self.obs_list
        engine = self.engine
        lidar_sensor = engine.get_sensor("lidar")
        side_sensor = engine.get_sensor("side_detector")
        lane_line_sensor = engine.get_sensor("lane_line_detector")
        dynamic_world = engine.physics_world.dynamic_world
        static_world = engine.physics_world.static_world
        for vehicle in self.controlled_agents.values():
            state = vehicle.get_state()
            traffic_light = self._get_traffic_light_state(vehicle)
            lidar = lidar_sensor.perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
                                          physics_world=dynamic_world)
            side_lidar = side_sensor.perceive(num_lasers=10, distance=8, base_vehicle=vehicle,
                                              physics_world=static_world)
            lane_line_lidar = lane_line_sensor.perceive(num_lasers=10, distance=3, base_vehicle=vehicle,
                                                        physics_world=static_world)
            # list() each part so '+' concatenates even if a part is a NumPy array
            # (on arrays '+' would add element-wise instead).
            obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
                   + list(lidar[0]) + list(side_lidar[0]) + list(lane_line_lidar[0]) + [traffic_light]
                   + list(vehicle.destination))
            self.obs_list.append(obs)
        return self.obs_list

    def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
        """Advance the simulation by one frame.

        Args:
            action_dict: agent_id -> action; ids that are not controlled agents are ignored.

        Returns:
            tuple: (obs, rewards, dones, infos).
                - obs: the list from ``_get_all_obs``, ordered like ``self.controlled_agents``.
                - rewards: 0.0 for every agent (placeholder).
                - dones: False for every agent (placeholder); ``dones["__all__"]`` is True
                  once ``horizon`` steps have been taken since reset.
                - infos: an empty dict per agent.
        """
        self.round += 1
        for agent_id, action in action_dict.items():
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].before_step(action)
        self.engine.step()
        for agent_id in action_dict:
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].after_step()
        # Spawn logged vehicles whose first valid frame is this step.
        self._spawn_controlled_agents()
        obs = self._get_all_obs()
        rewards = {aid: 0.0 for aid in self.controlled_agents}
        dones = {aid: False for aid in self.controlled_agents}
        # Count with self.round: this method calls engine.step() directly (not the engine's
        # before_step/after_step), so engine.episode_step may not be advanced here.
        dones["__all__"] = self.round >= self.config["horizon"]
        infos = {aid: {} for aid in self.controlled_agents}
        return obs, rewards, dones, infos