# MAGAIL Algorithm Application Guide

## Table of Contents

1. [Algorithm Module Overview](#algorithm-module-overview)
2. [How to Apply It to the Environment](#how-to-apply-it-to-the-environment)
3. [Complete Training Pipeline](#complete-training-pipeline)
4. [Current Implementation Status](#current-implementation-status)
5. [Parts That Still Need Work](#parts-that-still-need-work)

---

## Algorithm Module Overview

### 📁 Module Files

```
Algorithm/
├── bert.py      # BERT discriminator / value network
├── disc.py      # GAIL discriminator (inherits from BERT)
├── policy.py    # Policy network (actor)
├── ppo.py       # PPO base algorithm
├── magail.py    # Main MAGAIL algorithm (inherits from PPO)
├── buffer.py    # Experience rollout buffer
└── utils.py     # Utilities (normalization, etc.)
```

### 🔗 Module Dependencies

```
MAGAIL (magail.py)
├─ inherits from PPO (ppo.py)
│   ├─ uses RolloutBuffer (buffer.py)
│   ├─ uses StateIndependentPolicy (policy.py)
│   └─ uses Bert as the critic (bert.py)
│
├─ uses GAILDiscrim (disc.py)
│   └─ inherits from Bert (bert.py)
│
└─ uses Normalizer (utils.py)
```
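To make these relations concrete, here is a minimal skeleton of the class hierarchy. It is illustration only: the class names match the modules above, but the constructor parameters (`actor`, `critic`, `buffer`, `disc`, `buffer_exp`) are placeholders and the real signatures take more arguments.

```python
# Skeleton only -- shows inheritance/composition, not the real implementation
class Bert:                 # bert.py: shared backbone used for the critic and discriminator
    ...


class GAILDiscrim(Bert):    # disc.py: GAIL discriminator built on the BERT backbone
    ...


class PPO:                  # ppo.py: owns the actor, the critic, and the rollout buffer
    def __init__(self, actor, critic, buffer):
        self.actor, self.critic, self.buffer = actor, critic, buffer


class MAGAIL(PPO):          # magail.py: adds the discriminator and expert data on top of PPO
    def __init__(self, actor, critic, buffer, disc, buffer_exp):
        super().__init__(actor, critic, buffer)
        self.disc, self.buffer_exp = disc, buffer_exp
```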
---

## How to Apply It to the Environment

### ✅ Completed Preparation

The following has already been done for you:

1. **Fixed a PPO bug**: added the missing `action_shape` parameter
2. **Created the training script**: `train_magail.py`
3. **Provided a complete skeleton**: environment initialization, training loop, model saving, etc.

### 🚀 Quick Start

#### Option 1: Use the training script (recommended)

```bash
# Basic training (default parameters)
python train_magail.py

# Custom parameters
python train_magail.py \
    --data-dir /path/to/waymo/data \
    --episodes 1000 \
    --horizon 300 \
    --batch-size 256 \
    --lr-actor 3e-4 \
    --render  # visualization

# List all parameters
python train_magail.py --help
```

#### Option 2: Use it in a Jupyter Notebook

```python
import sys
sys.path.append('Algorithm')
sys.path.append('Env')

from Algorithm.magail import MAGAIL
from Env.scenario_env import MultiAgentScenarioEnv

# Initialize the environment
env = MultiAgentScenarioEnv(config={...})

# Initialize MAGAIL
magail = MAGAIL(
    buffer_exp=expert_buffer,
    input_dim=(108,),   # observation dimension
    device="cuda"
)

# Training loop
for episode in range(1000):
    obs = env.reset()
    for step in range(300):
        actions, log_pis = magail.explore(obs)
        next_obs, rewards, dones, infos = env.step(actions)
        # ... update logic
```

---
## Complete Training Pipeline

### 📊 Data Flow Diagram

```
┌──────────────────────────────────────────────────────────────┐
│                   MAGAIL Training Pipeline                    │
└──────────────────────────────────────────────────────────────┘

Step 1: Initialization
├─ Load Waymo expert data → ExpertBuffer
├─ Create the MAGAIL instance
│   ├─ Actor (policy.py)
│   ├─ Critic (bert.py)
│   ├─ Discriminator (disc.py)
│   └─ Buffers (buffer.py)
└─ Create the multi-agent environment

Step 2: Training loop
for episode in range(episodes):
    ├─ env.reset() → reset the environment, spawn vehicles
    │
    for step in range(horizon):
        ├─ obs = env._get_all_obs()        # collect observations
        │
        ├─ actions = magail.explore(obs)   # sample from the policy
        │
        ├─ next_obs, rewards, dones = env.step(actions)
        │
        ├─ buffer.append(obs, actions, rewards, ...)   # store experience
        │
        └─ if step % rollout_length == 0:
            ├─ Update the discriminator
            │   ├─ sample policy data: buffer.sample()
            │   ├─ sample expert data: expert_buffer.sample()
            │   └─ update_disc(policy_data, expert_data)
            │
            ├─ Compute the GAIL reward
            │   └─ reward = -log(1 - D(s, s'))
            │
            └─ Update PPO
                ├─ compute GAE advantages
                ├─ update_actor()
                └─ update_critic()

Step 3: Evaluation and saving
└─ Save models, log metrics
```
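The GAIL reward `reward = -log(1 - D(s, s'))` in the diagram can be computed in a numerically stable way from the discriminator's raw logit: with `D = sigmoid(logit)`, the expression equals `softplus(logit)`. The sketch below assumes the discriminator returns one logit per `(s, s')` pair; the actual interface in `disc.py` may differ.

```python
import torch
import torch.nn.functional as F


def gail_reward(disc_logits: torch.Tensor, max_reward: float = 10.0) -> torch.Tensor:
    """r = -log(1 - D(s, s')); with D = sigmoid(logit) this is exactly softplus(logit).

    Using softplus avoids log(0); the clamp keeps the reward from exploding
    when the discriminator is very confident.
    """
    return F.softplus(disc_logits).clamp(max=max_reward)


# Usage sketch: rewards = gail_reward(disc(states, next_states))
```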
### 🔑 Key Code Snippets

#### 1. Initialize MAGAIL

```python
from Algorithm.magail import MAGAIL

magail = MAGAIL(
    buffer_exp=expert_buffer,      # expert data buffer
    input_dim=(obs_dim,),          # observation dimension (108,)
    device=device,                 # cuda/cpu
    action_shape=(2,),             # action dimension [steering, throttle]

    # Discriminator parameters
    disc_coef=20.0,                # discriminator loss coefficient
    disc_grad_penalty=0.1,         # gradient penalty coefficient
    disc_logit_reg=0.25,           # logit regularization
    disc_weight_decay=0.0005,      # weight decay
    lr_disc=3e-4,                  # discriminator learning rate
    epoch_disc=5,                  # discriminator update epochs

    # PPO parameters
    rollout_length=2048,           # update interval
    lr_actor=3e-4,                 # actor learning rate
    lr_critic=3e-4,                # critic learning rate
    epoch_ppo=10,                  # PPO update epochs
    batch_size=256,                # batch size
    gamma=0.995,                   # discount factor
    lambd=0.97,                    # GAE lambda

    # Misc
    use_gail_norm=True,            # use data normalization
)
```
#### 2. Environment interaction

```python
import numpy as np

# Reset the environment
obs_list = env.reset(episode)

# Collect observations (all vehicles)
obs_array = np.array(obs_list)  # shape: (n_agents, 108)

# Sample from the policy
actions, log_pis = magail.explore(obs_array)
# actions: [steering, throttle] for each agent
# log_pis: log probability for each agent

# Build the action dict
action_dict = {
    agent_id: actions[i]
    for i, agent_id in enumerate(env.controlled_agents.keys())
}

# Step the environment
next_obs, rewards, dones, infos = env.step(action_dict)
```
#### 3. Model updates

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')

# Update the discriminator and the policy
if total_steps % rollout_length == 0:
    # MAGAIL will automatically:
    # 1. sample policy data from its buffer
    # 2. sample expert data from expert_buffer
    # 3. update the discriminator
    # 4. compute the GAIL reward
    # 5. update PPO (actor + critic)

    reward = magail.update(writer, total_steps)

    print(f"Step {total_steps}, Reward: {reward:.4f}")
```
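Inside `magail.update()`, the PPO step computes GAE advantages before updating the actor and critic. Below is a self-contained sketch of the standard GAE recursion, using the same `gamma`/`lambd` values as above; the real code in `ppo.py` may handle bootstrapping and normalization differently.

```python
import torch


def compute_gae(rewards, values, next_values, dones, gamma=0.995, lambd=0.97):
    """GAE: delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t);
    A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}."""
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * (1.0 - dones[t]) * next_values[t] - values[t]
        gae = delta + gamma * lambd * (1.0 - dones[t]) * gae
        advantages[t] = gae
    returns = advantages + values          # regression targets for the critic
    return advantages, returns
```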
#### 4. Saving and loading models

```python
# Save
magail.save_models('outputs/models/checkpoint_100')

# Load
magail.load_models('outputs/models/checkpoint_100/model.pth')
```
---

## Current Implementation Status

### ✅ Implemented Features

| Module | Status | Notes |
|--------|--------|-------|
| BERT discriminator | ✅ Complete | Supports a dynamic number of vehicles |
| GAIL discriminator | ✅ Complete | Gradient penalty, regularization |
| Policy network | ✅ Complete | Gaussian policy, reparameterization |
| PPO | ✅ Complete | GAE, clipped objective, adaptive LR |
| MAGAIL | ✅ Complete | Discriminator + PPO integration |
| Buffer | ✅ Complete | Experience storage and sampling |
| Data normalization | ✅ Complete | Running statistics |
| Environment interface | ✅ Complete | Multi-agent scenario environment |
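The "Data normalization" row corresponds to the `Normalizer` in `utils.py`. As a rough illustration of running-statistics normalization (the actual class may track its statistics differently):

```python
import numpy as np


class RunningNormalizer:
    """Tracks a running mean/variance and normalizes inputs with them."""

    def __init__(self, shape, eps=1e-8):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = eps

    def update(self, batch):
        batch = np.asarray(batch, dtype=np.float64)
        batch_mean, batch_var, n = batch.mean(0), batch.var(0), batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + n
        # Parallel mean/variance update (Chan et al.)
        self.mean = self.mean + delta * n / total
        self.var = (self.var * self.count + batch_var * n
                    + delta ** 2 * self.count * n / total) / total
        self.count = total

    def normalize(self, x):
        return (np.asarray(x) - self.mean) / np.sqrt(self.var + 1e-8)
```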
### ⚠️ Known Issues

#### 1. Multi-agent adaptation

**Current state:** the Algorithm modules are designed for a single agent, but the environment is multi-agent.

**Impact:**
- `buffer.append()` accepts a single state-action pair
- but the environment returns data for multiple agents

**Solution A:** treat all agents as one whole
```python
# Concatenate the observations of all agents
all_obs = np.concatenate(obs_list)
all_actions = np.concatenate(action_list)
```

**Solution B:** store each agent's experience separately
```python
for i, agent_id in enumerate(env.controlled_agents):
    buffer.append(obs_list[i], actions[i], rewards[i], ...)
```

**Recommendation:** Solution B, since MAGAIL is designed to handle multiple agents.
#### 2. Expert data loading

**Current state:** the `ExpertBuffer` class is only a skeleton; the actual loading is not implemented.

**To be completed:**
```python
def _extract_trajectories(self, scenario_data):
    """
    Implement according to the actual Waymo data format.

    Example structure:
    scenario_data = {
        'tracks': {
            'vehicle_id': {
                'states': [...],   # state sequence
                'actions': [...],  # action sequence (if available)
            }
        }
    }
    """
    # TODO: extract (state, next_state) pairs
    for track_id, track_data in scenario_data['tracks'].items():
        states = track_data['states']
        for i in range(len(states) - 1):
            self.states.append(states[i])
            self.next_states.append(states[i + 1])
```
#### 3. Observation dimension alignment

**Current assumption:** the observation is 108-dimensional
- position (2) + velocity (2) + heading (1) + lidar (80) + side detector (10) + lane lines (10) + traffic light (1) + goal point (2) = 108

**Needs verification:** print the observation shape at runtime
```python
obs = env.reset()
print(f"Observation dim: {len(obs[0]) if obs else 0}")
```
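A quick arithmetic check of the assumed layout (the keys below are just labels for the items listed above, not identifiers from the environment code):

```python
# Assumed observation layout; adjust to whatever the environment actually emits
obs_layout = {
    "position": 2, "velocity": 2, "heading": 1, "lidar": 80,
    "side_detector": 10, "lane_lines": 10, "traffic_light": 1, "goal": 2,
}
assert sum(obs_layout.values()) == 108, sum(obs_layout.values())
```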
---

## Parts That Still Need Work

### 🔨 Short-term TODO

#### 1. Fix the multi-agent buffer issue

**Create file:** `Algorithm/multi_agent_buffer.py`

```python
import numpy as np
import torch


class MultiAgentRolloutBuffer:
    """
    Multi-agent experience buffer.

    Supports a dynamic number of agents.
    """

    def __init__(self, buffer_size, state_shape, action_shape, device):
        self.buffer_size = buffer_size
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = device

        # Store episodes as plain lists so the number of agents can vary
        self.episodes = []
        self.current_episode = {
            'states': [],
            'actions': [],
            'rewards': [],
            'dones': [],
            'log_pis': [],
            'next_states': [],
        }

    def append(self, state, action, reward, done, log_pi, next_state):
        """Add a single step of experience."""
        self.current_episode['states'].append(state)
        self.current_episode['actions'].append(action)
        self.current_episode['rewards'].append(reward)
        self.current_episode['dones'].append(done)
        self.current_episode['log_pis'].append(log_pi)
        self.current_episode['next_states'].append(next_state)

    def finish_episode(self):
        """Close the current episode and start a new one."""
        self.episodes.append(self.current_episode)
        self.current_episode = {
            'states': [],
            'actions': [],
            'rewards': [],
            'dones': [],
            'log_pis': [],
            'next_states': [],
        }

    def sample(self, batch_size):
        """Sample a batch of (state, next_state) pairs from all stored episodes."""
        all_states = []
        all_next_states = []

        for episode in self.episodes:
            all_states.extend(episode['states'])
            all_next_states.extend(episode['next_states'])

        # Sample with replacement when fewer samples than batch_size are stored
        replace = len(all_states) < batch_size
        indices = np.random.choice(len(all_states), batch_size, replace=replace)

        states = torch.as_tensor(
            np.stack([all_states[i] for i in indices]),
            dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(
            np.stack([all_next_states[i] for i in indices]),
            dtype=torch.float32, device=self.device)

        return states, next_states
```
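Intended usage of this buffer, with illustrative dimensions (108-d observations, 2-d actions):

```python
import numpy as np
import torch

buffer = MultiAgentRolloutBuffer(
    buffer_size=2048, state_shape=(108,), action_shape=(2,), device=torch.device('cpu'))

# One environment step for one agent
buffer.append(
    state=np.zeros(108, dtype=np.float32),
    action=np.zeros(2, dtype=np.float32),
    reward=0.0,
    done=False,
    log_pi=0.0,
    next_state=np.zeros(108, dtype=np.float32),
)
buffer.finish_episode()

# (state, next_state) batch for the discriminator update
states, next_states = buffer.sample(batch_size=1)
```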
#### 2. Implement expert data loading

**You need to know:** the actual format of the Waymo data

```python
# Example: read one pkl file and print its structure
import glob
import pickle

files = glob.glob('Env/exp_converted/exp_converted_0/sd_waymo_*.pkl')
with open(files[0], 'rb') as f:
    data = pickle.load(f)
print(type(data))
print(data.keys() if isinstance(data, dict) else len(data))
# Adjust the loading code to the actual structure
```
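Once the real format is known, a minimal `ExpertBuffer` could look like the sketch below. The only hard requirement is that it exposes `sample(batch_size) -> (states, next_states)`, which is what the rest of this guide (and the `DummyExpertBuffer` in Example 3) assumes; the directory layout and the `tracks`/`states` keys follow the example structure shown earlier and are assumptions.

```python
import glob
import pickle

import numpy as np
import torch


class ExpertBuffer:
    """Loads expert (state, next_state) pairs from converted scenario pkl files."""

    def __init__(self, data_dir, device):
        self.device = device
        self.states, self.next_states = [], []
        for path in glob.glob(f'{data_dir}/*.pkl'):
            with open(path, 'rb') as f:
                self._extract_trajectories(pickle.load(f))

    def _extract_trajectories(self, scenario_data):
        # Assumes the structure sketched above: tracks -> vehicle_id -> states
        for track in scenario_data.get('tracks', {}).values():
            states = track['states']
            for i in range(len(states) - 1):
                self.states.append(np.asarray(states[i], dtype=np.float32))
                self.next_states.append(np.asarray(states[i + 1], dtype=np.float32))

    def sample(self, batch_size):
        idx = np.random.randint(0, len(self.states), size=batch_size)
        states = torch.as_tensor(
            np.stack([self.states[i] for i in idx]), device=self.device)
        next_states = torch.as_tensor(
            np.stack([self.next_states[i] for i in idx]), device=self.device)
        return states, next_states
```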
#### 3. Complete the training loop

**Add to `train_magail.py`:**

```python
# Full buffer storage logic
for i, agent_id in enumerate(env.controlled_agents.keys()):
    if i < len(obs_array) and i < len(actions):
        magail.buffer.append(
            state=obs_array[i],
            action=actions[i],
            reward=rewards.get(agent_id, 0.0),
            done=dones.get(agent_id, False),
            tm_done=dones.get(agent_id, False),
            log_pi=log_pis[i],
            next_state=next_obs_array[i] if i < len(next_obs_array) else obs_array[i],
            next_state_gail=next_obs_array[i] if i < len(next_obs_array) else obs_array[i],
            means=magail.actor.means[i].detach().cpu().numpy(),
            stds=magail.actor.log_stds.exp()[0].detach().cpu().numpy()
        )
```
### 🎯 Mid-term TODO

1. **Multi-agent BERT**: the current BERT takes `(batch, N, obs_dim)`; make sure this is handled correctly
2. **Reward design**: the environment reward is currently 0; a sensible task reward needs to be designed
3. **Evaluation script**: create an evaluation script to visualize the trained policy (see the sketch after this list)
4. **Hyperparameter tuning**: run hyperparameter searches with wandb or tensorboard
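For item 3, an evaluation run is essentially the training loop without updates. The sketch below (the file name `evaluate_magail.py` is only a suggestion) assumes the same per-agent dict interface for rewards/dones that the training loop above uses:

```python
# evaluate_magail.py -- minimal evaluation sketch (no learning updates)
import numpy as np


def evaluate(magail, env, episodes=10, horizon=300):
    """Roll out the current policy and report the mean episode return."""
    episode_returns = []
    for episode in range(episodes):
        obs_list = env.reset(episode)
        total = 0.0
        for _ in range(horizon):
            # explore() samples stochastically; use a deterministic act method if one exists
            actions, _ = magail.explore(np.array(obs_list))
            action_dict = {
                agent_id: actions[i]
                for i, agent_id in enumerate(env.controlled_agents.keys())
            }
            obs_list, rewards, dones, _ = env.step(action_dict)
            total += sum(rewards.values())
            if dones and all(dones.values()):
                break
        episode_returns.append(total)
    print(f"Mean episode return over {episodes} episodes: {np.mean(episode_returns):.3f}")


# Usage sketch:
#   magail = MAGAIL(buffer_exp=..., input_dim=(108,), device=device, action_shape=(2,))
#   magail.load_models('outputs/models/checkpoint_100/model.pth')
#   evaluate(magail, MultiAgentScenarioEnv(config={...}))
```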
---

## Usage Examples

### Example 1: Basic training

```bash
# 1. Make sure the environment runs
python Env/run_multiagent_env.py

# 2. Start training (no rendering, faster)
python train_magail.py \
    --episodes 100 \
    --horizon 200 \
    --rollout-length 1024 \
    --batch-size 128

# 3. Inspect the training logs
tensorboard --logdir outputs/magail_*/logs
```
### Example 2: Debug mode

```bash
# A few episodes with rendering enabled
python train_magail.py \
    --episodes 5 \
    --horizon 100 \
    --render
```
### Example 3: Using it from code

```python
# test_algorithm.py
import sys
sys.path.append('Algorithm')

import torch

from Algorithm.magail import MAGAIL


# Dummy expert data for a quick smoke test
class DummyExpertBuffer:
    def __init__(self, device):
        self.device = device

    def sample(self, batch_size):
        states = torch.randn(batch_size, 108, device=self.device)
        next_states = torch.randn(batch_size, 108, device=self.device)
        return states, next_states


# Initialization
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
expert_buffer = DummyExpertBuffer(device)

magail = MAGAIL(
    buffer_exp=expert_buffer,
    input_dim=(108,),
    device=device,
    action_shape=(2,),
)

# Test a forward pass
test_obs = torch.randn(5, 108, device=device)  # 5 agents
actions, log_pis = magail.explore(test_obs)

print(f"Observation shape: {test_obs.shape}")
print(f"Number of actions: {len(actions)}")
print(f"Shape of one action: {actions[0].shape}")
print("Test passed! ✅")
```
---

## Summary

### ✅ What you can do now

1. **Run the environment test**: `run_multiagent_env.py` already runs correctly
2. **Test the algorithm modules**: everything in Algorithm is implemented
3. **Start a first training run**: use `train_magail.py` (the buffer logic still needs to be completed)

### ⚠️ What you still need to do

1. **Debug the multi-agent buffer**: make sure experience is stored correctly
2. **Implement expert data loading**: adapt it to the actual data format
3. **Verify the observation dimension**: confirm the observation really is 108-dimensional
4. **Tune the training parameters**: adjust them based on training results

### 🎯 End Goal

```
Environment (Env/) + Algorithm (Algorithm/) = a complete MAGAIL training system
                         ↓
        a multi-agent autonomous-driving policy
        that imitates expert behavior
```

Good luck with training! 🚀