magail4autodrive: first commit

Author: ZHY
Date: 2025-09-28 18:57:04 +08:00
Commit: 947871a720
90 changed files with 1037 additions and 0 deletions

Algorithm/bert.py (new file)

@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
class Bert(nn.Module):
def __init__(self, input_dim, output_dim, embed_dim=128,
num_layers=4, ff_dim=512, num_heads=4, dropout=0.1, CLS=False, TANH=False):
super().__init__()
self.CLS = CLS
self.projection = nn.Linear(input_dim, embed_dim)
if self.CLS:
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim))
else:
self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim))
self.layers = nn.ModuleList([
TransformerLayer(embed_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
if TANH:
self.classifier = nn.Sequential(nn.Linear(embed_dim, output_dim), nn.Tanh())
else:
self.classifier = nn.Linear(embed_dim, output_dim)
self.layers.train()
self.classifier.train()
def forward(self, x, mask=None):
# x: (batch_size, seq_len, input_dim)
# Linear projection
x = self.projection(x)  # (batch_size, seq_len, embed_dim)
batch_size = x.size(0)
if self.CLS:
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, seq_len + 1, embed_dim)
# Add positional encoding
x = x + self.pos_embed
# Transpose to (seq_len, batch_size, embed_dim) for nn.MultiheadAttention
x = x.permute(1, 0, 2)
for layer in self.layers:
x = layer(x, mask=mask)
if self.CLS:
return self.classifier(x[0, :, :])
else:
pooled = x.mean(dim=0) # (batch_size, embed_dim)
return self.classifier(pooled)
class TransformerLayer(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
self.linear1 = nn.Linear(embed_dim, ff_dim)
self.linear2 = nn.Linear(ff_dim, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
# GELU activation
self.activation = nn.GELU()
def forward(self, x, mask=None):
# Post-LN structure (LayerNorm after the residual connection)
# Self-attention block
attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
x = x + self.dropout(attn_output)
x = self.norm1(x)
# Feed-forward (FFN) block
ff_output = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = x + self.dropout(ff_output)
x = self.norm2(x)
return x

Algorithm/buffer.py (new file)

@@ -0,0 +1,80 @@
import os
import numpy as np
import torch
class RolloutBuffer:
# TODO: state and action are list
def __init__(self, buffer_size, state_shape, action_shape, device):
self._n = 0
self._p = 0
self.buffer_size = buffer_size
self.states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device)
# self.states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device)
self.actions = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
self.rewards = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device)
self.dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device)
self.tm_dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device)
self.log_pis = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device)
self.next_states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device)
# self.next_states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device)
self.means = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
self.stds = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
def append(self, state, action, reward, done, tm_dones, log_pi, next_state, next_state_gail, means, stds):
self.states[self._p].copy_(state)
# self.states_gail[self._p].copy_(state_gail)
self.actions[self._p].copy_(torch.from_numpy(action))
self.rewards[self._p] = float(reward)
self.dones[self._p] = int(done)
self.tm_dones[self._p] = int(tm_dones)
self.log_pis[self._p] = float(log_pi)
self.next_states[self._p].copy_(torch.from_numpy(next_state))
# self.next_states_gail[self._p].copy_(torch.from_numpy(next_state_gail))
self.means[self._p].copy_(torch.from_numpy(means))
self.stds[self._p].copy_(torch.from_numpy(stds))
self._p = (self._p + 1) % self.buffer_size
self._n = min(self._n + 1, self.buffer_size)
def get(self):
assert self._p % self.buffer_size == 0
idxes = slice(0, self.buffer_size)
return (
self.states[idxes],
self.actions[idxes],
self.rewards[idxes],
self.dones[idxes],
self.tm_dones[idxes],
self.log_pis[idxes],
self.next_states[idxes],
self.means[idxes],
self.stds[idxes]
)
def sample(self, batch_size):
assert self._p % self.buffer_size == 0
idxes = np.random.randint(low=0, high=self._n, size=batch_size)
return (
self.states[idxes],
self.actions[idxes],
self.rewards[idxes],
self.dones[idxes],
self.tm_dones[idxes],
self.log_pis[idxes],
self.next_states[idxes],
self.means[idxes],
self.stds[idxes]
)
def clear(self):
self.states[:, :] = 0
self.actions[:, :] = 0
self.rewards[:, :] = 0
self.dones[:, :] = 0
self.tm_dones[:, :] = 0
self.log_pis[:, :] = 0
self.next_states[:, :] = 0
self.means[:, :] = 0
self.stds[:, :] = 0

