Files
MAGAIL4AutoDrive/专家数据说明.md

8.7 KiB
Raw Permalink Blame History

专家数据说明

📍 数据位置

您的专家数据已经存在于项目中了!

MAGAIL4AutoDrive/
└── Env/
    └── exp_converted/
        ├── dataset_mapping.pkl      # 数据集映射文件
        ├── dataset_summary.pkl      # 数据集摘要468KB
        └── exp_converted_0/         # 场景数据目录
            ├── dataset_mapping.pkl
            ├── dataset_summary.pkl
            └── sd_waymo_v1.2_*.pkl  # 75个Waymo场景文件共87MB

📊 数据概况

  • 数据来源: Waymo Open Motion Dataset
  • 场景数量: 75个场景
  • 数据格式: MetaDrive转换后的pkl格式
  • 总大小: 约87MB
  • 单个场景: 300KB - 3.4MB不等

🔍 数据内容

这些是Waymo自动驾驶数据集,包含:

场景信息

  • 道路网络(车道、路口)
  • 交通信号灯
  • 静态障碍物

车辆轨迹(专家数据)

  • 每辆车的完整行驶轨迹
  • 位置序列x, y, z
  • 速度、加速度
  • 朝向角
  • 时间戳

用途

这些轨迹数据可以作为GAIL算法的专家演示

  1. 判别器学习区分策略行为和专家行为
  2. 策略网络学习模仿专家的驾驶方式

⚠️ 当前问题

为什么显示"未找到专家数据"

train_magail.pyExpertBuffer 类中,_extract_trajectories() 方法是空的:

def _extract_trajectories(self, scenario_data):
    """
    从场景数据中提取轨迹
    
    TODO: 需要根据实际数据格式调整
    这里提供一个基本框架
    """
    # 由于实际数据格式未知,这里只是占位
    # 实际使用时需要根据Waymo数据格式提取state和next_state
    pass  # ← 这里什么都没做!

所以虽然找到了75个文件但没有提取出任何轨迹数据。

🔧 如何解析专家数据

方法1: 分析数据结构(推荐)

运行分析脚本查看数据格式:

# 在conda环境中运行
conda activate metadrive
python analyze_expert_data.py

这会告诉您:

  • pkl文件的内部结构
  • 哪些字段包含轨迹数据
  • 数据的shape和类型

方法2: 查看MetaDrive文档

MetaDrive的Waymo数据格式文档

通常包含的字段:

{
    'metadata': {...},  # 场景元数据
    'tracks': {         # 车辆轨迹
        'vehicle_id': {
            'state': {
                'position': np.array([...]),  # (T, 3) 位置序列
                'velocity': np.array([...]),  # (T, 2) 速度序列
                'heading': np.array([...]),   # (T,) 朝向序列
                'valid': np.array([...]),     # (T,) 有效标记
            },
            'type': 'VEHICLE',
            ...
        },
        ...
    },
    'map_features': {...},  # 地图特征
    ...
}

方法3: 参考环境代码

scenario_env.pyreset() 方法已经展示了如何读取这些数据:

# scenario_env.py 第92-108行
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
    if track["type"] == MetaDriveType.VEHICLE:
        valid = track['state']['valid']
        first_show = np.argmax(valid) if valid.any() else -1
        
        # 提取位置
        position = track['state']['position'][first_show]
        
        # 提取朝向
        heading = track['state']['heading'][first_show]

💡 实现专家数据加载

完整实现示例

修改 train_magail.py 中的 _extract_trajectories() 方法:

