# MAGAIL4AutoDrive Technical Documentation

## Table of Contents

1. [Project Overview](#project-overview)
2. [Core Technical Architecture](#core-technical-architecture)
3. [Algorithm Modules in Detail](#algorithm-modules-in-detail)
4. [Environment System Implementation](#environment-system-implementation)
5. [Data Flow and Training Pipeline](#data-flow-and-training-pipeline)
6. [Key Technical Details](#key-technical-details)
7. [Usage Guide](#usage-guide)

---
## Project Overview

### Background and Motivation

MAGAIL4AutoDrive (Multi-Agent Generative Adversarial Imitation Learning for Autonomous Driving) is an imitation-learning framework for multi-agent autonomous-driving scenarios. Its core contribution is extending single-agent GAIL to the multi-agent setting, addressing the learning problem that arises when the number of vehicles changes dynamically.

### Core Challenges

1. **Dynamic input dimensionality**: the number of vehicles in a multi-agent scene is not fixed, so conventional fixed-dimension neural networks cannot be applied directly.
2. **Global interaction modeling**: the interactions of all vehicles must be modeled jointly rather than independently.
3. **Use of real-world data**: real driving data such as Waymo must be exploited effectively for training.

### Technical Approach

- **BERT-style discriminator**: a Transformer that handles variable-length sequence inputs
- **GAIL framework**: learns expert behavior through adversarial training
- **PPO optimization**: a stable policy-gradient method
- **MetaDrive simulation**: a high-fidelity multi-agent traffic simulator

---
## Core Technical Architecture

### Overall Architecture

```
                      MAGAIL training system

Expert database (Waymo trajectories)        Policy buffer (agent experience)
                │                                          │
                └────────────── state-action pairs ────────┘
                                    │
                                    ▼
                 BERT discriminator (Discriminator)
                   input: (N, obs_dim)
                     → linear projection to embed_dim
                     → + positional encoding
                     → Transformer layers (×4)
                     → mean pooling / CLS token
                   output: real / fake score
                                    │  reward signal
                                    ▼
                 PPO policy optimization (Policy)
                   Actor network (MLP):   state → action distribution
                   Critic network (BERT): state → value estimate
                                    │  actions
                                    ▼
                 MetaDrive multi-agent environment
                   • vehicle dynamics simulation
                   • multi-modal sensors (lidar, etc.)
                   • traffic lights, lane lines and other traffic elements
```

---
## Algorithm Modules in Detail

### 3.1 BERT Discriminator Implementation

#### 3.1.1 Core Design Idea

The BERT-style discriminator is the core innovation of this project. A conventional GAIL discriminator is a fixed-dimension MLP and cannot cope with a varying number of vehicles in a multi-agent scene. Here, a Transformer treats the observations of all vehicles as a sequence and captures inter-vehicle interactions through self-attention.

#### 3.1.2 Code Implementation

**File: `Algorithm/bert.py`**
```python
import torch
import torch.nn as nn

# TransformerLayer is defined below (same file).


class Bert(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim=128,
                 num_layers=4, ff_dim=512, num_heads=4, dropout=0.1,
                 CLS=False, TANH=False):
        """
        BERT-style discriminator / value network.

        Arguments:
        - input_dim: observation dimension of a single vehicle
        - output_dim: output dimension (1 for both discriminator and value network)
        - embed_dim: Transformer embedding dimension, default 128
        - num_layers: number of Transformer layers, default 4
        - ff_dim: feed-forward dimension, default 512
        - num_heads: number of attention heads, default 4
        - CLS: aggregate features with a CLS token
        - TANH: apply Tanh to the output layer
        """
        super().__init__()
        self.CLS = CLS

        # Linear projection: map the observation dimension to the embedding dimension
        self.projection = nn.Linear(input_dim, embed_dim)

        # Positional encoding: a learnable embedding per vehicle slot
        if self.CLS:
            # CLS mode: one extra position for the CLS token
            self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim))
        else:
            # Mean-pooling mode
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim))

        # Transformer encoder layers
        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        # Classification head
        if TANH:
            self.classifier = nn.Sequential(
                nn.Linear(embed_dim, output_dim),
                nn.Tanh()
            )
        else:
            self.classifier = nn.Linear(embed_dim, output_dim)

    def forward(self, x, mask=None):
        """
        Forward pass.

        Input:
        - x: (batch_size, seq_len, input_dim)
             seq_len   = number of vehicles (varies between batches)
             input_dim = per-vehicle observation dimension
        - mask: optional attention mask

        Output:
        - (batch_size, output_dim) discriminator score or value estimate
        """
        # Step 1: linear projection
        # Map each vehicle's observation into a fixed embedding space
        x = self.projection(x)  # (batch_size, seq_len, embed_dim)

        batch_size = x.size(0)

        # Step 2: prepend the CLS token (if enabled)
        if self.CLS:
            # Copy the CLS token for every sample in the batch
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)

        # Step 3: add positional encoding
        # (sliced to the actual sequence length so variable vehicle counts broadcast correctly)
        x = x + self.pos_embed[:, :x.size(1), :]

        # Step 4: transpose to the layout expected by the Transformer
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)

        # Step 5: run the Transformer layers
        # Each layer applies self-attention and captures inter-vehicle interactions
        for layer in self.layers:
            x = layer(x, mask=mask)

        # Step 6: feature aggregation
        if self.CLS:
            # CLS mode: take the CLS token's output
            return self.classifier(x[0, :, :])
        else:
            # Mean pooling: average over all vehicle features
            pooled = x.mean(dim=0)  # (batch_size, embed_dim)
            return self.classifier(pooled)
```

**Transformer layer implementation:**
```python
class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        """
        Transformer encoder layer.

        Structure:
        1. Multi-head self-attention + residual connection + LayerNorm
        2. Feed-forward network + residual connection + LayerNorm
        """
        super().__init__()
        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout
        )

        # Feed-forward network
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)

        # Normalization layers
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        """
        Forward pass (Post-LN structure).
        """
        # Attention block
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output)  # residual connection
        x = self.norm1(x)                  # normalization

        # Feed-forward block
        ff_output = self.linear2(
            self.dropout(self.activation(self.linear1(x)))
        )
        x = x + self.dropout(ff_output)  # residual connection
        x = self.norm2(x)                # normalization

        return x
```
#### 3.1.3 Key Technical Points

1. **Dynamic sequence handling** (see the usage sketch after this list):
   - The input has shape `(batch_size, N, obs_dim)`, where N is the number of vehicles
   - N can differ between batches; no fixed size is required

2. **Positional encoding**:
   - Learnable positional embeddings are used instead of sinusoidal encoding
   - They let the model distinguish vehicles at different positions

3. **Self-attention**:
   - Attention weights are computed between every pair of vehicles
   - This captures inter-vehicle interactions and influence

4. **Feature aggregation**:
   - CLS mode: a dedicated classification token, as in BERT
   - Mean pooling: a simple but effective global feature
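
A minimal usage sketch of the dynamic-sequence point above, assuming the `Bert` class exactly as listed in Section 3.1.2; the sizes (`obs_dim=108`, 3 vs. 7 vehicles, batch of 16) are illustrative assumptions, not values fixed by the project.

```python
import torch

obs_dim = 108                                   # hypothetical per-vehicle observation size (cf. Section 4.1.3)
model = Bert(input_dim=obs_dim, output_dim=1)   # same weights serve any vehicle count

# Two batches with different numbers of vehicles go through the same network.
scores_3 = model(torch.randn(16, 3, obs_dim))   # 3 vehicles -> (16, 1)
scores_7 = model(torch.randn(16, 7, obs_dim))   # 7 vehicles -> (16, 1)
print(scores_3.shape, scores_7.shape)
```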
---

### 3.2 GAIL Discriminator

#### 3.2.1 Discriminator Design

**File: `Algorithm/disc.py`**
```python
import torch
import torch.nn as nn

# Bert is defined in Algorithm/bert.py (see Section 3.1.2).


class GAILDiscrim(Bert):
    """
    GAIL discriminator, built on the BERT architecture.

    Responsibilities:
    1. Distinguish expert data from policy-generated data
    2. Compute the intrinsic imitation reward
    """

    def __init__(self, input_dim, reward_i_coef=1.0,
                 reward_t_coef=1.0, normalizer=None, device=None):
        """
        Initialize the discriminator.

        Arguments:
        - input_dim: input dimension (state + next state)
        - reward_i_coef: intrinsic reward coefficient
        - reward_t_coef: task reward coefficient
        """
        # Call the BERT constructor with output dimension 1 (real/fake score)
        super().__init__(input_dim=input_dim, output_dim=1, TANH=False)

        self.device = device
        self.reward_t_coef = reward_t_coef
        self.reward_i_coef = reward_i_coef
        self.normalizer = normalizer

    def calculate_reward(self, states_gail, next_states_gail, rewards_t):
        """
        Compute the GAIL reward.

        Core idea:
        - The discriminator is trained so that D(s, s') -> 1 on expert data and
          -> 0 on policy data (see update_disc below)
        - -log(1 - D) is used as the intrinsic reward, so the more expert-like a
          transition looks, the higher the reward

        Arguments:
        - states_gail: current states
        - next_states_gail: next states
        - rewards_t: environment (task) rewards

        Returns:
        - rewards: combined reward
        - rewards_t: normalized task reward
        - rewards_i: normalized intrinsic reward
        """
        states_gail = states_gail.clone()
        next_states_gail = next_states_gail.clone()

        # Concatenate the state-transition pair
        states = torch.cat([states_gail, next_states_gail], dim=-1)

        with torch.no_grad():
            # Normalize the data
            if self.normalizer is not None:
                states = self.normalizer.normalize_torch(states, self.device)

            # Scale the task reward
            rewards_t = self.reward_t_coef * rewards_t

            # Discriminator output (logit)
            d = self.forward(states)

            # Convert to a probability: sigmoid(d) = 1 / (1 + exp(-d))
            prob = 1 / (1 + torch.exp(-d))

            # GAIL reward: -log(1 - D(s, s'))
            # D close to 1 (expert-like)  -> high reward
            # D close to 0 (policy-like)  -> low reward
            rewards_i = self.reward_i_coef * (
                -torch.log(torch.maximum(
                    1 - prob,
                    torch.tensor(0.0001, device=self.device)
                ))
            )

            # Combine the rewards
            rewards = rewards_t + rewards_i

        return (rewards,
                rewards_t / (self.reward_t_coef + 1e-10),
                rewards_i / (self.reward_i_coef + 1e-10))

    def get_disc_logit_weights(self):
        """Weights of the output layer (used for logit regularization)."""
        return torch.flatten(self.classifier.weight)

    def get_disc_weights(self):
        """Weights of all linear layers (used for weight decay)."""
        weights = []
        for m in self.layers.modules():
            if isinstance(m, nn.Linear):
                weights.append(torch.flatten(m.weight))
        weights.append(torch.flatten(self.classifier.weight))
        return weights
```
#### 3.2.2 Discriminator Training

**File: `Algorithm/magail.py`**
```python
# Method excerpt from the MAGAIL class; module-level imports
# (torch, torch.nn.functional as F, etc.) are omitted here.

def update_disc(self, states, states_exp, writer):
    """
    Update the discriminator.

    Objective: maximize E_expert[log D] + E_policy[log(1 - D)]

    Arguments:
    - states: state transitions generated by the policy
    - states_exp: state transitions from the expert demonstrations
    """
    states_cp = states.clone()
    states_exp_cp = states_exp.clone()

    # Step 1: discriminator outputs
    logits_pi = self.disc(states_cp)       # policy data
    logits_exp = self.disc(states_exp_cp)  # expert data

    # Step 2: adversarial loss
    # Desired: logits_pi < 0  (policy classified as fake)
    #          logits_exp > 0 (expert classified as real)
    loss_pi = -F.logsigmoid(-logits_pi).mean()   # -log(1 - sigmoid(logits_pi))
    loss_exp = -F.logsigmoid(logits_exp).mean()  # -log(sigmoid(logits_exp))
    loss_disc = 0.5 * (loss_pi + loss_exp)

    # Step 3: logit regularization
    # Keeps the discriminator outputs from growing too large and destabilizing training
    logit_weights = self.disc.get_disc_logit_weights()
    disc_logit_loss = torch.sum(torch.square(logit_weights))

    # Step 4: gradient penalty
    # Encourages a Lipschitz discriminator and improves training stability
    sample_expert = states_exp_cp
    sample_expert.requires_grad = True

    # Discriminator output on the expert samples
    disc = self.disc.linear(self.disc.trunk(sample_expert))
    ones = torch.ones(disc.size(), device=disc.device)

    # Gradient with respect to the expert inputs
    disc_demo_grad = torch.autograd.grad(
        disc, sample_expert,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Squared L2 norm of the gradient
    disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
    grad_pen_loss = torch.mean(disc_demo_grad)

    # Step 5: weight decay (L2 regularization)
    disc_weights = self.disc.get_disc_weights()
    disc_weights = torch.cat(disc_weights, dim=-1)
    disc_weight_decay = torch.sum(torch.square(disc_weights))

    # Step 6: combine the losses and update
    loss = (self.disc_coef * loss_disc +
            self.disc_grad_penalty * grad_pen_loss +
            self.disc_logit_reg * disc_logit_loss +
            self.disc_weight_decay * disc_weight_decay)

    self.optim_d.zero_grad()
    loss.backward()
    self.optim_d.step()

    # Step 7: log training metrics
    if self.learning_steps_disc % self.epoch_disc == 0:
        writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)

        with torch.no_grad():
            # Discriminator accuracy
            acc_pi = (logits_pi < 0).float().mean().item()    # accuracy on policy data
            acc_exp = (logits_exp > 0).float().mean().item()  # accuracy on expert data

        writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
        writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)
```
#### 3.2.3 Key Technical Details

**1. Role of the gradient penalty**
```python
# The gradient penalty keeps the discriminator approximately Lipschitz continuous,
# i.e. |D(x1) - D(x2)| <= K * |x1 - x2|.
# This prevents the discriminator from changing too sharply and stabilizes training.
```

**2. Why logits instead of probabilities**
```python
# Working with logits (the pre-sigmoid outputs) has several advantages:
# 1. Numerical stability: avoids log(0) and similar issues
# 2. Better gradients: sigmoid saturates and its gradient vanishes
# 3. Theory: the GAIL objective is naturally expressed in logit form
```
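
As a small illustration of the numerical-stability point above (not project code): the intrinsic reward `-log(1 - sigmoid(d))` used in `calculate_reward` is mathematically the softplus of the logit, `log(1 + exp(d))`, so it can be computed directly from `d` without forming the probability first.

```python
import torch
import torch.nn.functional as F

d = torch.tensor([-5.0, 0.0, 5.0])                   # example discriminator logits

reward_via_prob = -torch.log(1 - torch.sigmoid(d))   # probability route, as in calculate_reward
reward_via_logit = F.softplus(d)                     # equivalent: -log(1 - sigmoid(d)) = log(1 + exp(d))

print(torch.allclose(reward_via_prob, reward_via_logit, atol=1e-6))  # True
```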
---

### 3.3 PPO Policy Optimization

#### 3.3.1 Policy Network

**File: `Algorithm/policy.py`**
```python
import torch
import torch.nn as nn

# build_mlp, reparameterize and evaluate_lop_pi are helper functions defined
# alongside this class (reparameterize is shown below).


class StateIndependentPolicy(nn.Module):
    """
    State-independent (diagonal Gaussian) policy.

    Output: the mean of a Gaussian; the standard deviation is a learnable parameter.
    """

    def __init__(self, state_shape, action_shape,
                 hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        super().__init__()

        # Mean network (MLP)
        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )

        # Learnable log standard deviation
        self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
        self.means = None

    def forward(self, states):
        """
        Deterministic forward pass (used for evaluation).
        """
        return torch.tanh(self.net(states))

    def sample(self, states):
        """
        Sample actions from the policy distribution.

        Uses the reparameterization trick:
        a = tanh(mu + sigma * eps), eps ~ N(0, 1)
        """
        self.means = self.net(states)
        actions, log_pis = reparameterize(self.means, self.log_stds)
        return actions, log_pis

    def evaluate_log_pi(self, states, actions):
        """
        Log-probability of the given state-action pairs.
        """
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)
```

**Reparameterization trick:**
```python
import math

import torch


def reparameterize(means, log_stds):
    """
    Reparameterized sampling.

    Idea:
    Instead of sampling from N(mu, sigma^2) directly:
    1. Sample noise eps ~ N(0, 1)
    2. Compute z = mu + sigma * eps
    3. Squash with tanh: a = tanh(z)

    Benefits:
    - Gradients can flow back through mu and sigma
    - Compatible with gradient-based optimization
    """
    noises = torch.randn_like(means)      # eps ~ N(0, 1)
    us = means + noises * log_stds.exp()  # z = mu + sigma * eps
    actions = torch.tanh(us)              # a = tanh(z)

    # Log-probability (must account for the tanh Jacobian)
    return actions, calculate_log_pi(log_stds, noises, actions)


def calculate_log_pi(log_stds, noises, actions):
    """
    Log-probability of a tanh-squashed Gaussian.

    Formula (d = action dimension):
    log pi(a|s) = log N(u | mu, sigma^2) - sum log(1 - tanh^2(u))
                = -0.5 * ||eps||^2 - sum(log sigma) - 0.5 * d * log(2*pi) - sum log(1 - a^2)
    """
    # Gaussian log-probability
    gaussian_log_probs = (
        -0.5 * noises.pow(2) - log_stds
    ).sum(dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)

    # Jacobian correction for the tanh transform:
    # d/du tanh(u) = 1 - tanh^2(u)
    return gaussian_log_probs - torch.log(
        1 - actions.pow(2) + 1e-6
    ).sum(dim=-1, keepdim=True)
```
#### 3.3.2 PPO Update

**File: `Algorithm/ppo.py`**
```python
# Method excerpts from the PPO class; module-level imports
# (torch, torch.nn as nn) are omitted here.

def update_ppo(self, states, actions, rewards, dones, tm_dones,
               log_pi_list, next_states, mus, sigmas, writer, total_steps):
    """
    Update the PPO policy and value networks.
    """
    # Step 1: value estimates and advantages
    with torch.no_grad():
        values = self.critic(states.detach())
        next_values = self.critic(next_states.detach())

    # GAE (Generalized Advantage Estimation)
    targets, gaes = self.calculate_gae(
        values, rewards, dones, tm_dones, next_values,
        self.gamma, self.lambd
    )

    # Step 2: several update epochs
    state_list = states.permute(1, 0, 2)
    action_list = actions.permute(1, 0, 2)

    for i in range(self.epoch_ppo):
        self.learning_steps_ppo += 1

        # Update the value network
        self.update_critic(states, targets, writer)

        # Update the policy network
        for state, action, log_pi in zip(state_list, action_list, log_pi_list):
            self.update_actor(
                state, action, log_pi, gaes, mus, sigmas, writer
            )


def calculate_gae(self, values, rewards, dones, tm_dones,
                  next_values, gamma, lambd):
    """
    Generalized Advantage Estimation (GAE).

    Formulas:
    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    A_t     = sum_k (gamma * lambda)^k * delta_{t+k}

    Arguments:
    - gamma: discount factor
    - lambd: GAE lambda (bias-variance trade-off)
    """
    with torch.no_grad():
        # TD errors
        deltas = rewards + gamma * next_values * (1 - tm_dones) - values

        # Advantage buffer
        gaes = torch.empty_like(rewards)

        # Backward recursion over time
        gaes[-1] = deltas[-1]
        for t in reversed(range(rewards.size(0) - 1)):
            gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]

        # Value targets
        v_target = gaes + values

        # Advantage normalization
        if self.use_adv_norm:
            gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)

    return v_target, gaes


def update_actor(self, states, actions, log_pis_old, gaes,
                 mus_old, sigmas_old, writer):
    """
    Update the actor network.

    PPO clipped objective:
    L = min(r(theta) * A, clip(r(theta), 1 - eps, 1 + eps) * A)
    where r(theta) = pi_new / pi_old
    """
    self.optim_actor.zero_grad()

    # Log-probabilities under the new policy
    log_pis = self.actor.evaluate_log_pi(states, actions)
    mus = self.actor.means
    sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)

    # Entropy (encourages exploration)
    entropy = -log_pis.mean()

    # Importance-sampling ratio
    ratios = (log_pis - log_pis_old).exp_()

    # PPO clipped objective
    loss_actor1 = -ratios * gaes
    loss_actor2 = -torch.clamp(
        ratios,
        1.0 - self.clip_eps,
        1.0 + self.clip_eps
    ) * gaes
    loss_actor = torch.max(loss_actor1, loss_actor2).mean()
    loss_actor = loss_actor * self.surrogate_loss_coef

    # Adaptive learning rate (based on the KL divergence)
    if self.auto_lr:
        with torch.inference_mode():
            # KL divergence KL(old || new) for diagonal Gaussians
            kl = torch.sum(
                torch.log(sigmas / sigmas_old + 1.e-5) +
                (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                (2.0 * torch.square(sigmas)) - 0.5,
                axis=-1
            )
            kl_mean = torch.mean(kl)

        # Adjust the learning rates
        if kl_mean > self.desired_kl * 2.0:
            # KL too large: lower the learning rate
            self.lr_actor = max(1e-5, self.lr_actor / 1.5)
            self.lr_critic = max(1e-5, self.lr_critic / 1.5)
        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
            # KL too small: raise the learning rate
            self.lr_actor = min(1e-2, self.lr_actor * 1.5)
            self.lr_critic = min(1e-2, self.lr_critic * 1.5)

        # Push the new learning rates into the optimizers
        for param_group in self.optim_actor.param_groups:
            param_group['lr'] = self.lr_actor
        for param_group in self.optim_critic.param_groups:
            param_group['lr'] = self.lr_critic

    # Backpropagation
    loss = loss_actor
    loss.backward()
    nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
    self.optim_actor.step()
```
---

## Environment System Implementation

### 4.1 Multi-Agent Scenario Environment

#### 4.1.1 Environment Design

**File: `Env/scenario_env.py`**
```python
# ScenarioEnv and the other MetaDrive helpers used below are imported from the
# metadrive package in the original file.


class MultiAgentScenarioEnv(ScenarioEnv):
    """
    Multi-agent scenario environment.

    Extends MetaDrive's ScenarioEnv to the multi-agent setting.

    Core responsibilities:
    1. Spawn vehicles dynamically from the expert data
    2. Collect multi-modal observations
    3. Manage multi-agent interaction
    """

    @classmethod
    def default_config(cls):
        config = super().default_config()
        config.update(dict(
            data_directory=None,      # expert data directory
            num_controlled_agents=3,  # number of controlled vehicles
            horizon=1000,             # scenario length in time steps
        ))
        return config

    def __init__(self, config, agent2policy):
        """
        Initialize the environment.

        Arguments:
        - config: environment configuration
        - agent2policy: policy assigned to each agent
        """
        self.policy = agent2policy
        self.controlled_agents = {}     # controlled vehicles
        self.controlled_agent_ids = []  # controlled vehicle IDs
        self.obs_list = []              # observation list
        self.round = 0                  # current time step
        super().__init__(config)
```
#### 4.1.2 Environment Reset and Vehicle Spawning
```python
# Method excerpts from MultiAgentScenarioEnv; module-level imports
# (numpy as np, logging, defaultdict, the MetaDrive helpers, etc.) are omitted here.

def reset(self, seed: Union[None, int] = None):
    """
    Reset the environment.

    Steps:
    1. Parse the vehicle trajectories in the expert data
    2. Extract the spawn information for each vehicle
    3. Remove the raw trajectory data
    4. Initialize the scenario
    5. Spawn the first batch of vehicles
    """
    self.round = 0

    # Logger initialization
    if self.logger is None:
        self.logger = get_logger()
        log_level = self.config.get("log_level", logging.INFO)
        set_log_level(log_level)

    self.lazy_init()
    self._reset_global_seed(seed)

    # Step 1: parse the expert data
    _obj_to_clean_this_frame = []
    self.car_birth_info_list = []

    for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
        # Skip the ego vehicle
        if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
            continue

        # Only handle vehicle objects
        if track["type"] == MetaDriveType.VEHICLE:
            _obj_to_clean_this_frame.append(scenario_id)

            # Valid frames
            valid = track['state']['valid']
            first_show = np.argmax(valid) if valid.any() else -1
            last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1

            # Extract the key information
            self.car_birth_info_list.append({
                'id': track['metadata']['object_id'],  # vehicle ID
                'show_time': first_show,               # first appearance time
                'begin': (                             # start position
                    track['state']['position'][first_show, 0],
                    track['state']['position'][first_show, 1]
                ),
                'heading': track['state']['heading'][first_show],  # initial heading
                'end': (                               # end position
                    track['state']['position'][last_show, 0],
                    track['state']['position'][last_show, 1]
                )
            })

    # Step 2: remove the raw data (avoid spawning duplicates)
    for scenario_id in _obj_to_clean_this_frame:
        self.engine.traffic_manager.current_traffic_data.pop(scenario_id)

    # Step 3: reset the engine
    self.engine.reset()
    self.reset_sensors()
    self.engine.taskMgr.step()

    # Step 4: cache the lane network (used for traffic-light lookup)
    self.lanes = self.engine.map_manager.current_map.road_network.graph

    # Step 5: clear stale state
    if self.top_down_renderer is not None:
        self.top_down_renderer.clear()
        self.engine.top_down_renderer = None

    self.dones = {}
    self.episode_rewards = defaultdict(float)
    self.episode_lengths = defaultdict(int)

    self.controlled_agents.clear()
    self.controlled_agent_ids.clear()

    # Step 6: initialize the scenario and spawn the first batch of vehicles
    super().reset(seed)
    self._spawn_controlled_agents()

    return self._get_all_obs()


def _spawn_controlled_agents(self):
    """
    Dynamically spawn controlled vehicles.

    Vehicles are created at the time steps recorded in the expert data.
    """
    for car in self.car_birth_info_list:
        # Has this vehicle's spawn time arrived?
        if car['show_time'] == self.round:
            agent_id = f"controlled_{car['id']}"

            # Create the vehicle instance
            vehicle = self.engine.spawn_object(
                PolicyVehicle,          # custom vehicle class
                vehicle_config={},
                position=car['begin'],  # initial position
                heading=car['heading']  # initial heading
            )

            # Reset the vehicle state
            vehicle.reset(position=car['begin'], heading=car['heading'])

            # Assign the policy and the destination
            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])

            # Register the vehicle
            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)

            # Important: only vehicles registered in the engine's active_agents
            # take part in the physics update
            self.engine.agent_manager.active_agents[agent_id] = vehicle
```
#### 4.1.3 Observation System
```python
# Method excerpt from MultiAgentScenarioEnv (continued).

def _get_all_obs(self):
    """
    Collect observations for every controlled vehicle.

    Observation layout:
    - position: 2D (x, y)
    - velocity: 2D (vx, vy)
    - heading: 1D (theta)
    - forward lidar: 80D (distances)
    - side detector: 10D (distances)
    - lane-line detector: 10D (distances)
    - traffic light: 1D (encoded 0-3)
    - navigation: 2D (destination coordinates)

    Total: 2 + 2 + 1 + 80 + 10 + 10 + 1 + 2 = 108D
    """
    self.obs_list = []

    for agent_id, vehicle in self.controlled_agents.items():
        # Basic vehicle state
        state = vehicle.get_state()

        # Traffic-light lookup
        traffic_light = 0
        for lane in self.lanes.values():
            # Is the vehicle on this lane?
            if lane.lane.point_on_lane(state['position'][:2]):
                # Does this lane have a traffic light?
                if self.engine.light_manager.has_traffic_light(lane.lane.index):
                    traffic_light = self.engine.light_manager._lane_index_to_obj[
                        lane.lane.index
                    ].status

                # Encode the light status
                if traffic_light == 'TRAFFIC_LIGHT_GREEN':
                    traffic_light = 1
                elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
                    traffic_light = 2
                elif traffic_light == 'TRAFFIC_LIGHT_RED':
                    traffic_light = 3
                else:
                    traffic_light = 0
                break

        # Lidar perception
        # Forward lidar: 80 beams, 30 m range, senses dynamic objects
        lidar = self.engine.get_sensor("lidar").perceive(
            num_lasers=80,
            distance=30,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.dynamic_world
        )

        # Side detector: 10 beams, 8 m range, senses static obstacles
        side_lidar = self.engine.get_sensor("side_detector").perceive(
            num_lasers=10,
            distance=8,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.static_world
        )

        # Lane-line detector: 10 beams, 3 m range, senses lane boundaries
        lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(
            num_lasers=10,
            distance=3,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.static_world
        )

        # Assemble the observation vector
        obs = (
            state['position'][:2] +     # position (x, y)
            list(state['velocity']) +   # velocity (vx, vy)
            [state['heading_theta']] +  # heading theta
            lidar[0] +                  # forward lidar (80D)
            side_lidar[0] +             # side detector (10D)
            lane_line_lidar[0] +        # lane lines (10D)
            [traffic_light] +           # traffic light (1D)
            list(vehicle.destination)   # destination (x, y)
        )

        self.obs_list.append(obs)

    return self.obs_list
```
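A quick sanity check of the dimension count stated in the docstring above (illustrative only, not project code):

```python
# position + velocity + heading + lidar + side detector + lane lines + light + destination
obs_dim = 2 + 2 + 1 + 80 + 10 + 10 + 1 + 2
assert obs_dim == 108
```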
#### 4.1.4 Environment Step
```python
# Method excerpt from MultiAgentScenarioEnv; typing and numpy imports are omitted here.

def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
    """
    Advance the simulation by one step.

    Steps:
    1. Apply actions to the vehicles
    2. Run the physics engine
    3. Post-process the vehicles
    4. Spawn newly appearing vehicles
    5. Collect observations and rewards

    Arguments:
    - action_dict: {agent_id: action} dictionary

    Returns:
    - obs: list of observations
    - rewards: reward dictionary
    - dones: done-flag dictionary
    - infos: info dictionary
    """
    self.round += 1

    # Step 1: apply the actions
    for agent_id, action in action_dict.items():
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].before_step(action)

    # Step 2: physics simulation
    self.engine.step()

    # Step 3: vehicle post-processing
    for agent_id in action_dict:
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].after_step()

    # Step 4: spawn newly appearing vehicles
    self._spawn_controlled_agents()

    # Step 5: collect observations
    obs = self._get_all_obs()

    # Step 6: rewards and done flags
    rewards = {aid: 0.0 for aid in self.controlled_agents}
    dones = {aid: False for aid in self.controlled_agents}
    dones["__all__"] = self.episode_step >= self.config["horizon"]
    infos = {aid: {} for aid in self.controlled_agents}

    return obs, rewards, dones, infos
```
### 4.2 Custom Vehicle Class
```python
# DefaultVehicle and vehicle_class_to_type are imported from the MetaDrive
# package in the original file.


class PolicyVehicle(DefaultVehicle):
    """
    Policy-controlled vehicle.

    Extends MetaDrive's default vehicle with a policy and a destination.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None       # controlling policy
        self.destination = None  # destination point

    def set_policy(self, policy):
        """Assign the controlling policy."""
        self.policy = policy

    def set_destination(self, des):
        """Set the destination point."""
        self.destination = des

    def act(self, observation, policy=None):
        """
        Produce an action.

        Uses the assigned policy if available, otherwise samples randomly.
        """
        if self.policy is not None:
            return self.policy.act(observation)
        else:
            return self.action_space.sample()

    def before_step(self, action):
        """
        Pre-step processing.

        Records the previous state and applies the action.
        """
        self.last_position = self.position
        self.last_velocity = self.velocity
        self.last_speed = self.speed
        self.last_heading_dir = self.heading

        if action is not None:
            self.last_current_action.append(action)

        # Convert the action into vehicle control commands
        self._set_action(action)


# Register the vehicle type
vehicle_class_to_type[PolicyVehicle] = "default"
```

---
## Data Flow and Training Pipeline

### 5.1 Full Training Pipeline
```python
"""
Training-loop pseudocode
"""

# 1. Initialization
env = MultiAgentScenarioEnv(config, policy)
magail = MAGAIL(buffer_exp, input_dim, device)

# 2. Load the expert data
buffer_exp = load_expert_data("waymo_dataset")

# 3. Training loop
for episode in range(max_episodes):
    # 3.1 Reset the environment
    obs_list = env.reset()

    for step in range(max_steps):
        # 3.2 Sample actions from the policy
        actions, log_pis = magail.explore(obs_list)

        # 3.3 Interact with the environment
        next_obs_list, rewards, dones, infos = env.step(actions)

        # 3.4 Store the experience
        magail.buffer.append(
            obs_list, actions, rewards, dones,
            log_pis, next_obs_list
        )

        obs_list = next_obs_list

        # 3.5 Time to update?
        if magail.is_update(step):
            # 3.5.1 Update the discriminator
            for _ in range(epoch_disc):
                # Sample policy data
                states_policy = magail.buffer.sample(batch_size)
                # Sample expert data
                states_expert = buffer_exp.sample(batch_size)
                # Update the discriminator
                magail.update_disc(states_policy, states_expert)

            # 3.5.2 Compute the GAIL rewards
            rewards_gail = magail.disc.calculate_reward(states, next_states)

            # 3.5.3 Update PPO
            magail.update_ppo(states, actions, rewards_gail, ...)

            # 3.5.4 Clear the buffer
            magail.buffer.clear()

        if dones["__all__"]:
            break

    # 3.6 Save the model
    if episode % save_interval == 0:
        magail.save_models(save_path)
```
### 5.2 Data Flow Diagram

```
Expert data                         Policy data
     │                                   │
Expert Buffer (s, a, s')            Policy Buffer (s, a, s')
     │ sample                            │ sample
     └────────────────┬──────────────────┘
                      ▼
          BERT Discriminator
            input:  (N, obs_dim * 2)
            output: real / fake score
                      │  gradient backpropagation
                      ▼
          Disc Loss + Grad Pen + Logit Reg + Weight Decay


Policy data
     │
Policy Buffer (s, a, r, s')
     │  r_GAIL = -log(1 - D(s, s'))
     ▼
PPO Optimization
  Actor:  MLP
  Critic: BERT
     │  policy improvement
     ▼
Environment Interaction
```
### 5.3 Buffer Management

**File: `Algorithm/buffer.py`**
```python
import numpy as np
import torch


class RolloutBuffer:
    """
    Rollout buffer.

    Stores the experience collected while the policy interacts with the environment.
    """

    def __init__(self, buffer_size, state_shape, action_shape, device):
        """
        Initialize the buffer.

        Arguments:
        - buffer_size: buffer capacity (usually equal to the rollout length)
        - state_shape: state dimensions
        - action_shape: action dimensions
        - device: storage device (CPU/GPU)
        """
        self._n = 0  # number of stored items
        self._p = 0  # current write position
        self.buffer_size = buffer_size

        # Pre-allocate tensors (for efficiency)
        self.states = torch.empty(
            (self.buffer_size, *state_shape),
            dtype=torch.float, device=device
        )
        self.actions = torch.empty(
            (self.buffer_size, *action_shape),
            dtype=torch.float, device=device
        )
        self.rewards = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.float, device=device
        )
        self.dones = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.int, device=device
        )
        self.tm_dones = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.int, device=device
        )
        self.log_pis = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.float, device=device
        )
        self.next_states = torch.empty(
            (self.buffer_size, *state_shape),
            dtype=torch.float, device=device
        )
        self.means = torch.empty(
            (self.buffer_size, *action_shape),
            dtype=torch.float, device=device
        )
        self.stds = torch.empty(
            (self.buffer_size, *action_shape),
            dtype=torch.float, device=device
        )

    def append(self, state, action, reward, done, tm_dones,
               log_pi, next_state, next_state_gail, means, stds):
        """
        Append one experience tuple.

        The buffer is circular, so old data is overwritten automatically.
        """
        self.states[self._p].copy_(state)
        self.actions[self._p].copy_(torch.from_numpy(action))
        self.rewards[self._p] = float(reward)
        self.dones[self._p] = int(done)
        self.tm_dones[self._p] = int(tm_dones)
        self.log_pis[self._p] = float(log_pi)
        self.next_states[self._p].copy_(torch.from_numpy(next_state))
        self.means[self._p].copy_(torch.from_numpy(means))
        self.stds[self._p].copy_(torch.from_numpy(stds))

        # Advance the write pointer
        self._p = (self._p + 1) % self.buffer_size
        self._n = min(self._n + 1, self.buffer_size)

    def get(self):
        """
        Return all stored data (used for the PPO update).
        """
        assert self._p % self.buffer_size == 0
        idxes = slice(0, self.buffer_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def sample(self, batch_size):
        """
        Sample a random batch (used for the discriminator update).
        """
        assert self._p % self.buffer_size == 0
        idxes = np.random.randint(low=0, high=self._n, size=batch_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def clear(self):
        """Reset the buffer contents."""
        self.states[:, :] = 0
        self.actions[:, :] = 0
        self.rewards[:, :] = 0
        self.dones[:, :] = 0
        self.tm_dones[:, :] = 0
        self.log_pis[:, :] = 0
        self.next_states[:, :] = 0
        self.means[:, :] = 0
        self.stds[:, :] = 0
```
---

## Key Technical Details

### 6.1 Data Normalization

**File: `Algorithm/utils.py`**
```python
from typing import Tuple

import numpy as np
import torch


class RunningMeanStd(object):
    """
    Running mean and standard deviation.

    Uses a Welford-style online algorithm to compute streaming statistics efficiently.
    """

    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        """
        Update the statistics with a new batch.

        Merges the batch statistics into the running statistics.
        """
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """
        Update the statistics from batch moments.

        Reference: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        # New mean
        new_mean = self.mean + delta * batch_count / tot_count

        # New variance
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        new_var = m_2 / tot_count

        # Commit
        self.mean = new_mean
        self.var = new_var
        self.count = tot_count


class Normalizer(RunningMeanStd):
    """
    Data normalizer.

    Provides normalization (and clipping) on top of the running statistics.
    """

    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon
        self.clip_obs = clip_obs

    def normalize(self, input):
        """
        Normalize (NumPy version).

        Formula: (x - mu) / sqrt(var + eps)
        """
        return np.clip(
            (input - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip_obs, self.clip_obs
        )

    def normalize_torch(self, input, device):
        """
        Normalize (PyTorch version).

        Used for efficient computation on the GPU.
        """
        mean_torch = torch.tensor(
            self.mean, device=device, dtype=torch.float32
        )
        std_torch = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32
        ))
        return torch.clamp(
            (input - mean_torch) / std_torch,
            -self.clip_obs, self.clip_obs
        )
```
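A minimal usage sketch of the `Normalizer` class above. The input dimension 216 is an assumption (2 × 108, i.e. a concatenated state/next-state transition as in Section 3.2.1), and the random batch stands in for real data.

```python
import numpy as np

norm = Normalizer(input_dim=216)                 # hypothetical transition dimension

batch = np.random.randn(1024, 216) * 5.0 + 2.0   # fake transition batch
norm.update(batch)                               # update the running statistics

normalized = norm.normalize(batch)               # roughly zero mean, unit variance, clipped to +/-10
print(normalized.mean(), normalized.std())
```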
### 6.2 Why a BERT-Style Architecture

**Problems with a plain MLP:**
```python
# Suppose the scene contains N vehicles, each with observation dimension D.
# Conventional approach: concatenate all vehicle observations
input = concat([obs_1, obs_2, ..., obs_N])  # dimension: N*D
output = MLP(input)

# Problems:
# 1. The input dimension N*D changes with N, forcing the network to be retrained
# 2. Vehicles at different positions are semantically identical, but the MLP cannot share weights
# 3. Inter-vehicle interactions are not modeled
```

**Advantages of the BERT-style architecture:**
```python
# BERT-style approach: treat the vehicle observations as a sequence
input = [obs_1, obs_2, ..., obs_N]               # sequence length N is variable
embeddings = [Linear(obs_i) for obs_i in input]  # shared weights
output = Transformer(embeddings)                 # self-attention captures interactions

# Advantages:
# 1. Variable sequence length, no fixed N required
# 2. The linear projection shares parameters, which helps generalization
# 3. Self-attention models inter-vehicle interaction
# 4. Positional encoding distinguishes the vehicles
```
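A small runnable sketch of the weight-sharing point above (illustrative sizes only): the first layer of a concatenation-based MLP grows with the vehicle count N, while a shared per-vehicle projection, as used in `Bert`, does not.

```python
import torch.nn as nn

D, N_small, N_large = 108, 3, 30        # hypothetical obs dim and vehicle counts

mlp_small = nn.Linear(N_small * D, 256)  # fixed-input MLP for 3 vehicles
mlp_large = nn.Linear(N_large * D, 256)  # a different network is needed for 30 vehicles
proj = nn.Linear(D, 128)                 # shared per-vehicle projection, independent of N

print(sum(p.numel() for p in mlp_small.parameters()))  # grows with N
print(sum(p.numel() for p in mlp_large.parameters()))
print(sum(p.numel() for p in proj.parameters()))       # constant
```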
### 6.3 Gradient Penalty in Detail
```python
"""
Gradient penalty in detail.

Goal: keep the discriminator (approximately) Lipschitz continuous,
i.e. |D(x1) - D(x2)| <= K * |x1 - x2|.

Why it is needed:
1. WGAN theory requires a Lipschitz discriminator
2. It prevents exploding discriminator gradients and stabilizes training
3. It helps avoid mode collapse

Implementation: for expert data x, penalize ||grad_x D(x)||^2.
"""

# Step 1: make the expert data require gradients
sample_expert = states_exp_cp
sample_expert.requires_grad = True

# Step 2: forward pass
disc = self.disc.linear(self.disc.trunk(sample_expert))

# Step 3: compute the gradient
ones = torch.ones(disc.size(), device=disc.device)
disc_demo_grad = torch.autograd.grad(
    disc, sample_expert,
    grad_outputs=ones,   # d disc / d x
    create_graph=True,   # keep the graph (second-order derivatives)
    retain_graph=True,
    only_inputs=True
)[0]

# Step 4: squared L2 norm of the gradient
disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
grad_pen_loss = torch.mean(disc_demo_grad)

# Step 5: add to the total loss
loss += self.disc_grad_penalty * grad_pen_loss
```
### 6.4 Adaptive Learning Rate
```python
"""
KL-based adaptive learning rate.

Idea:
PPO wants policy updates to stay small; the KL divergence between the old and
new policies is monitored and the learning rate is adjusted accordingly.

KL(pi_old || pi_new) = expected amount of policy change
"""

# KL divergence between diagonal Gaussians
kl = torch.sum(
    torch.log(sigmas_new / sigmas_old + 1e-5) +                       # log of the std ratio
    (sigmas_old**2 + (mus_old - mus_new)**2) / (2 * sigmas_new**2) - 0.5,
    axis=-1
)
kl_mean = torch.mean(kl)

# Adjust the learning rate
if kl_mean > desired_kl * 2.0:
    # KL too large: the policy changed too much, lower the learning rate
    lr_actor /= 1.5
    lr_critic /= 1.5
elif kl_mean < desired_kl / 2.0:
    # KL too small: the policy changed too little, raise the learning rate
    lr_actor *= 1.5
    lr_critic *= 1.5

# Apply the new learning rate
for param_group in optimizer.param_groups:
    param_group['lr'] = lr_actor
```
---

## Usage Guide

### 7.1 Installing Dependencies
```bash
# Create a virtual environment
conda create -n magail python=3.8
conda activate magail

# Install PyTorch (choose the build matching your CUDA version)
pip install torch==1.12.0 torchvision torchaudio

# Install the MetaDrive simulator
pip install metadrive-simulator

# Install the remaining dependencies
pip install numpy pandas matplotlib tensorboard tqdm gym
```
### 7.2 Running the Environment Test
```bash
# Change into the project directory
cd /path/to/MAGAIL4AutoDrive

# Run the environment test (fix the import paths first)
python Env/run_multiagent_env.py
```
### 7.3 Training the Model
```python
"""
Example training script (to be created).
"""
import torch

from Algorithm.magail import MAGAIL
from Env.scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader

# Configuration
config = {
    "data_directory": "/path/to/exp_converted",
    "is_multi_agent": True,
    "num_controlled_agents": 5,
    "horizon": 300,
    "use_render": False,
}

# Create the environment
env = MultiAgentScenarioEnv(config, policy)

# Create the algorithm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
magail = MAGAIL(
    buffer_exp=expert_buffer,
    input_dim=obs_dim,
    device=device,
    lr_disc=1e-3,
    epoch_disc=50
)

# Training loop
for episode in range(10000):
    obs = env.reset()
    episode_reward = 0

    for step in range(config["horizon"]):
        actions, log_pis = magail.explore(obs)
        next_obs, rewards, dones, infos = env.step(actions)

        magail.buffer.append(...)

        if magail.is_update(step):
            reward = magail.update(writer, total_steps)
            episode_reward += reward

        obs = next_obs

        if dones["__all__"]:
            break

    print(f"Episode {episode}, Reward: {episode_reward}")

    if episode % 100 == 0:
        magail.save_models(f"models/episode_{episode}")
```
### 7.4 Key Parameters

| Parameter | Description | Recommended value |
|-----------|-------------|-------------------|
| `embed_dim` | BERT embedding dimension | 128 |
| `num_layers` | Number of Transformer layers | 4 |
| `num_heads` | Number of attention heads | 4 |
| `lr_disc` | Discriminator learning rate | 1e-3 |
| `lr_actor` | Actor learning rate | 1e-3 |
| `lr_critic` | Critic learning rate | 1e-3 |
| `epoch_disc` | Discriminator update epochs | 50 |
| `epoch_ppo` | PPO update epochs | 10 |
| `disc_grad_penalty` | Gradient penalty coefficient | 0.1 |
| `disc_logit_reg` | Logit regularization coefficient | 0.25 |
| `gamma` | Discount factor | 0.995 |
| `lambd` | GAE lambda | 0.97 |
| `clip_eps` | PPO clipping parameter | 0.2 |
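
For convenience, the recommended values above can be collected into a single dictionary (a sketch only; the key names simply mirror the table and are not a fixed interface of the project):

```python
# Recommended hyperparameters from the table above (illustrative grouping only).
hyperparams = {
    "embed_dim": 128,
    "num_layers": 4,
    "num_heads": 4,
    "lr_disc": 1e-3,
    "lr_actor": 1e-3,
    "lr_critic": 1e-3,
    "epoch_disc": 50,
    "epoch_ppo": 10,
    "disc_grad_penalty": 0.1,
    "disc_logit_reg": 0.25,
    "gamma": 0.995,
    "lambd": 0.97,
    "clip_eps": 0.2,
}
```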
### 7.5 FAQ

**Q1: Why does the discriminator accuracy hover around 50%?**
- This is expected: it means the discriminator can no longer tell the policy from the expert
- In other words, the policy has learned behavior close to the expert's

**Q2: What if training is unstable?**
- Increase the gradient penalty coefficient
- Lower the learning rates
- Apply (or strengthen) data normalization

**Q3: How do I tune the reward weights?**
- `reward_t_coef`: weight of the task reward
- `reward_i_coef`: weight of the imitation reward
- A 1:1 ratio is a common starting point; adjust it to balance the two terms
---

## Summary

MAGAIL4AutoDrive enables imitation learning for multi-agent autonomous driving through the following technical contributions:

1. **BERT-style discriminator**: a Transformer architecture that handles a dynamically changing number of vehicles
2. **GAIL framework**: learns the expert policy through adversarial training
3. **PPO optimization**: a stable policy-gradient method
4. **Multi-modal observations**: fuses several sensor modalities
5. **Real-world data**: leverages real driving data such as Waymo

Together these pieces form a complete, extensible solution for multi-agent autonomous-driving imitation learning.