Files
MAGAIL4AutoDrive/技术说明文档.md
huangfu b626702cbb 上传文件至 /
Signed-off-by: huangfu <3045324663@qq.com>
2025-10-21 18:23:00 +08:00

53 KiB
Raw Blame History

MAGAIL4AutoDrive 技术说明文档

目录

  1. 项目概述
  2. 核心技术架构
  3. 算法模块详解
  4. 环境系统实现
  5. 数据流与训练流程
  6. 关键技术细节
  7. 使用指南

项目概述

背景与动机

MAGAIL4AutoDriveMulti-Agent Generative Adversarial Imitation Learning for Autonomous Driving是一个针对多智能体自动驾驶场景的模仿学习框架。项目的核心创新在于将单智能体GAIL算法扩展到多智能体场景解决了车辆数量动态变化时的学习问题。

核心挑战

  1. 动态输入维度:多智能体场景中车辆数量不固定,传统固定维度的神经网络无法直接应用
  2. 全局交互建模:需要同时考虑所有车辆的交互行为,而非独立建模
  3. 真实数据利用如何有效利用Waymo等真实驾驶数据进行训练

技术方案

  • BERT架构判别器使用Transformer处理变长序列输入
  • GAIL框架:通过对抗训练学习专家行为
  • PPO优化:稳定的策略梯度方法
  • MetaDrive仿真:高保真多智能体交通仿真环境

核心技术架构

整体架构图

┌─────────────────────────────────────────────────────────────┐
│                      MAGAIL训练系统                          │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌──────────────┐      ┌──────────────┐                    │
│  │  专家数据库   │      │  策略缓冲区   │                    │
│  │(Waymo轨迹)   │      │(Agent经验)   │                    │
│  └──────┬───────┘      └──────┬───────┘                    │
│         │                     │                             │
│         │  状态-动作对        │  状态-动作对                 │
│         ▼                     ▼                             │
│  ┌──────────────────────────────────────┐                  │
│  │     BERT判别器 (Discriminator)        │                  │
│  │  ┌────────────────────────────────┐  │                  │
│  │  │ Input: (N, obs_dim)            │  │                  │
│  │  │ ↓                              │  │                  │
│  │  │ Linear Projection → embed_dim  │  │                  │
│  │  │ ↓                              │  │                  │
│  │  │ + Positional Encoding          │  │                  │
│  │  │ ↓                              │  │                  │
│  │  │ Transformer Layers (×4)        │  │                  │
│  │  │ ↓                              │  │                  │
│  │  │ Mean Pooling / CLS Token       │  │                  │
│  │  │ ↓                              │  │                  │
│  │  │ Output: Real/Fake Score        │  │                  │
│  │  └────────────────────────────────┘  │                  │
│  └──────────────┬───────────────────────┘                  │
│                 │                                           │
│                 │  Reward Signal                            │
│                 ▼                                           │
│  ┌──────────────────────────────────────┐                  │
│  │     PPO策略优化 (Policy)              │                  │
│  │  ┌────────────────────────────────┐  │                  │
│  │  │ Actor Network (MLP)            │  │                  │
│  │  │ Input: state → Action dist     │  │                  │
│  │  ├────────────────────────────────┤  │                  │
│  │  │ Critic Network (BERT)          │  │                  │
│  │  │ Input: state → Value estimate  │  │                  │
│  │  └────────────────────────────────┘  │                  │
│  └──────────────┬───────────────────────┘                  │
│                 │                                           │
│                 │  Actions                                  │
│                 ▼                                           │
│  ┌──────────────────────────────────────┐                  │
│  │   MetaDrive多智能体环境               │                  │
│  │  • 车辆动力学仿真                     │                  │
│  │  • 多维度传感器(激光雷达等)          │                  │
│  │  • 红绿灯、车道线等交通元素            │                  │
│  └──────────────────────────────────────┘                  │
│                                                              │
└─────────────────────────────────────────────────────────────┘

算法模块详解

3.1 BERT判别器实现

3.1.1 核心设计思想

BERT判别器是本项目的核心创新。传统GAIL的判别器使用固定维度的MLP无法处理多智能体场景下车辆数量变化的问题。本项目采用Transformer架构将多个车辆的观测视为序列通过自注意力机制捕捉车辆间的交互。

3.1.2 代码实现

文件:Algorithm/bert.py

class Bert(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim=128,
                 num_layers=4, ff_dim=512, num_heads=4, dropout=0.1, 
                 CLS=False, TANH=False):
        """
        BERT判别器/价值网络
        
        参数说明:
        - input_dim: 单个车辆的观测维度
        - output_dim: 输出维度判别器为1价值网络为1
        - embed_dim: Transformer嵌入维度默认128
        - num_layers: Transformer层数默认4层
        - ff_dim: 前馈网络维度默认512
        - num_heads: 多头注意力头数默认4
        - CLS: 是否使用CLS token进行特征聚合
        - TANH: 输出层是否使用Tanh激活
        """
        super().__init__()
        self.CLS = CLS
        
        # 线性投影层:将观测维度映射到嵌入维度
        self.projection = nn.Linear(input_dim, embed_dim)
        
        # 位置编码:为每个车辆位置添加可学习的编码
        if self.CLS:
            # CLS模式需要额外的CLS token位置
            self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim))
        else:
            # 均值池化模式
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim))
        
        # Transformer编码器层
        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        
        # 分类头
        if TANH:
            self.classifier = nn.Sequential(
                nn.Linear(embed_dim, output_dim), 
                nn.Tanh()
            )
        else:
            self.classifier = nn.Linear(embed_dim, output_dim)

    def forward(self, x, mask=None):
        """
        前向传播
        
        输入:
        - x: (batch_size, seq_len, input_dim)
             seq_len = 车辆数量(动态变化)
             input_dim = 单车辆观测维度
        - mask: 可选的注意力掩码
        
        输出:
        - (batch_size, output_dim) 判别分数或价值估计
        """
        # 步骤1: 线性投影
        # 将每个车辆的观测映射到固定的嵌入空间
        x = self.projection(x)  # (batch_size, seq_len, embed_dim)
        
        batch_size = x.size(0)
        
        # 步骤2: 添加CLS token如果启用
        if self.CLS:
            # 复制CLS token到batch中的每个样本
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
        
        # 步骤3: 添加位置编码
        # 让模型知道每个车辆在序列中的位置
        x = x + self.pos_embed
        
        # 步骤4: 转置为Transformer期望的格式
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        
        # 步骤5: 通过Transformer层
        # 每层进行自注意力计算,捕捉车辆间的交互
        for layer in self.layers:
            x = layer(x, mask=mask)
        
        # 步骤6: 特征聚合
        if self.CLS:
            # CLS模式取CLS token的输出
            return self.classifier(x[0, :, :])
        else:
            # 均值池化:对所有车辆特征求平均
            pooled = x.mean(dim=0)  # (batch_size, embed_dim)
            return self.classifier(pooled)