def _extract_trajectories(self, scenario_data):
    """
    从MetaDrive场景数据中提取车辆轨迹
    
    Args:
        scenario_data: MetaDrive格式的场景数据字典
    """
    try:
        # 方法1: 如果是字典且有'tracks'键
        if isinstance(scenario_data, dict) and 'tracks' in scenario_data:
            for vehicle_id, track_data in scenario_data['tracks'].items():
                if track_data.get('type') == 'VEHICLE':
                    states = track_data.get('state', {})
                    
                    # 获取有效帧
                    valid = states.get('valid', [])
                    if not hasattr(valid, 'any') or not valid.any():
                        continue
                    
                    # 提取位置、速度、朝向等
                    positions = states.get('position', [])
                    velocities = states.get('velocity', [])
                    headings = states.get('heading', [])
                    
                    # 构建state序列
                    for t in range(len(positions) - 1):
                        if valid[t] and valid[t+1]:
                            # 当前状态
                            state = np.concatenate([
                                positions[t][:2],      # x, y
                                velocities[t],         # vx, vy
                                [headings[t]],         # heading
                                # ... 其他观测维度激光雷达等暂时用0填充
                            ])
                            
                            # 下一状态
                            next_state = np.concatenate([
                                positions[t+1][:2],
                                velocities[t+1],
                                [headings[t+1]],
                            ])
                            
                            # 补齐到108维匹配实际观测维度
                            state = np.pad(state, (0, 108 - len(state)))
                            next_state = np.pad(next_state, (0, 108 - len(next_state)))
                            
                            self.states.append(state)
                            self.next_states.append(next_state)
                            
                            if len(self.states) >= self.max_samples:
                                return
        
        # 方法2: 其他可能的格式
        # ...
        
    except Exception as e:
        print(f"  ⚠️ 提取轨迹失败: {e}")
        import traceback
        traceback.print_exc()

简化版本(快速测试)

如果只是想快速测试,可以先用假数据:

def _extract_trajectories(self, scenario_data):
    """临时实现:使用随机数据"""
    print(f"  ⚠️ 使用随机数据代替真实轨迹(临时)")
    
    # 生成1000条随机轨迹用于测试
    for _ in range(min(1000, self.max_samples)):
        state = np.random.randn(108)
        next_state = np.random.randn(108)
        self.states.append(state)
        self.next_states.append(next_state)

🎯 下一步行动计划

立即行动(按顺序)

步骤1: 分析数据格式

conda activate metadrive
python analyze_expert_data.py

步骤2: 根据输出实现提取逻辑

查看步骤1的输出了解数据结构后修改 train_magail.py 中的 _extract_trajectories()

步骤3: 测试加载

python train_magail.py --episodes 1 --horizon 50

检查是否显示:

✅ 加载完成: XXXX 条专家轨迹  # 不再是0

步骤4: 验证训练

运行完整训练,观察判别器是否正常工作

备选方案

如果数据格式太复杂,暂时无法解析:

  1. 使用模拟数据: 从环境中收集轨迹作为"伪专家"
  2. 简化问题: 先用PPO训练不用GAIL
  3. 寻求帮助: 查看MetaDrive的示例代码

📚 参考资源

MetaDrive相关

项目内参考

  • Env/scenario_env.py - 第92-108行展示了如何读取轨迹数据
  • analyze_expert_data.py - 数据结构分析脚本
  • train_magail.py - ExpertBuffer类需要修改

常见问题

Q1: 数据从哪里来的?

A: 这是Waymo Open Motion Dataset经过MetaDrive转换的格式

Q2: 为什么环境能用但训练不能?

A: 环境直接读取pkl文件用于场景初始化但训练需要提取轨迹序列

Q3: 必须使用专家数据吗?

A: 不一定。您可以:

  • 只用PPO去掉GAIL部分
  • 用环境中收集的好轨迹作为"专家"
  • 用模拟数据测试框架

Q4: 如何验证提取的数据是否正确?

A: 检查:

  • shape是否正确N, 108
  • 数值是否合理不全是0或NaN
  • 可视化几条轨迹

🎉 总结

您已经有完整的专家数据了!只是需要:

  1. 数据存在 - 75个Waymo场景文件
  2. 解析代码 - 需要实现提取逻辑
  3. 验证加载 - 确保数据正确

完成这些后MAGAIL训练就可以真正利用专家数据进行模仿学习了


下一步: 运行 python analyze_expert_data.py 查看数据结构