Algorithm/disc.py (new file)

@@ -0,0 +1,44 @@
import torch
from torch import nn
from .bert import Bert
DISC_LOGIT_INIT_SCALE = 1.0
class GAILDiscrim(Bert):
def __init__(self, input_dim, reward_i_coef=1.0, reward_t_coef=1.0, normalizer=None, device=None):
super().__init__(input_dim=input_dim, output_dim=1, TANH=False)
self.device = device
self.reward_t_coef = reward_t_coef
self.reward_i_coef = reward_i_coef
self.normalizer = normalizer
def calculate_reward(self, states_gail, next_states_gail, rewards_t):
# PPO(GAIL) is to maximize E_{\pi} [-log(1 - D)].
states_gail = states_gail.clone()
next_states_gail = next_states_gail.clone()
states = torch.cat([states_gail, next_states_gail], dim=-1)
with torch.no_grad():
if self.normalizer is not None:
states = self.normalizer.normalize_torch(states, self.device)
rewards_t = self.reward_t_coef * rewards_t
d = self.forward(states)
prob = 1 / (1 + torch.exp(-d))
rewards_i = self.reward_i_coef * (
-torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device))))
rewards = rewards_t + rewards_i
return rewards, rewards_t / (self.reward_t_coef + 1e-10), rewards_i / (self.reward_i_coef + 1e-10)
def get_disc_logit_weights(self):
return torch.flatten(self.classifier.weight)
def get_disc_weights(self):
weights = []
for m in self.layers.modules():
if isinstance(m, nn.Linear):
weights.append(torch.flatten(m.weight))
weights.append(torch.flatten(self.classifier.weight))
return weights

Algorithm/magail.py (new file)