Transformer层实现

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        """
        Transformer编码器层
        
        结构:
        1. 多头自注意力 + 残差连接 + LayerNorm
        2. 前馈网络 + 残差连接 + LayerNorm
        """
        super().__init__()
        # 多头自注意力
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout
        )
        
        # 前馈网络
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        
        # 归一化层
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        """
        前向传播Post-LN结构
        """
        # 注意力模块
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output)  # 残差连接
        x = self.norm1(x)  # 归一化
        
        # 前馈网络模块
        ff_output = self.linear2(
            self.dropout(self.activation(self.linear1(x)))
        )
        x = x + self.dropout(ff_output)  # 残差连接
        x = self.norm2(x)  # 归一化
        
        return x

3.1.3 关键技术点

  1. 动态序列处理

    • 输入维度为(batch_size, N, obs_dim)其中N是车辆数量
    • N可以在不同batch中变化无需固定
  2. 位置编码

    • 使用可学习的位置编码而非正弦编码
    • 让模型能够区分不同位置的车辆
  3. 自注意力机制

    • 计算每个车辆与其他车辆的注意力权重
    • 捕捉车辆间的交互和影响关系
  4. 特征聚合

    • CLS模式专门的分类token类似BERT
    • 均值池化:简单但有效的全局特征提取

3.2 GAIL判别器

3.2.1 判别器设计

文件:Algorithm/disc.py

class GAILDiscrim(Bert):
    """
    GAIL判别器继承自BERT架构
    
    功能:
    1. 区分专家数据和策略生成数据
    2. 计算模仿学习的内在奖励
    """
    
    def __init__(self, input_dim, reward_i_coef=1.0, 
                 reward_t_coef=1.0, normalizer=None, device=None):
        """
        初始化判别器
        
        参数:
        - input_dim: 输入维度(状态+下一状态)
        - reward_i_coef: 内在奖励系数
        - reward_t_coef: 任务奖励系数
        """
        # 调用BERT构造函数输出维度为1真假分数
        super().__init__(input_dim=input_dim, output_dim=1, TANH=False)
        
        self.device = device
        self.reward_t_coef = reward_t_coef
        self.reward_i_coef = reward_i_coef
        self.normalizer = normalizer

    def calculate_reward(self, states_gail, next_states_gail, rewards_t):
        """
        计算GAIL奖励
        
        GAIL的核心思想
        - 判别器D(s,s')输出越小,说明越像专家,奖励越高
        - 使用 -log(1-D) 作为内在奖励
        
        参数:
        - states_gail: 当前状态
        - next_states_gail: 下一状态
        - rewards_t: 环境任务奖励
        
        返回:
        - rewards: 总奖励
        - rewards_t: 归一化后的任务奖励
        - rewards_i: 归一化后的内在奖励
        """
        states_gail = states_gail.clone()
        next_states_gail = next_states_gail.clone()
        
        # 拼接状态转移对
        states = torch.cat([states_gail, next_states_gail], dim=-1)
        
        with torch.no_grad():
            # 数据标准化
            if self.normalizer is not None:
                states = self.normalizer.normalize_torch(states, self.device)
            
            # 缩放任务奖励
            rewards_t = self.reward_t_coef * rewards_t
            
            # 获取判别器输出logit
            d = self.forward(states)
            
            # 转换为概率sigmoid(d) = 1/(1+exp(-d))
            prob = 1 / (1 + torch.exp(-d))
            
            # GAIL奖励公式-log(1-D(s,s'))
            # 当D(s,s')接近0像专家奖励高
            # 当D(s,s')接近1不像专家奖励低
            rewards_i = self.reward_i_coef * (
                -torch.log(torch.maximum(
                    1 - prob, 
                    torch.tensor(0.0001, device=self.device)
                ))
            )
            
            # 组合奖励
            rewards = rewards_t + rewards_i
        
        return (rewards, 
                rewards_t / (self.reward_t_coef + 1e-10), 
                rewards_i / (self.reward_i_coef + 1e-10))

    def get_disc_logit_weights(self):
        """获取输出层权重(用于正则化)"""
        return torch.flatten(self.classifier.weight)

    def get_disc_weights(self):
        """获取所有层权重(用于权重衰减)"""
        weights = []
        for m in self.layers.modules():
            if isinstance(m, nn.Linear):
                weights.append(torch.flatten(m.weight))
        weights.append(torch.flatten(self.classifier.weight))
        return weights

