优化车辆生成位置偏差问题,新增红绿灯信息采集方法
This commit is contained in:
@@ -53,6 +53,13 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
data_directory=None,
|
||||
num_controlled_agents=3,
|
||||
horizon=1000,
|
||||
# 车道检测与过滤配置
|
||||
filter_offroad_vehicles=True, # 是否过滤非车道区域的车辆
|
||||
lane_tolerance=3.0, # 车道检测容差(米),用于放宽边界条件
|
||||
max_controlled_vehicles=None, # 最大可控车辆数限制(None表示不限制)
|
||||
# 调试模式配置
|
||||
debug_traffic_light=False, # 是否启用红绿灯检测调试输出
|
||||
debug_lane_filter=False, # 是否启用车道过滤调试输出
|
||||
))
|
||||
return config
|
||||
|
||||
@@ -62,6 +69,9 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.controlled_agent_ids = []
|
||||
self.obs_list = []
|
||||
self.round = 0
|
||||
# 调试模式配置
|
||||
self.debug_traffic_light = config.get("debug_traffic_light", False)
|
||||
self.debug_lane_filter = config.get("debug_lane_filter", False)
|
||||
super().__init__(config)
|
||||
|
||||
def reset(self, seed: Union[None, int] = None):
|
||||
@@ -105,6 +115,47 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.engine.taskMgr.step()
|
||||
|
||||
self.lanes = self.engine.map_manager.current_map.road_network.graph
|
||||
|
||||
# 调试:场景信息统计
|
||||
if self.debug_lane_filter or self.debug_traffic_light:
|
||||
print(f"\n📍 场景信息统计:")
|
||||
print(f" - 总车道数: {len(self.lanes)}")
|
||||
|
||||
# 统计红绿灯数量
|
||||
if self.debug_traffic_light:
|
||||
traffic_light_lanes = []
|
||||
for lane in self.lanes.values():
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
traffic_light_lanes.append(lane.lane.index)
|
||||
print(f" - 有红绿灯的车道数: {len(traffic_light_lanes)}")
|
||||
if len(traffic_light_lanes) > 0:
|
||||
print(f" 车道索引: {traffic_light_lanes[:5]}" +
|
||||
(f" ... 共{len(traffic_light_lanes)}个" if len(traffic_light_lanes) > 5 else ""))
|
||||
else:
|
||||
print(f" ⚠️ 场景中没有红绿灯!")
|
||||
|
||||
# 在获取车道信息后,进行车道区域过滤
|
||||
total_cars_before = len(self.car_birth_info_list)
|
||||
valid_count, filtered_count, filtered_list = self._filter_valid_spawn_positions()
|
||||
|
||||
# 输出过滤信息
|
||||
if filtered_count > 0:
|
||||
self.logger.warning(f"车辆生成位置过滤: 原始 {total_cars_before} 辆, "
|
||||
f"有效 {valid_count} 辆, 过滤 {filtered_count} 辆")
|
||||
for filtered_car in filtered_list[:5]: # 只显示前5个
|
||||
self.logger.debug(f" - 过滤车辆 ID={filtered_car['id']}, "
|
||||
f"位置={filtered_car['position']}, "
|
||||
f"原因={filtered_car['reason']}")
|
||||
if filtered_count > 5:
|
||||
self.logger.debug(f" - ... 还有 {filtered_count - 5} 辆车被过滤")
|
||||
|
||||
# 限制最大车辆数(在过滤后应用)
|
||||
max_vehicles = self.config.get("max_controlled_vehicles", None)
|
||||
if max_vehicles is not None and len(self.car_birth_info_list) > max_vehicles:
|
||||
self.car_birth_info_list = self.car_birth_info_list[:max_vehicles]
|
||||
self.logger.info(f"限制最大车辆数为 {max_vehicles} 辆")
|
||||
|
||||
self.logger.info(f"最终生成 {len(self.car_birth_info_list)} 辆可控车辆")
|
||||
|
||||
if self.top_down_renderer is not None:
|
||||
self.top_down_renderer.clear()
|
||||
@@ -122,6 +173,108 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
|
||||
return self._get_all_obs()
|
||||
|
||||
def _is_position_on_lane(self, position, tolerance=None):
|
||||
"""
|
||||
检测给定位置是否在有效车道范围内
|
||||
|
||||
Args:
|
||||
position: (x, y) 车辆位置坐标
|
||||
tolerance: 容差范围(米),用于放宽检测条件。None时使用配置中的默认值
|
||||
|
||||
Returns:
|
||||
bool: True表示在车道上,False表示在非车道区域(如草坪、停车场等)
|
||||
"""
|
||||
if not hasattr(self, 'lanes') or self.lanes is None:
|
||||
if self.debug_lane_filter:
|
||||
print(f" ⚠️ 车道信息未初始化,默认允许")
|
||||
return True # 如果车道信息未初始化,默认允许生成
|
||||
|
||||
if tolerance is None:
|
||||
tolerance = self.config.get("lane_tolerance", 3.0)
|
||||
|
||||
position_2d = (position[0], position[1])
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f" 🔍 检测位置 ({position_2d[0]:.2f}, {position_2d[1]:.2f}), 容差={tolerance}m")
|
||||
|
||||
# 方法1:直接检测是否在任一车道上
|
||||
checked_lanes = 0
|
||||
for lane in self.lanes.values():
|
||||
try:
|
||||
checked_lanes += 1
|
||||
if lane.lane.point_on_lane(position_2d):
|
||||
if self.debug_lane_filter:
|
||||
print(f" ✅ 在车道上 (车道{lane.lane.index}, 检查了{checked_lanes}条)")
|
||||
return True
|
||||
except:
|
||||
continue
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f" ❌ 不在任何车道上 (检查了{checked_lanes}条车道)")
|
||||
|
||||
# 方法2:如果严格检测失败,使用容差范围检测(考虑车道边缘)
|
||||
# 注释:此方法已被禁用,如需启用请取消注释
|
||||
# if tolerance > 0:
|
||||
# for lane in self.lanes.values():
|
||||
# try:
|
||||
# # 计算点到车道中心线的距离
|
||||
# lane_obj = lane.lane
|
||||
# # 获取车道长度并检测最近点
|
||||
# s, lateral = lane_obj.local_coordinates(position_2d)
|
||||
|
||||
# # 如果横向距离在容差范围内,认为是有效的
|
||||
# if abs(lateral) <= tolerance and 0 <= s <= lane_obj.length:
|
||||
# return True
|
||||
# except:
|
||||
# continue
|
||||
|
||||
return False
|
||||
|
||||
def _filter_valid_spawn_positions(self):
|
||||
"""
|
||||
过滤掉生成位置不在有效车道上的车辆信息
|
||||
根据配置决定是否执行过滤
|
||||
|
||||
Returns:
|
||||
tuple: (有效车辆数量, 被过滤车辆数量, 被过滤车辆ID列表)
|
||||
"""
|
||||
# 如果配置中禁用了过滤,直接返回
|
||||
if not self.config.get("filter_offroad_vehicles", True):
|
||||
if self.debug_lane_filter:
|
||||
print(f"🚫 车道过滤已禁用")
|
||||
return len(self.car_birth_info_list), 0, []
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n🔍 开始车道过滤: 共 {len(self.car_birth_info_list)} 辆车待检测")
|
||||
|
||||
valid_cars = []
|
||||
filtered_cars = []
|
||||
tolerance = self.config.get("lane_tolerance", 3.0)
|
||||
|
||||
for idx, car in enumerate(self.car_birth_info_list):
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n车辆 {idx+1}/{len(self.car_birth_info_list)}: ID={car['id']}")
|
||||
|
||||
if self._is_position_on_lane(car['begin'], tolerance=tolerance):
|
||||
valid_cars.append(car)
|
||||
if self.debug_lane_filter:
|
||||
print(f" ✅ 保留")
|
||||
else:
|
||||
filtered_cars.append({
|
||||
'id': car['id'],
|
||||
'position': car['begin'],
|
||||
'reason': '生成位置不在有效车道上(可能在草坪/停车场等区域)'
|
||||
})
|
||||
if self.debug_lane_filter:
|
||||
print(f" ❌ 过滤 (原因: 不在车道上)")
|
||||
|
||||
self.car_birth_info_list = valid_cars
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n📊 过滤结果: 保留 {len(valid_cars)} 辆, 过滤 {len(filtered_cars)} 辆")
|
||||
|
||||
return len(valid_cars), len(filtered_cars), filtered_cars
|
||||
|
||||
def _spawn_controlled_agents(self):
|
||||
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
|
||||
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
|
||||
@@ -146,26 +299,148 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
# ✅ 关键:注册到引擎的 active_agents,才能参与物理更新
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
def _get_traffic_light_state(self, vehicle):
|
||||
"""
|
||||
获取车辆当前位置的红绿灯状态(优化版)
|
||||
|
||||
解决问题:
|
||||
1. 部分红绿灯状态为None的问题 - 添加异常处理和默认值
|
||||
2. 车道分段导致无法获取红绿灯的问题 - 优先使用导航模块,失败时回退到遍历
|
||||
|
||||
Returns:
|
||||
int: 0=无红绿灯, 1=绿灯, 2=黄灯, 3=红灯
|
||||
"""
|
||||
traffic_light = 0
|
||||
state = vehicle.get_state()
|
||||
position_2d = state['position'][:2]
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f"\n🚦 检测车辆红绿灯 - 位置: ({position_2d[0]:.1f}, {position_2d[1]:.1f})")
|
||||
|
||||
try:
|
||||
# 方法1:优先尝试从车辆导航模块获取当前车道(更高效)
|
||||
if hasattr(vehicle, 'navigation') and vehicle.navigation is not None:
|
||||
current_lane = vehicle.navigation.current_lane
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" 方法1-导航模块:")
|
||||
print(f" current_lane = {current_lane}")
|
||||
print(f" lane_index = {current_lane.index if current_lane else 'None'}")
|
||||
|
||||
if current_lane:
|
||||
has_light = self.engine.light_manager.has_traffic_light(current_lane.index)
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" has_traffic_light = {has_light}")
|
||||
|
||||
if has_light:
|
||||
status = self.engine.light_manager._lane_index_to_obj[current_lane.index].status
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" status = {status}")
|
||||
|
||||
if status == 'TRAFFIC_LIGHT_GREEN':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 绿灯")
|
||||
return 1
|
||||
elif status == 'TRAFFIC_LIGHT_YELLOW':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 黄灯")
|
||||
return 2
|
||||
elif status == 'TRAFFIC_LIGHT_RED':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 红灯")
|
||||
return 3
|
||||
elif status is None:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ⚠️ 方法1: 红绿灯状态为None")
|
||||
return 0
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 该车道没有红绿灯")
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 导航模块current_lane为None")
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
has_nav = hasattr(vehicle, 'navigation')
|
||||
nav_not_none = vehicle.navigation is not None if has_nav else False
|
||||
print(f" 方法1-导航模块: 不可用 (hasattr={has_nav}, not_none={nav_not_none})")
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ❌ 方法1异常: {type(e).__name__}: {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
# 方法2:遍历所有车道查找(兜底方案,处理车道分段问题)
|
||||
if self.debug_traffic_light:
|
||||
print(f" 方法2-遍历车道: 开始遍历 {len(self.lanes)} 条车道")
|
||||
|
||||
found_lane = False
|
||||
checked_lanes = 0
|
||||
|
||||
for lane in self.lanes.values():
|
||||
try:
|
||||
checked_lanes += 1
|
||||
if lane.lane.point_on_lane(position_2d):
|
||||
found_lane = True
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✓ 找到车辆所在车道: {lane.lane.index} (检查了{checked_lanes}条)")
|
||||
|
||||
has_light = self.engine.light_manager.has_traffic_light(lane.lane.index)
|
||||
if self.debug_traffic_light:
|
||||
print(f" has_traffic_light = {has_light}")
|
||||
|
||||
if has_light:
|
||||
status = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
|
||||
if self.debug_traffic_light:
|
||||
print(f" status = {status}")
|
||||
|
||||
if status == 'TRAFFIC_LIGHT_GREEN':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 绿灯")
|
||||
return 1
|
||||
elif status == 'TRAFFIC_LIGHT_YELLOW':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 黄灯")
|
||||
return 2
|
||||
elif status == 'TRAFFIC_LIGHT_RED':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 红灯")
|
||||
return 3
|
||||
elif status is None:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ⚠️ 方法2: 红绿灯状态为None")
|
||||
return 0
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 该车道没有红绿灯")
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
if self.debug_traffic_light and not found_lane:
|
||||
print(f" ⚠️ 未找到车辆所在车道 (检查了{checked_lanes}条)")
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ❌ 方法2异常: {type(e).__name__}: {e}")
|
||||
pass
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" 结果: 返回 {traffic_light} (无红绿灯/未知)")
|
||||
|
||||
return traffic_light
|
||||
|
||||
def _get_all_obs(self):
|
||||
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
|
||||
self.obs_list = []
|
||||
for agent_id, vehicle in self.controlled_agents.items():
|
||||
state = vehicle.get_state()
|
||||
|
||||
traffic_light = 0
|
||||
for lane in self.lanes.values():
|
||||
if lane.lane.point_on_lane(state['position'][:2]):
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
|
||||
if traffic_light == 'TRAFFIC_LIGHT_GREEN':
|
||||
traffic_light = 1
|
||||
elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
|
||||
traffic_light = 2
|
||||
elif traffic_light == 'TRAFFIC_LIGHT_RED':
|
||||
traffic_light = 3
|
||||
else:
|
||||
traffic_light = 0
|
||||
break
|
||||
# 使用优化后的红绿灯检测方法
|
||||
traffic_light = self._get_traffic_light_state(vehicle)
|
||||
|
||||
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world)
|
||||
|
||||
Reference in New Issue
Block a user