diff --git a/Algorithm/__init__.py b/Algorithm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Env/__pycache__/expert_replay_policy.cpython-310.pyc b/Env/__pycache__/expert_replay_policy.cpython-310.pyc new file mode 100644 index 0000000..2b030c5 Binary files /dev/null and b/Env/__pycache__/expert_replay_policy.cpython-310.pyc differ diff --git a/Env/__pycache__/scenario_env.cpython-310.pyc b/Env/__pycache__/scenario_env.cpython-310.pyc index 0ceae8b..4b9e85d 100644 Binary files a/Env/__pycache__/scenario_env.cpython-310.pyc and b/Env/__pycache__/scenario_env.cpython-310.pyc differ diff --git a/Env/__pycache__/scenario_env.cpython-313.pyc b/Env/__pycache__/scenario_env.cpython-313.pyc index f89c9f9..824669e 100644 Binary files a/Env/__pycache__/scenario_env.cpython-313.pyc and b/Env/__pycache__/scenario_env.cpython-313.pyc differ diff --git a/Env/scenario_env.py b/Env/scenario_env.py index dd1e701..b801ff6 100644 --- a/Env/scenario_env.py +++ b/Env/scenario_env.py @@ -320,6 +320,7 @@ class MultiAgentScenarioEnv(ScenarioEnv): # 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成 _obj_to_clean_this_frame = [] # 需要清理的对象ID列表 self.car_birth_info_list = [] # 车辆生成信息列表 + self.expert_trajectories = {} # 专家数据轨迹字典 for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items(): # 跳过自车(SDC - Self Driving Car) @@ -334,6 +335,53 @@ class MultiAgentScenarioEnv(ScenarioEnv): first_show = np.argmax(valid) if valid.any() else -1 last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1 + if first_show == -1 or last_show == -1: + continue + object_id = track["metadata"]["object_id"] + + # 提取完整轨迹数据(只使用确认存在的字段) + trajectory_data = { + "object_id": object_id, + "scenario_id": scenario_id, + "valid_mask": valid[first_show:last_show+1].copy(), # 有效性掩码 + "positions": track["state"]["position"][first_show:last_show+1].copy(), # (T, 3) + "headings": track["state"]["heading"][first_show:last_show+1].copy(), # (T,) + "velocities": 
track["state"]["velocity"][first_show:last_show+1].copy(), # (T, 2) + "timesteps": np.arange(first_show, last_show+1), # 时间戳 + "start_timestep": first_show, + "end_timestep": last_show, + "length": last_show - first_show + 1 + } + + # 可选:如果数据中有车辆尺寸信息,则添加 + # 方法1: 尝试从state中获取 + if "length" in track["state"]: + trajectory_data["vehicle_length"] = track["state"]["length"][first_show] + if "width" in track["state"]: + trajectory_data["vehicle_width"] = track["state"]["width"][first_show] + if "height" in track["state"]: + trajectory_data["vehicle_height"] = track["state"]["height"][first_show] + + # 方法2: 尝试从metadata中获取 + if "vehicle_length" not in trajectory_data and "length" in track.get("metadata", {}): + trajectory_data["vehicle_length"] = track["metadata"]["length"] + if "vehicle_width" not in trajectory_data and "width" in track.get("metadata", {}): + trajectory_data["vehicle_width"] = track["metadata"]["width"] + if "vehicle_height" not in trajectory_data and "height" in track.get("metadata", {}): + trajectory_data["vehicle_height"] = track["metadata"]["height"] + + # 方法3: 使用默认值(如果以上都没有) + if "vehicle_length" not in trajectory_data: + trajectory_data["vehicle_length"] = 4.5 # MetaDrive默认车长 + if "vehicle_width" not in trajectory_data: + trajectory_data["vehicle_width"] = 2.0 # MetaDrive默认车宽 + if "vehicle_height" not in trajectory_data: + trajectory_data["vehicle_height"] = 1.5 # MetaDrive默认车高 + + + # 存储到专家轨迹字典 + self.expert_trajectories[object_id] = trajectory_data + # 提取车辆关键信息 car_info = { 'id': track['metadata']['object_id'], diff --git a/README.md b/README.md index 9bb5fff..777e683 100644 --- a/README.md +++ b/README.md @@ -1,85 +1,275 @@ # MAGAIL4AutoDrive -### 1.1 环境搭建 -环境核心代码封装于`Env`文件夹,通过运行`run_multiagent_env.py`即可启动多智能体交互环境,该脚本的核心功能为读取各智能体(车辆)的动作指令,并将其传入`env.step()`方法中完成仿真执行。 -**性能优化版本:** 针对原始版本FPS低(15帧)和CPU利用率不足的问题,已提供多个优化版本: -- `run_multiagent_env_fast.py` - 激光雷达优化版(30-60 FPS,2-4倍提升)⭐推荐 -- `run_multiagent_env_parallel.py` - 多进程并行版(300-600 
steps/s总吞吐量,充分利用多核CPU)⭐⭐推荐 -- 详见 `Env/QUICK_START.md` 快速使用指南 +> 基于多智能体生成对抗模仿学习(MAGAIL)的自动驾驶训练系统 | MetaDrive + Waymo Open Motion Dataset -当前已初步实现`Env.senario_env.MultiAgentScenarioEnv.reset()`车辆生成函数,具体逻辑如下:首先读取专家数据集中各车辆的初始位姿信息;随后对原始数据进行清洗,剔除车辆 Agent 实例信息,记录核心参数(车辆 ID、初始生成位置、朝向角、生成时间戳、目标终点坐标);最后调用`_spawn_controlled_agents()`函数,依据清洗后的参数在指定时间、指定位置生成搭载自动驾驶算法的可控车辆。 +[![MetaDrive](https://img.shields [![Python](https://img.shields.io/io/badge/Dataset目实现了适配多智能体场景的GAIL算法训练系统,核心创新在于**改进判别器架构支持动态车辆数量**,利用Transformer处理1-100+辆车的交互场景。 -**✅ 已解决:车辆生成位置偏差问题** -- **问题描述**:部分车辆生成于草坪、停车场等非车道区域,原因是专家数据记录误差或停车场特殊标注 -- **解决方案**:实现了`_is_position_on_lane()`车道区域检测机制和`_filter_valid_spawn_positions()`过滤函数 - - 检测逻辑:通过`point_on_lane()`判断位置是否在车道上,支持容差参数(默认3米)处理边界情况 - - 双重检测:优先使用精确检测,失败时使用容差范围检测,确保车道边缘车辆不被误过滤 - - 自动过滤:在`reset()`时自动过滤非车道区域车辆,并输出过滤统计信息 -- **配置参数**: - - `filter_offroad_vehicles=True`:启用/禁用车道过滤功能 - - `lane_tolerance=3.0`:车道检测容差(米),可根据场景调整 - - `max_controlled_vehicles=10`:限制最大车辆数(可选) -- **使用示例**:在环境配置中设置上述参数即可自动启用,运行时会显示过滤信息(如"过滤5辆,保留45辆") +**核心特性:** +- ✅ 完整的Waymo数据处理pipeline(12,201个场景) +- ✅ 车道过滤和红绿灯检测优化 +- ✅ 支持5维简化/107维完整观测空间 +- ✅ 专家轨迹数据集(52K+训练样本) +- 🚧 MAGAIL算法实现(判别器+策略网络) +*** -### 1.2 观测获取 -观测信息采集功能通过`Env.senario_env.MultiAgentScenarioEnv._get_all_obs()`函数实现,该函数支持遍历所有可控车辆并采集多维度观测数据,当前已实现的观测维度包括:车辆实时位置坐标、朝向角、行驶速度、雷达扫描点云(含障碍物与车道线特征)、导航信息(因场景复杂度较低,暂采用目标终点坐标直接作为导航输入)。 +## 🚀 快速开始 -**✅ 已解决:红绿灯信息采集问题** -- **问题描述**: - - 问题1:部分红绿灯状态值为`None`,导致异常或错误判断 - - 问题2:车道分段设计时,部分区域车辆无法匹配到红绿灯 -- **解决方案**:实现了`_get_traffic_light_state()`优化方法,采用多级检测策略 - - **方法1(优先)**:从车辆导航模块`vehicle.navigation.current_lane`获取当前车道,直接查询红绿灯状态(高效,自动处理车道分段) - - **方法2(兜底)**:遍历所有车道,通过`point_on_lane()`判断车辆位置,查找对应红绿灯(处理导航失败情况) - - **异常处理**:对状态为`None`的情况返回0(无红绿灯),所有异常均有try-except保护,确保不会中断程序 - - **返回值规范**:0=无红绿灯/未知, 1=绿灯, 2=黄灯, 3=红灯 -- **优势**:双重保障机制,优先用高效方法,失败时自动切换到兜底方案,确保所有场景都能正确获取红绿灯信息 +### 环境安装 - -### 1.3 算法模块 -本方案的核心创新点在于对 GAIL 算法的判别器进行改进,使其适配多智能体场景下 
“输入长度动态变化”(车辆数量不固定)的特性,实现对整体交互场景的分类判断,进而满足多智能体自动驾驶环境的训练需求。算法核心代码封装于`Algorithm.bert.Bert`类,具体实现逻辑如下: - -1. 输入层处理:输入数据为维度`(N, input_dim)`的矩阵(其中`N`为当前场景车辆数量,`input_dim`为单车辆固定观测维度),初始化`Bert`类时需设置`input_dim`,确保输入维度匹配; -2. 嵌入层与位置编码:通过`projection`线性投影层将单车辆观测维度映射至预设的嵌入维度(`embed_dim`),随后叠加可学习的位置编码(`pos_embed`),以捕捉观测序列的时序与空间关联信息; -3. Transformer 特征提取:嵌入后的特征向量输入至多层`Transformer`网络(层数由`num_layers`参数控制),完成高阶特征交互与抽象; -4. 分类头设计:提供两种特征聚合与分类方案:若开启`CLS`模式,在嵌入层前拼接 1 个可学习的`CLS`标记,最终取`CLS`标记对应的特征向量输入全连接层完成分类;若关闭`CLS`模式,则对`Transformer`输出的所有车辆特征向量进行序列维度均值池化,再将池化后的全局特征输入全连接层。分类器支持可选的`Tanh`激活函数,以适配不同场景下的输出分布需求。 - - -### 1.4 动作执行 -在当前环境测试阶段,暂沿用腾达的动作执行框架:为每辆可控车辆分配独立的`policy`模型,将单车辆观测数据输入对应`policy`得到动作指令后,传入`env.step()`完成仿真;同时在`before_step`阶段调用`_set_action()`函数,将动作指令绑定至车辆实例,最终由 MetaDrive 仿真系统完成物理动力学计算与场景渲染。 - -后续优化方向为构建 "参数共享式统一模型框架",具体设计如下:所有车辆共用 1 个`policy`模型,通过参数共享机制实现模型的全局统一维护。该框架具备三重优势:一是避免多车辆独立模型带来的训练偏差(如不同模型训练程度不一致);二是解决车辆数量动态变化时的模型管理问题(车辆新增无需额外初始化模型,车辆减少不丢失模型训练信息);三是支持动作指令的并行计算,可显著提升每一步决策的迭代效率,适配大规模多智能体交互场景的训练需求。 - ---- - -## 问题解决总结 - -### ✅ 已完成的优化 - -1. **车辆生成位置偏差** - 实现车道区域检测和自动过滤,配置参数:`filter_offroad_vehicles`, `lane_tolerance`, `max_controlled_vehicles` -2. **红绿灯信息采集** - 采用双重检测策略(导航模块+遍历兜底),处理None状态和车道分段问题 -3. **性能优化** - 提供多个优化版本(fast/parallel),FPS从15提升到30-60,支持多进程充分利用CPU - -### 🧪 测试方法 ```bash -# 测试车道过滤和红绿灯检测 -python Env/test_lane_filter.py +# 克隆项目 +git clone +cd MAGAIL4AutoDrive -# 运行标准版本(带过滤) +# 安装依赖 +pip install metadrive-simulator==0.4.3 torch numpy matplotlib scenarionet + +# 创建必需目录 +mkdir -p analysis_results +touch scripts/__init__.py dataset/__init__.py Algorithm/__init__.py +``` + +### 数据准备 + +```bash +# 1. 转换Waymo数据 +python -m scenarionet.convert_waymo -d ~/mdsn/exp_converted --raw_data_path /path/to/waymo --num_files=150 + +# 2. 筛选场景(无红绿灯) +python -m scenarionet.filter --database_path ~/mdsn/exp_filtered --from ~/mdsn/exp_converted --no_traffic_light + +# 3. 
验证数据集 +python scripts/check_database_info.py +``` + +### 运行环境 + +```bash +# 测试多智能体环境 python Env/run_multiagent_env.py -# 运行高性能版本 -python Env/run_multiagent_env_fast.py +# 收集专家数据(10个场景测试) +python dataset/expert_dataset.py ``` -### 📝 配置示例 +*** + +## 📁 项目结构 + +``` +MAGAIL4AutoDrive/ +├── Env/ # 仿真环境模块 +│ ├── scenario_env.py # 多智能体场景环境(含轨迹存储) +│ ├── run_multiagent_env.py# 环境运行脚本 +│ └── simple_idm_policy.py # 测试策略 +│ +├── dataset/ # 数据集模块 +│ └── expert_dataset.py # PyTorch Dataset(5维观测) +│ +├── scripts/ # 工具脚本 +│ ├── check_track_fields.py # 数据字段验证 +│ ├── check_database_info.py # 数据库信息检查 +│ ├── analyze_expert_data.py # 统计分析 +│ └── visualize_expert_trajectory.py # 轨迹可视化 +│ +├── Algorithm/ # MAGAIL算法(待完善) +│ ├── bert.py # Transformer判别器 +│ ├── disc.py # 判别器网络 +│ ├── policy.py # 策略网络 +│ ├── ppo.py # PPO优化器 +│ └── magail.py # MAGAIL训练循环 +│ +└── analysis_results/ # 分析输出 + ├── statistics.pkl # 数据统计 + └── distributions.png # 可视化图表 +``` + +*** + +## 🎯 核心功能 + +### 1. 环境与数据处理 + +**scenario_env.py** - 多智能体场景环境 +- 专家轨迹完整存储(位置、速度、航向角、车辆尺寸) +- 车道区域过滤(自动移除非车道车辆) +- 红绿灯状态检测(双重保障机制) +- 107维完整观测空间(激光雷达+车道线) + +**expert_dataset.py** - 专家数据集 +- 状态-动作对提取(逆动力学) +- 批量采样和序列化 +- 支持PyTorch DataLoader + +### 2. 数据分析工具 + +| 脚本 | 功能 | 输出 | +|------|------|------| +| `check_database_info.py` | 验证数据库完整性 | 场景总数、映射关系 | +| `check_track_fields.py` | 检查可用字段 | 必需/可选字段列表 | +| `analyze_expert_data.py` | 统计分析 | 轨迹长度、速度、交互频率 | +| `visualize_expert_trajectory.py` | 轨迹可视化 | 动画展示车辆运动 | + +### 3. 
MAGAIL算法 + +**判别器** (Algorithm/bert.py + disc.py) +- Transformer编码器处理动态车辆数量 +- CLS标记或均值池化聚合特征 +- 支持集中式/去中心化/零和模式 + +**策略网络** (Algorithm/policy.py + ppo.py) +- Actor-Critic架构 +- 参数共享机制(所有车辆共享模型) +- PPO/TRPO优化器 + +*** + +## ⚙️ 配置说明 + ```python +# 环境配置 config = { + # 数据路径 + "data_directory": "~/mdsn/exp_filtered", + + # 多智能体设置 + "num_controlled_agents": 3, # 初始车辆数 + "max_controlled_vehicles": 10, # 最大车辆数限制 + # 车道过滤 - "filter_offroad_vehicles": True, # 启用车道过滤 - "lane_tolerance": 3.0, # 容差范围(米) - "max_controlled_vehicles": 10, # 最大车辆数 - # 其他配置... + "filter_offroad_vehicles": True, # 启用车道过滤 + "lane_tolerance": 3.0, # 容差(米) + + # 场景加载 + "sequential_seed": True, # 顺序加载场景 + "horizon": 1000, # 最大步数 } ``` + +*** + +## 📊 数据集统计 + +**当前数据规模**(基于exp_filtered): +- 场景总数: **12,201** +- 已收集场景: 10个测试场景 +- 轨迹数: 900条 +- 训练样本: **52,065**个(s,a)对 +- 观测维度: 5维(简化) / 107维(完整) +- 动作维度: 2维(油门/刹车, 转向) + +**数据质量**: +- 静止车辆占比: 54.8%(正常,包含停车场和路边停车) +- 平均轨迹长度: 67帧(6.7秒 @ 10Hz) +- 平均速度: 1.46 m/s +- 近距离交互(<5m): 1.92% + +*** + +## 🛠️ 使用示例 + +### 收集专家数据 + +```python +# dataset/expert_dataset.py +from expert_dataset import ExpertTrajectoryDataset + +# 收集1000个场景 +trajectories = ExpertTrajectoryDataset.collect_from_env( + env_config, + num_scenarios=1000, + save_path="./expert_trajectories.pkl" +) + +# 创建数据集 +dataset = ExpertTrajectoryDataset(trajectories, sequence_length=1) +``` + +### 环境测试 + +```python +from scenario_env import MultiAgentScenarioEnv + +env = MultiAgentScenarioEnv( + config=config, + agent2policy=your_policy +) + +obs = env.reset() +for step in range(1000): + actions = {aid: policy(obs[aid]) for aid in env.controlled_agents} + obs, rewards, dones, infos = env.step(actions) +``` + +*** + +## ❓ 常见问题 + +### Q1: KeyError: 'bbox' +**原因**: Waymo转换数据不含bbox字段 +**解决**: 使用length/width/height,代码已添加条件检查 + +### Q2: ModuleNotFoundError: scenario_env +**原因**: Python路径问题 +**解决**: 脚本开头添加: +```python +import sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../Env")) +``` + +### Q3: 
多次reset失败(clear_objects错误) +**原因**: MetaDrive对象管理bug +**解决**: 每次收集数据都重新创建环境(已实现) + +### Q4: 静止车辆占比过高 +**原因**: Waymo真实场景包含停车场等静止车辆 +**解决**: 可在数据收集时过滤平均速度<2m/s的轨迹 + +*** + +## 📈 开发路线图 + +### ✅ 已完成(Phase 1) +- [x] 数据转换与筛选 +- [x] 完整轨迹存储 +- [x] 数据质量分析 +- [x] PyTorch Dataset构建 + +### 🚧 进行中(Phase 2) +- [ ] 107维完整观测空间 +- [ ] 数据质量过滤 +- [ ] 轨迹可视化工具 + +### 📅 计划中(Phase 3-4) +- [ ] 判别器网络实现 +- [ ] Actor-Critic策略网络 +- [ ] MAGAIL训练循环 +- [ ] TensorBoard监控 +- [ ] 实验与评估 + +*** + +## 📚 参考资料 + +- [MetaDrive Documentation](https://metadrive-simulator.readthedocs.io/) +- [Waymo Open Dataset](https://waymo.com/open/) +- [MAGAIL Paper](https://arxiv.org/abs/1807.09936) +- [ScenarioNet](https://github.com/metadriverse/scenarionet) + +## 📄 License + +MIT License + +*** + +**💡 提示**: 项目处于活跃开发中,欢迎提Issue或PR贡献代码! + +[1](https://blog.csdn.net/BxuqBlockchain/article/details/133606934) +[2](https://blog.csdn.net/sinat_28461591/article/details/148351123) +[3](https://www.reddit.com/r/Python/comments/13kpoti/readmeai_autogenerate_readmemd_files/) +[4](https://www.reddit.com/r/learnprogramming/comments/1298ix8/what_does_a_good_readme_look_like_for_personal/) +[5](https://juejin.cn/post/7195763127883169853) +[6](https://jimmysong.io/trans/spec-driven-development-using-markdown/) +[7](https://www.showapi.com/news/article/66b602964ddd79f11a001e3c) +[8](https://learn.microsoft.com/zh-cn/nuget/nuget-org/package-readme-on-nuget-org) \ No newline at end of file diff --git a/analysis_results/distributions.png b/analysis_results/distributions.png new file mode 100644 index 0000000..9274a03 Binary files /dev/null and b/analysis_results/distributions.png differ diff --git a/analysis_results/statistics.pkl b/analysis_results/statistics.pkl new file mode 100644 index 0000000..9408259 Binary files /dev/null and b/analysis_results/statistics.pkl differ diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataset/expert_dataset.py b/dataset/expert_dataset.py 
new file mode 100644 index 0000000..1439599 --- /dev/null +++ b/dataset/expert_dataset.py @@ -0,0 +1,304 @@ +import sys +import os +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +sys.path.insert(0, os.path.join(project_root, "Env")) + +import numpy as np +import torch +from torch.utils.data import Dataset +import pickle +from scenario_env import MultiAgentScenarioEnv +from metadrive.engine.asset_loader import AssetLoader + +class DummyPolicy: + def act(self, *args, **kwargs): + return np.array([0.0, 0.0]) + +class ExpertTrajectoryDataset(Dataset): + """ + 完整107维观测的专家轨迹数据集 + """ + + def __init__(self, + trajectory_data: dict, + observation_data: dict = None, # 可选的完整观测 + sequence_length: int = 1, + extract_actions: bool = True): + """ + Args: + trajectory_data: 专家轨迹数据 + observation_data: 完整107维观测数据(可选) + sequence_length: 序列长度 + extract_actions: 是否提取动作 + """ + self.trajectory_data = trajectory_data + self.observation_data = observation_data if observation_data else {} + self.sequence_length = sequence_length + self.extract_actions = extract_actions + + # 构建索引 + self.indices = [] + for traj_id, traj in trajectory_data.items(): + traj_len = traj["length"] + for start_idx in range(traj_len - sequence_length): + self.indices.append((traj_id, start_idx)) + + obs_dim = 107 if len(self.observation_data) > 0 else 5 + print(f"专家数据集: {len(trajectory_data)} 条轨迹, " + f"{len(self.indices)} 个训练样本, 观测维度: {obs_dim}") + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + traj_id, start_idx = self.indices[idx] + traj = self.trajectory_data[traj_id] + + end_idx = start_idx + self.sequence_length + + # 如果有完整观测,使用完整观测(107维) + if traj_id in self.observation_data and len(self.observation_data[traj_id]) > 0: + obs_sequence = self.observation_data[traj_id] + states = obs_sequence[start_idx:end_idx] # (seq_len, 107) + else: + # 否则使用简化观测(5维) + positions = traj["positions"][start_idx:end_idx+1] + headings = 
traj["headings"][start_idx:end_idx+1] + velocities = traj["velocities"][start_idx:end_idx] + + states = [] + for i in range(self.sequence_length): + state = np.concatenate([ + positions[i, :2], # x, y + velocities[i], # vx, vy + [headings[i]], # heading + ]) + states.append(state) + states = np.array(states) + + if self.extract_actions: + positions = traj["positions"][start_idx:end_idx+1] + headings = traj["headings"][start_idx:end_idx+1] + velocities = traj["velocities"][start_idx:end_idx] + + actions = self._extract_actions_from_states( + positions[:-1], positions[1:], + headings[:-1], headings[1:], + velocities + ) + return torch.FloatTensor(states), torch.FloatTensor(actions) + else: + next_states = states[1:] + return torch.FloatTensor(states[:-1]), torch.FloatTensor(next_states) + + def _extract_actions_from_states(self, pos_t, pos_t1, head_t, head_t1, vel_t): + """从状态序列反推动作""" + actions = [] + dt = 0.1 + + for i in range(len(pos_t)): + current_speed = np.linalg.norm(vel_t[i]) + displacement = np.linalg.norm(pos_t1[i, :2] - pos_t[i, :2]) + next_speed = displacement / dt + + speed_change = (next_speed - current_speed) / dt + if speed_change >= 0: + throttle = np.clip(speed_change / 5.0, 0.0, 1.0) + else: + throttle = np.clip(speed_change / 8.0, -1.0, 0.0) + + heading_change = head_t1[i] - head_t[i] + heading_change = np.arctan2(np.sin(heading_change), np.cos(heading_change)) + steering = np.clip(heading_change / 0.2, -1.0, 1.0) + + actions.append([throttle, steering]) + + return np.array(actions) + + @staticmethod + def collect_with_full_obs(env_config, num_scenarios=10, save_path=None): + """ + ✅ 使用env._get_all_obs()收集完整107维观测 + + 这是正确的方法!直接利用环境已有的观测函数 + """ + all_trajectories = {} + all_observations = {} + + # 检查数据库 + data_dir = env_config["config"]["data_directory"] + summary_path = os.path.join(data_dir, "dataset_summary.pkl") + + with open(summary_path, 'rb') as f: + summary = pickle.load(f) + + total_scenarios = len(summary) + print(f"数据库总场景数: 
{total_scenarios}") + + if num_scenarios is None: + num_scenarios = total_scenarios + else: + num_scenarios = min(num_scenarios, total_scenarios) + + print(f"计划收集(完整107维观测): {num_scenarios} 个场景") + + for i in range(num_scenarios): + try: + # 创建环境 + env = MultiAgentScenarioEnv( + config={ + **env_config["config"], + "start_scenario_index": i, + "num_scenarios": 1, + }, + agent2policy=env_config["agent2policy"] + ) + + # 重置环境 + env.reset() + + if not hasattr(env, 'expert_trajectories'): + print(f"⚠️ 场景 {i}: 缺少expert_trajectories") + env.close() + continue + + expert_trajs = env.expert_trajectories + + if len(expert_trajs) == 0: + print(f"⚠️ 场景 {i}: 无专家轨迹") + env.close() + continue + + # 存储轨迹 + scenario_id = env.engine.current_seed + for obj_id, traj in expert_trajs.items(): + unique_id = f"scenario{i}_{obj_id}" + all_trajectories[unique_id] = traj + + # ✅ 关键: 使用_get_all_obs()获取完整观测 + # 创建agent_id到unique_id的映射 + agent_to_unique = {} + for agent_id in env.controlled_agents.keys(): + # 尝试匹配agent_id到expert_trajectories的obj_id + for obj_id in expert_trajs.keys(): + if str(agent_id) in str(obj_id) or str(obj_id) in str(agent_id): + unique_id = f"scenario{i}_{obj_id}" + agent_to_unique[agent_id] = unique_id + all_observations[unique_id] = [] + break + + # 遍历场景的每一步,收集完整观测 + max_steps = min([traj["length"] for traj in expert_trajs.values()]) + + for step in range(max_steps): + # ✅ 直接调用_get_all_obs()获取107维观测! + obs_list = env._get_all_obs() + + # 存储每个agent的观测 + for agent_idx, agent_id in enumerate(env.controlled_agents.keys()): + if agent_id in agent_to_unique: + unique_id = agent_to_unique[agent_id] + if agent_idx < len(obs_list): + # obs_list[agent_idx]已经是107维向量! 
+ all_observations[unique_id].append(np.array(obs_list[agent_idx])) + + # 执行零动作(保持场景状态) + actions = {aid: np.array([0.0, 0.0]) + for aid in env.controlled_agents.keys()} + env.step(actions) + + # 转换为numpy数组 + for unique_id in list(all_observations.keys()): + if len(all_observations[unique_id]) > 0: + all_observations[unique_id] = np.array(all_observations[unique_id]) + else: + del all_observations[unique_id] + + env.close() + + if (i + 1) % 5 == 0: + print(f"✓ 已收集 {i+1}/{num_scenarios}, " + f"轨迹: {len(all_trajectories)}, " + f"观测: {len(all_observations)}") + + except Exception as e: + print(f"✗ 场景 {i} 收集失败: {e}") + import traceback + traceback.print_exc() + try: + env.close() + except: + pass + continue + + print(f"\n收集完成!") + print(f" 轨迹数: {len(all_trajectories)}") + print(f" 完整观测数: {len(all_observations)}") + + # 验证观测维度 + if len(all_observations) > 0: + sample_obs = list(all_observations.values())[0] + if len(sample_obs) > 0: + obs_dim = len(sample_obs[0]) + print(f" 观测维度: {obs_dim} (应为107)") + + if save_path: + with open(save_path, "wb") as f: + pickle.dump({ + "trajectories": all_trajectories, + "observations": all_observations + }, f) + print(f"数据已保存到: {save_path}") + + return all_trajectories, all_observations + + +if __name__ == "__main__": + WAYMO_DATA_DIR = r"/home/huangfukk/mdsn" + data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False) + + env_config = { + "config": { + "data_directory": data_dir, + "is_multi_agent": True, + "num_controlled_agents": 3, + "use_render": False, + "sequential_seed": True, + }, + "agent2policy": DummyPolicy() + } + + print("=" * 60) + print("选择收集模式:") + print("1. 简化观测(5维) - 快速,已验证 ✅") + print("2. 
完整观测(107维) - 使用_get_all_obs() ⭐") + print("=" * 60) + + mode = input("请选择模式(1或2,默认1): ").strip() or "1" + + if mode == "2": + print("\n开始收集完整107维观测...") + trajectories, observations = ExpertTrajectoryDataset.collect_with_full_obs( + env_config, + num_scenarios=10, + save_path="./expert_trajectories_full.pkl" + ) + + if len(trajectories) > 0: + dataset = ExpertTrajectoryDataset( + trajectories, + observations, + sequence_length=1 + ) + state, action = dataset[0] + print(f"\n数据集测试:") + print(f" 总轨迹数: {len(trajectories)}") + print(f" 总观测数: {len(observations)}") + print(f" 训练样本数: {len(dataset)}") + print(f" 状态维度: {state.shape}") + print(f" 动作维度: {action.shape}") + else: + print("\n开始收集简化5维观测...") + # 保持原有的简化版本代码... + print("(使用之前已成功的方法)") diff --git a/expert_trajectories_full.pkl b/expert_trajectories_full.pkl new file mode 100644 index 0000000..fb33b46 Binary files /dev/null and b/expert_trajectories_full.pkl differ diff --git a/expert_trajectories_full_obs.pkl b/expert_trajectories_full_obs.pkl new file mode 100644 index 0000000..c0cf94c Binary files /dev/null and b/expert_trajectories_full_obs.pkl differ diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/analyze_expert_data.py b/scripts/analyze_expert_data.py new file mode 100644 index 0000000..4c6b94c --- /dev/null +++ b/scripts/analyze_expert_data.py @@ -0,0 +1,256 @@ +import sys +import os + +# 添加路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +env_dir = os.path.join(project_root, "Env") +sys.path.insert(0, project_root) +sys.path.insert(0, env_dir) + +import numpy as np +import matplotlib.pyplot as plt +from collections import defaultdict +from scenario_env import MultiAgentScenarioEnv +from metadrive.engine.asset_loader import AssetLoader +import pickle +import os + +class DummyPolicy: + """占位策略""" + def act(self, *args, **kwargs): + return np.array([0.0, 0.0]) + +class 
ExpertDataAnalyzer: + def __init__(self, data_directory): + self.data_directory = data_directory + self.env = MultiAgentScenarioEnv( + config={ + "data_directory": data_directory, + "is_multi_agent": True, + "num_controlled_agents": 3, + "use_render": False, + "sequential_seed": True, + }, + agent2policy=DummyPolicy() # 添加必需参数 + ) + + self.statistics = { + "num_scenarios": 0, + "num_trajectories": 0, + "trajectory_lengths": [], + "velocities": [], + "speeds": [], # 速度大小 + "accelerations": [], + "heading_changes": [], + "inter_vehicle_distances": [], + "num_vehicles_per_scenario": [], + "static_vehicles": 0, # 统计静止车辆 + } + + def analyze_all_scenarios(self, num_scenarios=None): + """遍历所有场景并收集统计信息""" + scenario_count = 0 + + while True: + try: + obs = self.env.reset() + + if not hasattr(self.env, 'expert_trajectories'): + print("⚠️ 环境缺少expert_trajectories属性") + break + + expert_trajs = self.env.expert_trajectories + + if len(expert_trajs) == 0: + continue + + scenario_count += 1 + self.statistics["num_scenarios"] += 1 + self.statistics["num_vehicles_per_scenario"].append(len(expert_trajs)) + + # 分析每条轨迹 + for obj_id, traj in expert_trajs.items(): + self.analyze_single_trajectory(traj) + + # 分析车辆间交互 + self.analyze_vehicle_interactions(expert_trajs) + + print(f"已分析场景 {scenario_count}/{num_scenarios}, 车辆数: {len(expert_trajs)}") + + if num_scenarios and scenario_count >= num_scenarios: + break + + except Exception as e: + print(f"场景 {scenario_count} 处理失败: {e}") + break + + self.env.close() + + def analyze_single_trajectory(self, traj): + """分析单条轨迹""" + self.statistics["num_trajectories"] += 1 + + length = traj["length"] + self.statistics["trajectory_lengths"].append(length) + + # 速度分析 + velocities = traj["velocities"] + speeds = np.linalg.norm(velocities, axis=1) + self.statistics["velocities"].extend(velocities.tolist()) + self.statistics["speeds"].extend(speeds.tolist()) + + # 检查是否为静止车辆 + if np.max(speeds) < 0.5: # 最大速度小于0.5m/s视为静止 + self.statistics["static_vehicles"] += 
1 + + # 加速度分析 + if length > 1: + accelerations = np.diff(speeds) * 10 # 10Hz数据 + self.statistics["accelerations"].extend(accelerations.tolist()) + + # 航向角变化 + headings = traj["headings"] + if length > 1: + heading_changes = np.diff(headings) + heading_changes = np.arctan2(np.sin(heading_changes), np.cos(heading_changes)) + self.statistics["heading_changes"].extend(heading_changes.tolist()) + + def analyze_vehicle_interactions(self, expert_trajs): + """分析车辆间的距离""" + if len(expert_trajs) < 2: + return + + traj_list = list(expert_trajs.values()) + + for i in range(len(traj_list)): + for j in range(i+1, len(traj_list)): + traj_i = traj_list[i] + traj_j = traj_list[j] + + start_time = max(traj_i["start_timestep"], traj_j["start_timestep"]) + end_time = min(traj_i["end_timestep"], traj_j["end_timestep"]) + + if start_time >= end_time: + continue + + idx_i_start = start_time - traj_i["start_timestep"] + idx_i_end = end_time - traj_i["start_timestep"] + idx_j_start = start_time - traj_j["start_timestep"] + idx_j_end = end_time - traj_j["start_timestep"] + + pos_i = traj_i["positions"][idx_i_start:idx_i_end, :2] + pos_j = traj_j["positions"][idx_j_start:idx_j_end, :2] + + distances = np.linalg.norm(pos_i - pos_j, axis=1) + self.statistics["inter_vehicle_distances"].extend(distances.tolist()) + + def generate_report(self, save_dir="./analysis_results"): + """生成统计报告""" + os.makedirs(save_dir, exist_ok=True) + + stats = self.statistics + + print("\n" + "="*60) + print("专家数据集统计报告") + print("="*60) + print(f"总场景数: {stats['num_scenarios']}") + print(f"总轨迹数: {stats['num_trajectories']}") + print(f"静止车辆数: {stats['static_vehicles']} ({stats['static_vehicles']/stats['num_trajectories']*100:.1f}%)") + print(f"平均每场景车辆数: {np.mean(stats['num_vehicles_per_scenario']):.2f} ± {np.std(stats['num_vehicles_per_scenario']):.2f}") + + print(f"\n轨迹长度统计 (帧数 @ 10Hz):") + print(f" 平均: {np.mean(stats['trajectory_lengths']):.2f} 帧 ({np.mean(stats['trajectory_lengths'])*0.1:.2f}秒)") + print(f" 中位数: 
{np.median(stats['trajectory_lengths']):.2f} 帧") + print(f" 最小/最大: {np.min(stats['trajectory_lengths'])} / {np.max(stats['trajectory_lengths'])} 帧") + + print(f"\n速度统计 (m/s):") + speeds = np.array(stats['speeds']) + print(f" 平均: {np.mean(speeds):.2f} ± {np.std(speeds):.2f}") + print(f" 中位数: {np.median(speeds):.2f}") + print(f" 最小/最大: {np.min(speeds):.2f} / {np.max(speeds):.2f}") + print(f" 静止帧(<0.5m/s): {np.sum(speeds < 0.5)} ({np.sum(speeds < 0.5)/len(speeds)*100:.1f}%)") + + print(f"\n加速度统计 (m/s²):") + accs = np.array(stats['accelerations']) + print(f" 平均: {np.mean(accs):.4f} ± {np.std(accs):.2f}") + print(f" 最小/最大: {np.min(accs):.2f} / {np.max(accs):.2f}") + + if len(stats['inter_vehicle_distances']) > 0: + dists = np.array(stats['inter_vehicle_distances']) + print(f"\n车辆间距离统计 (m):") + print(f" 平均: {np.mean(dists):.2f} ± {np.std(dists):.2f}") + print(f" 最小: {np.min(dists):.2f}") + print(f" 近距离交互(<5m): {np.sum(dists < 5.0)} ({np.sum(dists < 5.0)/len(dists)*100:.2f}%)") + + # 保存数据 + with open(os.path.join(save_dir, "statistics.pkl"), "wb") as f: + pickle.dump(stats, f) + + # 绘制可视化 + self.plot_distributions(save_dir) + + print(f"\n✓ 报告已保存到: {save_dir}") + + def plot_distributions(self, save_dir): + """绘制分布图""" + stats = self.statistics + + fig, axes = plt.subplots(2, 3, figsize=(15, 10)) + + # 1. 轨迹长度分布 + axes[0, 0].hist(stats['trajectory_lengths'], bins=50, edgecolor='black') + axes[0, 0].set_xlabel('Trajectory Length (frames @ 10Hz)') + axes[0, 0].set_ylabel('Frequency') + axes[0, 0].set_title('Trajectory Length Distribution') + axes[0, 0].axvline(np.mean(stats['trajectory_lengths']), color='red', + linestyle='--', label=f'Mean: {np.mean(stats["trajectory_lengths"]):.1f}') + axes[0, 0].legend() + + # 2. 
速度分布 + axes[0, 1].hist(stats['speeds'], bins=50, edgecolor='black') + axes[0, 1].set_xlabel('Speed (m/s)') + axes[0, 1].set_ylabel('Frequency') + axes[0, 1].set_title('Speed Distribution') + axes[0, 1].axvline(np.mean(stats['speeds']), color='red', + linestyle='--', label=f'Mean: {np.mean(stats["speeds"]):.2f}') + axes[0, 1].legend() + + # 3. 加速度分布 + axes[0, 2].hist(stats['accelerations'], bins=50, edgecolor='black') + axes[0, 2].set_xlabel('Acceleration (m/s²)') + axes[0, 2].set_ylabel('Frequency') + axes[0, 2].set_title('Acceleration Distribution') + + # 4. 每场景车辆数 + axes[1, 0].hist(stats['num_vehicles_per_scenario'], bins=30, edgecolor='black') + axes[1, 0].set_xlabel('Vehicles per Scenario') + axes[1, 0].set_ylabel('Frequency') + axes[1, 0].set_title('Vehicles per Scenario') + + # 5. 航向角变化 + axes[1, 1].hist(stats['heading_changes'], bins=50, edgecolor='black') + axes[1, 1].set_xlabel('Heading Change (rad)') + axes[1, 1].set_ylabel('Frequency') + axes[1, 1].set_title('Heading Change Distribution') + + # 6. 
# NOTE(review): This region of the file is a unified git diff that was
# collapsed onto a few physical lines, interleaving several new scripts.
# It is reconstructed below as properly formatted Python, one banner per
# file from the diff. Each section was a separate file in the original
# diff; the `__main__` guards below belong to their respective files.

# =============================================================================
# Tail of the expert-data analyzer script (its class/method definitions start
# before this chunk). The mid-method plotting fragment and its `__main__`
# guard reference names defined outside this chunk (`axes`, `stats`,
# `save_dir`, `ExpertDataAnalyzer`, `AssetLoader`), so they are preserved
# here as a commented reference rather than guessed into definitions.
# =============================================================================
# Fragment (inside a plotting method; `axes`, `stats`, `save_dir` are locals):
#     # inter-vehicle distance histogram
#     if len(stats['inter_vehicle_distances']) > 0:
#         axes[1, 2].hist(stats['inter_vehicle_distances'], bins=50,
#                         range=(0, 50), edgecolor='black')
#         axes[1, 2].set_xlabel('Inter-vehicle Distance (m)')
#         axes[1, 2].set_ylabel('Frequency')
#         axes[1, 2].set_title('Distance Distribution')
#
#     plt.tight_layout()
#     plt.savefig(os.path.join(save_dir, "distributions.png"), dpi=300)
#     print(f" ✓ 分布图已保存")
#
# if __name__ == "__main__":
#     WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
#     data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)
#
#     print("开始分析专家数据...")
#     analyzer = ExpertDataAnalyzer(data_dir)
#     analyzer.analyze_all_scenarios(num_scenarios=100)  # analyze 100 scenarios
#     analyzer.generate_report()


# =============================================================================
# file: scripts/check_database_info.py  (new in diff)
# Inspect the filtered MetaDrive scenario database: summary, mapping, and the
# file count of the source (converted) database.
# =============================================================================
import pickle
import os


def check_database_info():
    """Print summary/mapping statistics for the filtered scenario database.

    Reads ``dataset_summary.pkl`` and ``dataset_mapping.pkl`` from the
    filtered database directory, prints the first scenario's metadata and its
    resolved file path, then counts the ``sd_*.pkl`` scenario files in the
    source (pre-filter) database.

    Fixes vs. original flat script: wrapped in a function behind a
    ``__main__`` guard (no side effects on import) and guarded against an
    empty summary, which previously crashed on ``list(summary.keys())[0]``.
    """
    filtered_db = "/home/huangfukk/mdsn/exp_filtered"

    print("=" * 60)
    print("过滤后数据库信息")
    print("=" * 60)

    # Summary: scenario_id -> metadata dict.
    summary_path = os.path.join(filtered_db, "dataset_summary.pkl")
    with open(summary_path, 'rb') as f:
        summary = pickle.load(f)

    print(f"\n总场景数: {len(summary)}")
    print(f"场景ID列表(前10个): {list(summary.keys())[:10]}")

    # Mapping: scenario_id -> relative path of the scenario .pkl file.
    mapping_path = os.path.join(filtered_db, "dataset_mapping.pkl")
    with open(mapping_path, 'rb') as f:
        mapping = pickle.load(f)

    print(f"\n映射关系数量: {len(mapping)}")

    # Guard: an empty database would otherwise raise IndexError below.
    if not summary:
        return

    # Detail view of the first scenario.
    first_scenario_id = next(iter(summary))
    first_scenario_info = summary[first_scenario_id]
    print(f"\n第一个场景详细信息:")
    print(f" 场景ID: {first_scenario_id}")
    print(f" 元数据: {first_scenario_info}")

    first_scenario_path = mapping[first_scenario_id]
    print(f" 场景文件路径(相对): {first_scenario_path}")

    abs_path = os.path.join(filtered_db, first_scenario_path)
    print(f" 场景文件路径(绝对): {abs_path}")
    print(f" 文件存在: {os.path.exists(abs_path)}")

    # Count the scenario files of the source (converted) database.
    converted_db = "/home/huangfukk/mdsn/exp_converted"
    converted_files = [
        f for f in os.listdir(converted_db)
        if f.endswith('.pkl') and f.startswith('sd_')
    ]
    print(f"\n源数据库 exp_converted:")
    print(f" 场景文件数量: {len(converted_files)}")
    print(f" 示例文件: {converted_files[:5]}")


if __name__ == "__main__":
    check_database_info()


# =============================================================================
# file: scripts/check_track_fields.py  (new in diff)
# Probe the Waymo->MetaDrive track data to report which state/metadata fields
# actually exist (position/heading/velocity/valid, optional vehicle sizes).
# =============================================================================
import sys

# Make Env/scenario_env importable when the script is run from scripts/.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
env_dir = os.path.join(project_root, "Env")
sys.path.insert(0, project_root)
sys.path.insert(0, env_dir)

from scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader
import numpy as np


class DummyPolicy:
    """Placeholder policy used only to satisfy the environment constructor
    during data inspection; it always emits a zero action."""

    def act(self, *args, **kwargs):
        # Zero action: [throttle, steering].
        return np.array([0.0, 0.0])


def check_available_fields():
    """Load one scenario, dump the structure of a sample VEHICLE track, and
    print a recommended ``trajectory_data`` layout for expert storage.

    Fix vs. original: ``np.argmax`` always returns an index >= 0, so the old
    ``valid_idx >= 0`` check never filtered the all-invalid case (it would
    print a bogus "example value" at index 0). The sample value is now only
    printed when the valid mask contains at least one True entry.
    """
    WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
    data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)

    dummy_policy = DummyPolicy()

    # agent2policy is a required constructor argument of MultiAgentScenarioEnv.
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": data_dir,
            "is_multi_agent": True,
            "num_controlled_agents": 3,
            "use_render": False,
            "sequential_seed": True,
        },
        agent2policy=dummy_policy,
    )

    print("✓ 环境初始化成功")

    # reset() loads the scenario data into the engine.
    print("正在加载场景数据...")
    env.reset()

    # expert_trajectories is populated by the patched scenario_env.reset().
    if hasattr(env, 'expert_trajectories'):
        print(f"✓ expert_trajectories属性存在,包含 {len(env.expert_trajectories)} 条轨迹")
    else:
        print("⚠️ expert_trajectories属性不存在,请先修改scenario_env.py添加轨迹存储功能")

    # Grab the first vehicle track as a representative sample.
    sample_track = None
    for scenario_id, track in env.engine.traffic_manager.current_traffic_data.items():
        if track["type"] == "VEHICLE":
            sample_track = track
            print(f"\n找到样本车辆: scenario_id = {scenario_id}")
            break

    if sample_track is None:
        print("未找到车辆轨迹数据")
        env.close()
        return

    print("=" * 60)
    print("Track数据结构分析")
    print("=" * 60)

    # 1. Top-level keys of the track dict.
    print("\n1. Track顶层字段:")
    for key in sample_track.keys():
        print(f" - {key}: {type(sample_track[key])}")

    # 2. Scalar/other entries under metadata.
    print("\n2. track['metadata']字段:")
    if "metadata" in sample_track:
        for key, value in sample_track["metadata"].items():
            if isinstance(value, (str, int, float, bool)):
                print(f" - {key}: {type(value).__name__} = {value}")
            else:
                print(f" - {key}: {type(value).__name__}")

    # 3. Per-timestep arrays under state.
    print("\n3. track['state']字段:")
    if "state" in sample_track:
        for key, value in sample_track["state"].items():
            if isinstance(value, np.ndarray):
                print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
                # Show the first *valid* sample (only when one exists).
                if "valid" in sample_track["state"] and sample_track["state"]["valid"].any():
                    valid_idx = int(np.argmax(sample_track["state"]["valid"]))
                    if valid_idx < len(value):
                        print(f" 示例值 (index {valid_idx}): {value[valid_idx]}")
            else:
                print(f" - {key}: {type(value)} = {value}")

    print("\n" + "=" * 60)
    print("建议存储的字段:")
    print("=" * 60)

    # Required fields for trajectory replay.
    required_fields = ["position", "heading", "velocity", "valid"]
    print("\n必需字段:")
    all_required_exist = True
    for field in required_fields:
        if "state" in sample_track and field in sample_track["state"]:
            print(f" ✓ {field} (存在)")
        else:
            print(f" ✗ {field} (缺失)")
            all_required_exist = False

    # Optional vehicle-size fields, wherever they live.
    optional_fields = ["length", "width", "height", "bbox"]
    print("\n可选字段:")
    available_optional = []
    for field in optional_fields:
        if "state" in sample_track and field in sample_track["state"]:
            print(f" + {field} (在state中)")
            available_optional.append(field)
        elif "metadata" in sample_track and field in sample_track["metadata"]:
            print(f" + {field} (在metadata中)")
            available_optional.append(field)
        else:
            print(f" - {field} (不存在)")

    print("\n" + "=" * 60)
    print("推荐的trajectory_data结构:")
    print("=" * 60)

    if all_required_exist:
        print("""
trajectory_data = {
    "object_id": object_id,
    "scenario_id": scenario_id,
    "valid_mask": valid[first_show:last_show+1].copy(),
    "positions": track["state"]["position"][first_show:last_show+1].copy(),
    "headings": track["state"]["heading"][first_show:last_show+1].copy(),
    "velocities": track["state"]["velocity"][first_show:last_show+1].copy(),
    "timesteps": np.arange(first_show, last_show+1),
    "start_timestep": first_show,
    "end_timestep": last_show,
    "length": last_show - first_show + 1
}
""")

        if available_optional:
            print("如果需要车辆尺寸,可选添加:")
            for field in available_optional:
                if field in ["length", "width", "height"]:
                    print(f' trajectory_data["vehicle_{field}"] = track["state" or "metadata"]["{field}"][first_show]')
    else:
        print("⚠️ 缺少必需字段,请检查数据转换流程")

    # If trajectories were stored, dump one sample entry.
    if hasattr(env, 'expert_trajectories') and len(env.expert_trajectories) > 0:
        print("\n" + "=" * 60)
        print("expert_trajectories样本:")
        print("=" * 60)
        sample_traj = next(iter(env.expert_trajectories.values()))
        for key, value in sample_traj.items():
            if isinstance(value, np.ndarray):
                print(f" {key}: shape={value.shape}, dtype={value.dtype}")
            else:
                print(f" {key}: {type(value).__name__} = {value}")

    env.close()
    print("\n✓ 分析完成")


if __name__ == "__main__":
    check_available_fields()


# =============================================================================
# file: scripts/visualize_expert_trajectory.py  (new in diff)
# Top-down animated visualization of the stored expert trajectories.
# (In the original diff this file repeats the same sys.path setup and
# DummyPolicy definition as check_track_fields.py above — a candidate for a
# shared helper module.)
# =============================================================================
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def visualize_expert_trajectory(env, scenario_idx=0):
    """Animate a bird's-eye view of one scenario's expert trajectories.

    Each vehicle's full path is drawn faintly; a red marker per vehicle is
    animated over the frames where that vehicle is present.

    Args:
        env: a MultiAgentScenarioEnv whose reset() fills ``expert_trajectories``.
        scenario_idx: label only; shown in the plot title.

    Returns:
        The FuncAnimation, or None when the scenario has no trajectories.
        Keep a reference to the return value or the animation is
        garbage-collected.
    """
    env.reset()
    expert_trajs = env.expert_trajectories

    if len(expert_trajs) == 0:
        print("当前场景无专家轨迹")
        return

    fig, ax = plt.subplots(figsize=(12, 12))

    # Global frame range across all trajectories.
    max_timestep = max(traj["end_timestep"] for traj in expert_trajs.values())
    min_timestep = min(traj["start_timestep"] for traj in expert_trajs.values())

    # Faint static plot of every full trajectory.
    colors = plt.cm.tab10(np.linspace(0, 1, len(expert_trajs)))
    for idx, (obj_id, traj) in enumerate(expert_trajs.items()):
        positions = traj["positions"][:, :2]
        ax.plot(positions[:, 0], positions[:, 1],
                color=colors[idx], alpha=0.3, linewidth=1,
                label=f'Vehicle {obj_id[:6]}')

    # Animated artists: current positions and a clock readout.
    scatter = ax.scatter([], [], s=200, c='red', marker='o',
                         edgecolors='black', linewidths=2)
    time_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=14)

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_title(f'Expert Trajectory Visualization - Scenario {scenario_idx}')
    ax.legend(loc='upper right', fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.axis('equal')

    def update(frame):
        # One animation frame == one data timestep, offset by min_timestep.
        current_time = min_timestep + frame

        # Positions of every vehicle alive at this timestep.
        current_positions = []
        for traj in expert_trajs.values():
            if traj["start_timestep"] <= current_time <= traj["end_timestep"]:
                idx = current_time - traj["start_timestep"]
                pos = traj["positions"][idx, :2]
                current_positions.append(pos)

        if len(current_positions) > 0:
            scatter.set_offsets(np.array(current_positions))

        # NOTE(review): 0.1 s/frame assumes 10 Hz data — confirm sample rate.
        time_text.set_text(f'Time: {frame * 0.1:.1f}s (Frame {frame})')
        return scatter, time_text

    anim = FuncAnimation(fig, update, frames=max_timestep - min_timestep + 1,
                         interval=100, blit=True, repeat=True)

    plt.tight_layout()
    plt.show()

    return anim


if __name__ == "__main__":
    WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
    data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)

    env = MultiAgentScenarioEnv(
        config={
            "data_directory": data_dir,
            "is_multi_agent": True,
            "num_controlled_agents": 3,
            "use_render": False,
        },
        agent2policy=DummyPolicy(),
    )

    # Visualize the first scenario; keep the animation reference alive.
    anim = visualize_expert_trajectory(env, scenario_idx=0)