3.2.2 判别器训练

文件:Algorithm/magail.py

def update_disc(self, states, states_exp, writer):
    """
    更新判别器
    
    目标:最大化 E_expert[log D] + E_policy[log(1-D)]
    
    参数:
    - states: 策略生成的状态转移
    - states_exp: 专家演示的状态转移
    """
    states_cp = states.clone()
    states_exp_cp = states_exp.clone()
    
    # 步骤1: 获取判别器输出
    logits_pi = self.disc(states_cp)     # 策略数据
    logits_exp = self.disc(states_exp_cp) # 专家数据
    
    # 步骤2: 计算对抗损失
    # 希望logits_pi < 0策略被识别为假
    #       logits_exp > 0专家被识别为真
    loss_pi = -F.logsigmoid(-logits_pi).mean()   # -log(1-sigmoid(logits_pi))
    loss_exp = -F.logsigmoid(logits_exp).mean()  # -log(sigmoid(logits_exp))
    loss_disc = 0.5 * (loss_pi + loss_exp)
    
    # 步骤3: Logit正则化
    # 防止判别器输出过大,导致梯度爆炸
    logit_weights = self.disc.get_disc_logit_weights()
    disc_logit_loss = torch.sum(torch.square(logit_weights))
    
    # 步骤4: 梯度惩罚Gradient Penalty
    # 确保判别器满足Lipschitz约束提高训练稳定性
    sample_expert = states_exp_cp
    sample_expert.requires_grad = True
    
    # 对专家数据计算判别器输出
    disc = self.disc.linear(self.disc.trunk(sample_expert))
    ones = torch.ones(disc.size(), device=disc.device)
    
    # 计算梯度
    disc_demo_grad = torch.autograd.grad(
        disc, sample_expert,
        grad_outputs=ones,
        create_graph=True, 
        retain_graph=True, 
        only_inputs=True
    )[0]
    
    # 梯度的L2范数
    disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
    grad_pen_loss = torch.mean(disc_demo_grad)
    
    # 步骤5: 权重衰减L2正则化
    disc_weights = self.disc.get_disc_weights()
    disc_weights = torch.cat(disc_weights, dim=-1)
    disc_weight_decay = torch.sum(torch.square(disc_weights))
    
    # 步骤6: 组合损失并更新
    loss = (self.disc_coef * loss_disc + 
            self.disc_grad_penalty * grad_pen_loss + 
            self.disc_logit_reg * disc_logit_loss + 
            self.disc_weight_decay * disc_weight_decay)
    
    self.optim_d.zero_grad()
    loss.backward()
    self.optim_d.step()
    
    # 步骤7: 记录训练指标
    if self.learning_steps_disc % self.epoch_disc == 0:
        writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)
        
        with torch.no_grad():
            # 判别器准确率
            acc_pi = (logits_pi < 0).float().mean().item()   # 策略识别准确率
            acc_exp = (logits_exp > 0).float().mean().item() # 专家识别准确率
        
        writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
        writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)

3.2.3 关键技术细节

1. 梯度惩罚的作用

# 梯度惩罚确保判别器是Lipschitz连续的
# 即:|D(x1) - D(x2)| ≤ K|x1 - x2|
# 这防止判别器变化过于剧烈,提高训练稳定性

2. 为什么使用logit而非概率

# 使用logit未经sigmoid的输出有几个优点
# 1. 数值稳定性避免log(0)等问题
# 2. 梯度更好sigmoid饱和区梯度消失
# 3. 理论保证GAIL理论基于logit形式

3.3 PPO策略优化

3.3.1 策略网络

文件:Algorithm/policy.py

class StateIndependentPolicy(nn.Module):
    """
    状态独立策略(对角高斯策略)
    
    输出:高斯分布的均值,标准差是可学习参数
    """
    
    def __init__(self, state_shape, action_shape, 
                 hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        super().__init__()
        
        # 均值网络MLP
        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        
        # 可学习的对数标准差
        self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
        self.means = None

    def forward(self, states):
        """
        确定性前向传播(用于评估)
        """
        return torch.tanh(self.net(states))

    def sample(self, states):
        """
        从策略分布中采样动作
        
        使用重参数化技巧:
        a = tanh(μ + σ * ε), ε ~ N(0,1)
        """
        self.means = self.net(states)
        actions, log_pis = reparameterize(self.means, self.log_stds)
        return actions, log_pis

    def evaluate_log_pi(self, states, actions):
        """
        计算给定状态-动作对的对数概率
        """
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)

重参数化技巧:

def reparameterize(means, log_stds):
    """
    重参数化采样
    
    原理:
    不直接从N(μ,σ²)采样,而是:
    1. 从N(0,1)采样噪声ε
    2. 计算 z = μ + σ * ε
    3. 应用tanha = tanh(z)
    
    优点:
    - 梯度可以通过μ和σ反向传播
    - 支持梯度下降优化
    """
    noises = torch.randn_like(means)  # ε ~ N(0,1)
    us = means + noises * log_stds.exp()  # z = μ + σε
    actions = torch.tanh(us)  # a = tanh(z)
    
    # 计算对数概率需要考虑tanh的雅可比行列式
    return actions, calculate_log_pi(log_stds, noises, actions)

def calculate_log_pi(log_stds, noises, actions):
    """
    计算tanh高斯分布的对数概率
    
    公式:
    log π(a|s) = log N(u|μ,σ²) - log|1 - tanh²(u)|
               = -0.5||ε||² - log σ - 0.5log(2π) - Σlog(1-a²)
    """
    # 高斯分布的对数概率
    gaussian_log_probs = (
        -0.5 * noises.pow(2) - log_stds
    ).sum(dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)
    
    # tanh变换的雅可比修正
    # d/du tanh(u) = 1 - tanh²(u)
    return gaussian_log_probs - torch.log(
        1 - actions.pow(2) + 1e-6
    ).sum(dim=-1, keepdim=True)