@@ -0,0 +1,149 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .disc import GAILDiscrim
from .ppo import PPO
from .utils import Normalizer
class MAGAIL(PPO):
def __init__(self, buffer_exp, input_dim, device,
disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
):
super().__init__(state_shape=input_dim, device=device)
self.learning_steps = 0
self.learning_steps_disc = 0
self.disc = GAILDiscrim(input_dim=input_dim)
self.disc_grad_penalty = disc_grad_penalty
self.disc_coef = disc_coef
self.disc_logit_reg = disc_logit_reg
self.disc_weight_decay = disc_weight_decay
self.lr_disc = lr_disc
self.epoch_disc = epoch_disc
self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)
self.normalizer = None
if use_gail_norm:
self.normalizer = Normalizer(self.state_shape[0]*2)
self.batch_size = batch_size
self.buffer_exp = buffer_exp
def update_disc(self, states, states_exp, writer):
states_cp = states.clone()
states_exp_cp = states_exp.clone()
# Output of discriminator is (-inf, inf), not [0, 1].
logits_pi = self.disc(states_cp)
logits_exp = self.disc(states_exp_cp)
# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].
loss_pi = -F.logsigmoid(-logits_pi).mean()
loss_exp = -F.logsigmoid(logits_exp).mean()
loss_disc = 0.5 * (loss_pi + loss_exp)
# logit reg
logit_weights = self.disc.get_disc_logit_weights()
disc_logit_loss = torch.sum(torch.square(logit_weights))
# grad penalty
sample_expert = states_exp_cp
sample_expert.requires_grad = True
disc = self.disc(sample_expert)
ones = torch.ones(disc.size(), device=disc.device)
disc_demo_grad = torch.autograd.grad(disc, sample_expert,
grad_outputs=ones,
create_graph=True, retain_graph=True, only_inputs=True)
disc_demo_grad = disc_demo_grad[0]
disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
grad_pen_loss = torch.mean(disc_demo_grad)
# weight decay
disc_weights = self.disc.get_disc_weights()
disc_weights = torch.cat(disc_weights, dim=-1)
disc_weight_decay = torch.sum(torch.square(disc_weights))
loss = self.disc_coef * loss_disc + self.disc_grad_penalty * grad_pen_loss + \
self.disc_logit_reg * disc_logit_loss + self.disc_weight_decay * disc_weight_decay
self.optim_d.zero_grad()
loss.backward()
self.optim_d.step()
if self.learning_steps_disc % self.epoch_disc == 0:
writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)
# Discriminator's accuracies.
with torch.no_grad():
acc_pi = (logits_pi < 0).float().mean().item()
acc_exp = (logits_exp > 0).float().mean().item()
writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)
def update(self, writer, total_steps):
self.learning_steps += 1
for _ in range(self.epoch_disc):
self.learning_steps_disc += 1
# Samples from current policy trajectories.
samples_policy = self.buffer.sample(self.batch_size)
states, next_states = samples_policy[0], samples_policy[-3]
states = torch.cat([states, next_states], dim=-1)
# Samples from expert demonstrations.
samples_expert = self.buffer_exp.sample(self.batch_size)
states_exp, next_states_exp = samples_expert[0], samples_expert[1]
states_exp = torch.cat([states_exp, next_states_exp], dim=-1)
if self.normalizer is not None:
with torch.no_grad():
states = self.normalizer.normalize_torch(states, self.device)
states_exp = self.normalizer.normalize_torch(states_exp, self.device)
# Update the discriminator (and its encoder).
self.update_disc(states, states_exp, writer)
# Update the running mean and std of the observation stream.
if self.normalizer is not None:
self.normalizer.update(states.cpu().numpy())
self.normalizer.update(states_exp.cpu().numpy())
states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()
# Calculate rewards.
rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)
writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(),
self.learning_steps)
writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps)
writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps)
# Update PPO using estimated rewards.
self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer,
total_steps)
self.buffer.clear()
return rewards_t.mean().item() + rewards_i.mean().item()
def save_models(self, path):
torch.save({
'actor': self.actor.state_dict(),
'critic': self.critic.state_dict(),
'disc': self.disc.state_dict(),
'optim_actor': self.optim_actor.state_dict(),
'optim_critic': self.optim_critic.state_dict(),
'optim_d': self.optim_d.state_dict()
}, os.path.join(path, 'model.pth'))
def load_models(self, path, load_optimizer=True):
loaded_dict = torch.load(path, map_location='cuda:0')
self.actor.load_state_dict(loaded_dict['actor'])
self.critic.load_state_dict(loaded_dict['critic'])
self.disc.load_state_dict(loaded_dict['disc'])
if load_optimizer:
self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
self.optim_d.load_state_dict(loaded_dict['optim_d'])

Algorithm/policy.py (new file)

@@ -0,0 +1,31 @@
import torch
import numpy as np
from torch import nn
from .utils import build_mlp, reparameterize, evaluate_lop_pi
class StateIndependentPolicy(nn.Module):
def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
hidden_activation=nn.Tanh()):
super().__init__()
self.net = build_mlp(
input_dim=state_shape[0],
output_dim=action_shape[0],
hidden_units=hidden_units,
hidden_activation=hidden_activation
)
self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
self.means = None
def forward(self, states):
return torch.tanh(self.net(states))
def sample(self, states):
self.means = self.net(states)
actions, log_pis = reparameterize(self.means, self.log_stds)
return actions, log_pis
def evaluate_log_pi(self, states, actions):
self.means = self.net(states)
return evaluate_lop_pi(self.means, self.log_stds, actions)

Algorithm/ppo.py (new file)

