Compare commits
1 commit: train_not_... → dev

| Author | SHA1 | Date |
|---|---|---|
| | 113e86bda2 | |

.gitignore (vendored): 3 lines changed
@@ -1,3 +0,0 @@
# Log files
Env/logs/
*.log
@@ -1,27 +0,0 @@
"""
MAGAIL Algorithm Package

Multi-agent generative adversarial imitation learning implementation
"""

from .magail import MAGAIL
from .ppo import PPO
from .disc import GAILDiscrim
from .bert import Bert
from .policy import StateIndependentPolicy
from .buffer import RolloutBuffer
from .utils import Normalizer, build_mlp, reparameterize, evaluate_lop_pi

__all__ = [
    'MAGAIL',
    'PPO',
    'GAILDiscrim',
    'Bert',
    'StateIndependentPolicy',
    'RolloutBuffer',
    'Normalizer',
    'build_mlp',
    'reparameterize',
    'evaluate_lop_pi',
]
@@ -28,26 +28,17 @@ class Bert(nn.Module):
        self.classifier.train()

    def forward(self, x, mask=None):
        # x can be 2D (batch_size, input_dim) or 3D (batch_size, seq_len, feature_dim)
        is_2d_input = (x.dim() == 2)

        if is_2d_input:
            # If the input is 2D, add a sequence dimension
            x = x.unsqueeze(1)  # (batch_size, 1, input_dim)

        # x: (batch_size, seq_len, input_dim)
        # Linear projection
        x = self.projection(x)  # (batch_size, seq_len, embed_dim)
        x = self.projection(x)  # (batch_size, input_dim, embed_dim)

        batch_size = x.size(0)
        if self.CLS:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, seq_len+1, embed_dim)
            x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, 29, embed_dim)

        # Add positional encoding (truncated or extended to match the sequence length)
        seq_len = x.size(1)
        pos_embed = self.pos_embed[:, :seq_len, :]
        x = x + pos_embed
        # Add positional encoding
        x = x + self.pos_embed

        # Transpose to (seq_len, batch_size, embed_dim)
        x = x.permute(1, 0, 2)
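The two positional-encoding variants in this hunk differ in whether the learned embedding is truncated to the actual sequence length before being added. A minimal, self-contained sketch of that pattern (the `TinyEncoder` class, `max_len`, and layer sizes are illustrative assumptions, not the repository's `Bert` class):

```python
import torch
from torch import nn

class TinyEncoder(nn.Module):
    """Minimal sketch: linear projection + CLS token + truncated positional embedding.
    Names and shapes are assumptions for illustration only."""
    def __init__(self, input_dim, embed_dim, max_len=64):
        super().__init__()
        self.projection = nn.Linear(input_dim, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len + 1, embed_dim))

    def forward(self, x):
        if x.dim() == 2:                      # (batch, input_dim) -> add a sequence axis
            x = x.unsqueeze(1)
        x = self.projection(x)                # (batch, seq_len, embed_dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)        # (batch, seq_len + 1, embed_dim)
        x = x + self.pos_embed[:, :x.size(1), :]   # truncate to the actual length
        return x.permute(1, 0, 2)             # (seq_len + 1, batch, embed_dim)

x = torch.randn(4, 10, 8)                     # batch of 4, 10 tokens, 8 features each
print(TinyEncoder(input_dim=8, embed_dim=16)(x).shape)  # torch.Size([11, 4, 16])
```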
@@ -1,9 +1,6 @@
import torch
from torch import nn
try:
    from .bert import Bert
except ImportError:
    from bert import Bert
from .bert import Bert


DISC_LOGIT_INIT_SCALE = 1.0
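A later hunk calls `self.disc.linear(self.disc.trunk(sample_expert))`, which suggests the GAIL discriminator is an MLP trunk followed by a single logit head, with `DISC_LOGIT_INIT_SCALE` presumably scaling the head's initialisation. A hedged sketch of such a module (hidden sizes and the reward shaping at the end are assumptions):

```python
import torch
from torch import nn

DISC_LOGIT_INIT_SCALE = 1.0  # matches the constant shown in the diff

class TinyDiscrim(nn.Module):
    """Sketch of a GAIL-style discriminator: MLP trunk + scalar logit head.
    The trunk/linear split mirrors how the MAGAIL hunk accesses `disc.trunk` and `disc.linear`;
    layer sizes and the init scheme are illustrative assumptions."""
    def __init__(self, input_dim, hidden=(128, 128)):
        super().__init__()
        layers, last = [], input_dim
        for h in hidden:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        self.trunk = nn.Sequential(*layers)
        self.linear = nn.Linear(last, 1)
        nn.init.uniform_(self.linear.weight, -DISC_LOGIT_INIT_SCALE, DISC_LOGIT_INIT_SCALE)

    def forward(self, x):
        return self.linear(self.trunk(x))     # raw logit; positive values lean "expert"

d = TinyDiscrim(input_dim=8)
logits = d(torch.randn(5, 8))                               # (5, 1)
reward = -torch.log(1.0 - torch.sigmoid(logits) + 1e-8)     # one common GAIL reward shaping
print(logits.shape, reward.shape)
```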
@@ -2,30 +2,21 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from .disc import GAILDiscrim
    from .ppo import PPO
    from .utils import Normalizer
except ImportError:
    from disc import GAILDiscrim
    from ppo import PPO
    from utils import Normalizer
from .disc import GAILDiscrim
from .ppo import PPO
from .utils import Normalizer


class MAGAIL(PPO):
    def __init__(self, buffer_exp, input_dim, device, action_shape=(2,),
    def __init__(self, buffer_exp, input_dim, device,
                 disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True,
                 **kwargs  # accept additional PPO parameters
                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
                 ):
        super().__init__(state_shape=input_dim, device=device, action_shape=action_shape, **kwargs)
        super().__init__(state_shape=input_dim, device=device)
        self.learning_steps = 0
        self.learning_steps_disc = 0

        # If input_dim is a tuple, extract its first element
        state_dim = input_dim[0] if isinstance(input_dim, tuple) else input_dim
        # The discriminator input is the concatenation of state and next_state, so its dimension is state_dim*2
        self.disc = GAILDiscrim(input_dim=state_dim*2).to(device)  # move to the target device
        self.disc = GAILDiscrim(input_dim=input_dim)
        self.disc_grad_penalty = disc_grad_penalty
        self.disc_coef = disc_coef
        self.disc_logit_reg = disc_logit_reg
@@ -36,9 +27,7 @@ class MAGAIL(PPO):

        self.normalizer = None
        if use_gail_norm:
            # state_shape is already a tuple here
            state_dim = self.state_shape[0] if isinstance(self.state_shape, tuple) else self.state_shape
            self.normalizer = Normalizer(state_dim*2)
            self.normalizer = Normalizer(self.state_shape[0]*2)

        self.batch_size = batch_size
        self.buffer_exp = buffer_exp
@@ -63,7 +52,7 @@ class MAGAIL(PPO):
        # grad penalty
        sample_expert = states_exp_cp
        sample_expert.requires_grad = True
        disc = self.disc(sample_expert)  # call the forward method directly
        disc = self.disc.linear(self.disc.trunk(sample_expert))
        ones = torch.ones(disc.size(), device=disc.device)
        disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                             grad_outputs=ones,
@@ -102,8 +91,7 @@ class MAGAIL(PPO):

        # Samples from current policy trajectories.
        samples_policy = self.buffer.sample(self.batch_size)
        # samples_policy returns: (states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds)
        states, next_states = samples_policy[0], samples_policy[6]  # fix: use states rather than actions
        states, next_states = samples_policy[1], samples_policy[-3]
        states = torch.cat([states, next_states], dim=-1)

        # Samples from expert demonstrations.
@@ -141,8 +129,6 @@ class MAGAIL(PPO):
        return rewards_t.mean().item() + rewards_i.mean().item()

    def save_models(self, path):
        # Make sure the directory exists
        os.makedirs(path, exist_ok=True)
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
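The `torch.autograd.grad` call in the hunk above computes a gradient penalty on expert samples. A self-contained sketch of how that term is commonly folded into the discriminator loss; the coefficient name mirrors `disc_grad_penalty` from the diff, while the loss form and labels are assumptions:

```python
import torch
from torch import nn

def disc_loss_with_grad_penalty(disc, states_policy, states_exp, grad_penalty_coef=0.1):
    """Sketch: binary logistic GAIL loss plus a gradient penalty on expert inputs.
    `disc` maps (batch, dim) -> (batch, 1) logits; everything else is an assumption."""
    logits_pi = disc(states_policy)
    logits_exp = disc(states_exp)
    loss_pi = nn.functional.binary_cross_entropy_with_logits(
        logits_pi, torch.zeros_like(logits_pi))      # policy samples labelled 0
    loss_exp = nn.functional.binary_cross_entropy_with_logits(
        logits_exp, torch.ones_like(logits_exp))     # expert samples labelled 1

    # Gradient penalty: penalise the norm of dD/dinput on expert samples.
    sample_expert = states_exp.clone().requires_grad_(True)
    out = disc(sample_expert)
    ones = torch.ones_like(out)
    grad = torch.autograd.grad(out, sample_expert, grad_outputs=ones, create_graph=True)[0]
    grad_pen = grad_penalty_coef * grad.pow(2).sum(dim=-1).mean()

    return loss_pi + loss_exp + grad_pen

disc = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
loss = disc_loss_with_grad_penalty(disc, torch.randn(16, 8), torch.randn(16, 8))
loss.backward()
print(float(loss))
```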
@@ -1,10 +1,7 @@
import torch
import numpy as np
from torch import nn
try:
    from .utils import build_mlp, reparameterize, evaluate_lop_pi
except ImportError:
    from utils import build_mlp, reparameterize, evaluate_lop_pi
from .utils import build_mlp, reparameterize, evaluate_lop_pi


class StateIndependentPolicy(nn.Module):
@@ -3,14 +3,9 @@ import torch
import numpy as np
from torch import nn
from torch.optim import Adam
try:
    from .buffer import RolloutBuffer
    from .bert import Bert
    from .policy import StateIndependentPolicy
except ImportError:
    from buffer import RolloutBuffer
    from bert import Bert
    from policy import StateIndependentPolicy
from buffer import RolloutBuffer
from bert import Bert
from policy import StateIndependentPolicy
from abc import ABC, abstractmethod


@@ -60,7 +55,7 @@ class Algorithm(ABC):

class PPO(Algorithm):

    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
    def __init__(self, state_shape, device, gamma=0.995, rollout_length=2048,
                 units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                 lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                 value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
@@ -71,7 +66,6 @@ class PPO(Algorithm):
        self.lr_critic = lr_critic
        self.lr_disc = lr_disc
        self.auto_lr = auto_lr
        self.action_shape = action_shape

        self.use_adv_norm = use_adv_norm

@@ -92,10 +86,8 @@ class PPO(Algorithm):
        ).to(device)

        # Critic.
        # If state_shape is a tuple, extract its first element
        state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
        self.critic = Bert(
            input_dim=state_dim,
            input_dim=state_shape,
            output_dim=1
        ).to(device)

@@ -153,12 +145,14 @@ class PPO(Algorithm):
        targets, gaes = self.calculate_gae(
            values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)

        # Process the whole batch at once (no per-agent grouping needed, since the buffer already mixes data from all agents)
        state_list = states.permute(1, 0, 2)
        action_list = actions.permute(1, 0, 2)

        for i in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            # Update the actor with the entire batch directly
            self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)
            for state, action, log_pi in state_list, action_list, log_pi_list:
                self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)

        # self.lr_decay(total_steps, writer)
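`calculate_gae` is called above but not shown in this hunk. A minimal sketch of Generalized Advantage Estimation using the same argument order as the call (`values, rewards, dones, tm_dones, next_values, gamma, lambd`); treating `tm_dones` as true terminations that stop bootstrapping, and the final advantage normalisation, are assumptions:

```python
import torch

def calculate_gae(values, rewards, dones, tm_dones, next_values, gamma=0.995, lambd=0.97):
    """Sketch of GAE(lambda). All tensors are (T, ...); `tm_dones` marks true terminations
    (no bootstrap), while `dones` only cuts the recursion at episode boundaries."""
    deltas = rewards + gamma * next_values * (1.0 - tm_dones) - values
    gaes = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.size(0))):
        running = deltas[t] + gamma * lambd * (1.0 - dones[t]) * running
        gaes[t] = running
    targets = gaes + values                              # value targets
    gaes = (gaes - gaes.mean()) / (gaes.std() + 1e-8)    # optional advantage normalisation
    return targets, gaes

T = 8
targets, gaes = calculate_gae(
    values=torch.randn(T, 1), rewards=torch.randn(T, 1),
    dones=torch.zeros(T, 1), tm_dones=torch.zeros(T, 1), next_values=torch.randn(T, 1))
print(targets.shape, gaes.shape)
```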
CHANGELOG.md (136 lines changed)
@@ -1,136 +0,0 @@
|
||||
# 更新日志
|
||||
|
||||
## 2025-01-20 问题修复与优化
|
||||
|
||||
### ✅ 已解决的问题
|
||||
|
||||
#### 1. 车辆生成位置偏差问题
|
||||
**问题描述:** 部分车辆生成于草坪、停车场等非车道区域
|
||||
|
||||
**解决方案:**
|
||||
- 实现 `_is_position_on_lane()` 方法:检测位置是否在有效车道上
|
||||
- 实现 `_filter_valid_spawn_positions()` 方法:自动过滤非车道区域车辆
|
||||
- 支持容差参数(默认3米)处理边界情况
|
||||
- 在 `reset()` 时自动执行过滤,并输出统计信息
|
||||
|
||||
**配置参数:**
|
||||
```python
|
||||
"filter_offroad_vehicles": True, # 启用/禁用过滤
|
||||
"lane_tolerance": 3.0, # 容差范围(米)
|
||||
"max_controlled_vehicles": 10, # 最大车辆数限制
|
||||
```
|
||||
|
||||
#### 2. 红绿灯信息采集问题
|
||||
**问题描述:**
|
||||
- 部分红绿灯状态为 None
|
||||
- 车道分段时部分车辆无法获取红绿灯状态
|
||||
|
||||
**解决方案:**
|
||||
- 实现 `_get_traffic_light_state()` 方法,采用双重检测策略
|
||||
- 方法1(优先):从导航模块获取当前车道,直接查询(高效)
|
||||
- 方法2(兜底):遍历所有车道匹配位置(处理特殊情况)
|
||||
- 完善异常处理,None 状态返回 0(无红绿灯)
|
||||
- 返回值:0=无/未知, 1=绿灯, 2=黄灯, 3=红灯
|
||||
|
||||
#### 3. 性能优化问题
|
||||
**问题描述:** FPS只有15帧,CPU利用率不到20%
|
||||
|
||||
**解决方案:**
|
||||
- 创建 `run_multiagent_env_fast.py`:激光雷达优化版(30-60 FPS)
|
||||
- 创建 `run_multiagent_env_parallel.py`:多进程并行版(300-600 steps/s)
|
||||
- 提供详细的性能优化文档
|
||||
|
||||
### 📝 修改的文件
|
||||
|
||||
1. **Env/scenario_env.py**
|
||||
- 新增 `_is_position_on_lane()` 方法
|
||||
- 新增 `_filter_valid_spawn_positions()` 方法
|
||||
- 新增 `_get_traffic_light_state()` 方法
|
||||
- 更新 `default_config()` 添加配置参数
|
||||
- 更新 `reset()` 调用过滤逻辑
|
||||
- 更新 `_get_all_obs()` 使用新的红绿灯检测方法
|
||||
|
||||
2. **Env/run_multiagent_env.py**
|
||||
- 添加车道过滤配置参数
|
||||
|
||||
3. **Env/run_multiagent_env_fast.py**
|
||||
- 添加车道过滤配置
|
||||
- 性能优化配置
|
||||
|
||||
4. **Env/run_multiagent_env_parallel.py**
|
||||
- 添加车道过滤配置
|
||||
- 多进程并行实现
|
||||
|
||||
5. **README.md**
|
||||
- 更新问题说明,添加解决方案
|
||||
- 添加配置示例和测试方法
|
||||
- 添加问题解决总结
|
||||
|
||||
6. **新增文件**
|
||||
- `Env/test_lane_filter.py`:功能测试脚本
|
||||
|
||||
### 🧪 测试方法
|
||||
|
||||
```bash
|
||||
# 测试车道过滤和红绿灯检测功能
|
||||
python Env/test_lane_filter.py
|
||||
|
||||
# 运行标准版本(带过滤和可视化)
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 运行高性能版本(适合训练)
|
||||
python Env/run_multiagent_env_fast.py
|
||||
|
||||
# 运行多进程并行版本(最高吞吐量)
|
||||
python Env/run_multiagent_env_parallel.py
|
||||
```
|
||||
|
||||
### 💡 使用建议
|
||||
|
||||
1. **调试阶段**:使用 `run_multiagent_env.py`,启用渲染和车道过滤
|
||||
2. **训练阶段**:使用 `run_multiagent_env_fast.py`,关闭渲染,启用所有优化
|
||||
3. **大规模训练**:使用 `run_multiagent_env_parallel.py`,充分利用多核CPU
|
||||
|
||||
### ⚙️ 配置说明
|
||||
|
||||
所有配置参数都可以在创建环境时通过 `config` 字典传递:
|
||||
|
||||
```python
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
# 基础配置
|
||||
"data_directory": "...",
|
||||
"is_multi_agent": True,
|
||||
"horizon": 300,
|
||||
|
||||
# 车道过滤(新增)
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0, # 容差3米
|
||||
"max_controlled_vehicles": 10, # 最多10辆车
|
||||
|
||||
# 性能优化
|
||||
"use_render": False,
|
||||
"decision_repeat": 5,
|
||||
...
|
||||
},
|
||||
agent2policy=your_policy
|
||||
)
|
||||
```
|
||||
|
||||
### 🔍 技术细节
|
||||
|
||||
**车道检测逻辑:**
|
||||
1. 使用 `lane.lane.point_on_lane()` 精确检测
|
||||
2. 使用 `lane.local_coordinates()` 计算横向距离
|
||||
3. 支持容差参数处理边界情况
|
||||
|
||||
**红绿灯检测逻辑:**
|
||||
1. 优先从 `vehicle.navigation.current_lane` 获取
|
||||
2. 失败时遍历所有车道查找
|
||||
3. 所有异常均有保护,确保稳定性
|
||||
|
||||
**性能优化原理:**
|
||||
- 减少激光束数量降低计算量
|
||||
- 多进程绕过Python GIL限制
|
||||
- 充分利用多核CPU
|
||||
|
||||
@@ -1,339 +0,0 @@
|
||||
# 调试功能使用指南
|
||||
|
||||
## 📋 概述
|
||||
|
||||
已为车道过滤和红绿灯检测功能添加了详细的调试输出,帮助您诊断和理解代码行为。
|
||||
|
||||
---
|
||||
|
||||
## 🎛️ 调试开关
|
||||
|
||||
### 1. 配置参数
|
||||
|
||||
在创建环境时,可以通过 `config` 参数启用调试模式:
|
||||
|
||||
```python
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
# ... 其他配置 ...
|
||||
|
||||
# 🔥 调试开关
|
||||
"debug_lane_filter": True, # 启用车道过滤调试
|
||||
"debug_traffic_light": True, # 启用红绿灯检测调试
|
||||
},
|
||||
agent2policy=your_policy
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 默认值
|
||||
|
||||
两个调试开关默认都是 `False`(关闭),避免正常运行时产生大量日志。
|
||||
|
||||
---
|
||||
|
||||
## 📊 车道过滤调试 (`debug_lane_filter=True`)
|
||||
|
||||
### 输出内容
|
||||
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 总车道数: 123
|
||||
|
||||
🔍 开始车道过滤: 共 51 辆车待检测
|
||||
|
||||
车辆 1/51: ID=128
|
||||
🔍 检测位置 (-4.11, 46.76), 容差=3.0m
|
||||
✅ 在车道上 (车道184, 检查了32条)
|
||||
✅ 保留
|
||||
|
||||
车辆 7/51: ID=134
|
||||
🔍 检测位置 (-51.34, -3.77), 容差=3.0m
|
||||
❌ 不在任何车道上 (检查了123条车道)
|
||||
❌ 过滤 (原因: 不在车道上)
|
||||
|
||||
... (所有车辆)
|
||||
|
||||
📊 过滤结果: 保留 45 辆, 过滤 6 辆
|
||||
```
|
||||
|
||||
### 调试信息说明
|
||||
|
||||
| 信息 | 含义 |
|
||||
|------|------|
|
||||
| 📍 场景信息统计 | 场景的基本信息(车道数、红绿灯数) |
|
||||
| 🔍 开始车道过滤 | 开始过滤,显示待检测车辆总数 |
|
||||
| 🔍 检测位置 | 车辆的坐标和使用的容差值 |
|
||||
| ✅ 在车道上 | 找到了车辆所在的车道,显示车道ID和检查次数 |
|
||||
| ❌ 不在任何车道上 | 所有车道都检查完了,未找到匹配的车道 |
|
||||
| 📊 过滤结果 | 最终统计:保留多少辆,过滤多少辆 |
|
||||
|
||||
### 典型输出案例
|
||||
|
||||
**情况1:车辆在正常车道上**
|
||||
```
|
||||
车辆 1/51: ID=128
|
||||
🔍 检测位置 (-4.11, 46.76), 容差=3.0m
|
||||
✅ 在车道上 (车道184, 检查了32条)
|
||||
✅ 保留
|
||||
```
|
||||
→ 检查了32条车道后找到匹配的车道184
|
||||
|
||||
**情况2:车辆在草坪/停车场**
|
||||
```
|
||||
车辆 7/51: ID=134
|
||||
🔍 检测位置 (-51.34, -3.77), 容差=3.0m
|
||||
❌ 不在任何车道上 (检查了123条车道)
|
||||
❌ 过滤 (原因: 不在车道上)
|
||||
```
|
||||
→ 检查了所有123条车道都不匹配,该车辆被过滤
|
||||
|
||||
---
|
||||
|
||||
## 🚦 红绿灯检测调试 (`debug_traffic_light=True`)
|
||||
|
||||
### 输出内容
|
||||
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 总车道数: 123
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
|
||||
🚦 检测车辆红绿灯 - 位置: (-4.1, 46.8)
|
||||
方法1-导航模块:
|
||||
current_lane = <metadrive.component.lane.straight_lane.StraightLane object>
|
||||
lane_index = 184
|
||||
has_traffic_light = False
|
||||
该车道没有红绿灯
|
||||
方法2-遍历车道: 开始遍历 123 条车道
|
||||
✓ 找到车辆所在车道: 184 (检查了32条)
|
||||
has_traffic_light = False
|
||||
该车道没有红绿灯
|
||||
结果: 返回 0 (无红绿灯/未知)
|
||||
```
|
||||
|
||||
### 调试信息说明
|
||||
|
||||
| 信息 | 含义 |
|
||||
|------|------|
|
||||
| 有红绿灯的车道数 | 统计场景中有多少个红绿灯 |
|
||||
| ⚠️ 场景中没有红绿灯 | 如果数量为0,会特别提示 |
|
||||
| 方法1-导航模块 | 尝试从导航系统获取 |
|
||||
| current_lane | 导航系统返回的当前车道对象 |
|
||||
| lane_index | 车道的唯一标识符 |
|
||||
| has_traffic_light | 该车道是否有红绿灯 |
|
||||
| status | 红绿灯的状态(GREEN/YELLOW/RED/None) |
|
||||
| 方法2-遍历车道 | 兜底方案,遍历所有车道查找 |
|
||||
| ✓ 找到车辆所在车道 | 遍历找到了匹配的车道 |
|
||||
|
||||
### 典型输出案例
|
||||
|
||||
**情况1:场景没有红绿灯**
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
|
||||
🚦 检测车辆红绿灯 - 位置: (-4.1, 46.8)
|
||||
方法1-导航模块:
|
||||
...
|
||||
has_traffic_light = False
|
||||
该车道没有红绿灯
|
||||
结果: 返回 0 (无红绿灯/未知)
|
||||
```
|
||||
→ 所有车辆都会返回0,这是正常的
|
||||
|
||||
**情况2:有红绿灯且状态正常**
|
||||
```
|
||||
🚦 检测车辆红绿灯 - 位置: (10.5, 20.3)
|
||||
方法1-导航模块:
|
||||
current_lane = <...>
|
||||
lane_index = 205
|
||||
has_traffic_light = True
|
||||
status = TRAFFIC_LIGHT_GREEN
|
||||
✅ 方法1成功: 绿灯
|
||||
```
|
||||
→ 方法1直接成功,返回1(绿灯)
|
||||
|
||||
**情况3:红绿灯状态为None**
|
||||
```
|
||||
🚦 检测车辆红绿灯 - 位置: (10.5, 20.3)
|
||||
方法1-导航模块:
|
||||
current_lane = <...>
|
||||
lane_index = 205
|
||||
has_traffic_light = True
|
||||
status = None
|
||||
⚠️ 方法1: 红绿灯状态为None
|
||||
```
|
||||
→ 有红绿灯,但状态异常,返回0
|
||||
|
||||
**情况4:导航失败,方法2兜底**
|
||||
```
|
||||
🚦 检测车辆红绿灯 - 位置: (15.2, 30.5)
|
||||
方法1-导航模块: 不可用 (hasattr=True, not_none=False)
|
||||
方法2-遍历车道: 开始遍历 123 条车道
|
||||
✓ 找到车辆所在车道: 210 (检查了45条)
|
||||
has_traffic_light = True
|
||||
status = TRAFFIC_LIGHT_RED
|
||||
✅ 方法2成功: 红灯
|
||||
```
|
||||
→ 方法1失败,方法2兜底成功,返回3(红灯)
|
||||
|
||||
---
|
||||
|
||||
## 🧪 测试方法
|
||||
|
||||
### 方式1:使用测试脚本
|
||||
|
||||
```bash
|
||||
# 标准测试(无详细调试)
|
||||
python Env/test_lane_filter.py
|
||||
|
||||
# 调试模式(详细输出)
|
||||
python Env/test_lane_filter.py --debug
|
||||
```
|
||||
|
||||
### 方式2:在代码中直接启用
|
||||
|
||||
```python
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": "...",
|
||||
"use_render": False,
|
||||
|
||||
# 启用调试
|
||||
"debug_lane_filter": True,
|
||||
"debug_traffic_light": True,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0)
|
||||
# 调试信息会自动输出
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📝 调试输出控制
|
||||
|
||||
### 场景1:只想看车道过滤
|
||||
|
||||
```python
|
||||
config = {
|
||||
"debug_lane_filter": True,
|
||||
"debug_traffic_light": False, # 关闭红绿灯调试
|
||||
}
|
||||
```
|
||||
|
||||
### 场景2:只想看红绿灯检测
|
||||
|
||||
```python
|
||||
config = {
|
||||
"debug_lane_filter": False,
|
||||
"debug_traffic_light": True, # 只看红绿灯
|
||||
}
|
||||
```
|
||||
|
||||
### 场景3:生产环境(关闭所有调试)
|
||||
|
||||
```python
|
||||
config = {
|
||||
"debug_lane_filter": False,
|
||||
"debug_traffic_light": False,
|
||||
}
|
||||
# 或者直接不设置这两个参数,默认就是False
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 💡 常见问题诊断
|
||||
|
||||
### 问题1:所有红绿灯状态都是0
|
||||
|
||||
**检查调试输出:**
|
||||
```
|
||||
📍 场景信息统计:
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
```
|
||||
|
||||
**结论:** 场景本身没有红绿灯,返回0是正常的
|
||||
|
||||
---
|
||||
|
||||
### 问题2:车辆被过滤但不应该过滤
|
||||
|
||||
**检查调试输出:**
|
||||
```
|
||||
车辆 X: ID=XXX
|
||||
🔍 检测位置 (x, y), 容差=3.0m
|
||||
❌ 不在任何车道上 (检查了123条车道)
|
||||
❌ 过滤 (原因: 不在车道上)
|
||||
```
|
||||
|
||||
**可能原因:**
|
||||
1. 车辆确实在非车道区域(草坪/停车场)
|
||||
2. 容差值太小,可以尝试增大 `lane_tolerance`
|
||||
3. 车道数据有问题
|
||||
|
||||
**解决方案:**
|
||||
```python
|
||||
config = {
|
||||
"lane_tolerance": 5.0, # 增大容差到5米
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 问题3:性能下降
|
||||
|
||||
启用调试模式会有大量输出,影响性能:
|
||||
|
||||
**解决方案:**
|
||||
- 只在开发/调试时启用
|
||||
- 生产环境关闭所有调试开关
|
||||
- 或者只测试少量车辆:
|
||||
```python
|
||||
config = {
|
||||
"max_controlled_vehicles": 5, # 只测试5辆车
|
||||
"debug_traffic_light": True,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📌 最佳实践
|
||||
|
||||
1. **开发阶段**:启用调试,理解代码行为
|
||||
2. **调试问题**:根据需要选择性启用调试
|
||||
3. **性能测试**:关闭所有调试
|
||||
4. **生产运行**:永久关闭调试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 调试输出示例
|
||||
|
||||
完整的调试运行示例:
|
||||
|
||||
```bash
|
||||
cd /home/huangfukk/MAGAIL4AutoDrive
|
||||
python Env/test_lane_filter.py --debug
|
||||
```
|
||||
|
||||
输出会包含:
|
||||
- 场景统计信息
|
||||
- 每辆车的详细检测过程
|
||||
- 最终的过滤/检测结果
|
||||
- 性能统计
|
||||
|
||||
---
|
||||
|
||||
## 📖 相关文档
|
||||
|
||||
- `README.md` - 项目总览和问题解决
|
||||
- `CHANGELOG.md` - 更新日志
|
||||
- `PERFORMANCE_OPTIMIZATION.md` - 性能优化指南
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
# GPU加速指南
|
||||
|
||||
## 当前性能瓶颈分析
|
||||
|
||||
从测试结果看,即使关闭渲染,FPS仍然只有15-20左右,主要瓶颈是:
|
||||
|
||||
### 计算量分析(51辆车)
|
||||
```
|
||||
激光雷达计算:
|
||||
- 前向雷达:80束 × 51车 = 4,080次射线检测
|
||||
- 侧向雷达:10束 × 51车 = 510次射线检测
|
||||
- 车道线雷达:10束 × 51车 = 510次射线检测
|
||||
合计:5,100次射线检测/帧
|
||||
|
||||
红绿灯检测:
|
||||
- 遍历所有车道 × 51车 = 数千次几何计算
|
||||
```
|
||||
|
||||
**关键问题**:这些计算都是CPU单线程串行的,无法利用多核和GPU!
|
||||
|
||||
---
|
||||
|
||||
## GPU加速方案
|
||||
|
||||
### 方案1:优化激光雷达计算(已实现)✅
|
||||
|
||||
**优化内容:**
|
||||
1. 减少激光束数量:100束 → 52束(减少48%)
|
||||
2. 优化红绿灯检测:避免遍历所有车道
|
||||
3. 激光雷达缓存:每N帧才重新计算一次
|
||||
|
||||
**预期提升:** 2-4倍(30-60 FPS)
|
||||
|
||||
**使用方法:**
|
||||
```bash
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 方案2:MetaDrive GPU渲染(有限支持)
|
||||
|
||||
**说明:**
|
||||
MetaDrive基于Panda3D引擎,理论上支持GPU渲染,但:
|
||||
- GPU主要用于**图形渲染**,不是物理计算
|
||||
- 激光雷达的射线检测仍在CPU上
|
||||
- GPU渲染主要加速可视化,不加速训练
|
||||
|
||||
**启用方法:**
|
||||
```python
|
||||
config = {
|
||||
"use_render": True,
|
||||
"render_mode": "onscreen", # 或 "offscreen"
|
||||
# Panda3D会自动尝试使用GPU
|
||||
}
|
||||
```
|
||||
|
||||
**限制:**
|
||||
- 需要显示器或虚拟显示(Xvfb)
|
||||
- WSL2环境需要配置X11转发
|
||||
- 对无渲染训练无帮助
|
||||
|
||||
---
|
||||
|
||||
### 方案3:使用GPU加速的物理引擎(推荐但需要迁移)
|
||||
|
||||
**选项A:Isaac Gym (NVIDIA)**
|
||||
- 完全在GPU上运行物理模拟和渲染
|
||||
- 可同时模拟数千个环境
|
||||
- **缺点**:需要完全重写环境代码,迁移成本高
|
||||
|
||||
**选项B:IsaacSim/Omniverse**
|
||||
- NVIDIA的高级仿真平台
|
||||
- 支持GPU加速的激光雷达
|
||||
- **缺点**:学习曲线陡峭,环境配置复杂
|
||||
|
||||
**选项C:Brax (Google)**
|
||||
- JAX驱动,完全在GPU/TPU上运行
|
||||
- **缺点**:功能有限,不支持复杂场景
|
||||
|
||||
---
|
||||
|
||||
### 方案4:策略网络GPU加速(推荐)✅
|
||||
|
||||
虽然环境仿真在CPU,但可以让**策略网络在GPU上运行**:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# 创建GPU上的策略模型
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
policy = PolicyNetwork().to(device)
|
||||
|
||||
# 批量处理观测
|
||||
obs_batch = torch.tensor(obs_list).to(device)
|
||||
with torch.no_grad():
|
||||
actions = policy(obs_batch)
|
||||
actions = actions.cpu().numpy()
|
||||
```
|
||||
|
||||
**优势:**
|
||||
- 51辆车的推理可以并行
|
||||
- 如果使用RL训练,GPU加速训练过程
|
||||
- 不需要修改环境代码
|
||||
|
||||
---
|
||||
|
||||
### 方案5:多进程并行(最实用)✅
|
||||
|
||||
既然单个环境受限于CPU单线程,可以**并行运行多个环境**:
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool
|
||||
import os
|
||||
|
||||
def run_single_env(seed):
|
||||
"""运行单个环境实例"""
|
||||
env = MultiAgentScenarioEnv(config=...)
|
||||
obs = env.reset(seed)
|
||||
|
||||
for step in range(1000):
|
||||
actions = {...}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
env.close()
|
||||
return results
|
||||
|
||||
# 使用进程池并行运行
|
||||
if __name__ == "__main__":
|
||||
num_processes = os.cpu_count() # 12600KF有10核20线程
|
||||
seeds = list(range(num_processes))
|
||||
|
||||
with Pool(processes=num_processes) as pool:
|
||||
results = pool.map(run_single_env, seeds)
|
||||
```
|
||||
|
||||
**预期提升:** 接近线性(10核 ≈ 10倍吞吐量)
|
||||
|
||||
**CPU利用率:** 可达80-100%
|
||||
|
||||
---
|
||||
|
||||
## 推荐的完整优化方案
|
||||
|
||||
### 1. 立即可用(已实现)
|
||||
```bash
|
||||
# 使用优化版本,激光束减少+缓存
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
**预期:** 30-60 FPS(2-4倍提升)
|
||||
|
||||
### 2. 短期优化(1-2小时)
|
||||
- 实现多进程并行
|
||||
- 策略网络迁移到GPU
|
||||
|
||||
**预期:** 300-600 FPS(总吞吐量)
|
||||
|
||||
### 3. 中期优化(1-2天)
|
||||
- 使用NumPy矢量化批量处理观测
|
||||
- 优化Python代码热点(用Cython/Numba)
|
||||
|
||||
**预期:** 额外20-30%提升
|
||||
|
||||
### 4. 长期方案(1-2周)
|
||||
- 迁移到Isaac Gym等GPU加速仿真器
|
||||
- 或使用分布式训练框架(Ray/RLlib)
|
||||
|
||||
**预期:** 10-100倍提升
|
||||
|
||||
---
|
||||
|
||||
## 为什么MetaDrive无法直接使用GPU?
|
||||
|
||||
### 架构限制:
|
||||
1. **物理引擎**:使用Bullet/Panda3D的CPU物理引擎
|
||||
2. **射线检测**:串行CPU计算,无法并行
|
||||
3. **Python GIL**:全局解释器锁限制多线程
|
||||
4. **设计目标**:MetaDrive设计时主要考虑灵活性而非极致性能
|
||||
|
||||
### GPU在仿真中的作用:
|
||||
- ✅ **图形渲染**:绘制画面(但我们训练时不需要)
|
||||
- ✅ **神经网络推理/训练**:策略模型计算
|
||||
- ❌ **物理计算**:MetaDrive的物理引擎在CPU
|
||||
- ❌ **传感器模拟**:激光雷达等在CPU
|
||||
|
||||
---
|
||||
|
||||
## 检查GPU是否可用
|
||||
|
||||
```bash
|
||||
# 检查NVIDIA GPU
|
||||
nvidia-smi
|
||||
|
||||
# 检查PyTorch GPU支持
|
||||
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
|
||||
|
||||
# 检查MetaDrive渲染设备
|
||||
python -c "from panda3d.core import GraphicsPipeSelection; print(GraphicsPipeSelection.get_global_ptr().get_default_pipe())"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
| 方案 | 实现难度 | 性能提升 | GPU使用 | 推荐度 |
|
||||
|------|----------|----------|---------|--------|
|
||||
| 减少激光束 | ⭐ | 2-4x | ❌ | ⭐⭐⭐⭐⭐ |
|
||||
| 激光雷达缓存 | ⭐ | 1.5-3x | ❌ | ⭐⭐⭐⭐⭐ |
|
||||
| 多进程并行 | ⭐⭐ | 5-10x | ❌ | ⭐⭐⭐⭐⭐ |
|
||||
| 策略GPU加速 | ⭐⭐ | 2-5x | ✅ | ⭐⭐⭐⭐ |
|
||||
| GPU渲染 | ⭐⭐⭐ | 1.2x | ✅ | ⭐⭐ |
|
||||
| 迁移Isaac Gym | ⭐⭐⭐⭐⭐ | 10-100x | ✅ | ⭐⭐⭐ |
|
||||
|
||||
**结论:**
|
||||
1. 先用已实现的优化(减少激光束+缓存)
|
||||
2. 再实现多进程并行
|
||||
3. 策略网络用GPU训练
|
||||
4. 如果还不够,考虑迁移到GPU仿真器
|
||||
|
||||
@@ -1,413 +0,0 @@
|
||||
# 日志记录功能使用指南
|
||||
|
||||
## 📋 概述
|
||||
|
||||
为所有运行脚本添加了日志记录功能,可以将终端输出同时保存到文本文件,方便后续分析和问题排查。
|
||||
|
||||
---
|
||||
|
||||
## 🎯 功能特点
|
||||
|
||||
1. **双向输出**:同时输出到终端和文件,不影响实时查看
|
||||
2. **自动管理**:使用上下文管理器,自动处理文件开启/关闭
|
||||
3. **灵活配置**:支持自定义文件名和日志目录
|
||||
4. **时间戳命名**:默认使用时间戳生成唯一文件名
|
||||
5. **无缝集成**:只需添加命令行参数,无需修改代码
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速使用
|
||||
|
||||
### 1. 基础用法
|
||||
|
||||
```bash
|
||||
# 不启用日志(默认)
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 启用日志记录
|
||||
python Env/run_multiagent_env.py --log
|
||||
|
||||
# 或使用短选项
|
||||
python Env/run_multiagent_env.py -l
|
||||
```
|
||||
|
||||
### 2. 自定义文件名
|
||||
|
||||
```bash
|
||||
# 使用自定义日志文件名
|
||||
python Env/run_multiagent_env.py --log --log-file=my_test.log
|
||||
|
||||
# 测试脚本也支持
|
||||
python Env/test_lane_filter.py --log --log-file=test_results.log
|
||||
```
|
||||
|
||||
### 3. 组合使用调试和日志
|
||||
|
||||
```bash
|
||||
# 测试脚本:调试模式 + 日志记录
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
|
||||
# 会生成类似:test_debug_20251021_123456.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📁 日志文件位置
|
||||
|
||||
默认日志目录:`Env/logs/`
|
||||
|
||||
### 文件命名规则
|
||||
|
||||
| 脚本 | 默认文件名格式 | 示例 |
|
||||
|------|---------------|------|
|
||||
| `run_multiagent_env.py` | `run_YYYYMMDD_HHMMSS.log` | `run_20251021_143022.log` |
|
||||
| `run_multiagent_env_fast.py` | `run_fast.log` | `run_fast.log` |
|
||||
| `test_lane_filter.py` | `test_{mode}_YYYYMMDD_HHMMSS.log` | `test_debug_20251021_143500.log` |
|
||||
|
||||
**说明**:
|
||||
- `YYYYMMDD_HHMMSS` 是时间戳(年月日_时分秒)
|
||||
- `{mode}` 是测试模式(`standard` 或 `debug`)
|
||||
|
||||
---
|
||||
|
||||
## 📝 所有支持的脚本
|
||||
|
||||
### 1. run_multiagent_env.py(标准运行脚本)
|
||||
|
||||
```bash
|
||||
# 不启用日志
|
||||
python Env/run_multiagent_env.py
|
||||
|
||||
# 启用日志(自动生成时间戳文件名)
|
||||
python Env/run_multiagent_env.py --log
|
||||
|
||||
# 自定义文件名
|
||||
python Env/run_multiagent_env.py --log --log-file=run_test1.log
|
||||
```
|
||||
|
||||
**日志位置**:`Env/logs/run_YYYYMMDD_HHMMSS.log`
|
||||
|
||||
---
|
||||
|
||||
### 2. run_multiagent_env_fast.py(高性能版本)
|
||||
|
||||
```bash
|
||||
# 启用日志
|
||||
python Env/run_multiagent_env_fast.py --log
|
||||
|
||||
# 自定义文件名
|
||||
python Env/run_multiagent_env_fast.py --log --log-file=fast_test.log
|
||||
```
|
||||
|
||||
**日志位置**:`Env/logs/run_fast.log`(默认)
|
||||
|
||||
---
|
||||
|
||||
### 3. test_lane_filter.py(测试脚本)
|
||||
|
||||
```bash
|
||||
# 标准测试 + 日志
|
||||
python Env/test_lane_filter.py --log
|
||||
|
||||
# 调试测试 + 日志
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
|
||||
# 自定义文件名
|
||||
python Env/test_lane_filter.py --log --log-file=my_test.log
|
||||
|
||||
# 组合使用
|
||||
python Env/test_lane_filter.py --debug --log --log-file=debug_run.log
|
||||
```
|
||||
|
||||
**日志位置**:
|
||||
- 标准模式:`Env/logs/test_standard_YYYYMMDD_HHMMSS.log`
|
||||
- 调试模式:`Env/logs/test_debug_YYYYMMDD_HHMMSS.log`
|
||||
|
||||
---
|
||||
|
||||
## 💻 编程接口
|
||||
|
||||
如果您想在代码中直接使用日志功能:
|
||||
|
||||
```python
|
||||
from logger_utils import setup_logger
|
||||
|
||||
# 方式1:使用上下文管理器(推荐)
|
||||
with setup_logger(log_file="my_log.log", log_dir="logs"):
|
||||
print("这条消息会同时输出到终端和文件")
|
||||
# 运行您的代码
|
||||
# ...
|
||||
|
||||
# 方式2:手动管理
|
||||
from logger_utils import LoggerContext
|
||||
|
||||
logger = LoggerContext(log_file="custom.log", log_dir="output")
|
||||
logger.__enter__() # 开启日志
|
||||
print("输出消息")
|
||||
logger.__exit__(None, None, None) # 关闭日志
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 日志内容示例
|
||||
|
||||
### 标准运行
|
||||
|
||||
```
|
||||
📝 日志记录已启用
|
||||
📁 日志文件: Env/logs/run_20251021_143022.log
|
||||
------------------------------------------------------------
|
||||
💡 提示: 使用 --log 或 -l 参数启用日志记录
|
||||
示例: python run_multiagent_env.py --log
|
||||
自定义文件名: python run_multiagent_env.py --log --log-file=my_run.log
|
||||
------------------------------------------------------------
|
||||
[INFO] Environment: MultiAgentScenarioEnv
|
||||
[INFO] MetaDrive version: 0.4.3
|
||||
...
|
||||
------------------------------------------------------------
|
||||
✅ 日志已保存到: Env/logs/run_20251021_143022.log
|
||||
```
|
||||
|
||||
### 调试模式
|
||||
|
||||
```
|
||||
📝 日志记录已启用
|
||||
📁 日志文件: Env/logs/test_debug_20251021_143500.log
|
||||
------------------------------------------------------------
|
||||
🐛 调试模式启用
|
||||
============================================================
|
||||
|
||||
📍 场景信息统计:
|
||||
- 总车道数: 123
|
||||
- 有红绿灯的车道数: 0
|
||||
⚠️ 场景中没有红绿灯!
|
||||
|
||||
🔍 开始车道过滤: 共 51 辆车待检测
|
||||
...
|
||||
------------------------------------------------------------
|
||||
✅ 日志已保存到: Env/logs/test_debug_20251021_143500.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 高级配置
|
||||
|
||||
### 自定义日志目录
|
||||
|
||||
```python
|
||||
from logger_utils import setup_logger
|
||||
|
||||
# 指定不同的日志目录
|
||||
with setup_logger(log_file="test.log", log_dir="my_logs"):
|
||||
print("日志会保存到 my_logs/test.log")
|
||||
```
|
||||
|
||||
### 追加模式
|
||||
|
||||
```python
|
||||
from logger_utils import setup_logger
|
||||
|
||||
# 追加到现有文件(而不是覆盖)
|
||||
with setup_logger(log_file="test.log", mode='a'): # mode='a' 表示追加
|
||||
print("这条消息会追加到文件末尾")
|
||||
```
|
||||
|
||||
### 只重定向特定输出
|
||||
|
||||
```python
|
||||
from logger_utils import LoggerContext
|
||||
|
||||
# 只重定向stdout,不重定向stderr
|
||||
logger = LoggerContext(
|
||||
log_file="test.log",
|
||||
redirect_stdout=True, # 重定向标准输出
|
||||
redirect_stderr=False # 不重定向错误输出
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📋 命令行参数总结
|
||||
|
||||
| 参数 | 短选项 | 说明 | 示例 |
|
||||
|------|--------|------|------|
|
||||
| `--log` | `-l` | 启用日志记录 | `--log` |
|
||||
| `--log-file=NAME` | 无 | 指定日志文件名 | `--log-file=test.log` |
|
||||
| `--debug` | `-d` | 启用调试模式(test_lane_filter.py) | `--debug` |
|
||||
|
||||
### 参数组合
|
||||
|
||||
```bash
|
||||
# 示例1:标准模式 + 日志
|
||||
python Env/test_lane_filter.py --log
|
||||
|
||||
# 示例2:调试模式 + 日志
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
|
||||
# 示例3:调试 + 自定义文件名
|
||||
python Env/test_lane_filter.py -d --log --log-file=my_debug.log
|
||||
|
||||
# 示例4:所有参数
|
||||
python Env/test_lane_filter.py --debug --log --log-file=full_test.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ 常见问题
|
||||
|
||||
### Q1: 日志文件在哪里?
|
||||
|
||||
**A**: 默认在 `Env/logs/` 目录下。如果目录不存在,会自动创建。
|
||||
|
||||
```bash
|
||||
# 查看所有日志文件
|
||||
ls -lh Env/logs/
|
||||
|
||||
# 查看最新的日志
|
||||
ls -lt Env/logs/ | head -5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Q2: 如何查看日志内容?
|
||||
|
||||
**A**: 使用任何文本编辑器或命令行工具:
|
||||
|
||||
```bash
|
||||
# 方式1:使用cat
|
||||
cat Env/logs/run_20251021_143022.log
|
||||
|
||||
# 方式2:使用less(可翻页)
|
||||
less Env/logs/run_20251021_143022.log
|
||||
|
||||
# 方式3:查看末尾内容
|
||||
tail -n 50 Env/logs/run_20251021_143022.log
|
||||
|
||||
# 方式4:实时监控(适合长时间运行)
|
||||
tail -f Env/logs/run_20251021_143022.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Q3: 日志文件太多怎么办?
|
||||
|
||||
**A**: 可以定期清理旧日志:
|
||||
|
||||
```bash
|
||||
# 删除7天前的日志
|
||||
find Env/logs/ -name "*.log" -mtime +7 -delete
|
||||
|
||||
# 只保留最新的10个日志
|
||||
cd Env/logs && ls -t *.log | tail -n +11 | xargs rm -f
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Q4: 日志会影响性能吗?
|
||||
|
||||
**A**: 影响很小,因为:
|
||||
1. 文件I/O是异步的
|
||||
2. 使用了缓冲区
|
||||
3. 立即刷新确保数据不丢失
|
||||
|
||||
如果追求极致性能,建议训练时不启用日志,只在需要分析时启用。
|
||||
|
||||
---
|
||||
|
||||
### Q5: 可以同时记录多个脚本的日志吗?
|
||||
|
||||
**A**: 可以,每个脚本使用不同的日志文件:
|
||||
|
||||
```bash
|
||||
# 终端1
|
||||
python Env/run_multiagent_env.py --log --log-file=script1.log
|
||||
|
||||
# 终端2(同时运行)
|
||||
python Env/test_lane_filter.py --log --log-file=script2.log
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
### 1. 开发阶段
|
||||
|
||||
```bash
|
||||
# 使用调试模式 + 日志,方便排查问题
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
```
|
||||
|
||||
### 2. 长时间运行
|
||||
|
||||
```bash
|
||||
# 启用日志,避免输出丢失
|
||||
nohup python Env/run_multiagent_env.py --log > /dev/null 2>&1 &
|
||||
|
||||
# 查看实时输出
|
||||
tail -f Env/logs/run_*.log
|
||||
```
|
||||
|
||||
### 3. 批量实验
|
||||
|
||||
```bash
|
||||
# 为每次实验使用不同的日志文件
|
||||
for i in {1..5}; do
|
||||
python Env/run_multiagent_env.py --log --log-file=exp_${i}.log
|
||||
done
|
||||
```
|
||||
|
||||
### 4. 性能测试
|
||||
|
||||
```bash
|
||||
# 不启用日志,获得最佳性能
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📖 相关文档
|
||||
|
||||
- `README.md` - 项目总览
|
||||
- `DEBUG_GUIDE.md` - 调试功能使用指南
|
||||
- `CHANGELOG.md` - 更新日志
|
||||
|
||||
---
|
||||
|
||||
## 🔍 技术细节
|
||||
|
||||
### 实现原理
|
||||
|
||||
1. **TeeLogger类**:实现同时写入终端和文件
|
||||
2. **上下文管理器**:自动管理资源(文件打开/关闭)
|
||||
3. **sys.stdout重定向**:拦截所有print输出
|
||||
4. **即时刷新**:每次写入后立即刷新,确保数据不丢失
|
||||
|
||||
### 源代码
|
||||
|
||||
详见 `Env/logger_utils.py`
|
||||
|
||||
```python
|
||||
# 简化示例
|
||||
class TeeLogger:
|
||||
def write(self, message):
|
||||
self.terminal.write(message) # 输出到终端
|
||||
self.log_file.write(message) # 写入文件
|
||||
self.log_file.flush() # 立即刷新
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 总结
|
||||
|
||||
- ✅ 简单易用:只需添加 `--log` 参数
|
||||
- ✅ 不影响输出:终端仍可实时查看
|
||||
- ✅ 自动管理:文件自动开启/关闭
|
||||
- ✅ 灵活配置:支持自定义文件名和目录
|
||||
- ✅ 完整记录:包含所有调试信息
|
||||
|
||||
立即开始使用:
|
||||
|
||||
```bash
|
||||
python Env/test_lane_filter.py --debug --log
|
||||
```
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
# MetaDrive 性能优化指南
|
||||
|
||||
## 为什么帧率只有15FPS且CPU利用率不高?
|
||||
|
||||
### 主要原因:
|
||||
|
||||
1. **渲染瓶颈(最主要)**
|
||||
- `use_render: True` + 每帧调用 `env.render()` 会严重限制帧率
|
||||
- MetaDrive 使用 Panda3D 渲染引擎,渲染是**同步阻塞**的
|
||||
- 即使CPU有余力,也要等待渲染完成才能继续下一步
|
||||
- 这就是为什么CPU利用率低但帧率也低的原因
|
||||
|
||||
2. **激光雷达计算开销**
|
||||
- 每帧对每辆车进行3次激光雷达扫描(100个激光束)
|
||||
- 需要进行物理射线检测,计算量较大
|
||||
|
||||
3. **物理引擎同步**
|
||||
- 默认物理步长很小(0.02s),需要频繁计算
|
||||
|
||||
4. **Python GIL限制**
|
||||
- Python全局解释器锁限制了多核并行
|
||||
- 即使是多核CPU,Python单线程性能才是瓶颈
|
||||
|
||||
## 性能优化方案
|
||||
|
||||
### 方案1:关闭渲染(推荐用于训练)
|
||||
**预期提升:10-20倍(150-300+ FPS)**
|
||||
|
||||
```python
|
||||
config = {
|
||||
"use_render": False, # 关闭渲染
|
||||
"render_pipeline": False,
|
||||
"image_observation": False,
|
||||
"interface_panel": [],
|
||||
"manual_control": False,
|
||||
}
|
||||
```
|
||||
|
||||
### 方案2:降低物理计算频率
|
||||
**预期提升:2-3倍**
|
||||
|
||||
```python
|
||||
config = {
|
||||
"physics_world_step_size": 0.05, # 默认0.02,增大步长
|
||||
"decision_repeat": 5, # 每5个物理步执行一次决策
|
||||
}
|
||||
```
|
||||
|
||||
### 方案3:优化激光雷达
|
||||
**预期提升:1.5-2倍**
|
||||
|
||||
修改 `scenario_env.py` 中的 `_get_all_obs()` 函数:
|
||||
|
||||
```python
|
||||
# 减少激光束数量
|
||||
lidar = self.engine.get_sensor("lidar").perceive(
|
||||
num_lasers=40, # 从80减到40
|
||||
distance=30,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world
|
||||
)
|
||||
|
||||
# 或者降低扫描频率(每N步才扫描一次)
|
||||
if self.round % 5 == 0:
|
||||
lidar = self.engine.get_sensor("lidar").perceive(...)
|
||||
else:
|
||||
lidar = self.last_lidar[agent_id] # 使用缓存
|
||||
```
|
||||
|
||||
### 方案4:间歇性渲染
|
||||
**适用场景:既需要可视化又想提升性能**
|
||||
|
||||
```python
|
||||
# 每10步渲染一次,而不是每步都渲染
|
||||
if step % 10 == 0:
|
||||
env.render(mode="topdown")
|
||||
```
|
||||
|
||||
### 方案5:使用多进程并行(高级)
|
||||
**预期提升:接近线性(取决于进程数)**
|
||||
|
||||
```python
|
||||
from multiprocessing import Pool
|
||||
|
||||
def run_env(seed):
|
||||
env = MultiAgentScenarioEnv(config=...)
|
||||
# 运行仿真
|
||||
return results
|
||||
|
||||
# 使用进程池并行运行多个环境
|
||||
with Pool(processes=8) as pool:
|
||||
results = pool.map(run_env, range(8))
|
||||
```
|
||||
|
||||
## 文件说明
|
||||
|
||||
- `run_multiagent_env.py` - **标准版本**(无渲染,基础优化)
|
||||
- `run_multiagent_env_fast.py` - **极速版本**(激光雷达优化+缓存)⭐推荐
|
||||
- `run_multiagent_env_parallel.py` - **并行版本**(多进程,最高吞吐量)⭐⭐推荐
|
||||
- `run_multiagent_env_visual.py` - **可视化版本**(有渲染,适合调试)
|
||||
|
||||
## 性能对比
|
||||
|
||||
| 配置 | 单环境FPS | 总吞吐量 | CPU利用率 | 文件 | 适用场景 |
|
||||
|------|-----------|----------|-----------|------|----------|
|
||||
| 原始配置(有渲染) | 15-20 | 15-20 | 15-20% | visual | 实时可视化调试 |
|
||||
| 关闭渲染 | 20-25 | 20-25 | 20-30% | 标准版 | 基础训练 |
|
||||
| 激光雷达优化+缓存 | 30-60 | 30-60 | 30-50% | fast | 快速训练⭐ |
|
||||
| 多进程并行(10核) | 30-60 | 300-600 | 90-100% | parallel | 大规模训练⭐⭐ |
|
||||
|
||||
**说明:**
|
||||
- **单环境FPS**:单个环境实例的帧率
|
||||
- **总吞吐量**:所有进程合计的 steps/second
|
||||
- 12600KF(10核20线程)推荐使用并行版本
|
||||
|
||||
## 建议
|
||||
|
||||
1. **训练时**:使用高性能版本(关闭渲染)
|
||||
2. **调试时**:使用可视化版本,或间歇性渲染
|
||||
3. **大规模实验**:使用多进程并行
|
||||
4. **如果需要GPU加速**:考虑使用GPU渲染或将策略网络部署到GPU上
|
||||
|
||||
## 为什么CPU利用率低?
|
||||
|
||||
- **渲染阻塞**:CPU在等待渲染完成
|
||||
- **Python GIL**:限制了多核利用
|
||||
- **I/O等待**:可能在等待磁盘读取数据
|
||||
- **单线程瓶颈**:MetaDrive主循环是单线程的
|
||||
|
||||
解决方法:关闭渲染 + 多进程并行
|
||||
|
||||
@@ -1,241 +0,0 @@
|
||||
# 快速使用指南
|
||||
|
||||
## 🚀 已实现的性能优化
|
||||
|
||||
根据您的测试结果,原始版本FPS只有15左右,现已进行了全面优化。
|
||||
|
||||
---
|
||||
|
||||
## 📊 性能瓶颈分析
|
||||
|
||||
您的CPU是12600KF(10核20线程),但利用率不到20%,原因是:
|
||||
|
||||
1. **激光雷达计算瓶颈**:51辆车 × 100个激光束 = 每帧5100次射线检测
|
||||
2. **红绿灯检测低效**:遍历所有车道进行几何计算
|
||||
3. **Python GIL限制**:单线程执行,无法利用多核
|
||||
4. **计算串行化**:所有车辆依次处理,没有并行
|
||||
|
||||
---
|
||||
|
||||
## 🎯 推荐使用方案
|
||||
|
||||
### 方案1:极速单环境(推荐新手)⭐
|
||||
```bash
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
|
||||
**优化内容:**
|
||||
- ✅ 激光束:100束 → 52束(减少48%计算量)
|
||||
- ✅ 激光雷达缓存:每3帧才重新计算
|
||||
- ✅ 红绿灯检测优化:避免遍历所有车道
|
||||
- ✅ 关闭所有渲染和调试
|
||||
|
||||
**预期性能:** 30-60 FPS(2-4倍提升)
|
||||
|
||||
---
|
||||
|
||||
### 方案2:多进程并行(推荐训练)⭐⭐
|
||||
```bash
|
||||
python Env/run_multiagent_env_parallel.py
|
||||
```
|
||||
|
||||
**优化内容:**
|
||||
- ✅ 同时运行10个独立环境(充分利用10核CPU)
|
||||
- ✅ 每个环境应用所有单环境优化
|
||||
- ✅ CPU利用率可达90-100%
|
||||
|
||||
**预期性能:** 300-600 steps/s(20-40倍总吞吐量)
|
||||
|
||||
---
|
||||
|
||||
### 方案3:可视化调试
|
||||
```bash
|
||||
python Env/run_multiagent_env_visual.py
|
||||
```
|
||||
|
||||
**说明:** 保留渲染功能,FPS约15,仅用于调试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 关于GPU加速
|
||||
|
||||
### GPU能否加速MetaDrive?
|
||||
|
||||
**简短回答:有限支持,主要瓶颈不在GPU**
|
||||
|
||||
**详细说明:**
|
||||
|
||||
1. **物理计算(主要瓶颈)** ❌ 不支持GPU
|
||||
- MetaDrive使用Bullet物理引擎,只在CPU运行
|
||||
- 激光雷达射线检测也在CPU
|
||||
- 这是FPS低的主要原因
|
||||
|
||||
2. **图形渲染** ✅ 支持GPU
|
||||
- Panda3D会自动使用GPU渲染
|
||||
- 但我们训练时关闭了渲染,所以GPU无用武之地
|
||||
|
||||
3. **策略网络** ✅ 支持GPU
|
||||
- 可以把Policy模型放到GPU上
|
||||
- 但环境本身仍在CPU
|
||||
|
||||
### GPU渲染配置(可选)
|
||||
```python
|
||||
config = {
|
||||
"use_render": True,
|
||||
# GPU会自动用于渲染
|
||||
}
|
||||
```
|
||||
|
||||
### 策略网络GPU加速(推荐)
|
||||
```python
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
policy_model = PolicyNet().to(device)
|
||||
|
||||
# 批量推理
|
||||
obs_tensor = torch.tensor(obs_list).to(device)
|
||||
actions = policy_model(obs_tensor)
|
||||
```
|
||||
|
||||
**详细说明请看:** `GPU_ACCELERATION.md`
|
||||
|
||||
---
|
||||
|
||||
## 📈 性能对比
|
||||
|
||||
| 版本 | FPS | CPU利用率 | 改进 |
|
||||
|------|-----|-----------|------|
|
||||
| 原始版本 | 15 | 20% | - |
|
||||
| 极速版本 | 30-60 | 30-50% | 2-4x |
|
||||
| 并行版本 | 30-60/env | 90-100% | 总吞吐20-40x |
|
||||
|
||||
---
|
||||
|
||||
## 💡 使用建议
|
||||
|
||||
### 场景1:快速测试环境
|
||||
```bash
|
||||
python Env/run_multiagent_env_fast.py
|
||||
```
|
||||
单环境,快速验证功能
|
||||
|
||||
### 场景2:大规模数据收集
|
||||
```bash
|
||||
python Env/run_multiagent_env_parallel.py
|
||||
```
|
||||
多进程,最大化数据收集速度
|
||||
|
||||
### 场景3:RL训练
|
||||
```bash
|
||||
# 推荐使用Ray RLlib等框架,它们内置了并行环境管理
|
||||
# 或者修改parallel版本,保存经验到replay buffer
|
||||
```
|
||||
|
||||
### 场景4:调试/可视化
|
||||
```bash
|
||||
python Env/run_multiagent_env_visual.py
|
||||
```
|
||||
带渲染,可以看到车辆运行
|
||||
|
||||
---
|
||||
|
||||
## 🔍 性能监控
|
||||
|
||||
所有版本都内置了性能统计,运行时会显示:
|
||||
```
|
||||
Step 100: FPS = 45.23, 车辆数 = 51, 平均步时间 = 22.10ms
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ⚙️ 高级优化选项
|
||||
|
||||
### 调整激光雷达缓存频率
|
||||
|
||||
编辑 `run_multiagent_env_fast.py`:
|
||||
```python
|
||||
env.lidar_cache_interval = 3 # 改为5可进一步提速(但观测会更旧)
|
||||
```
|
||||
|
||||
### 调整并行进程数
|
||||
|
||||
编辑 `run_multiagent_env_parallel.py`:
|
||||
```python
|
||||
num_workers = 10 # 改为更少的进程数(如果内存不足)
|
||||
```
|
||||
|
||||
### 进一步减少激光束
|
||||
|
||||
编辑 `scenario_env.py` 的 `_get_all_obs()` 函数:
|
||||
```python
|
||||
lidar = self.engine.get_sensor("lidar").perceive(
|
||||
num_lasers=20, # 从40进一步减少到20
|
||||
distance=20, # 从30减少到20米
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎓 为什么CPU利用率低?
|
||||
|
||||
### 原因分析:
|
||||
|
||||
1. **单线程瓶颈**
|
||||
- Python GIL限制
|
||||
- MetaDrive主循环是单线程的
|
||||
- 即使有10个核心,也只用1个
|
||||
|
||||
2. **I/O等待**
|
||||
- 等待渲染完成(如果开启)
|
||||
- 等待磁盘读取数据
|
||||
|
||||
3. **计算不均衡**
|
||||
- 某些计算很重(激光雷达),某些很轻
|
||||
- CPU在重计算之间有空闲
|
||||
|
||||
### 解决方案:
|
||||
|
||||
✅ **已实现:** 多进程并行(`run_multiagent_env_parallel.py`)
|
||||
- 每个进程占用1个核心
|
||||
- 10个进程可充分利用10核CPU
|
||||
- CPU利用率可达90-100%
|
||||
|
||||
---
|
||||
|
||||
## 📚 相关文档
|
||||
|
||||
- `PERFORMANCE_OPTIMIZATION.md` - 详细的性能优化指南
|
||||
- `GPU_ACCELERATION.md` - GPU加速的完整说明
|
||||
|
||||
---
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
### Q: 为什么关闭渲染后FPS还是只有20?
|
||||
A: 主要瓶颈是激光雷达计算,不是渲染。请使用 `run_multiagent_env_fast.py`。
|
||||
|
||||
### Q: GPU能加速训练吗?
|
||||
A: 环境模拟在CPU,但策略网络可以在GPU上训练。
|
||||
|
||||
### Q: 如何最大化CPU利用率?
|
||||
A: 使用 `run_multiagent_env_parallel.py` 多进程版本。
|
||||
|
||||
### Q: 会影响观测精度吗?
|
||||
A: 激光束减少会略微降低精度,但实践中影响很小。缓存会让观测滞后1-2帧。
|
||||
|
||||
### Q: 如何恢复原始配置?
|
||||
A: 使用 `run_multiagent_env_visual.py` 或修改配置文件中的参数。
|
||||
|
||||
---
|
||||
|
||||
## 🚦 下一步
|
||||
|
||||
1. 先测试 `run_multiagent_env_fast.py`,验证性能提升
|
||||
2. 如果满意,用于日常训练
|
||||
3. 需要大规模训练时,使用 `run_multiagent_env_parallel.py`
|
||||
4. 考虑将策略网络迁移到GPU
|
||||
|
||||
祝训练顺利!🎉
|
||||
|
||||
@@ -1,15 +0,0 @@
"""
Multi-Agent Scenario Environment

Multi-agent scenario environment package
"""

from .scenario_env import MultiAgentScenarioEnv, PolicyVehicle
from .simple_idm_policy import ConstantVelocityPolicy

__all__ = [
    'MultiAgentScenarioEnv',
    'PolicyVehicle',
    'ConstantVelocityPolicy',
]
BIN Env/__pycache__/replay_policy.cpython-310.pyc (new file, binary not shown)
@@ -1,116 +0,0 @@
|
||||
"""
|
||||
日志记录功能示例
|
||||
演示如何在自定义脚本中使用日志功能
|
||||
"""
|
||||
from logger_utils import setup_logger
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
def example_without_logging():
|
||||
"""示例1:不使用日志"""
|
||||
print("=" * 60)
|
||||
print("示例1:普通输出(不记录日志)")
|
||||
print("=" * 60)
|
||||
|
||||
print("这是普通的print输出")
|
||||
print("只会显示在终端")
|
||||
print("不会保存到文件")
|
||||
print()
|
||||
|
||||
|
||||
def example_with_logging():
|
||||
"""示例2:使用日志记录"""
|
||||
print("=" * 60)
|
||||
print("示例2:使用日志记录")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用with语句,自动管理日志文件
|
||||
with setup_logger(log_file="example_demo.log", log_dir="logs"):
|
||||
print("✅ 这条消息会同时输出到终端和文件")
|
||||
print("✅ 运行一些计算...")
|
||||
|
||||
for i in range(5):
|
||||
print(f" 步骤 {i+1}/5: 处理中...")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("✅ 计算完成!")
|
||||
|
||||
print("日志文件已关闭")
|
||||
print()
|
||||
|
||||
|
||||
def example_custom_filename():
|
||||
"""示例3:使用时间戳命名"""
|
||||
print("=" * 60)
|
||||
print("示例3:自动生成时间戳文件名")
|
||||
print("=" * 60)
|
||||
|
||||
# log_file=None 会自动生成时间戳文件名
|
||||
with setup_logger(log_file=None, log_dir="logs"):
|
||||
print("文件名会自动包含时间戳")
|
||||
print("适合批量实验,避免覆盖")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_append_mode():
|
||||
"""示例4:追加模式"""
|
||||
print("=" * 60)
|
||||
print("示例4:追加到现有文件")
|
||||
print("=" * 60)
|
||||
|
||||
# 第一次写入
|
||||
with setup_logger(log_file="append_test.log", log_dir="logs", mode='w'):
|
||||
print("第一次写入:这会覆盖文件")
|
||||
|
||||
# 第二次写入(追加)
|
||||
with setup_logger(log_file="append_test.log", log_dir="logs", mode='a'):
|
||||
print("第二次写入:这会追加到文件末尾")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_complex_output():
|
||||
"""示例5:复杂输出(包含颜色、格式)"""
|
||||
print("=" * 60)
|
||||
print("示例5:复杂输出格式")
|
||||
print("=" * 60)
|
||||
|
||||
with setup_logger(log_file="complex_output.log", log_dir="logs"):
|
||||
# 模拟多种输出格式
|
||||
print("\n📊 实验统计:")
|
||||
print(" - 实验名称:车道过滤测试")
|
||||
print(" - 开始时间:", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
print(" - 车辆总数:51")
|
||||
print(" - 过滤后:45")
|
||||
print("\n🚦 红绿灯检测:")
|
||||
print(" ✅ 方法1成功:3辆")
|
||||
print(" ✅ 方法2成功:2辆")
|
||||
print(" ⚠️ 未检测到:40辆")
|
||||
print("\n" + "="*50)
|
||||
print("实验完成!")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""运行所有示例"""
|
||||
print("\n" + "🎯 " + "="*56)
|
||||
print("日志记录功能完整示例")
|
||||
print("="*60 + "\n")
|
||||
|
||||
example_without_logging()
|
||||
example_with_logging()
|
||||
example_custom_filename()
|
||||
example_append_mode()
|
||||
example_complex_output()
|
||||
|
||||
print("="*60)
|
||||
print("✅ 所有示例运行完成!")
|
||||
print("📁 查看日志文件:ls -lh logs/")
|
||||
print("="*60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
"""
|
||||
日志工具模块
|
||||
提供将终端输出同时保存到文件的功能
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TeeLogger:
|
||||
"""
|
||||
双向输出类:同时输出到终端和文件
|
||||
"""
|
||||
def __init__(self, filename, mode='w', terminal=None):
|
||||
"""
|
||||
Args:
|
||||
filename: 日志文件路径
|
||||
mode: 文件打开模式 ('w'=覆盖, 'a'=追加)
|
||||
terminal: 原始输出流(通常是sys.stdout或sys.stderr)
|
||||
"""
|
||||
self.terminal = terminal or sys.stdout
|
||||
self.log_file = open(filename, mode, encoding='utf-8')
|
||||
|
||||
def write(self, message):
|
||||
"""写入消息到终端和文件"""
|
||||
self.terminal.write(message)
|
||||
self.log_file.write(message)
|
||||
self.log_file.flush() # 立即写入磁盘
|
||||
|
||||
def flush(self):
|
||||
"""刷新缓冲区"""
|
||||
self.terminal.flush()
|
||||
self.log_file.flush()
|
||||
|
||||
def close(self):
|
||||
"""关闭日志文件"""
|
||||
if self.log_file:
|
||||
self.log_file.close()
|
||||
|
||||
|
||||
class LoggerContext:
|
||||
"""
|
||||
日志上下文管理器
|
||||
使用with语句自动管理日志的开启和关闭
|
||||
"""
|
||||
def __init__(self, log_file=None, log_dir="logs", mode='w',
|
||||
redirect_stdout=True, redirect_stderr=True):
|
||||
"""
|
||||
Args:
|
||||
log_file: 日志文件名(None则自动生成时间戳文件名)
|
||||
log_dir: 日志目录
|
||||
mode: 文件打开模式 ('w'=覆盖, 'a'=追加)
|
||||
redirect_stdout: 是否重定向标准输出
|
||||
redirect_stderr: 是否重定向标准错误
|
||||
"""
|
||||
self.log_dir = log_dir
|
||||
self.mode = mode
|
||||
self.redirect_stdout = redirect_stdout
|
||||
self.redirect_stderr = redirect_stderr
|
||||
|
||||
# 创建日志目录
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 生成日志文件名
|
||||
if log_file is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
log_file = f"run_{timestamp}.log"
|
||||
|
||||
self.log_path = os.path.join(log_dir, log_file)
|
||||
|
||||
# 保存原始的stdout和stderr
|
||||
self.original_stdout = sys.stdout
|
||||
self.original_stderr = sys.stderr
|
||||
|
||||
# 日志对象
|
||||
self.stdout_logger = None
|
||||
self.stderr_logger = None
|
||||
|
||||
def __enter__(self):
|
||||
"""进入上下文:开启日志"""
|
||||
print(f"📝 日志记录已启用")
|
||||
print(f"📁 日志文件: {self.log_path}")
|
||||
print("-" * 60)
|
||||
|
||||
# 创建TeeLogger对象
|
||||
if self.redirect_stdout:
|
||||
self.stdout_logger = TeeLogger(
|
||||
self.log_path,
|
||||
mode=self.mode,
|
||||
terminal=self.original_stdout
|
||||
)
|
||||
sys.stdout = self.stdout_logger
|
||||
|
||||
if self.redirect_stderr:
|
||||
self.stderr_logger = TeeLogger(
|
||||
self.log_path,
|
||||
mode='a', # stderr总是追加模式
|
||||
terminal=self.original_stderr
|
||||
)
|
||||
sys.stderr = self.stderr_logger
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""退出上下文:关闭日志"""
|
||||
# 恢复原始输出
|
||||
sys.stdout = self.original_stdout
|
||||
sys.stderr = self.original_stderr
|
||||
|
||||
# 关闭日志文件
|
||||
if self.stdout_logger:
|
||||
self.stdout_logger.close()
|
||||
if self.stderr_logger:
|
||||
self.stderr_logger.close()
|
||||
|
||||
print("-" * 60)
|
||||
print(f"✅ 日志已保存到: {self.log_path}")
|
||||
|
||||
# 返回False表示不抑制异常
|
||||
return False
|
||||
|
||||
|
||||
def setup_logger(log_file=None, log_dir="logs", mode='w'):
|
||||
"""
|
||||
快速设置日志记录
|
||||
|
||||
Args:
|
||||
log_file: 日志文件名(None则自动生成)
|
||||
log_dir: 日志目录
|
||||
mode: 文件模式 ('w'=覆盖, 'a'=追加)
|
||||
|
||||
Returns:
|
||||
LoggerContext对象
|
||||
|
||||
Example:
|
||||
with setup_logger("my_test.log"):
|
||||
print("这条消息会同时输出到终端和文件")
|
||||
"""
|
||||
return LoggerContext(log_file=log_file, log_dir=log_dir, mode=mode)
|
||||
|
||||
|
||||
def get_default_log_filename(prefix="run"):
|
||||
"""
|
||||
生成默认的日志文件名(带时间戳)
|
||||
|
||||
Args:
|
||||
prefix: 文件名前缀
|
||||
|
||||
Returns:
|
||||
str: 格式为 "prefix_YYYYMMDD_HHMMSS.log"
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{prefix}_{timestamp}.log"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
print("测试1: 使用默认配置")
|
||||
with setup_logger():
|
||||
print("这是测试消息1")
|
||||
print("这是测试消息2")
|
||||
print("日志记录已结束\n")
|
||||
|
||||
print("测试2: 使用自定义文件名")
|
||||
with setup_logger(log_file="test_custom.log"):
|
||||
print("自定义文件名测试")
|
||||
for i in range(3):
|
||||
print(f" 消息 {i+1}")
|
||||
print("完成")
|
||||
|
||||
Env/replay_policy.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import numpy as np


class ReplayPolicy:
    """
    Strict replay policy: replays vehicle states frame by frame from expert trajectory data
    """

    def __init__(self, expert_trajectory, vehicle_id):
        """
        Args:
            expert_trajectory: expert trajectory dict containing positions, headings, velocities, valid
            vehicle_id: vehicle ID (used for debugging)
        """
        self.trajectory = expert_trajectory
        self.vehicle_id = vehicle_id
        self.current_step = 0

    def act(self, observation=None):
        """
        Return an action: in replay mode this returns an empty action;
        the actual state is set directly by the environment
        """
        return [0.0, 0.0]

    def get_target_state(self, step):
        """
        Get the target state for the given time step

        Args:
            step: time step

        Returns:
            dict: a dict containing position, heading, velocity, or None if the step is invalid
        """
        if step >= len(self.trajectory['valid']):
            return None

        if not self.trajectory['valid'][step]:
            return None

        return {
            'position': self.trajectory['positions'][step],
            'heading': self.trajectory['headings'][step],
            'velocity': self.trajectory['velocities'][step]
        }

    def is_finished(self, step):
        """
        Check whether the trajectory has finished playing

        Args:
            step: current time step

        Returns:
            bool: True if the trajectory has been fully replayed or the current step is invalid
        """
        # Beyond the trajectory length
        if step >= len(self.trajectory['valid']):
            return True

        # The current step and all following steps are invalid
        return not any(self.trajectory['valid'][step:])
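A short usage sketch of the `ReplayPolicy` above, driven by a synthetic trajectory dictionary that follows the documented layout (`positions`, `headings`, `velocities`, `valid`); run it from the `Env/` directory so the import resolves:

```python
import numpy as np
from replay_policy import ReplayPolicy  # the new Env/replay_policy.py shown above

# A tiny synthetic expert trajectory in the documented dict layout.
T = 5
trajectory = {
    'positions': np.stack([np.linspace(0, 10, T), np.zeros(T), np.zeros(T)], axis=1),  # (T, 3) x/y/z
    'headings': np.zeros(T),
    'velocities': np.tile(np.array([2.0, 0.0]), (T, 1)),                               # (T, 2) vx/vy
    'valid': np.array([True, True, True, False, False]),
}

policy = ReplayPolicy(trajectory, vehicle_id="veh_0")
for step in range(T):
    if policy.is_finished(step):
        print(f"step {step}: trajectory exhausted")
        break
    state = policy.get_target_state(step)
    print(f"step {step}: xy={state['position'][:2]}, heading={state['heading']:.2f}")
```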
@@ -1,78 +1,363 @@
|
||||
import argparse
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from replay_policy import ReplayPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from logger_utils import setup_logger
|
||||
import sys
|
||||
import os
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn/exp_converted"
|
||||
|
||||
def main(enable_logging=False, log_file=None):
|
||||
|
||||
def run_replay_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||
scenario_id=None, use_scenario_duration=False,
|
||||
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
|
||||
"""
|
||||
主函数
|
||||
|
||||
回放模式:严格按照专家轨迹回放
|
||||
|
||||
Args:
|
||||
enable_logging: 是否启用日志记录到文件
|
||||
log_file: 日志文件名(None则自动生成时间戳文件名)
|
||||
data_dir: 数据目录
|
||||
num_episodes: 回合数(如果指定scenario_id,则忽略)
|
||||
horizon: 最大步数(如果use_scenario_duration=True,则自动设置)
|
||||
render: 是否渲染
|
||||
debug: 是否调试模式
|
||||
scenario_id: 指定场景ID(可选)
|
||||
use_scenario_duration: 是否使用场景原始时长
|
||||
spawn_vehicles: 是否生成车辆(默认True)
|
||||
spawn_pedestrians: 是否生成行人(默认True)
|
||||
spawn_cyclists: 是否生成自行车(默认True)
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("运行模式: 专家轨迹回放 (Replay Mode)")
|
||||
if scenario_id is not None:
|
||||
print(f"指定场景ID: {scenario_id}")
|
||||
if use_scenario_duration:
|
||||
print("使用场景原始时长")
|
||||
print("=" * 50)
|
||||
|
||||
# 如果指定了场景ID,只运行1个回合
|
||||
if scenario_id is not None:
|
||||
num_episodes = 1
|
||||
|
||||
# ✅ 环境创建移到循环外面,避免重复创建
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
# "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"horizon": horizon,
|
||||
"use_render": render,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False, # 回放模式下不需要反应式交通
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": True, # 标记为回放模式
|
||||
"debug": debug,
|
||||
"specific_scenario_id": scenario_id, # 指定场景ID
|
||||
"use_scenario_duration": use_scenario_duration, # 使用场景时长
|
||||
# 对象类型过滤
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
},
|
||||
agent2policy=None # 回放模式不需要统一策略
|
||||
)
|
||||
|
||||
try:
|
||||
for episode in range(num_episodes):
|
||||
print(f"\n{'='*50}")
|
||||
print(f"回合 {episode + 1}/{num_episodes}")
|
||||
if scenario_id is not None:
|
||||
print(f"场景ID: {scenario_id}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
# ✅ 如果不是指定场景,使用seed来遍历不同场景
|
||||
seed = scenario_id if scenario_id is not None else episode
|
||||
obs = env.reset(seed=seed)
|
||||
|
||||
# 为每个车辆分配 ReplayPolicy
|
||||
replay_policies = {}
|
||||
for agent_id, vehicle in env.controlled_agents.items():
|
||||
vehicle_id = vehicle.expert_vehicle_id
|
||||
if vehicle_id in env.expert_trajectories:
|
||||
replay_policy = ReplayPolicy(
|
||||
env.expert_trajectories[vehicle_id],
|
||||
vehicle_id
|
||||
)
|
||||
vehicle.set_policy(replay_policy)
|
||||
replay_policies[agent_id] = replay_policy
|
||||
|
||||
# 输出场景信息
|
||||
actual_horizon = env.config["horizon"]
|
||||
print(f"初始化完成:")
|
||||
print(f" 可控车辆数: {len(env.controlled_agents)}")
|
||||
print(f" 专家轨迹数: {len(env.expert_trajectories)}")
|
||||
print(f" 场景时长: {env.scenario_max_duration} 步")
|
||||
print(f" 实际Horizon: {actual_horizon} 步")
|
||||
|
||||
step_count = 0
|
||||
active_vehicles_count = []
|
||||
|
||||
while True:
|
||||
# 在回放模式下,直接使用专家轨迹设置车辆状态
|
||||
for agent_id, vehicle in list(env.controlled_agents.items()):
|
||||
vehicle_id = vehicle.expert_vehicle_id
|
||||
if vehicle_id in env.expert_trajectories and agent_id in replay_policies:
|
||||
target_state = replay_policies[agent_id].get_target_state(env.round)
|
||||
if target_state is not None:
|
||||
# 直接设置车辆状态(绕过物理引擎)
|
||||
# 只使用xy坐标,保持车辆在地面上
|
||||
position_2d = target_state['position'][:2]
|
||||
vehicle.set_position(position_2d)
|
||||
vehicle.set_heading_theta(target_state['heading'])
|
||||
vehicle.set_velocity(target_state['velocity'][:2] if len(target_state['velocity']) > 2 else target_state['velocity'])
|
||||
|
||||
# 使用空动作进行步进
|
||||
actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
if render:
|
||||
env.render(mode="topdown")
|
||||
|
||||
step_count += 1
|
||||
active_vehicles_count.append(len(env.controlled_agents))
|
||||
|
||||
# 每50步打印一次状态
|
||||
if step_count % 50 == 0:
|
||||
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
|
||||
|
||||
# 调试模式下打印车辆高度信息
|
||||
if debug and len(env.controlled_agents) > 0:
|
||||
sample_vehicle = list(env.controlled_agents.values())[0]
|
||||
z_pos = sample_vehicle.position[2] if len(sample_vehicle.position) > 2 else 0
|
||||
print(f" [DEBUG] 示例车辆高度: z={z_pos:.3f}m")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f"\n回合结束统计:")
|
||||
print(f" 总步数: {step_count}")
|
||||
print(f" 最大同时车辆数: {max(active_vehicles_count) if active_vehicles_count else 0}")
|
||||
print(f" 平均车辆数: {sum(active_vehicles_count) / len(active_vehicles_count) if active_vehicles_count else 0:.1f}")
|
||||
if use_scenario_duration:
|
||||
print(f" 场景完整回放: {'是' if step_count >= env.scenario_max_duration else '否'}")
|
||||
break
|
||||
finally:
|
||||
# ✅ 确保环境被正确关闭
|
||||
env.close()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("回放完成!")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debug=False,
|
||||
scenario_id=None, use_scenario_duration=False,
|
||||
spawn_vehicles=True, spawn_pedestrians=True, spawn_cyclists=True):
|
||||
"""
|
||||
仿真模式:使用自定义策略控制车辆
|
||||
车辆根据专家数据的初始位姿生成,然后由策略控制
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录
|
||||
num_episodes: 回合数
|
||||
horizon: 最大步数
|
||||
render: 是否渲染
|
||||
debug: 是否调试模式
|
||||
scenario_id: 指定场景ID(可选)
|
||||
use_scenario_duration: 是否使用场景原始时长
|
||||
spawn_vehicles: 是否生成车辆(默认True)
|
||||
spawn_pedestrians: 是否生成行人(默认True)
|
||||
spawn_cyclists: 是否生成自行车(默认True)
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("运行模式: 策略仿真 (Simulation Mode)")
|
||||
if scenario_id is not None:
|
||||
print(f"指定场景ID: {scenario_id}")
|
||||
if use_scenario_duration:
|
||||
print("使用场景原始时长")
|
||||
print("=" * 50)
|
||||
|
||||
# 如果指定了场景ID,只运行1个回合
|
||||
if scenario_id is not None:
|
||||
num_episodes = 1
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(data_dir, "training_20s", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
"use_render": True,
|
||||
"horizon": horizon,
|
||||
"use_render": render,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
"manual_control": True,
|
||||
|
||||
# 车道检测与过滤配置
|
||||
"filter_offroad_vehicles": True, # 启用车道区域过滤,过滤草坪等非车道区域的车辆
|
||||
"lane_tolerance": 3.0, # 车道检测容差(米),可根据需要调整
|
||||
"max_controlled_vehicles": 2, # 限制最大车辆数(可选,None表示不限制)
|
||||
"debug_lane_filter": True,
|
||||
"debug_traffic_light": True,
|
||||
"manual_control": False,
|
||||
"filter_offroad_vehicles": True, # 启用车道过滤
|
||||
"lane_tolerance": 3.0,
|
||||
"replay_mode": False, # 仿真模式
|
||||
"debug": debug,
|
||||
"specific_scenario_id": scenario_id, # 指定场景ID
|
||||
"use_scenario_duration": use_scenario_duration, # 使用场景时长
|
||||
# 对象类型过滤
|
||||
"spawn_vehicles": spawn_vehicles,
|
||||
"spawn_pedestrians": spawn_pedestrians,
|
||||
"spawn_cyclists": spawn_cyclists,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
try:
for episode in range(num_episodes):
print(f"\n{'='*50}")
print(f"回合 {episode + 1}/{num_episodes}")
if scenario_id is not None:
print(f"场景ID: {scenario_id}")
print(f"{'='*50}")

# 如果不是指定场景,使用seed来遍历不同场景
seed = scenario_id if scenario_id is not None else episode
obs = env.reset(seed=seed)

# 输出场景信息
actual_horizon = env.config["horizon"]
print(f"初始化完成:")
print(f"  可控车辆数: {len(env.controlled_agents)}")
print(f"  场景时长: {env.scenario_max_duration} 步")
print(f"  实际Horizon: {actual_horizon} 步")

step_count = 0
total_reward = 0.0
|
||||
|
||||
while True:
|
||||
# 使用策略生成动作
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
if render:
|
||||
env.render(mode="topdown")
|
||||
|
||||
step_count += 1
|
||||
total_reward += sum(rewards.values())
|
||||
|
||||
# 每50步打印一次状态
|
||||
if step_count % 50 == 0:
|
||||
print(f"Step {step_count}: {len(env.controlled_agents)} 辆活跃车辆")
|
||||
|
||||
if dones["__all__"]:
|
||||
print(f"\n回合结束统计:")
|
||||
print(f" 总步数: {step_count}")
|
||||
print(f" 总奖励: {total_reward:.2f}")
|
||||
break
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("仿真完成!")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MetaDrive 多智能体环境运行脚本")
|
||||
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=["replay", "simulation"],
|
||||
default="simulation",
|
||||
help="运行模式: replay=专家轨迹回放, simulation=策略仿真 (默认: simulation)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default=WAYMO_DATA_DIR,
|
||||
help=f"数据目录路径 (默认: {WAYMO_DATA_DIR})"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="运行回合数 (默认: 1)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--horizon",
|
||||
type=int,
|
||||
default=300,
|
||||
help="每回合最大步数 (默认: 300,如果启用 --use_scenario_duration 则自动设置)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_render",
|
||||
action="store_true",
|
||||
help="禁用渲染(加速运行)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="启用调试模式(显示详细日志)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--scenario_id",
|
||||
type=int,
|
||||
default=None,
|
||||
help="指定场景ID(可选,如指定则只运行该场景)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_scenario_duration",
|
||||
action="store_true",
|
||||
help="使用场景原始时长作为horizon(自动停止)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_vehicles",
|
||||
action="store_true",
|
||||
help="禁止生成车辆"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_pedestrians",
|
||||
action="store_true",
|
||||
help="禁止生成行人"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_cyclists",
|
||||
action="store_true",
|
||||
help="禁止生成自行车"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == "replay":
|
||||
run_replay_mode(
|
||||
data_dir=args.data_dir,
|
||||
num_episodes=args.episodes,
|
||||
horizon=args.horizon,
|
||||
render=not args.no_render,
|
||||
debug=args.debug,
|
||||
scenario_id=args.scenario_id,
|
||||
use_scenario_duration=args.use_scenario_duration,
|
||||
spawn_vehicles=not args.no_vehicles,
|
||||
spawn_pedestrians=not args.no_pedestrians,
|
||||
spawn_cyclists=not args.no_cyclists
|
||||
)
|
||||
else:
|
||||
run_simulation_mode(
|
||||
data_dir=args.data_dir,
|
||||
num_episodes=args.episodes,
|
||||
horizon=args.horizon,
|
||||
render=not args.no_render,
|
||||
debug=args.debug,
|
||||
scenario_id=args.scenario_id,
|
||||
use_scenario_duration=args.use_scenario_duration,
|
||||
spawn_vehicles=not args.no_vehicles,
|
||||
spawn_pedestrians=not args.no_pedestrians,
|
||||
spawn_cyclists=not args.no_cyclists
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 解析命令行参数
|
||||
enable_logging = "--log" in sys.argv or "-l" in sys.argv
|
||||
|
||||
# 提取自定义日志文件名
|
||||
log_file = None
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--log-file="):
|
||||
log_file = arg.split("=")[1]
|
||||
break
|
||||
|
||||
if enable_logging:
|
||||
# 使用日志记录
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
with setup_logger(log_file=log_file, log_dir=log_dir):
|
||||
main()
|
||||
else:
|
||||
# 普通运行(只输出到终端)
|
||||
print("💡 提示: 使用 --log 或 -l 参数启用日志记录")
|
||||
print(" 示例: python run_multiagent_env.py --log")
|
||||
print(" 自定义文件名: python run_multiagent_env.py --log --log-file=my_run.log")
|
||||
print("-" * 60)
|
||||
main()
|
||||
@@ -1,115 +0,0 @@
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from logger_utils import setup_logger
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
def main(enable_logging=False):
|
||||
"""极致性能优化版本 - 启用所有优化选项"""
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
|
||||
# 关闭所有渲染
|
||||
"use_render": False,
|
||||
"render_pipeline": False,
|
||||
"image_observation": False,
|
||||
"interface_panel": [],
|
||||
"manual_control": False,
|
||||
"show_fps": False,
|
||||
"debug": False,
|
||||
|
||||
# 物理引擎优化
|
||||
"physics_world_step_size": 0.02,
|
||||
"decision_repeat": 5,
|
||||
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
|
||||
# 车道检测与过滤配置
|
||||
"filter_offroad_vehicles": True, # 过滤非车道区域的车辆
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 15, # 限制车辆数以提升性能
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
# 【关键优化】启用激光雷达缓存
|
||||
# 每3帧才重新计算激光雷达,其余帧使用缓存
|
||||
# 可将激光雷达计算量减少到原来的1/3
|
||||
env.lidar_cache_interval = 3
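# (示意,非现有实现)环境内部大致可以按如下方式使用该间隔:
#   if self.round % self.lidar_cache_interval == 0:
#       self._lidar_cache[agent_id] = lidar.perceive(...)   # 本帧重新计算
#   obs_lidar = self._lidar_cache[agent_id]                 # 其余帧复用缓存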
|
||||
|
||||
obs = env.reset(0)
|
||||
|
||||
# 性能统计
|
||||
start_time = time.time()
|
||||
total_steps = 0
|
||||
|
||||
print("=" * 60)
|
||||
print("极致性能模式")
|
||||
print("激光雷达优化:80→40束 (前向), 10→6束 (侧向+车道线)")
|
||||
print("激光雷达缓存:每3帧计算一次,中间帧使用缓存")
|
||||
print("预期性能提升:3-5倍")
|
||||
print("=" * 60)
|
||||
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
total_steps += 1
|
||||
|
||||
# 每100步输出一次性能统计
|
||||
if step % 100 == 0 and step > 0:
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print(f"Step {step:4d}: FPS = {fps:6.2f}, 车辆数 = {len(env.controlled_agents):3d}, "
|
||||
f"平均步时间 = {1000/fps:.2f}ms")
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
# 最终统计
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print("\n" + "=" * 60)
|
||||
print(f"总计: {total_steps} 步")
|
||||
print(f"耗时: {elapsed:.2f}s")
|
||||
print(f"平均FPS: {fps:.2f}")
|
||||
print(f"单步平均耗时: {1000/fps:.2f}ms")
|
||||
print("=" * 60)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 解析命令行参数
|
||||
enable_logging = "--log" in sys.argv or "-l" in sys.argv
|
||||
|
||||
# 提取自定义日志文件名
|
||||
log_file = None
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--log-file="):
|
||||
log_file = arg.split("=")[1]
|
||||
break
|
||||
|
||||
if enable_logging:
|
||||
# 使用日志记录
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
with setup_logger(log_file=log_file or "run_fast.log", log_dir=log_dir):
|
||||
main(enable_logging=True)
|
||||
else:
|
||||
# 普通运行(只输出到终端)
|
||||
print("💡 提示: 使用 --log 或 -l 参数启用日志记录")
|
||||
print("-" * 60)
|
||||
main(enable_logging=False)
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
"""
|
||||
多进程并行版本 - 充分利用多核CPU
|
||||
适合大规模数据收集和训练
|
||||
"""
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import time
|
||||
import os
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
|
||||
def run_single_env(args):
|
||||
"""在单个进程中运行一个环境实例"""
|
||||
seed, num_steps, worker_id = args
|
||||
|
||||
# 创建环境(每个进程独立)
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
|
||||
# 性能优化
|
||||
"use_render": False,
|
||||
"render_pipeline": False,
|
||||
"image_observation": False,
|
||||
"interface_panel": [],
|
||||
"manual_control": False,
|
||||
"show_fps": False,
|
||||
"debug": False,
|
||||
|
||||
"physics_world_step_size": 0.02,
|
||||
"decision_repeat": 5,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
|
||||
# 车道检测与过滤配置
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 15,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
# 启用激光雷达缓存
|
||||
env.lidar_cache_interval = 3
|
||||
|
||||
# 运行仿真
|
||||
start_time = time.time()
|
||||
obs = env.reset(seed)
|
||||
total_steps = 0
|
||||
total_agents = 0
|
||||
|
||||
for step in range(num_steps):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
total_steps += 1
|
||||
total_agents += len(env.controlled_agents)
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed if elapsed > 0 else 0
|
||||
avg_agents = total_agents / total_steps if total_steps > 0 else 0
|
||||
|
||||
env.close()
|
||||
|
||||
return {
|
||||
'worker_id': worker_id,
|
||||
'seed': seed,
|
||||
'steps': total_steps,
|
||||
'elapsed': elapsed,
|
||||
'fps': fps,
|
||||
'avg_agents': avg_agents,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数:协调多个并行环境"""
|
||||
# 获取CPU核心数
|
||||
num_cores = cpu_count()
|
||||
# 建议使用物理核心数(12600KF是10核20线程,使用10个进程)
|
||||
num_workers = min(10, num_cores)
|
||||
|
||||
print("=" * 80)
|
||||
print(f"多进程并行模式")
|
||||
print(f"CPU核心数: {num_cores}")
|
||||
print(f"并行进程数: {num_workers}")
|
||||
print(f"每个环境运行: 1000步")
|
||||
print("=" * 80)
|
||||
|
||||
# 准备任务参数
|
||||
num_steps_per_env = 1000
|
||||
tasks = [(seed, num_steps_per_env, worker_id)
|
||||
for worker_id, seed in enumerate(range(num_workers))]
|
||||
|
||||
# 启动多进程池
|
||||
start_time = time.time()
|
||||
|
||||
with Pool(processes=num_workers) as pool:
|
||||
results = pool.map(run_single_env, tasks)
|
||||
|
||||
total_elapsed = time.time() - start_time
|
||||
|
||||
# 统计结果
|
||||
print("\n" + "=" * 80)
|
||||
print("各进程执行结果:")
|
||||
print("-" * 80)
|
||||
print(f"{'Worker':<8} {'Seed':<6} {'Steps':<8} {'Time(s)':<10} {'FPS':<8} {'平均车辆数':<12}")
|
||||
print("-" * 80)
|
||||
|
||||
total_steps = 0
|
||||
total_fps = 0
|
||||
|
||||
for result in results:
|
||||
print(f"{result['worker_id']:<8} "
|
||||
f"{result['seed']:<6} "
|
||||
f"{result['steps']:<8} "
|
||||
f"{result['elapsed']:<10.2f} "
|
||||
f"{result['fps']:<8.2f} "
|
||||
f"{result['avg_agents']:<12.1f}")
|
||||
total_steps += result['steps']
|
||||
total_fps += result['fps']
|
||||
|
||||
print("-" * 80)
|
||||
avg_fps_per_env = total_fps / len(results)
|
||||
total_throughput = total_steps / total_elapsed
|
||||
|
||||
print(f"\n总体统计:")
|
||||
print(f" 总步数: {total_steps}")
|
||||
print(f" 总耗时: {total_elapsed:.2f}s")
|
||||
print(f" 单环境平均FPS: {avg_fps_per_env:.2f}")
|
||||
print(f" 总吞吐量: {total_throughput:.2f} steps/s")
|
||||
print(f" 并行效率: {total_throughput / avg_fps_per_env:.1f}x")
|
||||
print("=" * 80)
|
||||
|
||||
# 与单进程对比
|
||||
print(f"\n性能对比:")
|
||||
print(f" 单进程FPS (预估): ~30 FPS")
|
||||
print(f" 多进程吞吐量: {total_throughput:.2f} steps/s")
|
||||
print(f" 性能提升: {total_throughput / 30:.1f}x")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import time
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
def main():
|
||||
"""带可视化的版本(低FPS,约15帧)"""
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
|
||||
# 可视化设置(牺牲性能)
|
||||
"use_render": True,
|
||||
"manual_control": False,
|
||||
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0)
|
||||
|
||||
start_time = time.time()
|
||||
total_steps = 0
|
||||
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
env.render(mode="topdown") # 实时渲染
|
||||
|
||||
total_steps += 1
|
||||
|
||||
if step % 100 == 0 and step > 0:
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print(f"Step {step}: FPS = {fps:.2f}, 车辆数 = {len(env.controlled_agents)}")
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
fps = total_steps / elapsed
|
||||
print(f"\n总计: {total_steps} 步,耗时 {elapsed:.2f}s,平均FPS = {fps:.2f}")
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -15,6 +15,7 @@ class PolicyVehicle(DefaultVehicle):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.policy = None
|
||||
self.destination = None
|
||||
self.expert_vehicle_id = None # 关联专家车辆ID
|
||||
|
||||
def set_policy(self, policy):
|
||||
self.policy = policy
|
||||
@@ -22,6 +23,9 @@ class PolicyVehicle(DefaultVehicle):
|
||||
def set_destination(self, des):
|
||||
self.destination = des
|
||||
|
||||
def set_expert_vehicle_id(self, vid):
|
||||
self.expert_vehicle_id = vid
|
||||
|
||||
def act(self, observation, policy=None):
|
||||
if self.policy is not None:
|
||||
return self.policy.act(observation)
|
||||
@@ -53,13 +57,15 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
data_directory=None,
|
||||
num_controlled_agents=3,
|
||||
horizon=1000,
|
||||
# 车道检测与过滤配置
|
||||
filter_offroad_vehicles=True, # 是否过滤非车道区域的车辆
|
||||
lane_tolerance=3.0, # 车道检测容差(米),用于放宽边界条件
|
||||
max_controlled_vehicles=None, # 最大可控车辆数限制(None表示不限制)
|
||||
# 调试模式配置
|
||||
debug_traffic_light=False, # 是否启用红绿灯检测调试输出
|
||||
debug_lane_filter=False, # 是否启用车道过滤调试输出
|
||||
replay_mode=False, # 回放模式开关
|
||||
specific_scenario_id=None, # 新增:指定场景ID(仅回放模式)
|
||||
use_scenario_duration=False, # 新增:使用场景原始时长作为horizon
|
||||
# 对象类型过滤选项
|
||||
spawn_vehicles=True, # 是否生成车辆
|
||||
spawn_pedestrians=True, # 是否生成行人
|
||||
spawn_cyclists=True, # 是否生成自行车
|
||||
))
|
||||
return config
|
||||
|
||||
@@ -69,96 +75,179 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.controlled_agent_ids = []
|
||||
self.obs_list = []
|
||||
self.round = 0
|
||||
# 调试模式配置
|
||||
self.debug_traffic_light = config.get("debug_traffic_light", False)
|
||||
self.debug_lane_filter = config.get("debug_lane_filter", False)
|
||||
self.expert_trajectories = {} # 存储完整专家轨迹
|
||||
self.replay_mode = config.get("replay_mode", False)
|
||||
self.scenario_max_duration = 0 # 场景实际最大时长
|
||||
super().__init__(config)
|
||||
|
||||
def reset(self, seed: Union[None, int] = None):
|
||||
self.round = 0
|
||||
|
||||
if self.logger is None:
|
||||
self.logger = get_logger()
|
||||
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
|
||||
set_log_level(log_level)
|
||||
|
||||
# ✅ 关键修复:在每次 reset 前清理所有自定义生成的对象
|
||||
if hasattr(self, 'engine') and self.engine is not None:
|
||||
if hasattr(self, 'controlled_agents') and self.controlled_agents:
|
||||
# 先从 agent_manager 中移除
|
||||
if hasattr(self.engine, 'agent_manager'):
|
||||
for agent_id in list(self.controlled_agents.keys()):
|
||||
if agent_id in self.engine.agent_manager.active_agents:
|
||||
self.engine.agent_manager.active_agents.pop(agent_id)
|
||||
|
||||
# 然后清理对象
|
||||
for agent_id, vehicle in list(self.controlled_agents.items()):
|
||||
try:
|
||||
self.engine.clear_objects([vehicle.id])
|
||||
except:
|
||||
pass
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
self.lazy_init()
|
||||
self._reset_global_seed(seed)
|
||||
|
||||
if self.engine is None:
|
||||
raise ValueError("Broken MetaDrive instance.")
|
||||
|
||||
# 在engine.reset()之前清理对象
|
||||
self.before_reset()
|
||||
|
||||
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
|
||||
_obj_to_clean_this_frame = []
|
||||
self.car_birth_info_list = []
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
|
||||
continue
|
||||
else:
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
# id,出现时间,出生点坐标,出生朝向,目的地
|
||||
self.car_birth_info_list.append({
|
||||
'id': track['metadata']['object_id'],
|
||||
'show_time': first_show,
|
||||
'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
for scenario_id in _obj_to_clean_this_frame:
|
||||
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
|
||||
# 如果指定了场景ID,修改start_scenario_index
|
||||
if self.config.get("specific_scenario_id") is not None:
|
||||
scenario_id = self.config.get("specific_scenario_id")
|
||||
self.config["start_scenario_index"] = scenario_id
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Using specific scenario ID: {scenario_id}")
|
||||
|
||||
# ✅ 先初始化引擎和 lanes
|
||||
self.engine.reset()
|
||||
self.reset_sensors()
|
||||
self.engine.taskMgr.step()
|
||||
|
||||
self.lanes = self.engine.map_manager.current_map.road_network.graph
|
||||
|
||||
# 调试:场景信息统计
|
||||
if self.debug_lane_filter or self.debug_traffic_light:
|
||||
print(f"\n📍 场景信息统计:")
|
||||
print(f" - 总车道数: {len(self.lanes)}")
|
||||
|
||||
# 记录专家数据(现在 self.lanes 已经初始化)
|
||||
_obj_to_clean_this_frame = []
|
||||
self.car_birth_info_list = []
|
||||
self.expert_trajectories.clear()
|
||||
total_vehicles = 0
|
||||
total_pedestrians = 0
|
||||
total_cyclists = 0
|
||||
filtered_vehicles = 0
|
||||
filtered_by_type = 0
|
||||
self.scenario_max_duration = 0 # 重置场景时长
|
||||
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
|
||||
continue
|
||||
|
||||
# 对象类型过滤
|
||||
obj_type = track["type"]
|
||||
|
||||
# 统计红绿灯数量
|
||||
if self.debug_traffic_light:
|
||||
traffic_light_lanes = []
|
||||
for lane in self.lanes.values():
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
traffic_light_lanes.append(lane.lane.index)
|
||||
print(f" - 有红绿灯的车道数: {len(traffic_light_lanes)}")
|
||||
if len(traffic_light_lanes) > 0:
|
||||
print(f" 车道索引: {traffic_light_lanes[:5]}" +
|
||||
(f" ... 共{len(traffic_light_lanes)}个" if len(traffic_light_lanes) > 5 else ""))
|
||||
else:
|
||||
print(f" ⚠️ 场景中没有红绿灯!")
|
||||
|
||||
# 在获取车道信息后,进行车道区域过滤
|
||||
total_cars_before = len(self.car_birth_info_list)
|
||||
valid_count, filtered_count, filtered_list = self._filter_valid_spawn_positions()
|
||||
|
||||
# 输出过滤信息
|
||||
if filtered_count > 0:
|
||||
self.logger.warning(f"车辆生成位置过滤: 原始 {total_cars_before} 辆, "
|
||||
f"有效 {valid_count} 辆, 过滤 {filtered_count} 辆")
|
||||
for filtered_car in filtered_list[:5]: # 只显示前5个
|
||||
self.logger.debug(f" - 过滤车辆 ID={filtered_car['id']}, "
|
||||
f"位置={filtered_car['position']}, "
|
||||
f"原因={filtered_car['reason']}")
|
||||
if filtered_count > 5:
|
||||
self.logger.debug(f" - ... 还有 {filtered_count - 5} 辆车被过滤")
|
||||
|
||||
# 限制最大车辆数(在过滤后应用)
|
||||
max_vehicles = self.config.get("max_controlled_vehicles", None)
|
||||
if max_vehicles is not None and len(self.car_birth_info_list) > max_vehicles:
|
||||
self.car_birth_info_list = self.car_birth_info_list[:max_vehicles]
|
||||
self.logger.info(f"限制最大车辆数为 {max_vehicles} 辆")
|
||||
|
||||
self.logger.info(f"最终生成 {len(self.car_birth_info_list)} 辆可控车辆")
|
||||
# 统计对象类型
|
||||
if obj_type == MetaDriveType.VEHICLE:
|
||||
total_vehicles += 1
|
||||
elif obj_type == MetaDriveType.PEDESTRIAN:
|
||||
total_pedestrians += 1
|
||||
elif obj_type == MetaDriveType.CYCLIST:
|
||||
total_cyclists += 1
|
||||
|
||||
# 根据配置过滤对象类型
|
||||
if obj_type == MetaDriveType.VEHICLE and not self.config.get("spawn_vehicles", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering VEHICLE {track['metadata']['object_id']} - spawn_vehicles=False")
|
||||
continue
|
||||
|
||||
if obj_type == MetaDriveType.PEDESTRIAN and not self.config.get("spawn_pedestrians", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering PEDESTRIAN {track['metadata']['object_id']} - spawn_pedestrians=False")
|
||||
continue
|
||||
|
||||
if obj_type == MetaDriveType.CYCLIST and not self.config.get("spawn_cyclists", True):
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
filtered_by_type += 1
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Filtering CYCLIST {track['metadata']['object_id']} - spawn_cyclists=False")
|
||||
continue
|
||||
|
||||
# 只处理车辆类型(行人和自行车暂时只做过滤)
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
|
||||
if first_show == -1 or last_show == -1:
|
||||
continue
|
||||
|
||||
# 更新场景最大时长
|
||||
self.scenario_max_duration = max(self.scenario_max_duration, last_show + 1)
|
||||
|
||||
# 获取车辆初始位置
|
||||
initial_position = (
|
||||
track['state']['position'][first_show, 0],
|
||||
track['state']['position'][first_show, 1]
|
||||
)
|
||||
|
||||
# 车道过滤
|
||||
if self.config.get("filter_offroad_vehicles", True):
|
||||
if not self._is_position_on_lane(initial_position):
|
||||
filtered_vehicles += 1
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(
|
||||
f"Filtering vehicle {track['metadata']['object_id']} - "
|
||||
f"not on lane at position {initial_position}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 存储完整专家轨迹(只使用2D位置,避免高度问题)
|
||||
object_id = track['metadata']['object_id']
|
||||
positions_2d = track['state']['position'].copy()
|
||||
positions_2d[:, 2] = 0 # 将z坐标设为0,让MetaDrive自动处理高度
|
||||
|
||||
self.expert_trajectories[object_id] = {
|
||||
'positions': positions_2d,
|
||||
'headings': track['state']['heading'].copy(),
|
||||
'velocities': track['state']['velocity'].copy(),
|
||||
'valid': track['state']['valid'].copy(),
|
||||
}
|
||||
|
||||
# 保存车辆生成信息
|
||||
self.car_birth_info_list.append({
|
||||
'id': object_id,
|
||||
'show_time': first_show,
|
||||
'begin': initial_position,
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'velocity': track['state']['velocity'][first_show] if self.config.get("inherit_expert_velocity", False) else None,
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
# 在回放和仿真模式下都清除原始专家车辆
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
|
||||
# 清除专家车辆和过滤的对象
|
||||
for scenario_id in _obj_to_clean_this_frame:
|
||||
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
|
||||
|
||||
# 输出统计信息
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"=== 对象统计 ===")
|
||||
self.logger.info(f"车辆 (VEHICLE): 总数={total_vehicles}, 车道过滤={filtered_vehicles}, 保留={total_vehicles - filtered_vehicles}")
|
||||
self.logger.info(f"行人 (PEDESTRIAN): 总数={total_pedestrians}")
|
||||
self.logger.info(f"自行车 (CYCLIST): 总数={total_cyclists}")
|
||||
self.logger.info(f"类型过滤: {filtered_by_type} 个对象")
|
||||
self.logger.info(f"场景时长: {self.scenario_max_duration} 步")
|
||||
|
||||
# 如果启用场景时长控制,更新horizon
|
||||
if self.config.get("use_scenario_duration", False) and self.scenario_max_duration > 0:
|
||||
original_horizon = self.config["horizon"]
|
||||
self.config["horizon"] = self.scenario_max_duration
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Horizon updated from {original_horizon} to {self.scenario_max_duration} (scenario duration)")
|
||||
|
||||
if self.top_down_renderer is not None:
|
||||
self.top_down_renderer.clear()
|
||||
@@ -167,336 +256,178 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
self.dones = {}
|
||||
self.episode_rewards = defaultdict(float)
|
||||
self.episode_lengths = defaultdict(int)
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
# 调用父类reset会清理场景
|
||||
super().reset(seed) # 初始化场景
|
||||
|
||||
# 重新生成车辆
|
||||
self._spawn_controlled_agents()
|
||||
|
||||
return self._get_all_obs()
|
||||
|
||||
def _is_position_on_lane(self, position, tolerance=None):
|
||||
"""
|
||||
检测给定位置是否在有效车道范围内
|
||||
|
||||
Args:
|
||||
position: (x, y) 车辆位置坐标
|
||||
tolerance: 容差范围(米),用于放宽检测条件。None时使用配置中的默认值
|
||||
|
||||
Returns:
|
||||
bool: True表示在车道上,False表示在非车道区域(如草坪、停车场等)
|
||||
"""
|
||||
if not hasattr(self, 'lanes') or self.lanes is None:
|
||||
if self.debug_lane_filter:
|
||||
print(f" ⚠️ 车道信息未初始化,默认允许")
|
||||
return True # 如果车道信息未初始化,默认允许生成
|
||||
|
||||
if tolerance is None:
|
||||
tolerance = self.config.get("lane_tolerance", 3.0)
|
||||
|
||||
position_2d = (position[0], position[1])
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f" 🔍 检测位置 ({position_2d[0]:.2f}, {position_2d[1]:.2f}), 容差={tolerance}m")
|
||||
|
||||
# 方法1:直接检测是否在任一车道上
position_2d = np.array(position_2d)  # 转为ndarray,便于方法2中的向量运算
checked_lanes = 0
for lane in self.lanes.values():
try:
checked_lanes += 1
if lane.lane.point_on_lane(position_2d):
if self.debug_lane_filter:
print(f" ✅ 在车道上 (车道{lane.lane.index}, 检查了{checked_lanes}条)")
return True
except:
continue
|
||||
if self.debug_lane_filter:
|
||||
print(f" ❌ 不在任何车道上 (检查了{checked_lanes}条车道)")
|
||||
|
||||
# 方法2:如果严格检测失败,使用容差范围检测(考虑车道边缘)
# 做法:把位置投影到车道起点-终点连线上,横向距离在容差范围内即视为有效
# (另一种实现思路是用 lane.local_coordinates 计算 (s, lateral),此处未启用)
if tolerance > 0:
for lane in self.lanes.values():
try:
lane_start = np.array(lane.lane.start)[:2]
lane_end = np.array(lane.lane.end)[:2]
lane_vec = lane_end - lane_start
lane_length = np.linalg.norm(lane_vec)

if lane_length < 1e-6:
continue

lane_vec_normalized = lane_vec / lane_length
point_vec = position_2d - lane_start
projection = np.dot(point_vec, lane_vec_normalized)

if 0 <= projection <= lane_length:
closest_point = lane_start + projection * lane_vec_normalized
distance = np.linalg.norm(position_2d - closest_point)
if distance <= tolerance:
return True
except Exception as e:
if self.config.get("debug", False):
self.logger.warning(f"Lane check error: {e}")
continue

return False
|
||||
|
||||
def _filter_valid_spawn_positions(self):
|
||||
"""
|
||||
过滤掉生成位置不在有效车道上的车辆信息
|
||||
根据配置决定是否执行过滤
|
||||
|
||||
Returns:
|
||||
tuple: (有效车辆数量, 被过滤车辆数量, 被过滤车辆ID列表)
|
||||
"""
|
||||
# 如果配置中禁用了过滤,直接返回
|
||||
if not self.config.get("filter_offroad_vehicles", True):
|
||||
if self.debug_lane_filter:
|
||||
print(f"🚫 车道过滤已禁用")
|
||||
return len(self.car_birth_info_list), 0, []
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n🔍 开始车道过滤: 共 {len(self.car_birth_info_list)} 辆车待检测")
|
||||
|
||||
valid_cars = []
|
||||
filtered_cars = []
|
||||
tolerance = self.config.get("lane_tolerance", 3.0)
|
||||
|
||||
for idx, car in enumerate(self.car_birth_info_list):
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n车辆 {idx+1}/{len(self.car_birth_info_list)}: ID={car['id']}")
|
||||
|
||||
if self._is_position_on_lane(car['begin'], tolerance=tolerance):
|
||||
valid_cars.append(car)
|
||||
if self.debug_lane_filter:
|
||||
print(f" ✅ 保留")
|
||||
else:
|
||||
filtered_cars.append({
|
||||
'id': car['id'],
|
||||
'position': car['begin'],
|
||||
'reason': '生成位置不在有效车道上(可能在草坪/停车场等区域)'
|
||||
})
|
||||
if self.debug_lane_filter:
|
||||
print(f" ❌ 过滤 (原因: 不在车道上)")
|
||||
|
||||
self.car_birth_info_list = valid_cars
|
||||
|
||||
if self.debug_lane_filter:
|
||||
print(f"\n📊 过滤结果: 保留 {len(valid_cars)} 辆, 过滤 {len(filtered_cars)} 辆")
|
||||
|
||||
return len(valid_cars), len(filtered_cars), filtered_cars
|
||||
|
||||
|
||||
def _spawn_controlled_agents(self):
|
||||
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
|
||||
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
|
||||
for car in self.car_birth_info_list:
|
||||
if car['show_time'] == self.round:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
|
||||
vehicle_config = {}
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config=vehicle_config,
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
)
|
||||
|
||||
# 重置车辆状态
|
||||
reset_kwargs = {
|
||||
'position': car['begin'],
|
||||
'heading': car['heading']
|
||||
}
|
||||
|
||||
# 如果启用速度继承,设置初始速度
|
||||
if car.get('velocity') is not None:
|
||||
reset_kwargs['velocity'] = car['velocity']
|
||||
|
||||
vehicle.reset(**reset_kwargs)
|
||||
|
||||
# 设置策略和目的地
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
vehicle.set_expert_vehicle_id(car['id'])
|
||||
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# ✅ 关键:注册到引擎的 active_agents,才能参与物理更新
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
def before_reset(self):
|
||||
"""在reset之前清理对象"""
|
||||
# 清理所有可控车辆
|
||||
if hasattr(self, 'controlled_agents') and hasattr(self, 'engine'):
|
||||
# 使用MetaDrive的clear_objects方法清理
|
||||
if hasattr(self.engine, 'clear_objects'):
|
||||
try:
|
||||
self.engine.clear_objects(list(self.controlled_agents.keys()))
|
||||
except:
|
||||
pass
|
||||
|
||||
# 从agent_manager中移除
|
||||
if hasattr(self.engine, 'agent_manager'):
|
||||
for agent_id in list(self.controlled_agents.keys()):
|
||||
if agent_id in self.engine.agent_manager.active_agents:
|
||||
self.engine.agent_manager.active_agents.pop(agent_id)
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
def _get_traffic_light_state(self, vehicle):
|
||||
"""
|
||||
获取车辆当前位置的红绿灯状态(优化版)
|
||||
|
||||
解决问题:
|
||||
1. 部分红绿灯状态为None的问题 - 添加异常处理和默认值
|
||||
2. 车道分段导致无法获取红绿灯的问题 - 优先使用导航模块,失败时回退到遍历
|
||||
|
||||
Returns:
|
||||
int: 0=无红绿灯, 1=绿灯, 2=黄灯, 3=红灯
|
||||
"""
|
||||
traffic_light = 0
|
||||
state = vehicle.get_state()
|
||||
position_2d = state['position'][:2]
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f"\n🚦 检测车辆红绿灯 - 位置: ({position_2d[0]:.1f}, {position_2d[1]:.1f})")
|
||||
|
||||
try:
|
||||
# 方法1:优先尝试从车辆导航模块获取当前车道(更高效)
|
||||
if hasattr(vehicle, 'navigation') and vehicle.navigation is not None:
|
||||
current_lane = vehicle.navigation.current_lane
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" 方法1-导航模块:")
|
||||
print(f" current_lane = {current_lane}")
|
||||
print(f" lane_index = {current_lane.index if current_lane else 'None'}")
|
||||
|
||||
if current_lane:
|
||||
has_light = self.engine.light_manager.has_traffic_light(current_lane.index)
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" has_traffic_light = {has_light}")
|
||||
|
||||
if has_light:
|
||||
status = self.engine.light_manager._lane_index_to_obj[current_lane.index].status
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" status = {status}")
|
||||
|
||||
if status == 'TRAFFIC_LIGHT_GREEN':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 绿灯")
|
||||
return 1
|
||||
elif status == 'TRAFFIC_LIGHT_YELLOW':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 黄灯")
|
||||
return 2
|
||||
elif status == 'TRAFFIC_LIGHT_RED':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法1成功: 红灯")
|
||||
return 3
|
||||
elif status is None:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ⚠️ 方法1: 红绿灯状态为None")
|
||||
return 0
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 该车道没有红绿灯")
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 导航模块current_lane为None")
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
has_nav = hasattr(vehicle, 'navigation')
|
||||
nav_not_none = vehicle.navigation is not None if has_nav else False
|
||||
print(f" 方法1-导航模块: 不可用 (hasattr={has_nav}, not_none={nav_not_none})")
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ❌ 方法1异常: {type(e).__name__}: {e}")
|
||||
pass
|
||||
|
||||
try:
|
||||
# 方法2:遍历所有车道查找(兜底方案,处理车道分段问题)
|
||||
if self.debug_traffic_light:
|
||||
print(f" 方法2-遍历车道: 开始遍历 {len(self.lanes)} 条车道")
|
||||
|
||||
found_lane = False
|
||||
checked_lanes = 0
|
||||
|
||||
for lane in self.lanes.values():
|
||||
try:
|
||||
checked_lanes += 1
|
||||
if lane.lane.point_on_lane(position_2d):
|
||||
found_lane = True
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✓ 找到车辆所在车道: {lane.lane.index} (检查了{checked_lanes}条)")
|
||||
|
||||
has_light = self.engine.light_manager.has_traffic_light(lane.lane.index)
|
||||
if self.debug_traffic_light:
|
||||
print(f" has_traffic_light = {has_light}")
|
||||
|
||||
if has_light:
|
||||
status = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
|
||||
if self.debug_traffic_light:
|
||||
print(f" status = {status}")
|
||||
|
||||
if status == 'TRAFFIC_LIGHT_GREEN':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 绿灯")
|
||||
return 1
|
||||
elif status == 'TRAFFIC_LIGHT_YELLOW':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 黄灯")
|
||||
return 2
|
||||
elif status == 'TRAFFIC_LIGHT_RED':
|
||||
if self.debug_traffic_light:
|
||||
print(f" ✅ 方法2成功: 红灯")
|
||||
return 3
|
||||
elif status is None:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ⚠️ 方法2: 红绿灯状态为None")
|
||||
return 0
|
||||
else:
|
||||
if self.debug_traffic_light:
|
||||
print(f" 该车道没有红绿灯")
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
if self.debug_traffic_light and not found_lane:
|
||||
print(f" ⚠️ 未找到车辆所在车道 (检查了{checked_lanes}条)")
|
||||
|
||||
except Exception as e:
|
||||
if self.debug_traffic_light:
|
||||
print(f" ❌ 方法2异常: {type(e).__name__}: {e}")
|
||||
pass
|
||||
|
||||
if self.debug_traffic_light:
|
||||
print(f" 结果: 返回 {traffic_light} (无红绿灯/未知)")
|
||||
|
||||
return traffic_light
|
||||
|
||||
if self.config.get("debug", False):
|
||||
self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}")
|
||||
|
||||
def _get_all_obs(self):
|
||||
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
|
||||
self.obs_list = []
|
||||
|
||||
for agent_id, vehicle in self.controlled_agents.items():
|
||||
state = vehicle.get_state()
|
||||
# 使用优化后的红绿灯检测方法
traffic_light = self._get_traffic_light_state(vehicle)

lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
physics_world=self.engine.physics_world.dynamic_world)
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)

obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']]
|
||||
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
+ list(vehicle.destination))
|
||||
|
||||
self.obs_list.append(obs)
|
||||
|
||||
return self.obs_list
|
||||
|
||||
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
|
||||
self.round += 1
|
||||
|
||||
# 应用动作
|
||||
for agent_id, action in action_dict.items():
|
||||
if agent_id in self.controlled_agents:
|
||||
self.controlled_agents[agent_id].before_step(action)
|
||||
|
||||
# 物理引擎步进
|
||||
self.engine.step()
|
||||
|
||||
# 后处理
|
||||
for agent_id in action_dict:
|
||||
if agent_id in self.controlled_agents:
|
||||
self.controlled_agents[agent_id].after_step()
|
||||
|
||||
# 生成新车辆
|
||||
self._spawn_controlled_agents()
|
||||
|
||||
# 获取观测
|
||||
obs = self._get_all_obs()
|
||||
|
||||
rewards = {aid: 0.0 for aid in self.controlled_agents}
|
||||
dones = {aid: False for aid in self.controlled_agents}
|
||||
dones["__all__"] = self.episode_step >= self.config["horizon"]
|
||||
|
||||
# ✅ 修复:添加回放模式的完成检查
|
||||
replay_finished = False
|
||||
if self.replay_mode and self.config.get("use_scenario_duration", False):
|
||||
# 检查是否所有专家轨迹都已播放完毕
|
||||
if self.round >= self.scenario_max_duration:
|
||||
replay_finished = True
|
||||
if self.config.get("debug", False):
|
||||
self.logger.info(f"Replay finished at step {self.round}/{self.scenario_max_duration}")
|
||||
|
||||
dones["__all__"] = self.episode_step >= self.config["horizon"] or replay_finished
|
||||
|
||||
infos = {aid: {} for aid in self.controlled_agents}
|
||||
|
||||
return obs, rewards, dones, infos
|
||||
|
||||
def close(self):
|
||||
# ✅ 清理所有生成的车辆
|
||||
if hasattr(self, 'controlled_agents') and self.controlled_agents:
|
||||
for agent_id, vehicle in list(self.controlled_agents.items()):
|
||||
if vehicle in self.engine.get_objects():
|
||||
self.engine.clear_objects([vehicle.id])
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
super().close()
|
||||
@@ -6,8 +6,13 @@ class ConstantVelocityPolicy:
|
||||
|
||||
def act(self):
|
||||
self.step_num += 1
|
||||
# 以下油门/转向的交替逻辑暂未启用,当前直接返回固定的小油门动作,便于调试环境
if self.step_num % 30 < 15:
throttle = 1.0
else:
throttle = 1.0

steering = 0.1

# return [steering, throttle]

return [0.0, 0.05]
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
"""
|
||||
测试车道过滤和红绿灯检测功能
|
||||
"""
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from logger_utils import setup_logger
|
||||
import os
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/MAGAIL4AutoDrive/Env"
|
||||
|
||||
def test_lane_filter():
|
||||
"""测试车道过滤功能(基础版)"""
|
||||
print("=" * 60)
|
||||
print("测试1:车道过滤功能(基础)")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建启用过滤的环境
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
|
||||
# 车道过滤配置
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 10,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
print("\n启用车道过滤...")
|
||||
obs = env.reset(0)
|
||||
print(f"生成车辆数: {len(env.controlled_agents)}")
|
||||
print(f"观测数据长度: {len(obs)}")
|
||||
|
||||
# 运行几步
|
||||
for step in range(5):
|
||||
actions = {aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents}
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
|
||||
env.close()
|
||||
print("✓ 车道过滤测试通过\n")
|
||||
|
||||
|
||||
def test_lane_filter_debug():
|
||||
"""测试车道过滤功能(详细调试)"""
|
||||
print("=" * 60)
|
||||
print("测试1b:车道过滤功能(详细调试模式)")
|
||||
print("=" * 60)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
|
||||
# 车道过滤配置
|
||||
"filter_offroad_vehicles": True,
|
||||
"lane_tolerance": 3.0,
|
||||
"max_controlled_vehicles": 5, # 只看前5辆车
|
||||
|
||||
# 🔥 启用调试模式
|
||||
"debug_lane_filter": True, # 启用车道过滤调试
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
print("\n启用车道过滤调试...")
|
||||
obs = env.reset(0)
|
||||
|
||||
env.close()
|
||||
print("\n✓ 车道过滤调试测试完成\n")
|
||||
|
||||
|
||||
def test_traffic_light():
|
||||
"""测试红绿灯检测功能"""
|
||||
print("=" * 60)
|
||||
print("测试2:红绿灯检测功能(启用详细调试)")
|
||||
print("=" * 60)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
"filter_offroad_vehicles": True,
|
||||
"max_controlled_vehicles": 3, # 只测试3辆车
|
||||
|
||||
# 🔥 启用调试模式
|
||||
"debug_traffic_light": True, # 启用红绿灯调试
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0)
|
||||
|
||||
# 测试红绿灯检测(调试模式会自动输出详细信息)
|
||||
print(f"\n" + "="*60)
|
||||
print(f"开始逐车检测红绿灯状态(共 {len(env.controlled_agents)} 辆车)")
|
||||
print("="*60)
|
||||
|
||||
for idx, (aid, vehicle) in enumerate(list(env.controlled_agents.items())[:3]): # 只测试前3辆
|
||||
print(f"\n【车辆 {idx+1}/3】 ID={aid}")
|
||||
traffic_light = env._get_traffic_light_state(vehicle)
|
||||
state = vehicle.get_state()
|
||||
|
||||
status_text = {0: '无/未知', 1: '绿灯', 2: '黄灯', 3: '红灯'}[traffic_light]
|
||||
print(f"最终结果: 红绿灯状态={traffic_light} ({status_text})\n")
|
||||
|
||||
env.close()
|
||||
print("="*60)
|
||||
print("✓ 红绿灯检测测试完成")
|
||||
print("="*60 + "\n")
|
||||
|
||||
|
||||
def test_without_filter():
|
||||
"""测试禁用过滤的情况"""
|
||||
print("=" * 60)
|
||||
print("测试3:禁用过滤(对比测试)")
|
||||
print("=" * 60)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 100,
|
||||
"use_render": False,
|
||||
|
||||
# 禁用过滤
|
||||
"filter_offroad_vehicles": False,
|
||||
"max_controlled_vehicles": None,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
print("\n禁用车道过滤...")
|
||||
obs = env.reset(0)
|
||||
print(f"生成车辆数(未过滤): {len(env.controlled_agents)}")
|
||||
|
||||
env.close()
|
||||
print("✓ 禁用过滤测试通过\n")
|
||||
|
||||
|
||||
def run_tests(debug_mode=False):
|
||||
"""运行测试的主函数"""
|
||||
try:
|
||||
if debug_mode:
|
||||
print("🐛 调试模式启用")
|
||||
print("=" * 60 + "\n")
|
||||
test_lane_filter_debug()
|
||||
test_traffic_light()
|
||||
else:
|
||||
print("⚡ 标准测试模式(使用 --debug 参数启用详细调试)")
|
||||
print("=" * 60 + "\n")
|
||||
test_lane_filter()
|
||||
test_traffic_light()
|
||||
test_without_filter()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ 所有测试通过!")
|
||||
print("=" * 60)
|
||||
print("\n功能说明:")
|
||||
print("1. 车道过滤功能已启用,自动过滤非车道区域车辆")
|
||||
print("2. 红绿灯检测采用双重策略,确保稳定获取状态")
|
||||
print("3. 可通过配置参数灵活启用/禁用功能")
|
||||
print("\n使用方法:")
|
||||
print(" python Env/test_lane_filter.py # 标准测试")
|
||||
print(" python Env/test_lane_filter.py --debug # 详细调试")
|
||||
print(" python Env/test_lane_filter.py --log # 保存日志")
|
||||
print(" python Env/test_lane_filter.py --debug --log # 调试+日志")
|
||||
print("\n请运行 run_multiagent_env.py 查看完整效果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
# 解析命令行参数
|
||||
debug_mode = "--debug" in sys.argv or "-d" in sys.argv
|
||||
enable_logging = "--log" in sys.argv or "-l" in sys.argv
|
||||
|
||||
# 提取自定义日志文件名
|
||||
log_file = None
|
||||
for arg in sys.argv:
|
||||
if arg.startswith("--log-file="):
|
||||
log_file = arg.split("=")[1]
|
||||
break
|
||||
|
||||
if enable_logging:
|
||||
# 启用日志记录
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
|
||||
# 生成默认日志文件名
|
||||
if log_file is None:
|
||||
mode_suffix = "debug" if debug_mode else "standard"
|
||||
from datetime import datetime
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
log_file = f"test_{mode_suffix}_{timestamp}.log"
|
||||
|
||||
with setup_logger(log_file=log_file, log_dir=log_dir):
|
||||
run_tests(debug_mode=debug_mode)
|
||||
else:
|
||||
# 不启用日志,直接运行
|
||||
run_tests(debug_mode=debug_mode)
|
||||
|
||||
543
MAGAIL算法应用指南.md
543
MAGAIL算法应用指南.md
@@ -1,543 +0,0 @@
|
||||
# MAGAIL算法应用指南
|
||||
|
||||
## 目录
|
||||
1. [Algorithm模块概览](#algorithm模块概览)
|
||||
2. [如何应用到环境](#如何应用到环境)
|
||||
3. [完整训练流程](#完整训练流程)
|
||||
4. [当前实现状态](#当前实现状态)
|
||||
5. [需要完善的部分](#需要完善的部分)
|
||||
|
||||
---
|
||||
|
||||
## Algorithm模块概览
|
||||
|
||||
### 📁 模块文件说明
|
||||
|
||||
```
|
||||
Algorithm/
|
||||
├── bert.py # BERT判别器/价值网络
|
||||
├── disc.py # GAIL判别器(继承BERT)
|
||||
├── policy.py # 策略网络(Actor)
|
||||
├── ppo.py # PPO算法基类
|
||||
├── magail.py # MAGAIL主算法(继承PPO)
|
||||
├── buffer.py # 经验回放缓冲区
|
||||
└── utils.py # 工具函数(标准化等)
|
||||
```
|
||||
|
||||
### 🔗 模块依赖关系
|
||||
|
||||
```
|
||||
MAGAIL (magail.py)
|
||||
├─ 继承 PPO (ppo.py)
|
||||
│ ├─ 使用 RolloutBuffer (buffer.py)
|
||||
│ ├─ 使用 StateIndependentPolicy (policy.py)
|
||||
│ └─ 使用 Bert作为Critic (bert.py)
|
||||
│
|
||||
├─ 使用 GAILDiscrim (disc.py)
|
||||
│ └─ 继承 Bert (bert.py)
|
||||
│
|
||||
└─ 使用 Normalizer (utils.py)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 如何应用到环境
|
||||
|
||||
### ✅ 已完成的准备工作
|
||||
|
||||
我已经为您:
|
||||
|
||||
1. **修复了PPO代码bug**:添加了缺失的`action_shape`参数
|
||||
2. **创建了训练脚本**:`train_magail.py`
|
||||
3. **提供了完整框架**:包含环境初始化、训练循环、模型保存等
|
||||
|
||||
### 🚀 快速开始
|
||||
|
||||
#### 方法1:使用训练脚本(推荐)
|
||||
|
||||
```bash
|
||||
# 基本训练(使用默认参数)
|
||||
python train_magail.py
|
||||
|
||||
# 自定义参数
|
||||
python train_magail.py \
|
||||
--data-dir /path/to/waymo/data \
|
||||
--episodes 1000 \
|
||||
--horizon 300 \
|
||||
--batch-size 256 \
|
||||
--lr-actor 3e-4 \
|
||||
--render # 可视化
|
||||
|
||||
# 查看所有参数
|
||||
python train_magail.py --help
|
||||
```
|
||||
|
||||
#### 方法2:在Jupyter Notebook中使用
|
||||
|
||||
```python
|
||||
import sys
|
||||
sys.path.append('Algorithm')
|
||||
sys.path.append('Env')
|
||||
|
||||
from Algorithm.magail import MAGAIL
|
||||
from Env.scenario_env import MultiAgentScenarioEnv
|
||||
|
||||
# 初始化环境
|
||||
env = MultiAgentScenarioEnv(config={...})
|
||||
|
||||
# 初始化MAGAIL
|
||||
magail = MAGAIL(
|
||||
buffer_exp=expert_buffer,
|
||||
input_dim=(108,), # 观测维度
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# 训练循环
|
||||
for episode in range(1000):
|
||||
obs = env.reset()
|
||||
for step in range(300):
|
||||
actions, log_pis = magail.explore(obs)
|
||||
next_obs, rewards, dones, infos = env.step(actions)
|
||||
# ... 更新逻辑
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 完整训练流程
|
||||
|
||||
### 📊 数据流程图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ MAGAIL训练流程 │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
|
||||
第1步: 初始化
|
||||
├─ 加载Waymo专家数据 → ExpertBuffer
|
||||
├─ 创建MAGAIL算法实例
|
||||
│ ├─ Actor (policy.py)
|
||||
│ ├─ Critic (bert.py)
|
||||
│ ├─ Discriminator (disc.py)
|
||||
│ └─ Buffers (buffer.py)
|
||||
└─ 创建多智能体环境
|
||||
|
||||
第2步: 训练循环
|
||||
for episode in range(episodes):
|
||||
├─ env.reset() → 重置环境,生成车辆
|
||||
│
|
||||
for step in range(horizon):
|
||||
├─ obs = env._get_all_obs() # 收集观测
|
||||
│
|
||||
├─ actions = magail.explore(obs) # 策略采样
|
||||
│
|
||||
├─ next_obs, rewards, dones = env.step(actions)
|
||||
│
|
||||
├─ buffer.append(obs, actions, rewards, ...) # 存储经验
|
||||
│
|
||||
└─ if step % rollout_length == 0:
|
||||
├─ 更新判别器
|
||||
│ ├─ 采样策略数据: buffer.sample()
|
||||
│ ├─ 采样专家数据: expert_buffer.sample()
|
||||
│ └─ update_disc(policy_data, expert_data)
|
||||
│
|
||||
├─ 计算GAIL奖励
|
||||
│ └─ reward = -log(1 - D(s, s'))
|
||||
│
|
||||
└─ 更新PPO
|
||||
├─ 计算GAE优势
|
||||
├─ update_actor()
|
||||
└─ update_critic()
|
||||
|
||||
第3步: 评估与保存
|
||||
└─ 保存模型、记录指标
|
||||
```
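
下面给出流程图中 `reward = -log(1 - D(s, s'))` 这一步的一个最小示意实现(假设判别器输出未经过 sigmoid 的 logits;函数名 `gail_reward` 仅为说明用,并非 Algorithm 模块中已有的接口):

```python
import torch.nn.functional as F

def gail_reward(disc_logits):
    """由判别器 logits 计算 GAIL 奖励(示意)。

    r = -log(1 - D(s, s')),其中 D = sigmoid(logits);
    利用 -log(1 - sigmoid(x)) = softplus(x) 保证数值稳定。
    """
    return F.softplus(disc_logits)
```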
|
||||
|
||||
### 🔑 关键代码段
|
||||
|
||||
#### 1. 初始化MAGAIL
|
||||
|
||||
```python
|
||||
from Algorithm.magail import MAGAIL
|
||||
|
||||
magail = MAGAIL(
|
||||
buffer_exp=expert_buffer, # 专家数据缓冲区
|
||||
input_dim=(obs_dim,), # 观测维度 (108,)
|
||||
device=device, # cuda/cpu
|
||||
action_shape=(2,), # 动作维度 [转向, 油门]
|
||||
|
||||
# 判别器参数
|
||||
disc_coef=20.0, # 判别器损失系数
|
||||
disc_grad_penalty=0.1, # 梯度惩罚系数
|
||||
disc_logit_reg=0.25, # Logit正则化
|
||||
disc_weight_decay=0.0005, # 权重衰减
|
||||
lr_disc=3e-4, # 判别器学习率
|
||||
epoch_disc=5, # 判别器更新轮数
|
||||
|
||||
# PPO参数
|
||||
rollout_length=2048, # 更新间隔
|
||||
lr_actor=3e-4, # Actor学习率
|
||||
lr_critic=3e-4, # Critic学习率
|
||||
epoch_ppo=10, # PPO更新轮数
|
||||
batch_size=256, # 批次大小
|
||||
gamma=0.995, # 折扣因子
|
||||
lambd=0.97, # GAE lambda
|
||||
|
||||
# 其他
|
||||
use_gail_norm=True, # 使用数据标准化
|
||||
)
|
||||
```
|
||||
|
||||
#### 2. 环境交互
|
||||
|
||||
```python
|
||||
# 重置环境
|
||||
obs_list = env.reset(episode)
|
||||
|
||||
# 收集观测(所有车辆)
|
||||
obs_array = np.array(env.obs_list) # shape: (n_agents, 108)
|
||||
|
||||
# 策略采样
|
||||
actions, log_pis = magail.explore(obs_array)
|
||||
# actions: list of [转向, 油门] for each agent
|
||||
# log_pis: list of log probabilities
|
||||
|
||||
# 构建动作字典
|
||||
action_dict = {
|
||||
agent_id: actions[i]
|
||||
for i, agent_id in enumerate(env.controlled_agents.keys())
|
||||
}
|
||||
|
||||
# 环境步进
|
||||
next_obs, rewards, dones, infos = env.step(action_dict)
|
||||
```
|
||||
|
||||
#### 3. 模型更新
|
||||
|
||||
```python
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
writer = SummaryWriter('logs')
|
||||
|
||||
# 更新判别器和策略
|
||||
if total_steps % rollout_length == 0:
|
||||
# MAGAIL会自动:
|
||||
# 1. 从buffer采样策略数据
|
||||
# 2. 从expert_buffer采样专家数据
|
||||
# 3. 更新判别器
|
||||
# 4. 计算GAIL奖励
|
||||
# 5. 更新PPO(Actor + Critic)
|
||||
|
||||
reward = magail.update(writer, total_steps)
|
||||
|
||||
print(f"Step {total_steps}, Reward: {reward:.4f}")
|
||||
```
|
||||
|
||||
#### 4. 保存和加载模型
|
||||
|
||||
```python
|
||||
# 保存
|
||||
magail.save_models('outputs/models/checkpoint_100')
|
||||
|
||||
# 加载
|
||||
magail.load_models('outputs/models/checkpoint_100/model.pth')
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 当前实现状态
|
||||
|
||||
### ✅ 已实现功能
|
||||
|
||||
| 模块 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| BERT判别器 | ✅ 完整 | 支持动态车辆数量 |
|
||||
| GAIL判别器 | ✅ 完整 | 包含梯度惩罚、正则化 |
|
||||
| 策略网络 | ✅ 完整 | 高斯策略,重参数化 |
|
||||
| PPO算法 | ✅ 完整 | GAE、裁剪目标、自适应LR |
|
||||
| MAGAIL | ✅ 完整 | 判别器+PPO整合 |
|
||||
| 缓冲区 | ✅ 完整 | 经验存储和采样 |
|
||||
| 数据标准化 | ✅ 完整 | 运行时统计量 |
|
||||
| 环境接口 | ✅ 完整 | 多智能体场景环境 |
|
||||
|

### ⚠️ Known Issues

#### 1. Multi-Agent Adaptation

**Current state:** the Algorithm module is designed for a single agent, but the environment is multi-agent.

**Impact:**
- `buffer.append()` accepts a single state-action pair
- but the environment returns data for multiple agents

**Option A:** treat all agents as one entity
```python
# Concatenate the observations of all agents
all_obs = np.concatenate(obs_list)
all_actions = np.concatenate(action_list)
```

**Option B:** store each agent's experience independently
```python
for i, agent_id in enumerate(env.controlled_agents):
    buffer.append(obs_list[i], actions[i], rewards[i], ...)
```

**Recommendation:** Option B, since MAGAIL is designed to handle multiple agents.

#### 2. Expert Data Loading

**Current state:** the `ExpertBuffer` class is only a skeleton; actual loading is not implemented.

**To be completed:**
```python
def _extract_trajectories(self, scenario_data):
    """
    Must be implemented according to the actual Waymo data format.

    Example structure:
    scenario_data = {
        'tracks': {
            'vehicle_id': {
                'states': [...],   # state sequence
                'actions': [...],  # action sequence (if available)
            }
        }
    }
    """
    # TODO: extract (state, next_state) pairs
    for track_id, track_data in scenario_data['tracks'].items():
        states = track_data['states']
        for i in range(len(states) - 1):
            self.states.append(states[i])
            self.next_states.append(states[i + 1])
```

#### 3. Observation Dimension Alignment

**Current assumption:** the observation dimension is 108
- position (2) + velocity (2) + heading (1) + LiDAR (80) + side detector (10) + lane lines (10) + traffic light (1) + goal point (2) = 108

**Needs verification:** print the observation shape at runtime
```python
obs = env.reset()
print(f"Observation dim: {len(obs[0]) if obs else 0}")
```
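
For a quick consistency check, the assumed layout can be written down explicitly and compared against the live environment. The component names below are illustrative; only the sizes come from the breakdown above.

```python
# Assumed observation layout; only the sizes are taken from the breakdown above.
OBS_LAYOUT = {
    "position": 2, "velocity": 2, "heading": 1,
    "lidar": 80, "side_detector": 10, "lane_lines": 10,
    "traffic_light": 1, "goal": 2,
}
assert sum(OBS_LAYOUT.values()) == 108  # 2+2+1+80+10+10+1+2

# Compare against the real environment, e.g.:
#   obs = env.reset(); print(len(obs[0]))
```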

---

## Parts That Still Need Work

### 🔨 Short-Term TODO

#### 1. Fix the multi-agent buffer

**Create file:** `Algorithm/multi_agent_buffer.py`

```python
import numpy as np
import torch


class MultiAgentRolloutBuffer:
    """
    Multi-agent experience buffer.

    Supports a dynamic number of agents.
    """

    def __init__(self, buffer_size, state_shape, action_shape, device):
        self.buffer_size = buffer_size
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = device

        # Store episodes in plain lists so the number of agents can vary
        self.episodes = []
        self.current_episode = {
            'states': [],
            'actions': [],
            'rewards': [],
            'dones': [],
            'log_pis': [],
            'next_states': [],
        }

    def append(self, state, action, reward, done, log_pi, next_state):
        """Add a single transition."""
        self.current_episode['states'].append(state)
        self.current_episode['actions'].append(action)
        self.current_episode['rewards'].append(reward)
        self.current_episode['dones'].append(done)
        self.current_episode['log_pis'].append(log_pi)
        self.current_episode['next_states'].append(next_state)

    def finish_episode(self):
        """Close the current episode."""
        self.episodes.append(self.current_episode)
        self.current_episode = {
            'states': [],
            'actions': [],
            'rewards': [],
            'dones': [],
            'log_pis': [],
            'next_states': [],
        }

    def sample(self, batch_size):
        """Sample a batch of (state, next_state) pairs from all episodes."""
        all_states = []
        all_next_states = []
        for episode in self.episodes:
            all_states.extend(episode['states'])
            all_next_states.extend(episode['next_states'])

        indices = np.random.choice(len(all_states), batch_size, replace=False)

        states = torch.as_tensor(
            np.stack([all_states[i] for i in indices]),
            dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(
            np.stack([all_next_states[i] for i in indices]),
            dtype=torch.float32, device=self.device)

        return states, next_states
```
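
A brief usage sketch of this buffer with dummy data, following the per-agent storage recommended above (Option B). All shapes and values below are placeholders.

```python
import numpy as np
import torch

device = torch.device('cpu')
buffer = MultiAgentRolloutBuffer(
    buffer_size=10_000, state_shape=(108,), action_shape=(2,), device=device)

# One environment step with three controlled agents (dummy transitions).
for _ in range(3):
    buffer.append(
        state=np.random.randn(108).astype(np.float32),
        action=np.random.uniform(-1, 1, size=2).astype(np.float32),
        reward=0.0,
        done=False,
        log_pi=0.0,
        next_state=np.random.randn(108).astype(np.float32),
    )

buffer.finish_episode()
states, next_states = buffer.sample(batch_size=3)
print(states.shape, next_states.shape)  # torch.Size([3, 108]) for both
```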

#### 2. Implement expert data loading

**First find out:** the actual format of the Waymo data

```python
# Example: read one pkl file and print its structure
import glob
import pickle

# open() does not expand wildcards, so resolve the pattern first
pkl_files = glob.glob('Env/exp_converted/exp_converted_0/sd_waymo_*.pkl')
with open(pkl_files[0], 'rb') as f:
    data = pickle.load(f)
print(type(data))
print(data.keys() if isinstance(data, dict) else len(data))
# Adapt the loading code to the actual structure
```

#### 3. Complete the training loop

**Add to `train_magail.py`:**

```python
# Full buffer storage logic
for i, agent_id in enumerate(env.controlled_agents.keys()):
    if i < len(obs_array) and i < len(actions):
        magail.buffer.append(
            state=obs_array[i],
            action=actions[i],
            reward=rewards.get(agent_id, 0.0),
            done=dones.get(agent_id, False),
            tm_done=dones.get(agent_id, False),
            log_pi=log_pis[i],
            next_state=next_obs_array[i] if i < len(next_obs_array) else obs_array[i],
            next_state_gail=next_obs_array[i] if i < len(next_obs_array) else obs_array[i],
            means=magail.actor.means[i].detach().cpu().numpy(),
            stds=magail.actor.log_stds.exp()[0].detach().cpu().numpy()
        )
```

### 🎯 Mid-Term TODO

1. **Multi-agent BERT**: the BERT discriminator takes input of shape (batch, N, obs_dim); make sure this is handled correctly
2. **Reward design**: the environment reward is currently 0; a sensible task reward still needs to be designed
3. **Evaluation script**: create an evaluation script to visualize the trained policy (a rough sketch follows this list)
4. **Hyperparameter tuning**: run hyperparameter searches with wandb or tensorboard

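For item 3, a minimal evaluation-loop sketch is shown below. It reuses the environment/algorithm interfaces that appear elsewhere in this document (`env.reset`, `env.obs_list`, `env.controlled_agents`, `env.step`, `magail.explore`); a dedicated deterministic-action method is an assumption and may need to be added.

```python
import numpy as np

def evaluate(env, magail, episodes=5, horizon=300):
    """Roll out the current policy and report the mean episode return."""
    returns = []
    for ep in range(episodes):
        env.reset(ep)
        ep_return, step = 0.0, 0
        while step < horizon:
            obs_array = np.array(env.obs_list)
            # explore() samples stochastic actions; a deterministic
            # exploit()-style method would be preferable for evaluation.
            actions, _ = magail.explore(obs_array)
            action_dict = {
                agent_id: actions[i]
                for i, agent_id in enumerate(env.controlled_agents.keys())
            }
            _, rewards, dones, _ = env.step(action_dict)
            ep_return += sum(rewards.values())
            step += 1
            if dones.get("__all__", False):
                break
        returns.append(ep_return)
    return float(np.mean(returns))
```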

---

## Usage Examples

### Example 1: Basic Training

```bash
# 1. Make sure the environment works
python Env/run_multiagent_env.py

# 2. Start training (no rendering, fast)
python train_magail.py \
    --episodes 100 \
    --horizon 200 \
    --rollout-length 1024 \
    --batch-size 128

# 3. Inspect the training logs
tensorboard --logdir outputs/magail_*/logs
```

### Example 2: Debug Mode

```bash
# A few episodes with rendering enabled
python train_magail.py \
    --episodes 5 \
    --horizon 100 \
    --render
```

### Example 3: Using the API in Code

```python
# test_algorithm.py
import sys
sys.path.append('Algorithm')

from Algorithm.magail import MAGAIL
import torch


# Dummy expert data for testing
class DummyExpertBuffer:
    def __init__(self, device):
        self.device = device

    def sample(self, batch_size):
        states = torch.randn(batch_size, 108, device=self.device)
        next_states = torch.randn(batch_size, 108, device=self.device)
        return states, next_states


# Initialization
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
expert_buffer = DummyExpertBuffer(device)

magail = MAGAIL(
    buffer_exp=expert_buffer,
    input_dim=(108,),
    device=device,
    action_shape=(2,),
)

# Test a forward pass
test_obs = torch.randn(5, 108, device=device)  # 5 agents
actions, log_pis = magail.explore(test_obs)

print(f"Observation shape: {test_obs.shape}")
print(f"Number of actions: {len(actions)}")
print(f"Single action shape: {actions[0].shape}")
print("Test passed! ✅")
```

---

## Summary

### ✅ What Works Right Now

1. **Run the environment test**: `run_multiagent_env.py` already runs correctly
2. **Test the algorithm modules**: everything in Algorithm is implemented
3. **Start preliminary training**: use `train_magail.py` (the buffer logic still needs finishing)

### ⚠️ What You Still Need to Do

1. **Debug the multi-agent buffer**: make sure experience is stored correctly
2. **Implement expert data loading**: adapt it to the actual data format
3. **Verify the observation dimension**: confirm that observations really are 108-dimensional
4. **Tune the training parameters**: adjust based on training results

### 🎯 End Goal

```
Environment (Env/) + Algorithm (Algorithm/) = complete MAGAIL training system
                          ↓
        a multi-agent autonomous-driving policy
        that imitates expert behavior
```

Happy training! 🚀

456 README.md
@@ -1,85 +1,401 @@

# MAGAIL4AutoDrive

### 1.1 Environment Setup
The core environment code lives in the `Env` folder. Running `run_multiagent_env.py` starts the multi-agent interaction environment; the script's main job is to read each agent's (vehicle's) action command and pass it into `env.step()` to run the simulation.

# MAGAIL4AutoDrive - Multi-Agent Autonomous Driving Environment

**Performance-optimized builds:** to address the low FPS (15) and poor CPU utilization of the original version, several optimized variants are provided:
- `run_multiagent_env_fast.py` - LiDAR-optimized version (30-60 FPS, 2-4x speedup) ⭐ recommended
- `run_multiagent_env_parallel.py` - multi-process parallel version (300-600 steps/s total throughput, makes full use of multi-core CPUs) ⭐⭐ recommended
- see `Env/QUICK_START.md` for a quick-start guide

A MetaDrive-based multi-agent autonomous-driving simulation and replay environment, supporting expert-trajectory replay from the Waymo Open Dataset and simulation with custom policies.

The vehicle-spawning function `Env.scenario_env.MultiAgentScenarioEnv.reset()` has been implemented in a first version. Its logic: first read each vehicle's initial pose from the expert dataset; then clean the raw data, dropping the vehicle Agent instances and keeping only the core parameters (vehicle ID, initial spawn position, heading angle, spawn timestamp, goal coordinates); finally call `_spawn_controlled_agents()` to spawn controllable vehicles running the autonomous-driving algorithm at the specified time and position according to the cleaned parameters.

## 📋 Table of Contents

**✅ Resolved: vehicle spawn position deviation**
- **Problem**: some vehicles spawned on lawns, parking lots, and other off-lane areas, caused by recording errors in the expert data or special parking-lot annotations
- **Solution**: implemented the `_is_position_on_lane()` lane-area check and the `_filter_valid_spawn_positions()` filter
  - Detection logic: `point_on_lane()` decides whether a position is on a lane, with a tolerance parameter (default 3 m) for boundary cases
  - Two-stage check: exact detection first, tolerance-based detection as a fallback, so vehicles at lane edges are not filtered out by mistake
  - Automatic filtering: off-lane vehicles are filtered in `reset()` and filtering statistics are printed
- **Configuration parameters**:
  - `filter_offroad_vehicles=True`: enable/disable lane filtering
  - `lane_tolerance=3.0`: lane-detection tolerance (meters), adjustable per scenario
  - `max_controlled_vehicles=10`: cap on the number of vehicles (optional)
- **Usage**: set the parameters above in the environment config; at runtime the filtering statistics are printed (e.g. "filtered 5 vehicles, kept 45")

- [Project Overview](#project-overview)
- [Features](#features)
- [Requirements](#requirements)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Usage Guide](#usage-guide)
- [Project Structure](#project-structure)
- [Configuration](#configuration)
- [FAQ](#faq)

## Project Overview

MAGAIL4AutoDrive is a multi-agent autonomous-driving environment built on MetaDrive 0.4.3, designed for imitation learning and reinforcement learning research. It can load scenarios from real-world datasets (such as the Waymo Open Dataset) and offers two core run modes:

- **Replay Mode**: replays expert trajectories exactly, for data visualization and validation
- **Simulation Mode**: controls vehicles with custom policies, for algorithm training and testing

## Features

### Core Features
- ✅ **Multi-agent support**: control multiple vehicles simultaneously for cooperative simulation
- ✅ **Expert trajectory replay**: faithfully replay expert driving behavior from the Waymo dataset
- ✅ **Custom policy interface**: plug in arbitrary control policies (IDM, RL, ...)
- ✅ **Smart lane filtering**: automatically filter out vehicles that are not on a lane
- ✅ **Scenario duration control**: use the dataset's original scenario length or a custom horizon
- ✅ **Rich sensors**: LiDAR, side detectors, lane-line detectors, cameras, dashboard

### Advanced Features
- 🎯 Run a specific scenario by ID
- 🔄 Automatic scenario switching (fixed version)
- 📊 Detailed debug logging
- 🚗 Dynamic vehicle spawning and management
- 🎮 Supports both visual rendering and headless runs

## Requirements

### System Requirements
- **OS**: Ubuntu 18.04+ / macOS 10.14+ / Windows 10+
- **Python**: 3.8 - 3.10
- **GPU**: optional but recommended (for faster rendering)

### Dependencies
```
metadrive-simulator==0.4.3
numpy>=1.19.0
pygame>=2.0.0
```

## Installation

### 1. Create a Conda Environment
```
conda create -n metadrive python=3.10
conda activate metadrive
```

### 2. Install MetaDrive
```
pip install metadrive-simulator==0.4.3
```

### 3. Clone the Project
```
git clone https://github.com/your-username/MAGAIL4AutoDrive.git
cd MAGAIL4AutoDrive/Env
```

### 4. Prepare the Dataset
Convert the Waymo dataset to MetaDrive format and place it under the project directory:
```
MAGAIL4AutoDrive/Env/
├── exp_converted/
│   ├── scenario_0/
│   ├── scenario_1/
│   └── ...
```

## Quick Start

### Replay Mode (try this first)

### 1.2 Observation Collection
Observation collection is implemented in `Env.scenario_env.MultiAgentScenarioEnv._get_all_obs()`, which iterates over all controllable vehicles and gathers multi-dimensional observations. The dimensions implemented so far include: real-time vehicle position, heading angle, speed, LiDAR point cloud (with obstacle and lane-line features), and navigation information (since the scenarios are relatively simple, the goal coordinates are used directly as the navigation input for now).

**✅ Resolved: traffic-light information collection**
- **Problem**:
  - Issue 1: some traffic-light states were `None`, causing exceptions or wrong decisions
  - Issue 2: with segmented lanes, vehicles in some areas could not be matched to a traffic light
- **Solution**: implemented an improved `_get_traffic_light_state()` with a multi-level detection strategy
  - **Method 1 (preferred)**: get the current lane from the vehicle's navigation module `vehicle.navigation.current_lane` and query its light state directly (efficient, handles lane segmentation automatically)
  - **Method 2 (fallback)**: iterate over all lanes, locate the vehicle with `point_on_lane()`, and look up the corresponding light (handles navigation failures)
  - **Exception handling**: `None` states return 0 (no light); everything is wrapped in try-except so the program never crashes
  - **Return convention**: 0 = no light/unknown, 1 = green, 2 = yellow, 3 = red
  - **Benefit**: a two-level safety net; the efficient method is tried first and the fallback kicks in automatically, so every scenario gets a correct traffic-light reading

```
# Replay the first scenario using its original duration
python run_multiagent_env.py --mode replay --episodes 1 --use_scenario_duration

# Replay a specific scenario
python run_multiagent_env.py --mode replay --scenario_id 0 --use_scenario_duration

# Replay several scenarios
python run_multiagent_env.py --mode replay --episodes 3 --use_scenario_duration
```

### Simulation Mode

### 1.3 Algorithm Module
The core innovation of this approach is a modified GAIL discriminator that handles the dynamically sized input of multi-agent scenes (the number of vehicles is not fixed) and classifies the interaction scene as a whole, which is what training in a multi-agent autonomous-driving environment requires. The core algorithm code lives in the `Algorithm.bert.Bert` class and works as follows:

1. Input handling: the input is a matrix of shape `(N, input_dim)`, where `N` is the number of vehicles in the current scene and `input_dim` is the fixed per-vehicle observation dimension; `input_dim` must be set when constructing `Bert` so the dimensions match;
2. Embedding and positional encoding: a linear `projection` layer maps the per-vehicle observation to the configured embedding dimension (`embed_dim`), and a learnable positional encoding (`pos_embed`) is added to capture temporal and spatial relations in the observation sequence;
3. Transformer feature extraction: the embedded vectors are fed through a multi-layer `Transformer` (the number of layers is set by `num_layers`) for high-level feature interaction and abstraction;
4. Classification head: two aggregation options are provided. With `CLS` enabled, a learnable `CLS` token is prepended before the embedding and the feature at the `CLS` position is fed to a fully connected layer for classification; with `CLS` disabled, the Transformer outputs for all vehicles are mean-pooled over the sequence dimension and the pooled global feature goes to the fully connected layer. The classifier supports an optional `Tanh` activation to match different output distributions.

### 1.4 Action Execution
During the current environment-testing phase we still use Tengda's action-execution framework: each controllable vehicle gets its own `policy` model; the per-vehicle observation is fed into its `policy` to obtain an action, which is passed to `env.step()` for simulation. In `before_step`, `_set_action()` binds the action to the vehicle instance, and MetaDrive then performs the physics update and rendering.

The planned next step is a parameter-shared unified model: all vehicles share a single `policy`, maintained globally through parameter sharing. This has three advantages: it avoids the training bias of per-vehicle models (e.g. models trained to different degrees); it solves model management when the number of vehicles changes (new vehicles need no extra initialization, removed vehicles lose no training progress); and it allows actions to be computed in parallel, significantly speeding up each decision step and scaling to large multi-agent scenes.

```
# Run the simulation with the default policy
python run_multiagent_env.py --mode simulation --episodes 1

# Run without rendering (faster training)
python run_multiagent_env.py --mode simulation --episodes 5 --no_render
```

## Usage Guide

### Command-Line Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--mode` | str | simulation | Run mode: `replay` or `simulation` |
| `--data_dir` | str | current dir | Path to the Waymo data directory |
| `--episodes` | int | 1 | Number of episodes to run |
| `--horizon` | int | 300 | Maximum steps per episode |
| `--no_render` | flag | False | Disable rendering (faster) |
| `--debug` | flag | False | Enable debug mode |
| `--scenario_id` | int | None | Run a specific scenario ID |
| `--use_scenario_duration` | flag | False | Use the scenario's original duration |
| `--no_vehicles` | flag | False | Do not spawn vehicles |
| `--no_pedestrians` | flag | False | Do not spawn pedestrians |
| `--no_cyclists` | flag | False | Do not spawn cyclists |

### Replay Mode in Detail

Replay mode reproduces vehicle states strictly from the expert trajectories, without physics-engine control. Main uses:
- dataset visualization
- data quality checks
- producing demo videos

```bash
# Full-argument example
python run_multiagent_env.py \
    --mode replay \
    --episodes 1 \
    --use_scenario_duration \
    --debug

# Replay vehicles only, no pedestrians or cyclists
python run_multiagent_env.py \
    --mode replay \
    --use_scenario_duration \
    --no_pedestrians \
    --no_cyclists
```

**Important:** always enable `--use_scenario_duration` in replay mode; otherwise the run keeps going after the scenario has finished playing.

### Simulation Mode in Detail

Simulation mode controls vehicles with a custom policy and is meant for algorithm development and testing:

```bash
# Basic simulation
python run_multiagent_env.py --mode simulation

# Long training runs (no rendering)
python run_multiagent_env.py \
    --mode simulation \
    --episodes 100 \
    --horizon 500 \
    --no_render

# Vehicles only (to focus on vehicle-vehicle interaction)
python run_multiagent_env.py \
    --mode simulation \
    --no_pedestrians \
    --no_cyclists
```

### Custom Policies

Modify `simple_idm_policy.py` or create a new policy class:

```python
class CustomPolicy:
    def __init__(self, **kwargs):
        # Initialize policy parameters
        pass

    def act(self, observation=None):
        # Return an action [steering, acceleration]
        # steering: [-1, 1]
        # acceleration: [-1, 1]
        return [0.0, 0.5]
```

Use it in `run_multiagent_env.py`:
```
from custom_policy import CustomPolicy

env = MultiAgentScenarioEnv(
    config={...},
    agent2policy=CustomPolicy()
)
```

## Project Structure

```
MAGAIL4AutoDrive/Env/
├── run_multiagent_env.py      # main entry script
├── scenario_env.py            # multi-agent scenario environment
├── replay_policy.py           # expert trajectory replay policy
├── simple_idm_policy.py       # IDM policy implementation
├── utils.py                   # utility functions
├── ENHANCED_USAGE_GUIDE.md    # detailed usage guide
├── README.md                  # this document
└── exp_converted/             # Waymo dataset (prepare yourself)
    ├── scenario_0/
    ├── scenario_1/
    └── ...
```

### Core Files

**run_multiagent_env.py**
- main entry script
- handles command-line arguments
- manages the replay and simulation run logic

**scenario_env.py**
- custom multi-agent environment class
- vehicle spawning and management
- lane filtering logic
- observation space definition

**replay_policy.py**
- expert trajectory replay policy
- per-frame state lookup
- trajectory completion check

**simple_idm_policy.py**
- simple constant-speed policy example
- usable as a template for custom policies

## Configuration

### Environment Config Parameters

These can be changed in `default_config()` in `scenario_env.py`:

```python
config.update(dict(
    data_directory=None,             # data directory
    num_controlled_agents=3,         # number of controlled vehicles (simulation mode only)
    horizon=1000,                    # maximum steps
    filter_offroad_vehicles=True,    # filter vehicles that are off the lanes
    lane_tolerance=3.0,              # lane tolerance (meters)
    replay_mode=False,               # replay mode or not
    specific_scenario_id=None,       # specific scenario ID
    use_scenario_duration=False,     # use the scenario's original duration
    # Object type filtering
    spawn_vehicles=True,             # spawn vehicles
    spawn_pedestrians=True,          # spawn pedestrians
    spawn_cyclists=True,             # spawn cyclists
))
```

### Sensor Configuration

Sensors enabled by default (adjustable at environment initialization):
- **LiDAR**: 80 beams, 30 m range
- **Side detector**: 10 beams, 8 m range
- **Lane-line detector**: 10 beams, 3 m range
- **Main camera**: 1200x900 resolution
- **Dashboard**: vehicle state information

## FAQ

### Q1: Why does replay mode keep running past the dataset's last frame?
**A**: Add the `--use_scenario_duration` flag. The fixed version of `scenario_env.py` also adds an automatic detection mechanism.

### Q2: How do I switch between scenarios?
**A**:
- Option 1: pick a scenario with `--scenario_id`
- Option 2: use `--episodes N` to iterate over N scenarios automatically

### Q3: Why are some vehicles missing?
**A**: Lane filtering is enabled (`filter_offroad_vehicles=True`), so vehicles that are not on a lane are dropped. Adjust `lane_tolerance` or disable the feature.

### Q4: How do I speed things up?
**A**:
- use `--no_render` to disable visualization
- reduce `num_controlled_agents`
- use GPU acceleration

### Q5: How do I control which object types appear in the scene?
**A**: Use the object filtering flags:
```bash
# Vehicles only, no pedestrians or cyclists
python run_multiagent_env.py --mode replay --no_pedestrians --no_cyclists

# Pedestrians and cyclists only, no vehicles (special scenarios)
python run_multiagent_env.py --mode replay --no_vehicles

# Debug mode to see the filtering statistics
python run_multiagent_env.py --mode replay --debug --no_pedestrians
```

### Q6: Why do some vehicles spawn in mid-air?
**A**: Fixed in v1.2.0. Vehicle positions now use only 2D coordinates (x, y) with z set to 0, letting MetaDrive handle the height so vehicles sit on the ground.

### Q7: How do I export observation data?
**A**: Add saving logic in `run_multiagent_env.py`:
```python
import pickle

obs_data = []
while True:
    obs, rewards, dones, infos = env.step(actions)
    obs_data.append(obs)
    if dones["__all__"]:
        break

with open('observations.pkl', 'wb') as f:
    pickle.dump(obs_data, f)
```

## Changelog

### v1.2.0 (2025-10-26)
- ✅ Fixed the vehicle spawn height issue (floating vehicles)
- ✅ Added object-type filtering (vehicles / pedestrians / cyclists)
- ✅ New command-line flags: `--no_vehicles`, `--no_pedestrians`, `--no_cyclists`
- ✅ Improved debug output with per-type object statistics
- ✅ Position handling now uses 2D coordinates only, avoiding height problems

### v1.1.0 (2025-10-26)
- ✅ Fixed replay mode running past the scenario duration
- ✅ Added automatic scenario switching
- ✅ Improved `replay_policy.py` with a new `is_finished()` method
- ✅ Refined the done logic in `scenario_env.py`
- ✅ Fixed object cleanup across multiple episodes

### v1.0.0 (initial release)
- Basic multi-agent environment
- Replay and simulation modes
- Lane filtering
- Waymo dataset support

## Contributing

Issues and pull requests are welcome!

### Filing an Issue
- describe the problem and the steps to reproduce it
- attach logs and error messages
- state your environment (OS, Python version, etc.)

### Submitting a PR
- fork the project
- create a feature branch: `git checkout -b feature/your-feature`
- commit your changes: `git commit -m 'Add some feature'`
- push the branch: `git push origin feature/your-feature`
- open a Pull Request

## License

This project is released under the MIT license.

## Acknowledgements

- [MetaDrive](https://github.com/metadriverse/metadrive) - an excellent driving simulation platform
- [Waymo Open Dataset](https://waymo.com/open/) - a high-quality autonomous-driving dataset

## Contact

For questions or suggestions:
- GitHub Issues: [project issues page]
- Email: huangfukk@xxx.com

---

## Problem-Solving Summary

### ✅ Completed Optimizations

1. **Vehicle spawn position deviation** - lane-area detection and automatic filtering, configured via `filter_offroad_vehicles`, `lane_tolerance`, `max_controlled_vehicles`
2. **Traffic-light information collection** - a two-level detection strategy (navigation module first, lane scan as fallback) that handles `None` states and segmented lanes (a schematic sketch follows this list)
3. **Performance** - several optimized variants (fast/parallel); FPS improved from 15 to 30-60, with multi-process support to fully use the CPU

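The sketch below only illustrates the two-level dispatch described in item 2; it is not the actual `_get_traffic_light_state()` implementation. The lane-to-light lookup is passed in as a callable because the exact MetaDrive API is not shown in this document, and `point_on_lane()` is assumed to be a method on the lane objects, as referenced above.

```python
def traffic_light_state(vehicle, all_lanes, light_state_of_lane):
    """Return 0 = no light/unknown, 1 = green, 2 = yellow, 3 = red.

    `light_state_of_lane(lane)` is a caller-supplied lookup returning an
    already-encoded state or None.
    """
    # Level 1 (preferred): ask the navigation module for the current lane.
    try:
        state = light_state_of_lane(vehicle.navigation.current_lane)
        if state is not None:
            return state
    except Exception:
        pass
    # Level 2 (fallback): scan all lanes and match the vehicle's position.
    try:
        for lane in all_lanes:
            if lane.point_on_lane(vehicle.position):
                state = light_state_of_lane(lane)
                return state if state is not None else 0
    except Exception:
        pass
    return 0
```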

### 🧪 Testing

```bash
# Test lane filtering and traffic-light detection
python Env/test_lane_filter.py

# Run the standard version (with filtering)
python Env/run_multiagent_env.py

# Run the high-performance version
python Env/run_multiagent_env_fast.py
```

### 📝 Configuration Example

```python
config = {
    # Lane filtering
    "filter_offroad_vehicles": True,   # enable lane filtering
    "lane_tolerance": 3.0,             # tolerance (meters)
    "max_controlled_vehicles": 10,     # maximum number of vehicles
    # other options...
}
```

**Happy Driving! 🚗💨**

@@ -1,103 +0,0 @@
"""
Analyze the structure of the Waymo expert data.

Run: python analyze_expert_data.py
"""

import pickle
import numpy as np
import os


def analyze_pkl_file(filepath):
    """Analyze the structure of a single pkl file."""
    print(f"\n{'='*80}")
    print(f"Analyzing file: {os.path.basename(filepath)}")
    print(f"{'='*80}")

    with open(filepath, 'rb') as f:
        data = pickle.load(f)

    print(f"\n1. Data type: {type(data)}")
    print(f"   File size: {os.path.getsize(filepath) / 1024:.1f} KB")

    if isinstance(data, dict):
        print(f"\n2. Dict structure:")
        print(f"   Number of keys: {len(data)}")
        print(f"   Keys: {list(data.keys())[:10]}")

        # Inspect the first few keys in detail
        for i, (key, value) in enumerate(list(data.items())[:5]):
            print(f"\n   Key [{i+1}]: '{key}'")
            print(f"   Type: {type(value)}")

            if isinstance(value, dict):
                print(f"   Sub-keys: {list(value.keys())}")

                # Inspect the sub-dict
                for subkey, subvalue in list(value.items())[:3]:
                    print(f"     - {subkey}: {type(subvalue)}", end="")
                    if isinstance(subvalue, np.ndarray):
                        print(f" shape={subvalue.shape}, dtype={subvalue.dtype}")
                    elif isinstance(subvalue, dict):
                        print(f" keys={list(subvalue.keys())[:5]}")
                    elif isinstance(subvalue, (list, tuple)):
                        print(f" len={len(subvalue)}")
                    else:
                        print(f" = {subvalue}")

            elif isinstance(value, np.ndarray):
                print(f"   Shape: {value.shape}, dtype: {value.dtype}")
                print(f"   Sample: {value.flatten()[:5]}")
            elif isinstance(value, (list, tuple)):
                print(f"   Length: {len(value)}")
                if len(value) > 0:
                    print(f"   First element: {type(value[0])}")

    elif isinstance(data, (list, tuple)):
        print(f"\n2. List/tuple structure:")
        print(f"   Length: {len(data)}")
        if len(data) > 0:
            print(f"   First element type: {type(data[0])}")
            if isinstance(data[0], dict):
                print(f"   First element keys: {list(data[0].keys())}")

    return data


def find_trajectory_data(data, max_depth=3, current_depth=0, path=""):
    """Recursively look for fields that may contain trajectory data."""
    if current_depth > max_depth:
        return

    if isinstance(data, dict):
        for key, value in data.items():
            new_path = f"{path}.{key}" if path else key

            # Arrays with a long leading axis are likely time series
            if isinstance(value, np.ndarray):
                if len(value.shape) >= 2 and value.shape[0] > 10:
                    print(f"  🎯 Possible trajectory data: {new_path}")
                    print(f"     Shape: {value.shape}, dtype: {value.dtype}")
                    print(f"     First 3 values: {value[:3]}")

            # Recurse into nested dicts
            elif isinstance(value, dict):
                find_trajectory_data(value, max_depth, current_depth + 1, new_path)


if __name__ == "__main__":
    # Analyze the first data file
    data_dir = "Env/exp_converted/exp_converted_0"
    pkl_files = [f for f in os.listdir(data_dir) if f.startswith('sd_waymo')]

    if pkl_files:
        filepath = os.path.join(data_dir, pkl_files[0])
        data = analyze_pkl_file(filepath)

        print(f"\n\n{'='*80}")
        print("Searching for possible trajectory data...")
        print(f"{'='*80}")
        find_trajectory_data(data)
    else:
        print("No data files found!")