3.3.2 PPO更新

文件:Algorithm/ppo.py

def update_ppo(self, states, actions, rewards, dones, tm_dones, 
               log_pi_list, next_states, mus, sigmas, writer, total_steps):
    """
    PPO策略和价值网络更新
    """
    # 步骤1: 计算价值估计和优势函数
    with torch.no_grad():
        values = self.critic(states.detach())
        next_values = self.critic(next_states.detach())
    
    # GAE广义优势估计
    targets, gaes = self.calculate_gae(
        values, rewards, dones, tm_dones, next_values, 
        self.gamma, self.lambd
    )
    
    # 步骤2: 多轮更新
    state_list = states.permute(1, 0, 2)
    action_list = actions.permute(1, 0, 2)
    
    for i in range(self.epoch_ppo):
        self.learning_steps_ppo += 1
        
        # 更新价值网络
        self.update_critic(states, targets, writer)
        
        # 更新策略网络
        for state, action, log_pi in zip(state_list, action_list, log_pi_list):
            self.update_actor(
                state, action, log_pi, gaes, mus, sigmas, writer
            )

def calculate_gae(self, values, rewards, dones, tm_dones, 
                  next_values, gamma, lambd):
    """
    计算广义优势估计GAE
    
    公式:
    δt = r + γV(s') - V(s)
    At = Σ(γλ)^k δt+k
    
    参数:
    - gamma: 折扣因子
    - lambd: GAE的λ参数权衡偏差-方差)
    """
    with torch.no_grad():
        # TD误差
        deltas = rewards + gamma * next_values * (1 - tm_dones) - values
        
        # 初始化优势
        gaes = torch.empty_like(rewards)
        
        # 从后往前计算GAE
        gaes[-1] = deltas[-1]
        for t in reversed(range(rewards.size(0) - 1)):
            gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
        
        # 价值目标
        v_target = gaes + values
        
        # 优势标准化
        if self.use_adv_norm:
            gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)
    
    return v_target, gaes