@@ -0,0 +1,267 @@
import os
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from .buffer import RolloutBuffer
from .bert import Bert
from .policy import StateIndependentPolicy
from abc import ABC, abstractmethod
class Algorithm(ABC):
def __init__(self, state_shape, device, gamma):
self.learning_steps = 0
self.state_shape = state_shape
self.device = device
self.gamma = gamma
def explore(self, state_list):
action_list = []
log_pi_list = []
if type(state_list).__module__ != "torch":
state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
with torch.no_grad():
for state in state_list:
action, log_pi = self.actor.sample(state.unsqueeze(0))
action_list.append(action.cpu().numpy()[0])
log_pi_list.append(log_pi.item())
return action_list, log_pi_list
def exploit(self, state_list):
action_list = []
state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
with torch.no_grad():
for state in state_list:
action = self.actor(state.unsqueeze(0))
action_list.append(action.cpu().numpy()[0])
return action_list
@abstractmethod
def is_update(self, step):
pass
@abstractmethod
def update(self, writer, total_steps):
pass
@abstractmethod
def save_models(self, save_dir):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
class PPO(Algorithm):
def __init__(self, state_shape, action_shape, device, gamma=0.995, rollout_length=2048,
units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
lr_disc=1e-3, auto_lr=True, use_adv_norm=True, max_steps=10000000):
super().__init__(state_shape, device, gamma)
self.lr_actor = lr_actor
self.lr_critic = lr_critic
self.lr_disc = lr_disc
self.auto_lr = auto_lr
self.use_adv_norm = use_adv_norm
# Rollout buffer.
self.buffer = RolloutBuffer(
buffer_size=rollout_length,
state_shape=state_shape,
action_shape=action_shape,
device=device
)
# Actor.
self.actor = StateIndependentPolicy(
state_shape=state_shape,
action_shape=action_shape,
hidden_units=units_actor,
hidden_activation=nn.Tanh()
).to(device)
# Critic.
self.critic = Bert(
input_dim=state_shape,
output_dim=1
).to(device)
self.learning_steps_ppo = 0
self.rollout_length = rollout_length
self.epoch_ppo = epoch_ppo
self.clip_eps = clip_eps
self.lambd = lambd
self.max_grad_norm = max_grad_norm
self.desired_kl = desired_kl
self.surrogate_loss_coef = surrogate_loss_coef
self.value_loss_coef = value_loss_coef
self.entropy_coef = entropy_coef
self.bounds_loss_coef = bounds_loss_coef
self.max_steps = max_steps
self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor)
# self.optim_actor = Adam([
# {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor},
# {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor/3}])
self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic)
def is_update(self, step):
return step % self.rollout_length == 0
def step(self, env, state_list, state_gail):
state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device)
action_list, log_pi_list = self.explore(state_list)
next_state, reward, terminated, truncated, info = env.step(np.array(action_list))
next_state_gail = env.state_gail
done = terminated or truncated
means = self.actor.means.detach().cpu().numpy()[0]
stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0]
self.buffer.append(state_list, action_list, reward, done, terminated, log_pi_list,
next_state, next_state_gail, means, stds)
if done:
next_state = env.reset()
next_state_gail = env.state_gail
return next_state, next_state_gail, info
def update(self, writer, total_steps):
pass
def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list, next_states, mus, sigmas, writer,
total_steps):
with torch.no_grad():
values = self.critic(states.detach())
next_values = self.critic(next_states.detach())
targets, gaes = self.calculate_gae(
values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)
state_list = states.permute(1, 0, 2)
action_list = actions.permute(1, 0, 2)
for i in range(self.epoch_ppo):
self.learning_steps_ppo += 1
self.update_critic(states, targets, writer)
for state, action, log_pi in zip(state_list, action_list, log_pi_list):
self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)
# self.lr_decay(total_steps, writer)
def update_critic(self, states, targets, writer):
loss_critic = (self.critic(states) - targets).pow_(2).mean()
loss_critic = loss_critic * self.value_loss_coef
self.optim_critic.zero_grad()
loss_critic.backward(retain_graph=False)
nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
self.optim_critic.step()
if self.learning_steps_ppo % self.epoch_ppo == 0:
writer.add_scalar(
'Loss/critic', loss_critic.item(), self.learning_steps)
def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer):
self.optim_actor.zero_grad()
log_pis = self.actor.evaluate_log_pi(states, actions)
mus = self.actor.means
sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
entropy = -log_pis.mean()
ratios = (log_pis - log_pis_old).exp_()
loss_actor1 = -ratios * gaes
loss_actor2 = -torch.clamp(
ratios,
1.0 - self.clip_eps,
1.0 + self.clip_eps
) * gaes
loss_actor = torch.max(loss_actor1, loss_actor2).mean()
loss_actor = loss_actor * self.surrogate_loss_coef
if self.auto_lr:
# desired_kl: 0.01
with torch.inference_mode():
kl = torch.sum(torch.log(sigmas / sigmas_old + 1.e-5) +
(torch.square(sigmas_old) + torch.square(mus_old - mus)) /
(2.0 * torch.square(sigmas)) - 0.5, axis=-1)
kl_mean = torch.mean(kl)
if kl_mean > self.desired_kl * 2.0:
self.lr_actor = max(1e-5, self.lr_actor / 1.5)
self.lr_critic = max(1e-5, self.lr_critic / 1.5)
self.lr_disc = max(1e-5, self.lr_disc / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
self.lr_actor = min(1e-2, self.lr_actor * 1.5)
self.lr_critic = min(1e-2, self.lr_critic * 1.5)
self.lr_disc = min(1e-2, self.lr_disc * 1.5)
for param_group in self.optim_actor.param_groups:
param_group['lr'] = self.lr_actor
for param_group in self.optim_critic.param_groups:
param_group['lr'] = self.lr_critic
for param_group in self.optim_d.param_groups:
param_group['lr'] = self.lr_disc
loss = loss_actor # + b_loss * 0 - self.entropy_coef * entropy * 0
loss.backward()
nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
self.optim_actor.step()
if self.learning_steps_ppo % self.epoch_ppo == 0:
writer.add_scalar(
'Loss/actor', loss_actor.item(), self.learning_steps)
writer.add_scalar(
'Loss/entropy', entropy.item(), self.learning_steps)
writer.add_scalar(
'Loss/learning_rate', self.lr_actor, self.learning_steps)
def lr_decay(self, total_steps, writer):
lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps))
lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps))
lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps))
for p in self.optim_actor.param_groups:
p['lr'] = lr_a_now
for p in self.optim_critic.param_groups:
p['lr'] = lr_c_now
for p in self.optim_d.param_groups:
p['lr'] = lr_d_now
writer.add_scalar(
'Loss/learning_rate', lr_a_now, self.learning_steps)
def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd):
"""
Calculate the advantage using GAE
'tm_dones=True' means dead or win, there is no next state s'
'dones=True' represents the terminal of an episode(dead or win or reaching the max_episode_steps).
When calculating the adv, if dones=True, gae=0
Reference: https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py
"""
with torch.no_grad():
# Calculate TD errors.
deltas = rewards + gamma * next_values * (1 - tm_dones) - values
# Initialize gae.
gaes = torch.empty_like(rewards)
# Calculate gae recursively from behind.
gaes[-1] = deltas[-1]
for t in reversed(range(rewards.size(0) - 1)):
gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
v_target = gaes + values
if self.use_adv_norm:
gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)
return v_target, gaes
def save_models(self, save_dir):
pass

