修改了算法代码,并建立了一个简单的训练脚本.修改bert处理二维输入,移除PPO的permute参数
This commit is contained in:
298
专家数据说明.md
Normal file
298
专家数据说明.md
Normal file
@@ -0,0 +1,298 @@
|
||||
# 专家数据说明
|
||||
|
||||
## 📍 数据位置
|
||||
|
||||
您的专家数据已经存在于项目中了!
|
||||
|
||||
```
|
||||
MAGAIL4AutoDrive/
|
||||
└── Env/
|
||||
└── exp_converted/
|
||||
├── dataset_mapping.pkl # 数据集映射文件
|
||||
├── dataset_summary.pkl # 数据集摘要(468KB)
|
||||
└── exp_converted_0/ # 场景数据目录
|
||||
├── dataset_mapping.pkl
|
||||
├── dataset_summary.pkl
|
||||
└── sd_waymo_v1.2_*.pkl # 75个Waymo场景文件(共87MB)
|
||||
```
|
||||
|
||||
## 📊 数据概况
|
||||
|
||||
- **数据来源**: Waymo Open Motion Dataset
|
||||
- **场景数量**: 75个场景
|
||||
- **数据格式**: MetaDrive转换后的pkl格式
|
||||
- **总大小**: 约87MB
|
||||
- **单个场景**: 300KB - 3.4MB不等
|
||||
|
||||
## 🔍 数据内容
|
||||
|
||||
这些是**Waymo自动驾驶数据集**,包含:
|
||||
|
||||
### 场景信息
|
||||
- 道路网络(车道、路口)
|
||||
- 交通信号灯
|
||||
- 静态障碍物
|
||||
|
||||
### 车辆轨迹(专家数据)
|
||||
- 每辆车的完整行驶轨迹
|
||||
- 位置序列(x, y, z)
|
||||
- 速度、加速度
|
||||
- 朝向角
|
||||
- 时间戳
|
||||
|
||||
### 用途
|
||||
这些轨迹数据可以作为**GAIL算法的专家演示**:
|
||||
1. 判别器学习区分策略行为和专家行为
|
||||
2. 策略网络学习模仿专家的驾驶方式
|
||||
|
||||
## ⚠️ 当前问题
|
||||
|
||||
### 为什么显示"未找到专家数据"?
|
||||
|
||||
在 `train_magail.py` 的 `ExpertBuffer` 类中,`_extract_trajectories()` 方法是空的:
|
||||
|
||||
```python
|
||||
def _extract_trajectories(self, scenario_data):
|
||||
"""
|
||||
从场景数据中提取轨迹
|
||||
|
||||
TODO: 需要根据实际数据格式调整
|
||||
这里提供一个基本框架
|
||||
"""
|
||||
# 由于实际数据格式未知,这里只是占位
|
||||
# 实际使用时需要根据Waymo数据格式提取state和next_state
|
||||
pass # ← 这里什么都没做!
|
||||
```
|
||||
|
||||
所以虽然找到了75个文件,但没有提取出任何轨迹数据。
|
||||
|
||||
## 🔧 如何解析专家数据
|
||||
|
||||
### 方法1: 分析数据结构(推荐)
|
||||
|
||||
运行分析脚本查看数据格式:
|
||||
|
||||
```bash
|
||||
# 在conda环境中运行
|
||||
conda activate metadrive
|
||||
python analyze_expert_data.py
|
||||
```
|
||||
|
||||
这会告诉您:
|
||||
- pkl文件的内部结构
|
||||
- 哪些字段包含轨迹数据
|
||||
- 数据的shape和类型
|
||||
|
||||
### 方法2: 查看MetaDrive文档
|
||||
|
||||
MetaDrive的Waymo数据格式文档:
|
||||
- GitHub: https://github.com/metadriverse/metadrive
|
||||
- 文档: https://metadrive-simulator.readthedocs.io/
|
||||
|
||||
通常包含的字段:
|
||||
```python
|
||||
{
|
||||
'metadata': {...}, # 场景元数据
|
||||
'tracks': { # 车辆轨迹
|
||||
'vehicle_id': {
|
||||
'state': {
|
||||
'position': np.array([...]), # (T, 3) 位置序列
|
||||
'velocity': np.array([...]), # (T, 2) 速度序列
|
||||
'heading': np.array([...]), # (T,) 朝向序列
|
||||
'valid': np.array([...]), # (T,) 有效标记
|
||||
},
|
||||
'type': 'VEHICLE',
|
||||
...
|
||||
},
|
||||
...
|
||||
},
|
||||
'map_features': {...}, # 地图特征
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
### 方法3: 参考环境代码
|
||||
|
||||
在 `scenario_env.py` 中,reset() 方法已经展示了如何读取这些数据:
|
||||
|
||||
```python
|
||||
# scenario_env.py 第92-108行
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
|
||||
# 提取位置
|
||||
position = track['state']['position'][first_show]
|
||||
|
||||
# 提取朝向
|
||||
heading = track['state']['heading'][first_show]
|
||||
```
|
||||
|
||||
## 💡 实现专家数据加载
|
||||
|
||||
### 完整实现示例
|
||||
|
||||
修改 `train_magail.py` 中的 `_extract_trajectories()` 方法:
|
||||
|
||||
```python
|
||||
def _extract_trajectories(self, scenario_data):
|
||||
"""
|
||||
从MetaDrive场景数据中提取车辆轨迹
|
||||
|
||||
Args:
|
||||
scenario_data: MetaDrive格式的场景数据字典
|
||||
"""
|
||||
try:
|
||||
# 方法1: 如果是字典且有'tracks'键
|
||||
if isinstance(scenario_data, dict) and 'tracks' in scenario_data:
|
||||
for vehicle_id, track_data in scenario_data['tracks'].items():
|
||||
if track_data.get('type') == 'VEHICLE':
|
||||
states = track_data.get('state', {})
|
||||
|
||||
# 获取有效帧
|
||||
valid = states.get('valid', [])
|
||||
if not hasattr(valid, 'any') or not valid.any():
|
||||
continue
|
||||
|
||||
# 提取位置、速度、朝向等
|
||||
positions = states.get('position', [])
|
||||
velocities = states.get('velocity', [])
|
||||
headings = states.get('heading', [])
|
||||
|
||||
# 构建state序列
|
||||
for t in range(len(positions) - 1):
|
||||
if valid[t] and valid[t+1]:
|
||||
# 当前状态
|
||||
state = np.concatenate([
|
||||
positions[t][:2], # x, y
|
||||
velocities[t], # vx, vy
|
||||
[headings[t]], # heading
|
||||
# ... 其他观测维度(激光雷达等暂时用0填充)
|
||||
])
|
||||
|
||||
# 下一状态
|
||||
next_state = np.concatenate([
|
||||
positions[t+1][:2],
|
||||
velocities[t+1],
|
||||
[headings[t+1]],
|
||||
])
|
||||
|
||||
# 补齐到108维(匹配实际观测维度)
|
||||
state = np.pad(state, (0, 108 - len(state)))
|
||||
next_state = np.pad(next_state, (0, 108 - len(next_state)))
|
||||
|
||||
self.states.append(state)
|
||||
self.next_states.append(next_state)
|
||||
|
||||
if len(self.states) >= self.max_samples:
|
||||
return
|
||||
|
||||
# 方法2: 其他可能的格式
|
||||
# ...
|
||||
|
||||
except Exception as e:
|
||||
print(f" ⚠️ 提取轨迹失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
```
|
||||
|
||||
### 简化版本(快速测试)
|
||||
|
||||
如果只是想快速测试,可以先用假数据:
|
||||
|
||||
```python
|
||||
def _extract_trajectories(self, scenario_data):
|
||||
"""临时实现:使用随机数据"""
|
||||
print(f" ⚠️ 使用随机数据代替真实轨迹(临时)")
|
||||
|
||||
# 生成1000条随机轨迹用于测试
|
||||
for _ in range(min(1000, self.max_samples)):
|
||||
state = np.random.randn(108)
|
||||
next_state = np.random.randn(108)
|
||||
self.states.append(state)
|
||||
self.next_states.append(next_state)
|
||||
```
|
||||
|
||||
## 🎯 下一步行动计划
|
||||
|
||||
### 立即行动(按顺序)
|
||||
|
||||
#### 步骤1: 分析数据格式
|
||||
```bash
|
||||
conda activate metadrive
|
||||
python analyze_expert_data.py
|
||||
```
|
||||
|
||||
#### 步骤2: 根据输出实现提取逻辑
|
||||
|
||||
查看步骤1的输出,了解数据结构后,修改 `train_magail.py` 中的 `_extract_trajectories()`
|
||||
|
||||
#### 步骤3: 测试加载
|
||||
|
||||
```bash
|
||||
python train_magail.py --episodes 1 --horizon 50
|
||||
```
|
||||
|
||||
检查是否显示:
|
||||
```
|
||||
✅ 加载完成: XXXX 条专家轨迹 # 不再是0
|
||||
```
|
||||
|
||||
#### 步骤4: 验证训练
|
||||
|
||||
运行完整训练,观察判别器是否正常工作
|
||||
|
||||
### 备选方案
|
||||
|
||||
如果数据格式太复杂,暂时无法解析:
|
||||
|
||||
1. **使用模拟数据**: 从环境中收集轨迹作为"伪专家"
|
||||
2. **简化问题**: 先用PPO训练(不用GAIL)
|
||||
3. **寻求帮助**: 查看MetaDrive的示例代码
|
||||
|
||||
## 📚 参考资源
|
||||
|
||||
### MetaDrive相关
|
||||
- MetaDrive GitHub: https://github.com/metadriverse/metadrive
|
||||
- Waymo数据集: https://waymo.com/open/
|
||||
|
||||
### 项目内参考
|
||||
- `Env/scenario_env.py` - 第92-108行展示了如何读取轨迹数据
|
||||
- `analyze_expert_data.py` - 数据结构分析脚本
|
||||
- `train_magail.py` - ExpertBuffer类需要修改
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
### Q1: 数据从哪里来的?
|
||||
A: 这是Waymo Open Motion Dataset,经过MetaDrive转换的格式
|
||||
|
||||
### Q2: 为什么环境能用但训练不能?
|
||||
A: 环境直接读取pkl文件用于场景初始化,但训练需要提取轨迹序列
|
||||
|
||||
### Q3: 必须使用专家数据吗?
|
||||
A: 不一定。您可以:
|
||||
- 只用PPO(去掉GAIL部分)
|
||||
- 用环境中收集的好轨迹作为"专家"
|
||||
- 用模拟数据测试框架
|
||||
|
||||
### Q4: 如何验证提取的数据是否正确?
|
||||
A: 检查:
|
||||
- shape是否正确(N, 108)
|
||||
- 数值是否合理(不全是0或NaN)
|
||||
- 可视化几条轨迹
|
||||
|
||||
## 🎉 总结
|
||||
|
||||
您已经**有完整的专家数据**了!只是需要:
|
||||
|
||||
1. ✅ 数据存在 - 75个Waymo场景文件
|
||||
2. ⏳ 解析代码 - 需要实现提取逻辑
|
||||
3. ⏳ 验证加载 - 确保数据正确
|
||||
|
||||
完成这些后,MAGAIL训练就可以真正利用专家数据进行模仿学习了!
|
||||
|
||||
---
|
||||
|
||||
**下一步**: 运行 `python analyze_expert_data.py` 查看数据结构
|
||||
|
||||
Reference in New Issue
Block a user