def update_actor(self, states, actions, log_pis_old, gaes, 
                 mus_old, sigmas_old, writer):
    """
    更新Actor网络
    
    PPO裁剪目标
    L = min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A)
    其中 r(θ) = π_new/π_old
    """
    self.optim_actor.zero_grad()
    
    # 新策略的对数概率
    log_pis = self.actor.evaluate_log_pi(states, actions)
    mus = self.actor.means
    sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
    
    # 熵(鼓励探索)
    entropy = -log_pis.mean()
    
    # 重要性采样比率
    ratios = (log_pis - log_pis_old).exp_()
    
    # PPO裁剪目标
    loss_actor1 = -ratios * gaes
    loss_actor2 = -torch.clamp(
        ratios,
        1.0 - self.clip_eps,
        1.0 + self.clip_eps
    ) * gaes
    loss_actor = torch.max(loss_actor1, loss_actor2).mean()
    loss_actor = loss_actor * self.surrogate_loss_coef
    
    # 自适应学习率基于KL散度
    if self.auto_lr:
        with torch.inference_mode():
            # 计算KL散度KL(old||new)
            kl = torch.sum(
                torch.log(sigmas / sigmas_old + 1.e-5) +
                (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                (2.0 * torch.square(sigmas)) - 0.5, 
                axis=-1
            )
            kl_mean = torch.mean(kl)
            
            # 调整学习率
            if kl_mean > self.desired_kl * 2.0:
                # KL过大降低学习率
                self.lr_actor = max(1e-5, self.lr_actor / 1.5)
                self.lr_critic = max(1e-5, self.lr_critic / 1.5)
            elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                # KL过小提高学习率
                self.lr_actor = min(1e-2, self.lr_actor * 1.5)
                self.lr_critic = min(1e-2, self.lr_critic * 1.5)
            
            # 更新优化器学习率
            for param_group in self.optim_actor.param_groups:
                param_group['lr'] = self.lr_actor
            for param_group in self.optim_critic.param_groups:
                param_group['lr'] = self.lr_critic
    
    # 反向传播
    loss = loss_actor
    loss.backward()
    nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
    self.optim_actor.step()

环境系统实现

4.1 多智能体场景环境

4.1.1 环境设计

文件:Env/scenario_env.py

class MultiAgentScenarioEnv(ScenarioEnv):
    """
    多智能体场景环境
    
    继承自MetaDrive的ScenarioEnv扩展为多智能体场景
    
    核心功能:
    1. 从专家数据动态生成车辆
    2. 收集多维度观测
    3. 管理多智能体交互
    """
    
    @classmethod
    def default_config(cls):
        config = super().default_config()
        config.update(dict(
            data_directory=None,          # 专家数据目录
            num_controlled_agents=3,      # 可控车辆数量
            horizon=1000,                 # 场景时间步
        ))
        return config

    def __init__(self, config, agent2policy):
        """
        初始化环境
        
        参数:
        - config: 环境配置
        - agent2policy: 为每个智能体分配的策略
        """
        self.policy = agent2policy
        self.controlled_agents = {}      # 可控车辆字典
        self.controlled_agent_ids = []   # 可控车辆ID列表
        self.obs_list = []               # 观测列表
        self.round = 0                   # 当前时间步
        super().__init__(config)

4.1.2 环境重置与车辆生成

def reset(self, seed: Union[None, int] = None):
    """
    重置环境
    
    流程:
    1. 解析专家数据中的车辆轨迹
    2. 提取车辆生成信息
    3. 清理原始数据
    4. 初始化场景
    5. 生成第一批车辆
    """
    self.round = 0
    
    # 日志初始化
    if self.logger is None:
        self.logger = get_logger()
        log_level = self.config.get("log_level", logging.INFO)
        set_log_level(log_level)
    
    self.lazy_init()
    self._reset_global_seed(seed)
    
    # 步骤1: 解析专家数据
    _obj_to_clean_this_frame = []
    self.car_birth_info_list = []
    
    for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
        # 跳过自车
        if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
            continue
        
        # 只处理车辆类型
        if track["type"] == MetaDriveType.VEHICLE:
            _obj_to_clean_this_frame.append(scenario_id)
            
            # 获取有效帧
            valid = track['state']['valid']
            first_show = np.argmax(valid) if valid.any() else -1
            last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
            
            # 提取关键信息
            self.car_birth_info_list.append({
                'id': track['metadata']['object_id'],           # 车辆ID
                'show_time': first_show,                        # 出现时间
                'begin': (                                      # 起点位置
                    track['state']['position'][first_show, 0], 
                    track['state']['position'][first_show, 1]
                ),
                'heading': track['state']['heading'][first_show], # 初始朝向
                'end': (                                         # 终点位置
                    track['state']['position'][last_show, 0], 
                    track['state']['position'][last_show, 1]
                )
            })
    
    # 步骤2: 清理原始数据(避免重复生成)
    for scenario_id in _obj_to_clean_this_frame:
        self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
    
    # 步骤3: 重置引擎
    self.engine.reset()
    self.reset_sensors()
    self.engine.taskMgr.step()
    
    # 步骤4: 获取车道网络(用于红绿灯检测)
    self.lanes = self.engine.map_manager.current_map.road_network.graph
    
    # 步骤5: 清理旧状态
    if self.top_down_renderer is not None:
        self.top_down_renderer.clear()
        self.engine.top_down_renderer = None
    
    self.dones = {}
    self.episode_rewards = defaultdict(float)
    self.episode_lengths = defaultdict(int)
    
    self.controlled_agents.clear()
    self.controlled_agent_ids.clear()
    
    # 步骤6: 初始化场景并生成第一批车辆
    super().reset(seed)
    self._spawn_controlled_agents()
    
    return self._get_all_obs()

def _spawn_controlled_agents(self):
    """
    动态生成可控车辆
    
    根据专家数据中记录的时间戳,在正确的时间点生成车辆
    """
    for car in self.car_birth_info_list:
        # 检查是否到了该车辆的出现时间
        if car['show_time'] == self.round:
            agent_id = f"controlled_{car['id']}"
            
            # 生成车辆实例
            vehicle = self.engine.spawn_object(
                PolicyVehicle,           # 自定义车辆类
                vehicle_config={},
                position=car['begin'],   # 初始位置
                heading=car['heading']   # 初始朝向
            )
            
            # 重置车辆状态
            vehicle.reset(position=car['begin'], heading=car['heading'])
            
            # 设置策略和目标
            vehicle.set_policy(self.policy)
            vehicle.set_destination(car['end'])
            
            # 注册车辆
            self.controlled_agents[agent_id] = vehicle
            self.controlled_agent_ids.append(agent_id)
            
            # 关键注册到引擎的active_agents才能参与物理更新
            self.engine.agent_manager.active_agents[agent_id] = vehicle

4.1.3 观测系统

def _get_all_obs(self):
    """
    收集所有可控车辆的观测
    
    观测维度构成:
    - 位置: 2D (x, y)
    - 速度: 2D (vx, vy)
    - 朝向: 1D (θ)
    - 前向激光雷达: 80D (距离)
    - 侧向检测器: 10D (距离)
    - 车道线检测: 10D (距离)
    - 红绿灯: 1D (0-3编码)
    - 导航: 2D (目标点坐标)
    
    总维度: 2+2+1+80+10+10+1+2 = 108D
    """
    self.obs_list = []
    
    for agent_id, vehicle in self.controlled_agents.items():
        # 获取车辆基础状态
        state = vehicle.get_state()
        
        # 红绿灯检测
        traffic_light = 0
        for lane in self.lanes.values():
            # 检查车辆是否在车道上
            if lane.lane.point_on_lane(state['position'][:2]):
                # 检查该车道是否有红绿灯
                if self.engine.light_manager.has_traffic_light(lane.lane.index):
                    traffic_light = self.engine.light_manager._lane_index_to_obj[
                        lane.lane.index
                    ].status
                    
                    # 状态编码
                    if traffic_light == 'TRAFFIC_LIGHT_GREEN':
                        traffic_light = 1
                    elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
                        traffic_light = 2
                    elif traffic_light == 'TRAFFIC_LIGHT_RED':
                        traffic_light = 3
                    else:
                        traffic_light = 0
                break
        
        # 激光雷达感知
        # 前向激光雷达80束30米距离检测动态物体
        lidar = self.engine.get_sensor("lidar").perceive(
            num_lasers=80, 
            distance=30, 
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.dynamic_world
        )
        
        # 侧向检测器10束8米距离检测静态障碍物
        side_lidar = self.engine.get_sensor("side_detector").perceive(
            num_lasers=10, 
            distance=8,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.static_world
        )
        
        # 车道线检测器10束3米距离检测车道边界
        lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(
            num_lasers=10, 
            distance=3,
            base_vehicle=vehicle,
            physics_world=self.engine.physics_world.static_world
        )
        
        # 组合观测向量
        obs = (
            state['position'][:2] +           # 位置 (x, y)
            list(state['velocity']) +         # 速度 (vx, vy)
            [state['heading_theta']] +        # 朝向 θ
            lidar[0] +                        # 激光雷达 (80D)
            side_lidar[0] +                   # 侧向检测 (10D)
            lane_line_lidar[0] +              # 车道线 (10D)
            [traffic_light] +                 # 红绿灯 (1D)
            list(vehicle.destination)         # 目标点 (x, y)
        )
        
        self.obs_list.append(obs)
    
    return self.obs_list

4.1.4 环境步进

def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
    """
    执行一步仿真
    
    流程:
    1. 应用动作到车辆
    2. 运行物理引擎
    3. 车辆后处理
    4. 生成新车辆
    5. 收集观测和奖励
    
    参数:
    - action_dict: {agent_id: action} 字典
    
    返回:
    - obs: 观测列表
    - rewards: 奖励字典
    - dones: 完成标志字典
    - infos: 信息字典
    """
    self.round += 1
    
    # 步骤1: 应用动作
    for agent_id, action in action_dict.items():
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].before_step(action)
    
    # 步骤2: 物理仿真
    self.engine.step()
    
    # 步骤3: 车辆后处理
    for agent_id in action_dict:
        if agent_id in self.controlled_agents:
            self.controlled_agents[agent_id].after_step()
    
    # 步骤4: 动态生成新车辆
    self._spawn_controlled_agents()
    
    # 步骤5: 收集观测
    obs = self._get_all_obs()
    
    # 步骤6: 计算奖励和完成标志
    rewards = {aid: 0.0 for aid in self.controlled_agents}
    dones = {aid: False for aid in self.controlled_agents}
    dones["__all__"] = self.episode_step >= self.config["horizon"]
    infos = {aid: {} for aid in self.controlled_agents}
    
    return obs, rewards, dones, infos