Algorithm/utils.py (new file)

@@ -0,0 +1,108 @@
import math
import torch
import numpy as np
from torch import nn
from typing import Tuple
class RunningMeanStd(object):
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
"""
Calculates the running mean and std of a data stream
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
:param epsilon: helps with arithmetic issues
:param shape: the shape of the data stream's output
"""
self.mean = np.zeros(shape, np.float64)
self.var = np.ones(shape, np.float64)
self.count = epsilon
def update(self, arr: np.ndarray) -> None:
batch_mean = np.mean(arr, axis=0)
batch_var = np.var(arr, axis=0)
batch_count = arr.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
delta = batch_mean - self.mean
tot_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / tot_count
m_a = self.var * self.count
m_b = batch_var * batch_count
m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
new_var = m_2 / (self.count + batch_count)
new_count = batch_count + self.count
self.mean = new_mean
self.var = new_var
self.count = new_count
class Normalizer(RunningMeanStd):
def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
super().__init__(shape=input_dim)
self.epsilon = epsilon
self.clip_obs = clip_obs
def normalize(self, input):
return np.clip(
(input - self.mean) / np.sqrt(self.var + self.epsilon),
-self.clip_obs, self.clip_obs)
def normalize_torch(self, input, device):
mean_torch = torch.tensor(
self.mean, device=device, dtype=torch.float32)
std_torch = torch.sqrt(torch.tensor(
self.var + self.epsilon, device=device, dtype=torch.float32))
return torch.clamp(
(input - mean_torch) / std_torch, -self.clip_obs, self.clip_obs)
def update_normalizer(self, rollouts, expert_loader):
policy_data_generator = rollouts.feed_forward_generator_amp(
None, mini_batch_size=expert_loader.batch_size)
expert_data_generator = expert_loader.dataset.feed_forward_generator_amp(
expert_loader.batch_size)
for expert_batch, policy_batch in zip(expert_data_generator, policy_data_generator):
self.update(
torch.vstack(tuple(policy_batch) + tuple(expert_batch)).cpu().numpy())
def build_mlp(input_dim, output_dim, hidden_units=[64, 64],
hidden_activation=nn.Tanh(), output_activation=None):
layers = []
units = input_dim
for next_units in hidden_units:
layers.append(nn.Linear(units, next_units))
layers.append(hidden_activation)
units = next_units
layers.append(nn.Linear(units, output_dim))
if output_activation is not None:
layers.append(output_activation)
return nn.Sequential(*layers)
def calculate_log_pi(log_stds, noises, actions):
gaussian_log_probs = (-0.5 * noises.pow(2) - log_stds).sum(
dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)
return gaussian_log_probs - torch.log(
1 - actions.pow(2) + 1e-6).sum(dim=-1, keepdim=True)
def reparameterize(means, log_stds):
noises = torch.randn_like(means)
us = means + noises * log_stds.exp()
actions = torch.tanh(us)
return actions, calculate_log_pi(log_stds, noises, actions)
def atanh(x):
return 0.5 * (torch.log(1 + x + 1e-6) - torch.log(1 - x + 1e-6))
def evaluate_lop_pi(means, log_stds, actions):
noises = (atanh(actions) - means) / (log_stds.exp() + 1e-8)
return calculate_log_pi(log_stds, noises, actions)

