# MAGAIL4AutoDrive Technical Documentation

## Table of Contents

1. [Project Overview](#project-overview)
2. [Core Technical Architecture](#core-technical-architecture)
3. [Algorithm Modules](#algorithm-modules)
4. [Environment System](#environment-system)
5. [Data Flow and Training Pipeline](#data-flow-and-training-pipeline)
6. [Key Technical Details](#key-technical-details)
7. [Usage Guide](#usage-guide)

---
## Project Overview

### Background and Motivation

MAGAIL4AutoDrive (Multi-Agent Generative Adversarial Imitation Learning for Autonomous Driving) is an imitation-learning framework for multi-agent autonomous-driving scenarios. Its core contribution is extending single-agent GAIL to the multi-agent setting, addressing learning when the number of vehicles changes dynamically.

### Core Challenges

1. **Dynamic input dimensionality**: the number of vehicles in a multi-agent scene is not fixed, so a conventional fixed-size network cannot be applied directly
2. **Global interaction modeling**: the interactions of all vehicles must be modeled jointly rather than independently
3. **Exploiting real data**: real driving data such as Waymo must be used effectively for training

### Technical Approach

- **BERT-style discriminator**: a Transformer that handles variable-length sequence input
- **GAIL framework**: learns expert behavior through adversarial training
- **PPO optimization**: a stable policy-gradient method
- **MetaDrive simulation**: a high-fidelity multi-agent traffic simulator
---

## Core Technical Architecture

### Overall Architecture Diagram

```
                     MAGAIL Training System

┌──────────────────┐          ┌──────────────────┐
│ Expert database  │          │  Policy buffer   │
│ (Waymo traject.) │          │ (agent rollouts) │
└────────┬─────────┘          └────────┬─────────┘
         │ state-action pairs          │ state-action pairs
         ▼                             ▼
┌──────────────────────────────────────────────┐
│              BERT Discriminator              │
│  Input: (N, obs_dim)                         │
│    → Linear projection → embed_dim           │
│    → + positional encoding                   │
│    → Transformer layers (×4)                 │
│    → mean pooling / CLS token                │
│  Output: real/fake score                     │
└──────────────────────┬───────────────────────┘
                       │ reward signal
                       ▼
┌──────────────────────────────────────────────┐
│            PPO Policy Optimization           │
│  Actor network (MLP):   state → action dist  │
│  Critic network (BERT): state → value        │
└──────────────────────┬───────────────────────┘
                       │ actions
                       ▼
┌──────────────────────────────────────────────┐
│      MetaDrive multi-agent environment       │
│  • vehicle dynamics simulation               │
│  • multi-modal sensors (lidar, etc.)         │
│  • traffic lights, lane lines, etc.          │
└──────────────────────────────────────────────┘
```
---

## Algorithm Modules

### 3.1 BERT Discriminator

#### 3.1.1 Core Design Idea

The BERT discriminator is the core innovation of this project. The discriminator in standard GAIL is a fixed-size MLP and cannot cope with a varying number of vehicles in a multi-agent scene. This project instead uses a Transformer architecture: the observations of all vehicles are treated as a sequence, and self-attention captures the interactions between vehicles.

#### 3.1.2 Implementation

**File: `Algorithm/bert.py`**
```python
import torch
import torch.nn as nn


class Bert(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim=128,
                 num_layers=4, ff_dim=512, num_heads=4, dropout=0.1,
                 CLS=False, TANH=False):
        """
        BERT-style discriminator / value network.

        Args:
        - input_dim: observation dimension of a single vehicle
        - output_dim: output dimension (1 for both the discriminator and the value network)
        - embed_dim: Transformer embedding dimension, default 128
        - num_layers: number of Transformer layers, default 4
        - ff_dim: feed-forward dimension, default 512
        - num_heads: number of attention heads, default 4
        - CLS: whether to aggregate features with a CLS token
        - TANH: whether to apply Tanh to the output layer
        """
        super().__init__()
        self.CLS = CLS

        # Linear projection: map the observation dimension to the embedding dimension
        self.projection = nn.Linear(input_dim, embed_dim)

        # Positional encoding: a learnable embedding for each vehicle slot
        if self.CLS:
            # CLS mode: one extra position for the CLS token
            self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim))
        else:
            # Mean-pooling mode
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim))

        # Transformer encoder layers
        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        # Classification head
        if TANH:
            self.classifier = nn.Sequential(
                nn.Linear(embed_dim, output_dim),
                nn.Tanh()
            )
        else:
            self.classifier = nn.Linear(embed_dim, output_dim)

    def forward(self, x, mask=None):
        """
        Forward pass.

        Input:
        - x: (batch_size, seq_len, input_dim)
             seq_len   = number of vehicles (varies between batches)
             input_dim = per-vehicle observation dimension
        - mask: optional attention mask

        Output:
        - (batch_size, output_dim) discriminator score or value estimate
        """
        # Step 1: linear projection
        # Map each vehicle's observation into a fixed embedding space
        x = self.projection(x)  # (batch_size, seq_len, embed_dim)

        batch_size = x.size(0)

        # Step 2: prepend the CLS token (if enabled)
        if self.CLS:
            # Copy the CLS token for every sample in the batch
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)

        # Step 3: add the positional encoding
        # Lets the model distinguish the position of each vehicle in the sequence
        x = x + self.pos_embed

        # Step 4: transpose to the layout expected by the Transformer
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)

        # Step 5: run the Transformer layers
        # Each layer applies self-attention to capture vehicle-vehicle interactions
        for layer in self.layers:
            x = layer(x, mask=mask)

        # Step 6: feature aggregation
        if self.CLS:
            # CLS mode: take the output of the CLS token
            return self.classifier(x[0, :, :])
        else:
            # Mean pooling: average the features of all vehicles
            pooled = x.mean(dim=0)  # (batch_size, embed_dim)
            return self.classifier(pooled)
```

**Transformer layer implementation:**

```python
class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        """
        Transformer encoder layer.

        Structure:
        1. Multi-head self-attention + residual connection + LayerNorm
        2. Feed-forward network + residual connection + LayerNorm
        """
        super().__init__()
        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout
        )

        # Feed-forward network
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)

        # Normalization layers
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        """
        Forward pass (Post-LN structure).
        """
        # Attention block
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output)  # residual connection
        x = self.norm1(x)                  # normalization

        # Feed-forward block
        ff_output = self.linear2(
            self.dropout(self.activation(self.linear1(x)))
        )
        x = x + self.dropout(ff_output)  # residual connection
        x = self.norm2(x)                # normalization

        return x
```
#### 3.1.3 Key Technical Points

1. **Dynamic sequence handling**:
   - The input has shape `(batch_size, N, obs_dim)`, where N is the number of vehicles
   - N can differ between batches; it does not have to be fixed

2. **Positional encoding**:
   - Learnable positional embeddings are used instead of sinusoidal encodings
   - They let the model tell vehicles at different positions apart

3. **Self-attention**:
   - Attention weights are computed between every pair of vehicles
   - This captures how vehicles interact with and influence each other

4. **Feature aggregation**:
   - CLS mode: a dedicated classification token, as in BERT
   - Mean pooling: a simple but effective way to extract a global feature

A minimal usage sketch of the variable-length idea follows.
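The sketch below is self-contained and is not the project's `Bert` class; the dimensions (108-dimensional observations, 128-dimensional embeddings) are assumptions chosen to match the rest of the document. It only demonstrates that a shared projection plus a Transformer encoder and mean pooling yields a fixed-size score regardless of the number of vehicles N.

```python
# Minimal sketch (assumed dimensions, not project code): variable vehicle count N,
# fixed-size output, no weight changes required.
import torch
import torch.nn as nn

obs_dim, embed_dim = 108, 128
proj = nn.Linear(obs_dim, embed_dim)          # shared across all vehicles
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=512),
    num_layers=4,
)
head = nn.Linear(embed_dim, 1)

for n_vehicles in (3, 7):                     # N changes between batches
    x = torch.randn(2, n_vehicles, obs_dim)   # (batch, N, obs_dim)
    h = proj(x).permute(1, 0, 2)              # (N, batch, embed_dim)
    score = head(encoder(h).mean(dim=0))      # mean-pool over vehicles
    print(n_vehicles, score.shape)            # torch.Size([2, 1]) in both cases
```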
---

### 3.2 GAIL Discriminator

#### 3.2.1 Discriminator Design

**File: `Algorithm/disc.py`**
```python
import torch
import torch.nn as nn

from Algorithm.bert import Bert


class GAILDiscrim(Bert):
    """
    GAIL discriminator, built on the BERT architecture.

    Responsibilities:
    1. Distinguish expert data from policy-generated data
    2. Compute the intrinsic imitation reward
    """

    def __init__(self, input_dim, reward_i_coef=1.0,
                 reward_t_coef=1.0, normalizer=None, device=None):
        """
        Initialize the discriminator.

        Args:
        - input_dim: input dimension (state concatenated with next state)
        - reward_i_coef: intrinsic (imitation) reward coefficient
        - reward_t_coef: task reward coefficient
        """
        # Call the BERT constructor; output dimension is 1 (real/fake score)
        super().__init__(input_dim=input_dim, output_dim=1, TANH=False)

        self.device = device
        self.reward_t_coef = reward_t_coef
        self.reward_i_coef = reward_i_coef
        self.normalizer = normalizer

    def calculate_reward(self, states_gail, next_states_gail, rewards_t):
        """
        Compute the GAIL reward.

        Core idea (with the sign convention used here):
        - The larger D(s, s'), the more expert-like the transition, and the higher the reward
        - -log(1 - D) is used as the intrinsic reward

        Args:
        - states_gail: current states
        - next_states_gail: next states
        - rewards_t: task reward from the environment

        Returns:
        - rewards: total reward
        - rewards_t: normalized task reward
        - rewards_i: normalized intrinsic reward
        """
        states_gail = states_gail.clone()
        next_states_gail = next_states_gail.clone()

        # Concatenate the transition pair
        states = torch.cat([states_gail, next_states_gail], dim=-1)

        with torch.no_grad():
            # Normalize the data
            if self.normalizer is not None:
                states = self.normalizer.normalize_torch(states, self.device)

            # Scale the task reward
            rewards_t = self.reward_t_coef * rewards_t

            # Discriminator output (logit)
            d = self.forward(states)

            # Convert to a probability: sigmoid(d) = 1 / (1 + exp(-d))
            prob = 1 / (1 + torch.exp(-d))

            # GAIL reward: -log(1 - D(s, s'))
            # When D(s, s') is close to 1 (expert-like), the reward is high
            # When D(s, s') is close to 0 (not expert-like), the reward is low
            rewards_i = self.reward_i_coef * (
                -torch.log(torch.maximum(
                    1 - prob,
                    torch.tensor(0.0001, device=self.device)
                ))
            )

            # Combine the rewards
            rewards = rewards_t + rewards_i

        return (rewards,
                rewards_t / (self.reward_t_coef + 1e-10),
                rewards_i / (self.reward_i_coef + 1e-10))

    def get_disc_logit_weights(self):
        """Weights of the output layer (used for logit regularization)."""
        return torch.flatten(self.classifier.weight)

    def get_disc_weights(self):
        """Weights of all linear layers (used for weight decay)."""
        weights = []
        for m in self.layers.modules():
            if isinstance(m, nn.Linear):
                weights.append(torch.flatten(m.weight))
        weights.append(torch.flatten(self.classifier.weight))
        return weights
```
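A quick numerical check of the `-log(1 - D)` reward above; the values of D are chosen only for illustration.

```python
# Illustrative check of r_i = -log(1 - D): expert-like transitions get large rewards.
import math

for D in (0.1, 0.5, 0.9):
    print(D, round(-math.log(1 - D), 3))
# 0.1 -> 0.105  (judged non-expert-like: small reward)
# 0.5 -> 0.693
# 0.9 -> 2.303  (judged expert-like: large reward)
```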
#### 3.2.2 Discriminator Training

**File: `Algorithm/magail.py`**
```python
def update_disc(self, states, states_exp, writer):
    """
    Update the discriminator.

    Objective: maximize E_expert[log D] + E_policy[log(1 - D)]

    Args:
    - states: transitions generated by the policy
    - states_exp: transitions from the expert demonstrations
    """
    states_cp = states.clone()
    states_exp_cp = states_exp.clone()

    # Step 1: discriminator outputs
    logits_pi = self.disc(states_cp)       # policy data
    logits_exp = self.disc(states_exp_cp)  # expert data

    # Step 2: adversarial loss
    # We want: logits_pi < 0 (policy classified as fake)
    #          logits_exp > 0 (expert classified as real)
    loss_pi = -F.logsigmoid(-logits_pi).mean()   # -log(1 - sigmoid(logits_pi))
    loss_exp = -F.logsigmoid(logits_exp).mean()  # -log(sigmoid(logits_exp))
    loss_disc = 0.5 * (loss_pi + loss_exp)

    # Step 3: logit regularization
    # Keeps the discriminator outputs from growing too large and blowing up gradients
    logit_weights = self.disc.get_disc_logit_weights()
    disc_logit_loss = torch.sum(torch.square(logit_weights))

    # Step 4: gradient penalty
    # Encourages the discriminator to satisfy a Lipschitz constraint, improving stability
    sample_expert = states_exp_cp
    sample_expert.requires_grad = True

    # Discriminator output on the expert data
    disc = self.disc.linear(self.disc.trunk(sample_expert))
    ones = torch.ones(disc.size(), device=disc.device)

    # Gradient of the output w.r.t. the expert input
    disc_demo_grad = torch.autograd.grad(
        disc, sample_expert,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Squared L2 norm of that gradient
    disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
    grad_pen_loss = torch.mean(disc_demo_grad)

    # Step 5: weight decay (L2 regularization)
    disc_weights = self.disc.get_disc_weights()
    disc_weights = torch.cat(disc_weights, dim=-1)
    disc_weight_decay = torch.sum(torch.square(disc_weights))

    # Step 6: combine the losses and update
    loss = (self.disc_coef * loss_disc +
            self.disc_grad_penalty * grad_pen_loss +
            self.disc_logit_reg * disc_logit_loss +
            self.disc_weight_decay * disc_weight_decay)

    self.optim_d.zero_grad()
    loss.backward()
    self.optim_d.step()

    # Step 7: log training metrics
    if self.learning_steps_disc % self.epoch_disc == 0:
        writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)

        with torch.no_grad():
            # Discriminator accuracies
            acc_pi = (logits_pi < 0).float().mean().item()    # accuracy on policy data
            acc_exp = (logits_exp > 0).float().mean().item()  # accuracy on expert data

        writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
        writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)
```
#### 3.2.3 Key Technical Details

**1. What the gradient penalty does**
```python
# The gradient penalty keeps the discriminator Lipschitz continuous,
# i.e. |D(x1) - D(x2)| <= K * |x1 - x2|.
# This prevents the discriminator from changing too sharply and stabilizes training.
```

**2. Why logits are used instead of probabilities**
```python
# Working with logits (the pre-sigmoid output) has several advantages:
# 1. Numerical stability: avoids log(0) and similar issues
# 2. Better gradients: sigmoid saturation would otherwise kill the gradient
# 3. Theory: the GAIL objective is naturally expressed in logit form
```
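To illustrate the numerical-stability point, the policy term `-log(1 - sigmoid(d))` can be computed directly from the logit as `softplus(d)` (the identity `1 - sigmoid(d) = sigmoid(-d)` gives `-log(sigmoid(-d)) = softplus(d)`). The sketch below is standalone, not project code.

```python
# Standalone sketch: computing -log(1 - sigmoid(d)) stably from the logit d.
import torch
import torch.nn.functional as F

d = torch.tensor([-20.0, 0.0, 20.0])        # discriminator logits
naive = -torch.log(1 - torch.sigmoid(d))    # underflows / overflows in float32
stable = F.softplus(d)                      # -log(1 - sigmoid(d)) == softplus(d)
print(naive)   # ~ [0.0, 0.693, inf]
print(stable)  # ~ [2.06e-9, 0.693, 20.0]
```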
---

### 3.3 PPO Policy Optimization

#### 3.3.1 Policy Network

**File: `Algorithm/policy.py`**
```python
import torch
import torch.nn as nn

# Project helpers (module path may differ in the repository)
from Algorithm.utils import build_mlp, reparameterize, evaluate_lop_pi


class StateIndependentPolicy(nn.Module):
    """
    State-independent (diagonal Gaussian) policy.

    Output: the mean of a Gaussian; the standard deviation is a learnable parameter.
    """

    def __init__(self, state_shape, action_shape,
                 hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        super().__init__()

        # Mean network (MLP)
        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )

        # Learnable log standard deviation
        self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
        self.means = None

    def forward(self, states):
        """
        Deterministic forward pass (used for evaluation).
        """
        return torch.tanh(self.net(states))

    def sample(self, states):
        """
        Sample an action from the policy distribution.

        Uses the reparameterization trick:
            a = tanh(mu + sigma * eps),  eps ~ N(0, 1)
        """
        self.means = self.net(states)
        actions, log_pis = reparameterize(self.means, self.log_stds)
        return actions, log_pis

    def evaluate_log_pi(self, states, actions):
        """
        Log-probability of a given state-action pair.
        """
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)
```
**Reparameterization trick:**

```python
import math

import torch


def reparameterize(means, log_stds):
    """
    Reparameterized sampling.

    Idea:
    Instead of sampling from N(mu, sigma^2) directly:
    1. Sample noise eps ~ N(0, 1)
    2. Compute z = mu + sigma * eps
    3. Apply tanh: a = tanh(z)

    Advantages:
    - Gradients can flow back through mu and sigma
    - Compatible with gradient-based optimization
    """
    noises = torch.randn_like(means)        # eps ~ N(0, 1)
    us = means + noises * log_stds.exp()    # z = mu + sigma * eps
    actions = torch.tanh(us)                # a = tanh(z)

    # Log-probability (must account for the Jacobian of tanh)
    return actions, calculate_log_pi(log_stds, noises, actions)


def calculate_log_pi(log_stds, noises, actions):
    """
    Log-probability of a tanh-squashed Gaussian.

    Formula:
        log pi(a|s) = log N(u | mu, sigma^2) - sum_i log(1 - tanh^2(u_i))
                    = -0.5 * ||eps||^2 - sum_i log(sigma_i)
                      - 0.5 * d * log(2*pi) - sum_i log(1 - a_i^2)
    """
    # Log-probability under the Gaussian
    gaussian_log_probs = (
        -0.5 * noises.pow(2) - log_stds
    ).sum(dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)

    # Jacobian correction for the tanh transform
    # d/du tanh(u) = 1 - tanh^2(u)
    return gaussian_log_probs - torch.log(
        1 - actions.pow(2) + 1e-6
    ).sum(dim=-1, keepdim=True)
```
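As a sanity check, the same log-probability can be computed with `torch.distributions`; the sketch below assumes `calculate_log_pi` from the block above is in scope, and the shapes and values are arbitrary.

```python
# Standalone check (illustrative): compare calculate_log_pi with torch.distributions.
import torch
from torch.distributions import Normal

torch.manual_seed(0)
means = torch.randn(4, 2)
log_stds = torch.zeros(1, 2) - 0.5

noises = torch.randn_like(means)
us = means + noises * log_stds.exp()
actions = torch.tanh(us)

ref = Normal(means, log_stds.exp()).log_prob(us).sum(-1, keepdim=True) \
      - torch.log(1 - actions.pow(2) + 1e-6).sum(-1, keepdim=True)
ours = calculate_log_pi(log_stds, noises, actions)
print(torch.allclose(ref, ours, atol=1e-5))  # True
```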
#### 3.3.2 PPO Update

**File: `Algorithm/ppo.py`**
```python
def update_ppo(self, states, actions, rewards, dones, tm_dones,
               log_pi_list, next_states, mus, sigmas, writer, total_steps):
    """
    Update the PPO policy and value networks.
    """
    # Step 1: value estimates and advantages
    with torch.no_grad():
        values = self.critic(states.detach())
        next_values = self.critic(next_states.detach())

    # GAE (Generalized Advantage Estimation)
    targets, gaes = self.calculate_gae(
        values, rewards, dones, tm_dones, next_values,
        self.gamma, self.lambd
    )

    # Step 2: several update epochs
    state_list = states.permute(1, 0, 2)
    action_list = actions.permute(1, 0, 2)

    for i in range(self.epoch_ppo):
        self.learning_steps_ppo += 1

        # Update the value network
        self.update_critic(states, targets, writer)

        # Update the policy network
        for state, action, log_pi in zip(state_list, action_list, log_pi_list):
            self.update_actor(
                state, action, log_pi, gaes, mus, sigmas, writer
            )


def calculate_gae(self, values, rewards, dones, tm_dones,
                  next_values, gamma, lambd):
    """
    Generalized Advantage Estimation (GAE).

    Formulas:
        delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        A_t     = sum_k (gamma * lambda)^k * delta_{t+k}

    Args:
    - gamma: discount factor
    - lambd: GAE lambda (bias-variance trade-off)
    """
    with torch.no_grad():
        # TD errors
        deltas = rewards + gamma * next_values * (1 - tm_dones) - values

        # Initialize the advantages
        gaes = torch.empty_like(rewards)

        # Compute GAE backwards in time
        gaes[-1] = deltas[-1]
        for t in reversed(range(rewards.size(0) - 1)):
            gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]

        # Value targets
        v_target = gaes + values

        # Advantage normalization
        if self.use_adv_norm:
            gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)

    return v_target, gaes


def update_actor(self, states, actions, log_pis_old, gaes,
                 mus_old, sigmas_old, writer):
    """
    Update the actor network.

    PPO clipped objective:
        L = min(r(theta) * A, clip(r(theta), 1 - eps, 1 + eps) * A)
    where r(theta) = pi_new / pi_old.
    """
    self.optim_actor.zero_grad()

    # Log-probabilities under the new policy
    log_pis = self.actor.evaluate_log_pi(states, actions)
    mus = self.actor.means
    sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)

    # Entropy (encourages exploration)
    entropy = -log_pis.mean()

    # Importance-sampling ratio
    ratios = (log_pis - log_pis_old).exp_()

    # PPO clipped objective
    loss_actor1 = -ratios * gaes
    loss_actor2 = -torch.clamp(
        ratios,
        1.0 - self.clip_eps,
        1.0 + self.clip_eps
    ) * gaes
    loss_actor = torch.max(loss_actor1, loss_actor2).mean()
    loss_actor = loss_actor * self.surrogate_loss_coef

    # Adaptive learning rate (based on the KL divergence)
    if self.auto_lr:
        with torch.inference_mode():
            # KL divergence KL(old || new)
            kl = torch.sum(
                torch.log(sigmas / sigmas_old + 1.e-5) +
                (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                (2.0 * torch.square(sigmas)) - 0.5,
                axis=-1
            )
            kl_mean = torch.mean(kl)

        # Adjust the learning rates
        if kl_mean > self.desired_kl * 2.0:
            # KL too large: lower the learning rates
            self.lr_actor = max(1e-5, self.lr_actor / 1.5)
            self.lr_critic = max(1e-5, self.lr_critic / 1.5)
        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
            # KL too small: raise the learning rates
            self.lr_actor = min(1e-2, self.lr_actor * 1.5)
            self.lr_critic = min(1e-2, self.lr_critic * 1.5)

        # Apply the new learning rates to the optimizers
        for param_group in self.optim_actor.param_groups:
            param_group['lr'] = self.lr_actor
        for param_group in self.optim_critic.param_groups:
            param_group['lr'] = self.lr_critic

    # Backpropagation
    loss = loss_actor
    loss.backward()
    nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
    self.optim_actor.step()
```
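For reference, the clipped surrogate and GAE used above in standard notation:

```latex
% PPO clipped surrogate and GAE, matching the code above.
\[
L^{\mathrm{CLIP}}(\theta) =
\mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\;
\operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t\mid s_t)}
\]
\[
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t),
\qquad
\hat{A}_t = \sum_{k=0}^{\infty} (\gamma\lambda)^k\,\delta_{t+k}
\]
```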
---

## Environment System

### 4.1 Multi-Agent Scenario Environment

#### 4.1.1 Environment Design

**File: `Env/scenario_env.py`**
```python
class MultiAgentScenarioEnv(ScenarioEnv):
    """
    Multi-agent scenario environment.

    Extends MetaDrive's ScenarioEnv to the multi-agent setting.

    Core responsibilities:
    1. Spawn vehicles dynamically from expert data
    2. Collect multi-modal observations
    3. Manage multi-agent interaction
    """

    @classmethod
    def default_config(cls):
        config = super().default_config()
        config.update(dict(
            data_directory=None,       # expert data directory
            num_controlled_agents=3,   # number of controlled vehicles
            horizon=1000,              # scenario length in time steps
        ))
        return config

    def __init__(self, config, agent2policy):
        """
        Initialize the environment.

        Args:
        - config: environment configuration
        - agent2policy: policy assigned to each agent
        """
        self.policy = agent2policy
        self.controlled_agents = {}      # controlled vehicles, keyed by agent id
        self.controlled_agent_ids = []   # list of controlled agent ids
        self.obs_list = []               # observation list
        self.round = 0                   # current time step
        super().__init__(config)
```
#### 4.1.2 Environment Reset and Vehicle Spawning
```python
def reset(self, seed: Union[None, int] = None):
    """
    Reset the environment.

    Flow:
    1. Parse the vehicle trajectories in the expert data
    2. Extract the spawn information for each vehicle
    3. Remove the raw data (to avoid duplicate spawning)
    4. Initialize the scenario
    5. Spawn the first batch of vehicles
    """
    self.round = 0

    # Logger initialization
    if self.logger is None:
        self.logger = get_logger()
        log_level = self.config.get("log_level", logging.INFO)
        set_log_level(log_level)

    self.lazy_init()
    self._reset_global_seed(seed)

    # Step 1: parse the expert data
    _obj_to_clean_this_frame = []
    self.car_birth_info_list = []

    for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
        # Skip the ego vehicle
        if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
            continue

        # Only handle vehicle-type objects
        if track["type"] == MetaDriveType.VEHICLE:
            _obj_to_clean_this_frame.append(scenario_id)

            # Valid frames of the track
            valid = track['state']['valid']
            first_show = np.argmax(valid) if valid.any() else -1
            last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1

            # Extract the key information
            self.car_birth_info_list.append({
                'id': track['metadata']['object_id'],   # vehicle id
                'show_time': first_show,                # first frame the vehicle appears
                'begin': (                              # start position
                    track['state']['position'][first_show, 0],
                    track['state']['position'][first_show, 1]
                ),
                'heading': track['state']['heading'][first_show],  # initial heading
                'end': (                                # end position
                    track['state']['position'][last_show, 0],
                    track['state']['position'][last_show, 1]
                )
            })

    # Step 2: remove the raw data (avoid duplicate spawning)
    for scenario_id in _obj_to_clean_this_frame:
        self.engine.traffic_manager.current_traffic_data.pop(scenario_id)

    # Step 3: reset the engine
    self.engine.reset()
    self.reset_sensors()
    self.engine.taskMgr.step()

    # Step 4: get the lane network (used for traffic-light lookup)
    self.lanes = self.engine.map_manager.current_map.road_network.graph

    # Step 5: clear stale state
    if self.top_down_renderer is not None:
        self.top_down_renderer.clear()
        self.engine.top_down_renderer = None

    self.dones = {}
    self.episode_rewards = defaultdict(float)
    self.episode_lengths = defaultdict(int)

    self.controlled_agents.clear()
    self.controlled_agent_ids.clear()

    # Step 6: initialize the scenario and spawn the first batch of vehicles
    super().reset(seed)
    self._spawn_controlled_agents()

    return self._get_all_obs()


def _spawn_controlled_agents(self):
    """
    Spawn controlled vehicles dynamically.

    Each vehicle is created at the time step recorded in the expert data.
    """
    for car in self.car_birth_info_list:
        # Has this vehicle's spawn time been reached?
        if car['show_time'] == self.round:
            agent_id = f"controlled_{car['id']}"

            # Create the vehicle instance
            vehicle = self.engine.spawn_object(
                PolicyVehicle,            # custom vehicle class
                vehicle_config={},
                position=car['begin'],    # initial position
                heading=car['heading']    # initial heading
            )

            # Reset the vehicle state
            vehicle.reset(position=car['begin'], heading=car['heading'])

            # Assign the policy and the destination
            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])

            # Register the vehicle
            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)

            # Important: only vehicles registered in the engine's active_agents
            # take part in the physics update
            self.engine.agent_manager.active_agents[agent_id] = vehicle
```
#### 4.1.3 Observation System
```python
def _get_all_obs(self):
    """
    Collect the observations of all controlled vehicles.

    Observation layout:
    - position: 2D (x, y)
    - velocity: 2D (vx, vy)
    - heading: 1D (theta)
    - forward lidar: 80D (distances)
    - side detector: 10D (distances)
    - lane-line detector: 10D (distances)
    - traffic light: 1D (encoded 0-3)
    - navigation: 2D (destination coordinates)

    Total: 2 + 2 + 1 + 80 + 10 + 10 + 1 + 2 = 108D
    """
    self.obs_list = []

    for agent_id, vehicle in self.controlled_agents.items():
        # Base vehicle state
        state = vehicle.get_state()

        # Traffic-light lookup
        traffic_light = 0
        for lane in self.lanes.values():
            # Is the vehicle on this lane?
            if lane.lane.point_on_lane(state['position'][:2]):
                # Does the lane have a traffic light?
                if self.engine.light_manager.has_traffic_light(lane.lane.index):
                    traffic_light = self.engine.light_manager._lane_index_to_obj[
                        lane.lane.index
                    ].status

                # Encode the light state
                if traffic_light == 'TRAFFIC_LIGHT_GREEN':
                    traffic_light = 1
                elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
                    traffic_light = 2
                elif traffic_light == 'TRAFFIC_LIGHT_RED':
                    traffic_light = 3
                else:
                    traffic_light = 0
                break

        # Lidar perception
        # Forward lidar: 80 beams, 30 m range, senses dynamic objects
        lidar = self.engine.get_sensor("lidar").perceive(
            num_lasers=80,
            distance=30,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.dynamic_world
        )

        # Side detector: 10 beams, 8 m range, senses static obstacles
        side_lidar = self.engine.get_sensor("side_detector").perceive(
            num_lasers=10,
            distance=8,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.static_world
        )

        # Lane-line detector: 10 beams, 3 m range, senses lane boundaries
        lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(
            num_lasers=10,
            distance=3,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.static_world
        )

        # Assemble the observation vector
        obs = (
            state['position'][:2] +        # position (x, y)
            list(state['velocity']) +      # velocity (vx, vy)
            [state['heading_theta']] +     # heading theta
            lidar[0] +                     # forward lidar (80D)
            side_lidar[0] +                # side detector (10D)
            lane_line_lidar[0] +           # lane lines (10D)
            [traffic_light] +              # traffic light (1D)
            list(vehicle.destination)      # destination (x, y)
        )

        self.obs_list.append(obs)

    return self.obs_list
```
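A small sanity check of the 108-dimensional layout listed in the docstring above; the numbers are taken directly from the code.

```python
# Sanity check of the per-vehicle observation layout.
OBS_LAYOUT = {
    "position": 2, "velocity": 2, "heading": 1,
    "lidar": 80, "side_detector": 10, "lane_line_detector": 10,
    "traffic_light": 1, "destination": 2,
}
assert sum(OBS_LAYOUT.values()) == 108
```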
#### 4.1.4 Environment Step
```python
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
    """
    Advance the simulation by one step.

    Flow:
    1. Apply the actions to the vehicles
    2. Run the physics engine
    3. Vehicle post-processing
    4. Spawn newly appearing vehicles
    5. Collect observations and rewards

    Args:
    - action_dict: {agent_id: action} dictionary

    Returns:
    - obs: list of observations
    - rewards: reward dictionary
    - dones: done-flag dictionary
    - infos: info dictionary
    """
    self.round += 1

    # Step 1: apply the actions
    for agent_id, action in action_dict.items():
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].before_step(action)

    # Step 2: physics simulation
    self.engine.step()

    # Step 3: vehicle post-processing
    for agent_id in action_dict:
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].after_step()

    # Step 4: spawn newly appearing vehicles
    self._spawn_controlled_agents()

    # Step 5: collect observations
    obs = self._get_all_obs()

    # Step 6: rewards and done flags
    rewards = {aid: 0.0 for aid in self.controlled_agents}
    dones = {aid: False for aid in self.controlled_agents}
    dones["__all__"] = self.episode_step >= self.config["horizon"]
    infos = {aid: {} for aid in self.controlled_agents}

    return obs, rewards, dones, infos
```
### 4.2 Custom Vehicle Class
```python
class PolicyVehicle(DefaultVehicle):
    """
    Policy-controlled vehicle.

    Extends MetaDrive's default vehicle with a policy and a destination.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None        # controlling policy
        self.destination = None   # destination point

    def set_policy(self, policy):
        """Assign the controlling policy."""
        self.policy = policy

    def set_destination(self, des):
        """Assign the destination point."""
        self.destination = des

    def act(self, observation, policy=None):
        """
        Produce an action.

        Uses the assigned policy if there is one; otherwise samples a random action.
        """
        if self.policy is not None:
            return self.policy.act(observation)
        else:
            return self.action_space.sample()

    def before_step(self, action):
        """
        Pre-step processing.

        Records the previous state and applies the action.
        """
        self.last_position = self.position
        self.last_velocity = self.velocity
        self.last_speed = self.speed
        self.last_heading_dir = self.heading

        if action is not None:
            self.last_current_action.append(action)

        # Convert the action into vehicle control commands
        self._set_action(action)


# Register the vehicle type
vehicle_class_to_type[PolicyVehicle] = "default"
```
---

## Data Flow and Training Pipeline

### 5.1 Full Training Loop
```python
"""
Pseudocode for the training loop.
"""

# 1. Load the expert data
buffer_exp = load_expert_data("waymo_dataset")

# 2. Initialization
env = MultiAgentScenarioEnv(config, policy)
magail = MAGAIL(buffer_exp, input_dim, device)

# 3. Training loop
for episode in range(max_episodes):
    # 3.1 Reset the environment
    obs_list = env.reset()

    for step in range(max_steps):
        # 3.2 Sample actions from the policy
        actions, log_pis = magail.explore(obs_list)

        # 3.3 Interact with the environment
        next_obs_list, rewards, dones, infos = env.step(actions)

        # 3.4 Store the experience
        magail.buffer.append(
            obs_list, actions, rewards, dones,
            log_pis, next_obs_list
        )

        obs_list = next_obs_list

        # 3.5 Time to update?
        if magail.is_update(step):
            # 3.5.1 Update the discriminator
            for _ in range(epoch_disc):
                # Sample policy data
                states_policy = magail.buffer.sample(batch_size)
                # Sample expert data
                states_expert = buffer_exp.sample(batch_size)
                # Update the discriminator
                magail.update_disc(states_policy, states_expert)

            # 3.5.2 Compute the GAIL rewards
            rewards_gail = magail.disc.calculate_reward(states, next_states)

            # 3.5.3 Update PPO
            magail.update_ppo(states, actions, rewards_gail, ...)

            # 3.5.4 Clear the buffer
            magail.buffer.clear()

        if dones["__all__"]:
            break

    # 3.6 Save the model
    if episode % save_interval == 0:
        magail.save_models(save_path)
```
### 5.2 Data Flow Diagram

```
  Expert data              Policy data
       ↓                       ↓
┌──────────────┐        ┌──────────────┐
│ Expert Buffer│        │Policy Buffer │
│   (s,a,s')   │        │   (s,a,s')   │
└──────┬───────┘        └──────┬───────┘
       │                       │
       │ sample                │ sample
       ▼                       ▼
  ┌────────────────────────────┐
  │    BERT Discriminator      │
  │  Input: (N, obs_dim*2)     │
  │  Output: real/fake score   │
  └────────────┬───────────────┘
               │
               │ gradient backprop
               ▼
       ┌──────────────┐
       │  Disc Loss   │
       │ + Grad Pen   │
       │ + Logit Reg  │
       │ + Weight Dec │
       └──────────────┘

  Policy data
       ↓
┌──────────────┐
│Policy Buffer │
│  (s,a,r,s')  │
└──────┬───────┘
       │
       │ r_GAIL = -log(1-D(s,s'))
       ▼
  ┌────────────────────────────┐
  │      PPO Optimization      │
  │      Actor:  MLP           │
  │      Critic: BERT          │
  └────────────┬───────────────┘
               │
               │ policy improvement
               ▼
       ┌──────────────┐
       │ Environment  │
       │ Interaction  │
       └──────────────┘
```
### 5.3 Buffer Management

**File: `Algorithm/buffer.py`**
```python
import numpy as np
import torch


class RolloutBuffer:
    """
    Rollout buffer.

    Stores the experience collected while the policy interacts with the environment.
    """

    def __init__(self, buffer_size, state_shape, action_shape, device):
        """
        Initialize the buffer.

        Args:
        - buffer_size: buffer capacity (usually equal to rollout_length)
        - state_shape: state dimensions
        - action_shape: action dimensions
        - device: storage device (CPU/GPU)
        """
        self._n = 0   # number of stored items
        self._p = 0   # current write position
        self.buffer_size = buffer_size

        # Pre-allocate tensors (for efficiency)
        self.states = torch.empty(
            (self.buffer_size, *state_shape),
            dtype=torch.float, device=device
        )
        self.actions = torch.empty(
            (self.buffer_size, *action_shape),
            dtype=torch.float, device=device
        )
        self.rewards = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.float, device=device
        )
        self.dones = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.int, device=device
        )
        self.tm_dones = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.int, device=device
        )
        self.log_pis = torch.empty(
            (self.buffer_size, 1),
            dtype=torch.float, device=device
        )
        self.next_states = torch.empty(
            (self.buffer_size, *state_shape),
            dtype=torch.float, device=device
        )
        self.means = torch.empty(
            (self.buffer_size, *action_shape),
            dtype=torch.float, device=device
        )
        self.stds = torch.empty(
            (self.buffer_size, *action_shape),
            dtype=torch.float, device=device
        )

    def append(self, state, action, reward, done, tm_dones,
               log_pi, next_state, next_state_gail, means, stds):
        """
        Append one experience.

        The buffer is circular: old data is overwritten automatically.
        """
        self.states[self._p].copy_(state)
        self.actions[self._p].copy_(torch.from_numpy(action))
        self.rewards[self._p] = float(reward)
        self.dones[self._p] = int(done)
        self.tm_dones[self._p] = int(tm_dones)
        self.log_pis[self._p] = float(log_pi)
        self.next_states[self._p].copy_(torch.from_numpy(next_state))
        self.means[self._p].copy_(torch.from_numpy(means))
        self.stds[self._p].copy_(torch.from_numpy(stds))

        # Advance the pointers
        self._p = (self._p + 1) % self.buffer_size
        self._n = min(self._n + 1, self.buffer_size)

    def get(self):
        """
        Return all stored data (used for the PPO update).
        """
        assert self._p % self.buffer_size == 0
        idxes = slice(0, self.buffer_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def sample(self, batch_size):
        """
        Sample a random batch (used for the discriminator update).
        """
        assert self._p % self.buffer_size == 0
        idxes = np.random.randint(low=0, high=self._n, size=batch_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def clear(self):
        """Zero out the buffer."""
        self.states[:, :] = 0
        self.actions[:, :] = 0
        self.rewards[:, :] = 0
        self.dones[:, :] = 0
        self.tm_dones[:, :] = 0
        self.log_pis[:, :] = 0
        self.next_states[:, :] = 0
        self.means[:, :] = 0
        self.stds[:, :] = 0
```
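A minimal usage sketch of `RolloutBuffer`; the shapes and dtypes are assumptions, and `get()` expects the buffer to be exactly full.

```python
# Illustrative usage of RolloutBuffer (assumed shapes).
import numpy as np
import torch

buf = RolloutBuffer(buffer_size=4, state_shape=(108,), action_shape=(2,),
                    device=torch.device("cpu"))
for _ in range(4):                      # fill the buffer completely
    buf.append(
        state=torch.zeros(108),
        action=np.zeros(2, dtype=np.float32),
        reward=0.0, done=False, tm_dones=False, log_pi=0.0,
        next_state=np.zeros(108, dtype=np.float32),
        next_state_gail=None,           # accepted by the signature but not stored
        means=np.zeros(2, dtype=np.float32),
        stds=np.ones(2, dtype=np.float32),
    )
states, actions, *_ = buf.get()
print(states.shape, actions.shape)      # torch.Size([4, 108]) torch.Size([4, 2])
```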
---

## Key Technical Details

### 6.1 Data Normalization

**File: `Algorithm/utils.py`**
```python
from typing import Tuple

import numpy as np
import torch


class RunningMeanStd(object):
    """
    Running mean and standard deviation.

    Uses a Welford-style online algorithm to compute statistics of streaming data efficiently.
    """

    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        """
        Update the statistics.

        Merges the batch statistics with the running statistics using the parallel algorithm.
        """
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """
        Update the statistics from batch moments.

        Reference: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        # New mean
        new_mean = self.mean + delta * batch_count / tot_count

        # New variance
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        new_var = m_2 / tot_count

        # Commit
        self.mean = new_mean
        self.var = new_var
        self.count = tot_count


class Normalizer(RunningMeanStd):
    """
    Data normalizer.

    Provides normalization and clipping on top of the running statistics.
    """

    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon
        self.clip_obs = clip_obs

    def normalize(self, input):
        """
        Normalize (NumPy version).

        Formula: (x - mu) / sqrt(var + eps), clipped to [-clip_obs, clip_obs]
        """
        return np.clip(
            (input - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip_obs, self.clip_obs
        )

    def normalize_torch(self, input, device):
        """
        Normalize (PyTorch version).

        Used for efficient computation on the GPU.
        """
        mean_torch = torch.tensor(
            self.mean, device=device, dtype=torch.float32
        )
        std_torch = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32
        ))
        return torch.clamp(
            (input - mean_torch) / std_torch,
            -self.clip_obs, self.clip_obs
        )
```
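A minimal usage sketch of `Normalizer`; the 108-dimensional shape is an assumption.

```python
# Illustrative usage: update running statistics from a batch, then normalize new data.
import numpy as np

norm = Normalizer(input_dim=(108,))
norm.update(np.random.randn(256, 108) * 5.0 + 2.0)   # feed a batch of observations
obs = np.random.randn(108) * 5.0 + 2.0
print(norm.normalize(obs).std())                      # roughly unit scale after normalization
```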
### 6.2 Why a BERT Architecture

**The problem with a plain MLP:**
```python
# Suppose the scene contains N vehicles, each with a D-dimensional observation.
# Conventional approach: concatenate all vehicle observations
input = concat([obs_1, obs_2, ..., obs_N])  # dimension: N*D
output = MLP(input)

# Problems:
# 1. The input dimension N*D changes with N, so the network would have to be retrained
# 2. Vehicles at different positions carry the same semantics, but an MLP cannot share weights across them
# 3. Interactions between vehicles are not modeled
```

**Advantages of the BERT architecture:**
```python
# BERT approach: treat the vehicle observations as a sequence
input = [obs_1, obs_2, ..., obs_N]                 # sequence length N can vary
embeddings = [Linear(obs_i) for obs_i in input]    # shared weights
output = Transformer(embeddings)                   # self-attention captures interactions

# Advantages:
# 1. The sequence length is variable; N does not need to be fixed
# 2. The linear projection shares parameters across vehicles, which generalizes well
# 3. Self-attention models vehicle-vehicle interactions
# 4. Positional encodings distinguish individual vehicles
```
### 6.3 Gradient Penalty in Detail

```python
"""
Gradient penalty, explained.

Goal: keep the discriminator Lipschitz continuous,
i.e. |D(x1) - D(x2)| <= K * |x1 - x2|.

Why it is needed:
1. WGAN theory requires a Lipschitz-continuous critic
2. It prevents exploding discriminator gradients and stabilizes training
3. It helps avoid mode collapse

Implementation:
for expert data x, penalize ||grad_x D(x)||^2
"""

# Step 1: make the expert data require gradients
sample_expert = states_exp_cp
sample_expert.requires_grad = True

# Step 2: forward pass
disc = self.disc.linear(self.disc.trunk(sample_expert))

# Step 3: compute the gradient
ones = torch.ones(disc.size(), device=disc.device)
disc_demo_grad = torch.autograd.grad(
    disc, sample_expert,
    grad_outputs=ones,       # d(disc)/dx
    create_graph=True,       # keep the graph (second-order derivatives)
    retain_graph=True,
    only_inputs=True
)[0]

# Step 4: squared L2 norm of the gradient
disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
grad_pen_loss = torch.mean(disc_demo_grad)

# Step 5: add it to the total loss
loss += self.disc_grad_penalty * grad_pen_loss
```
### 6.4 Adaptive Learning Rate

```python
"""
KL-based adaptive learning rate.

Idea:
PPO wants each policy update to stay small; the learning rate is adjusted by
monitoring the KL divergence between the old and new policies.

KL(pi_old || pi_new) = how much the policy has changed in expectation
"""

# KL divergence between diagonal Gaussian policies
kl = torch.sum(
    torch.log(sigmas_new / sigmas_old + 1e-5) +     # log of the std ratio
    (sigmas_old ** 2 + (mus_old - mus_new) ** 2) / (2 * sigmas_new ** 2) - 0.5,
    axis=-1
)
kl_mean = torch.mean(kl)

# Adjust the learning rates
if kl_mean > desired_kl * 2.0:
    # KL too large: the policy changed too sharply, lower the learning rate
    lr_actor /= 1.5
    lr_critic /= 1.5
elif kl_mean < desired_kl / 2.0:
    # KL too small: the update is too conservative, raise the learning rate
    lr_actor *= 1.5
    lr_critic *= 1.5

# Apply the new learning rate
for param_group in optimizer.param_groups:
    param_group['lr'] = lr_actor
```
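For reference, the per-dimension formula implemented above (diagonal Gaussians, summed over action dimensions):

```latex
% KL divergence between two univariate Gaussians, as used in the code above.
\[
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_{\mathrm{old}},\sigma_{\mathrm{old}}^2)\,\|\,
\mathcal{N}(\mu_{\mathrm{new}},\sigma_{\mathrm{new}}^2)\big)
= \log\frac{\sigma_{\mathrm{new}}}{\sigma_{\mathrm{old}}}
+ \frac{\sigma_{\mathrm{old}}^2 + (\mu_{\mathrm{old}}-\mu_{\mathrm{new}})^2}{2\,\sigma_{\mathrm{new}}^2}
- \frac{1}{2}
\]
```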
---

## Usage Guide

### 7.1 Installing Dependencies
```bash
# Create a virtual environment
conda create -n magail python=3.8
conda activate magail

# Install PyTorch (pick the build matching your CUDA version)
pip install torch==1.12.0 torchvision torchaudio

# Install the MetaDrive simulator
pip install metadrive-simulator

# Install the remaining dependencies
pip install numpy pandas matplotlib tensorboard tqdm gym
```
### 7.2 Environment Smoke Test
```bash
# Enter the project directory
cd /path/to/MAGAIL4AutoDrive

# Run the environment test (fix the import paths first if necessary)
python Env/run_multiagent_env.py
```
### 7.3 Training a Model
```python
"""
Example training script (to be created).
"""
import torch

from Algorithm.magail import MAGAIL
from Env.scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader

# Configuration
config = {
    "data_directory": "/path/to/exp_converted",
    "is_multi_agent": True,
    "num_controlled_agents": 5,
    "horizon": 300,
    "use_render": False,
}

# Create the environment
env = MultiAgentScenarioEnv(config, policy)

# Create the algorithm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
magail = MAGAIL(
    buffer_exp=expert_buffer,
    input_dim=obs_dim,
    device=device,
    lr_disc=1e-3,
    epoch_disc=50
)

# Training loop
for episode in range(10000):
    obs = env.reset()
    episode_reward = 0

    for step in range(config["horizon"]):
        actions, log_pis = magail.explore(obs)
        next_obs, rewards, dones, infos = env.step(actions)

        magail.buffer.append(...)

        if magail.is_update(step):
            reward = magail.update(writer, total_steps)
            episode_reward += reward

        obs = next_obs

        if dones["__all__"]:
            break

    print(f"Episode {episode}, Reward: {episode_reward}")

    if episode % 100 == 0:
        magail.save_models(f"models/episode_{episode}")
```
### 7.4 Key Parameters

| Parameter | Description | Recommended value |
|-----------|-------------|-------------------|
| `embed_dim` | BERT embedding dimension | 128 |
| `num_layers` | Number of Transformer layers | 4 |
| `num_heads` | Number of attention heads | 4 |
| `lr_disc` | Discriminator learning rate | 1e-3 |
| `lr_actor` | Actor learning rate | 1e-3 |
| `lr_critic` | Critic learning rate | 1e-3 |
| `epoch_disc` | Discriminator update epochs | 50 |
| `epoch_ppo` | PPO update epochs | 10 |
| `disc_grad_penalty` | Gradient penalty coefficient | 0.1 |
| `disc_logit_reg` | Logit regularization coefficient | 0.25 |
| `gamma` | Discount factor | 0.995 |
| `lambd` | GAE lambda | 0.97 |
| `clip_eps` | PPO clipping parameter | 0.2 |
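For convenience, the recommended values above collected into a single dictionary; how they are actually passed to `MAGAIL` depends on its constructor, so this is only a sketch.

```python
# Recommended hyperparameters from the table above (sketch; wiring is project-specific).
HYPERPARAMS = {
    "embed_dim": 128, "num_layers": 4, "num_heads": 4,
    "lr_disc": 1e-3, "lr_actor": 1e-3, "lr_critic": 1e-3,
    "epoch_disc": 50, "epoch_ppo": 10,
    "disc_grad_penalty": 0.1, "disc_logit_reg": 0.25,
    "gamma": 0.995, "lambd": 0.97, "clip_eps": 0.2,
}
```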
### 7.5 FAQ

**Q1: Why does the discriminator accuracy hover around 50%?**
- This is expected: it means the discriminator can no longer tell policy data from expert data
- In other words, the policy has learned behavior close to the expert's

**Q2: What if training is unstable?**
- Increase the gradient penalty coefficient
- Lower the learning rates
- Normalize the input data

**Q3: How do I tune the reward weights?**
- `reward_t_coef`: weight of the task reward
- `reward_i_coef`: weight of the imitation reward
- A 1:1 ratio is a common starting point; adjust it to balance the two terms
---

## Summary

MAGAIL4AutoDrive realizes imitation learning for multi-agent autonomous driving through the following components:

1. **BERT discriminator**: a Transformer that handles a dynamic number of vehicles
2. **GAIL framework**: learns the expert policy through adversarial training
3. **PPO optimization**: a stable policy-gradient method
4. **Multi-modal observations**: fuses several sensor modalities
5. **Real data**: leverages real driving data such as Waymo

The project provides a complete, extensible solution for multi-agent autonomous driving with practical value.