4.2 自定义车辆类

class PolicyVehicle(DefaultVehicle):
    """
    策略控制车辆
    
    扩展MetaDrive的默认车辆添加策略和目标点
    """
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None       # 控制策略
        self.destination = None  # 目标点

    def set_policy(self, policy):
        """设置控制策略"""
        self.policy = policy

    def set_destination(self, des):
        """设置目标点"""
        self.destination = des

    def act(self, observation, policy=None):
        """
        执行动作
        
        如果有策略,使用策略;否则随机动作
        """
        if self.policy is not None:
            return self.policy.act(observation)
        else:
            return self.action_space.sample()

    def before_step(self, action):
        """
        步进前处理
        
        记录历史状态并应用动作
        """
        self.last_position = self.position
        self.last_velocity = self.velocity
        self.last_speed = self.speed
        self.last_heading_dir = self.heading
        
        if action is not None:
            self.last_current_action.append(action)
        
        # 将动作转换为车辆控制指令
        self._set_action(action)

# 注册车辆类型
vehicle_class_to_type[PolicyVehicle] = "default"

数据流与训练流程

5.1 完整训练流程

"""
训练流程伪代码
"""

# 1. 初始化
env = MultiAgentScenarioEnv(config, policy)
magail = MAGAIL(buffer_exp, input_dim, device)

# 2. 加载专家数据
buffer_exp = load_expert_data("waymo_dataset")

# 3. 训练循环
for episode in range(max_episodes):
    # 3.1 重置环境
    obs_list = env.reset()
    
    for step in range(max_steps):
        # 3.2 策略采样动作
        actions, log_pis = magail.explore(obs_list)
        
        # 3.3 环境交互
        next_obs_list, rewards, dones, infos = env.step(actions)
        
        # 3.4 存储经验
        magail.buffer.append(
            obs_list, actions, rewards, dones, 
            log_pis, next_obs_list
        )
        
        obs_list = next_obs_list
        
        # 3.5 判断是否更新
        if magail.is_update(step):
            # 3.5.1 更新判别器
            for _ in range(epoch_disc):
                # 采样策略数据
                states_policy = magail.buffer.sample(batch_size)
                # 采样专家数据
                states_expert = buffer_exp.sample(batch_size)
                # 更新判别器
                magail.update_disc(states_policy, states_expert)
            
            # 3.5.2 计算GAIL奖励
            rewards_gail = magail.disc.calculate_reward(states, next_states)
            
            # 3.5.3 更新PPO
            magail.update_ppo(states, actions, rewards_gail, ...)
            
            # 3.5.4 清空缓冲区
            magail.buffer.clear()
        
        if dones["__all__"]:
            break
    
    # 3.6 保存模型
    if episode % save_interval == 0:
        magail.save_models(save_path)

5.2 数据流图

专家数据                策略数据
   ↓                      ↓
┌──────────────┐    ┌──────────────┐
│ Expert Buffer│    │Policy Buffer │
│  (s,a,s')   │    │  (s,a,s')   │
└──────┬───────┘    └──────┬───────┘
       │                   │
       │  采样             │  采样
       ▼                   ▼
   ┌────────────────────────────┐
   │     BERT Discriminator     │
   │  Input: (N, obs_dim*2)    │
   │  Output: Real/Fake Score  │
   └────────────┬───────────────┘
                │
                │  梯度反向传播
                ▼
         ┌──────────────┐
         │ Disc Loss    │
         │ + Grad Pen   │
         │ + Logit Reg  │
         │ + Weight Dec │
         └──────────────┘

策略数据
   ↓