(4 binary files not shown)
Env/run_multiagent_env.py (new file)

@@ -0,0 +1,41 @@
from scenario_env import MultiAgentScenarioEnv
from Env.simple_idm_policy import ConstantVelocityPolicy
from metadrive.engine.asset_loader import AssetLoader
WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env"
def main():
env = MultiAgentScenarioEnv(
config={
# "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
"is_multi_agent": True,
"num_controlled_agents": 3,
"horizon": 300,
"use_render": True,
"sequential_seed": True,
"reactive_traffic": True,
"manual_control": True,
},
agent2policy=ConstantVelocityPolicy(target_speed=50)
)
obs = env.reset(0)
for step in range(10000):
actions = {
aid: env.controlled_agents[aid].policy.act()
for aid in env.controlled_agents
}
obs, rewards, dones, infos = env.step(actions)
env.render(mode="topdown")
if dones["__all__"]:
break
env.close()
if __name__ == "__main__":
main()

Env/scenario_env.py (new file)

@@ -0,0 +1,204 @@
import numpy as np
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
import math
import logging
from collections import defaultdict
from typing import Union, Dict, AnyStr
from metadrive.engine.logger import get_logger, set_log_level
from metadrive.type import MetaDriveType
class PolicyVehicle(DefaultVehicle):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.policy = None
self.destination = None
def set_policy(self, policy):
self.policy = policy
def set_destination(self, des):
self.destination = des
def act(self, observation, policy=None):
if self.policy is not None:
return self.policy.act(observation)
else:
return self.action_space.sample()
def before_step(self, action):
self.last_position = self.position # 2D vector
self.last_velocity = self.velocity # 2D vector
self.last_speed = self.speed # Scalar
self.last_heading_dir = self.heading
if action is not None:
self.last_current_action.append(action)
self._set_action(action)
def is_done(self):
# arrive or crash
pass
vehicle_class_to_type[PolicyVehicle] = "default"
class MultiAgentScenarioEnv(ScenarioEnv):
@classmethod
def default_config(cls):
config = super().default_config()
config.update(dict(
data_directory=None,
num_controlled_agents=3,
horizon=1000,
))
return config
def __init__(self, config, agent2policy):
self.policy = agent2policy
self.controlled_agents = {}
self.controlled_agent_ids = []
self.obs_list = []
self.round = 0
super().__init__(config)
def reset(self, seed: Union[None, int] = None):
self.round = 0
if self.logger is None:
self.logger = get_logger()
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
set_log_level(log_level)
self.lazy_init()
self._reset_global_seed(seed)
if self.engine is None:
raise ValueError("Broken MetaDrive instance.")
# Record each vehicle's spawn info from the expert data (first-appearance time, position, heading, destination), then remove those objects; the saved info is used later to re-spawn controlled vehicles
_obj_to_clean_this_frame = []
self.car_birth_info_list = []
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
continue
else:
if track["type"] == MetaDriveType.VEHICLE:
_obj_to_clean_this_frame.append(scenario_id)
valid = track['state']['valid']
first_show = np.argmax(valid) if valid.any() else -1
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
# id, first-appearance time, spawn position, spawn heading, destination
self.car_birth_info_list.append({
'id': track['metadata']['object_id'],
'show_time': first_show,
'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
'heading': track['state']['heading'][first_show],
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
})
for scenario_id in _obj_to_clean_this_frame:
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
self.engine.reset()
self.reset_sensors()
self.engine.taskMgr.step()
self.lanes = self.engine.map_manager.current_map.road_network.graph
if self.top_down_renderer is not None:
self.top_down_renderer.clear()
self.engine.top_down_renderer = None
self.dones = {}
self.episode_rewards = defaultdict(float)
self.episode_lengths = defaultdict(int)
self.controlled_agents.clear()
self.controlled_agent_ids.clear()
super().reset(seed)  # Initialize the scenario
self._spawn_controlled_agents()
return self._get_all_obs()
def _spawn_controlled_agents(self):
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
for car in self.car_birth_info_list:
if car['show_time'] == self.round:
agent_id = f"controlled_{car['id']}"
vehicle = self.engine.spawn_object(
PolicyVehicle,
vehicle_config={},
position=car['begin'],
heading=car['heading']
)
vehicle.reset(position=car['begin'], heading=car['heading'])
vehicle.set_policy(self.policy)
vehicle.set_destination(car['end'])
self.controlled_agents[agent_id] = vehicle
self.controlled_agent_ids.append(agent_id)
# Key step: register with the engine's active_agents so the vehicle participates in physics updates
self.engine.agent_manager.active_agents[agent_id] = vehicle
def _get_all_obs(self):
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
self.obs_list = []
for agent_id, vehicle in self.controlled_agents.items():
state = vehicle.get_state()
traffic_light = 0
for lane in self.lanes.values():
if lane.lane.point_on_lane(state['position'][:2]):
if self.engine.light_manager.has_traffic_light(lane.lane.index):
traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
if traffic_light == 'TRAFFIC_LIGHT_GREEN':
traffic_light = 1
elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
traffic_light = 2
elif traffic_light == 'TRAFFIC_LIGHT_RED':
traffic_light = 3
else:
traffic_light = 0
break
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
physics_world=self.engine.physics_world.dynamic_world)
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
base_vehicle=vehicle,
physics_world=self.engine.physics_world.static_world)
obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
+ list(vehicle.destination))
self.obs_list.append(obs)
return self.obs_list
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
self.round += 1
for agent_id, action in action_dict.items():
if agent_id in self.controlled_agents:
self.controlled_agents[agent_id].before_step(action)
self.engine.step()
for agent_id in action_dict:
if agent_id in self.controlled_agents:
self.controlled_agents[agent_id].after_step()
self._spawn_controlled_agents()
obs = self._get_all_obs()
rewards = {aid: 0.0 for aid in self.controlled_agents}
dones = {aid: False for aid in self.controlled_agents}
dones["__all__"] = self.episode_step >= self.config["horizon"]
infos = {aid: {} for aid in self.controlled_agents}
return obs, rewards, dones, infos

Env/simple_idm_policy.py (new file)

@@ -0,0 +1,18 @@
import numpy as np
class ConstantVelocityPolicy:
def __init__(self, target_speed=50):
self.target_speed = target_speed  # currently unused; kept for future speed control
self.step_num = 0
def act(self):
self.step_num += 1
# Alternating-throttle logic kept for reference; both branches currently use
# full throttle, and the fixed action below is returned instead.
if self.step_num % 30 < 15:
throttle = 1.0
else:
throttle = 1.0
steering = 0.1
# return [steering, throttle]
return [0.0, 0.05]

Env/utils.py (new file)

@@ -0,0 +1,14 @@
import numpy as np
import torch
import random
def set_seed(seed):
if seed == -1:
seed = np.random.randint(0, 10000)
print('Random seed: {}'.format(seed))
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)