┌──────────────┐
│Policy Buffer │
│  (s,a,r,s') │
└──────┬───────┘
       │
       │  r_GAIL = -log(1-D(s,s'))
       ▼
   ┌────────────────────────────┐
   │     PPO Optimization       │
   │  Actor: MLP                │
   │  Critic: BERT              │
   └────────────┬───────────────┘
                │
                │  策略改进
                ▼
         ┌──────────────┐
         │ Environment  │
         │  Interaction │
         └──────────────┘

5.3 缓冲区管理

文件:Algorithm/buffer.py

class RolloutBuffer:
    """
    滚动缓冲区
    
    存储策略与环境交互的经验
    """
    
    def __init__(self, buffer_size, state_shape, action_shape, device):
        """
        初始化缓冲区
        
        参数:
        - buffer_size: 缓冲区大小通常等于rollout_length
        - state_shape: 状态维度
        - action_shape: 动作维度
        - device: 存储设备CPU/GPU
        """
        self._n = 0  # 当前存储数量
        self._p = 0  # 当前写入位置
        self.buffer_size = buffer_size
        
        # 预分配张量(提高效率)
        self.states = torch.empty(
            (self.buffer_size, *state_shape), 
            dtype=torch.float, device=device
        )
        self.actions = torch.empty(
            (self.buffer_size, *action_shape), 
            dtype=torch.float, device=device
        )
        self.rewards = torch.empty(
            (self.buffer_size, 1), 
            dtype=torch.float, device=device
        )
        self.dones = torch.empty(
            (self.buffer_size, 1), 
            dtype=torch.int, device=device
        )
        self.tm_dones = torch.empty(
            (self.buffer_size, 1), 
            dtype=torch.int, device=device
        )
        self.log_pis = torch.empty(
            (self.buffer_size, 1), 
            dtype=torch.float, device=device
        )
        self.next_states = torch.empty(
            (self.buffer_size, *state_shape), 
            dtype=torch.float, device=device
        )
        self.means = torch.empty(
            (self.buffer_size, *action_shape), 
            dtype=torch.float, device=device
        )
        self.stds = torch.empty(
            (self.buffer_size, *action_shape), 
            dtype=torch.float, device=device
        )

    def append(self, state, action, reward, done, tm_dones, 
               log_pi, next_state, next_state_gail, means, stds):
        """
        添加经验
        
        使用循环缓冲区,自动覆盖旧数据
        """
        self.states[self._p].copy_(state)
        self.actions[self._p].copy_(torch.from_numpy(action))
        self.rewards[self._p] = float(reward)
        self.dones[self._p] = int(done)
        self.tm_dones[self._p] = int(tm_dones)
        self.log_pis[self._p] = float(log_pi)
        self.next_states[self._p].copy_(torch.from_numpy(next_state))
        self.means[self._p].copy_(torch.from_numpy(means))
        self.stds[self._p].copy_(torch.from_numpy(stds))
        
        # 更新指针
        self._p = (self._p + 1) % self.buffer_size
        self._n = min(self._n + 1, self.buffer_size)

    def get(self):
        """
        获取所有数据用于PPO更新
        """
        assert self._p % self.buffer_size == 0
        idxes = slice(0, self.buffer_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def sample(self, batch_size):
        """
        随机采样批次(用于判别器更新)
        """
        assert self._p % self.buffer_size == 0
        idxes = np.random.randint(low=0, high=self._n, size=batch_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def clear(self):
        """清空缓冲区"""
        self.states[:, :] = 0
        self.actions[:, :] = 0
        self.rewards[:, :] = 0
        self.dones[:, :] = 0
        self.tm_dones[:, :] = 0
        self.log_pis[:, :] = 0
        self.next_states[:, :] = 0
        self.means[:, :] = 0
        self.stds[:, :] = 0

关键技术细节

6.1 数据标准化

文件:Algorithm/utils.py

class RunningMeanStd(object):
    """
    运行时均值和标准差计算
    
    使用Welford在线算法高效计算流式数据的统计量
    """
    
    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        """
        更新统计量
        
        使用并行算法合并批次统计量和当前统计量
        """
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """
        从矩更新统计量
        
        参考https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        """
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        
        # 新均值
        new_mean = self.mean + delta * batch_count / tot_count
        
        # 新方差
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
        new_var = m_2 / tot_count
        
        # 更新
        self.mean = new_mean
        self.var = new_var
        self.count = tot_count


class Normalizer(RunningMeanStd):
    """
    数据标准化器
    
    提供标准化和逆标准化功能
    """
    
    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon
        self.clip_obs = clip_obs

    def normalize(self, input):
        """
        标准化NumPy版本
        
        公式:(x - μ) / √(σ² + ε)
        """
        return np.clip(
            (input - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip_obs, self.clip_obs
        )

    def normalize_torch(self, input, device):
        """
        标准化PyTorch版本
        
        用于在GPU上高效计算
        """
        mean_torch = torch.tensor(
            self.mean, device=device, dtype=torch.float32
        )
        std_torch = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32
        ))
        return torch.clamp(
            (input - mean_torch) / std_torch, 
            -self.clip_obs, self.clip_obs
        )

6.2 为什么使用BERT架构

传统MLP的问题

# 假设场景中有N辆车每辆车观测维度为D
# 传统方法:拼接所有车辆观测
input = concat([obs_1, obs_2, ..., obs_N])  # 维度: N*D
output = MLP(input)

# 问题:
# 1. 输入维度N*D随N变化需要重新训练网络
# 2. 不同位置的车辆语义相同但MLP无法共享权重
# 3. 无法处理车辆间的交互关系

BERT架构的优势

# BERT方法将车辆观测视为序列
input = [obs_1, obs_2, ..., obs_N]  # 序列长度N可变
embeddings = [Linear(obs_i) for obs_i in input]  # 共享权重
output = Transformer(embeddings)  # 自注意力捕捉交互

# 优势:
# 1. 序列长度可变无需固定N
# 2. Linear投影层参数共享泛化性好
# 3. 自注意力机制建模车辆间交互
# 4. 位置编码区分不同车辆

6.3 梯度惩罚详解

"""
梯度惩罚Gradient Penalty详解

目标确保判别器是Lipschitz连续的
即:|D(x1) - D(x2)| ≤ K|x1 - x2|

为什么需要:
1. WGAN理论要求判别器是Lipschitz连续
2. 防止判别器梯度过大,提高训练稳定性
3. 避免模式崩溃mode collapse

实现:
对于专家数据x惩罚 ||∇_x D(x)||²
"""

# 步骤1: 使专家数据需要梯度
sample_expert = states_exp_cp
sample_expert.requires_grad = True

# 步骤2: 前向传播
disc = self.disc.linear(self.disc.trunk(sample_expert))

# 步骤3: 计算梯度
ones = torch.ones(disc.size(), device=disc.device)
disc_demo_grad = torch.autograd.grad(
    disc, sample_expert,
    grad_outputs=ones,  # ∂disc/∂x
    create_graph=True,  # 保留计算图(二阶导数)
    retain_graph=True,
    only_inputs=True
)[0]

# 步骤4: 梯度的L2范数
disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
grad_pen_loss = torch.mean(disc_demo_grad)

# 步骤5: 添加到总损失
loss += self.disc_grad_penalty * grad_pen_loss

6.4 自适应学习率

"""
基于KL散度的自适应学习率

原理:
PPO希望策略更新不要太大通过监控KL散度来调整学习率

KL(π_old || π_new) = 期望策略变化程度
"""

# 计算KL散度对角高斯分布
kl = torch.sum(
    torch.log(sigmas_new / sigmas_old + 1e-5) +  # 标准差比值的对数
    (sigmas_old² + (mus_old - mus_new)²) / (2 * sigmas_new²) - 0.5,
    axis=-1
)
kl_mean = torch.mean(kl)

# 调整学习率
if kl_mean > desired_kl * 2.0:
    # KL过大策略变化太剧烈降低学习率
    lr_actor *= 0.67  # 除以1.5
    lr_critic *= 0.67
elif kl_mean < desired_kl / 2.0:
    # KL过小策略变化太保守提高学习率
    lr_actor *= 1.5
    lr_critic *= 1.5

# 应用新学习率
for param_group in optimizer.param_groups:
    param_group['lr'] = lr_actor

使用指南

7.1 环境依赖安装

# 创建虚拟环境
conda create -n magail python=3.8
conda activate magail

# 安装PyTorch根据CUDA版本
pip install torch==1.12.0 torchvision torchaudio

# 安装MetaDrive仿真环境
pip install metadrive-simulator

# 安装其他依赖
pip install numpy pandas matplotlib tensorboard tqdm gym

7.2 运行环境测试

# 进入项目目录
cd /path/to/MAGAIL4AutoDrive

# 运行环境测试(需先修复导入路径)
python Env/run_multiagent_env.py

7.3 训练模型

"""
训练脚本示例(需要创建)
"""
import torch
from Algorithm.magail import MAGAIL
from Env.scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader

# 配置
config = {
    "data_directory": "/path/to/exp_converted",
    "is_multi_agent": True,
    "num_controlled_agents": 5,
    "horizon": 300,
    "use_render": False,
}

# 创建环境
env = MultiAgentScenarioEnv(config, policy)

# 创建算法
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
magail = MAGAIL(
    buffer_exp=expert_buffer,
    input_dim=obs_dim,
    device=device,
    lr_disc=1e-3,
    epoch_disc=50
)

# 训练循环
for episode in range(10000):
    obs = env.reset()
    episode_reward = 0
    
    for step in range(config["horizon"]):
        actions, log_pis = magail.explore(obs)
        next_obs, rewards, dones, infos = env.step(actions)
        
        magail.buffer.append(...)
        
        if magail.is_update(step):
            reward = magail.update(writer, total_steps)
            episode_reward += reward
        
        obs = next_obs
        
        if dones["__all__"]:
            break
    
    print(f"Episode {episode}, Reward: {episode_reward}")
    
    if episode % 100 == 0:
        magail.save_models(f"models/episode_{episode}")

7.4 关键参数说明

参数 说明 推荐值
embed_dim BERT嵌入维度 128
num_layers Transformer层数 4
num_heads 注意力头数 4
lr_disc 判别器学习率 1e-3
lr_actor Actor学习率 1e-3
lr_critic Critic学习率 1e-3
epoch_disc 判别器更新轮数 50
epoch_ppo PPO更新轮数 10
disc_grad_penalty 梯度惩罚系数 0.1
disc_logit_reg Logit正则化系数 0.25
gamma 折扣因子 0.995
lambd GAE λ参数 0.97
clip_eps PPO裁剪参数 0.2

7.5 常见问题

Q1: 为什么判别器准确率总是50%

  • 这是正常现象,说明判别器无法区分策略和专家
  • 表示策略已经学习到接近专家的行为

Q2: 训练不稳定怎么办?

  • 增大梯度惩罚系数
  • 降低学习率
  • 增加数据标准化

Q3: 如何调整奖励权重?

  • reward_t_coef: 任务奖励权重
  • reward_i_coef: 模仿奖励权重
  • 通常设置为1:1或调整以平衡两者

总结

MAGAIL4AutoDrive项目通过以下技术创新实现了多智能体自动驾驶的模仿学习

  1. BERT判别器使用Transformer架构处理动态数量的车辆
  2. GAIL框架:通过对抗训练学习专家策略
  3. PPO优化:稳定的策略梯度方法
  4. 多维观测:融合多种传感器信息
  5. 真实数据利用Waymo等真实驾驶数据

该项目为多智能体自动驾驶提供了一个完整的解决方案,具有良好的可扩展性和实用价值。