magail4autodrive: first commit
81
Algorithm/bert.py
Normal file
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn


class Bert(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim=128,
                 num_layers=4, ff_dim=512, num_heads=4, dropout=0.1, CLS=False, TANH=False):
        super().__init__()
        self.CLS = CLS
        self.projection = nn.Linear(input_dim, embed_dim)
        # NOTE: the positional embedding is sized with input_dim, i.e. the model
        # assumes the sequence length equals input_dim (plus one slot for the CLS token).
        if self.CLS:
            self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim))
        else:
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim))

        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        if TANH:
            self.classifier = nn.Sequential(nn.Linear(embed_dim, output_dim), nn.Tanh())
        else:
            self.classifier = nn.Linear(embed_dim, output_dim)

        self.layers.train()
        self.classifier.train()

    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, input_dim)
        # Linear projection.
        x = self.projection(x)  # (batch_size, seq_len, embed_dim)

        batch_size = x.size(0)
        if self.CLS:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, seq_len + 1, embed_dim)

        # Add positional encoding.
        x = x + self.pos_embed

        # Transpose to (seq_len, batch_size, embed_dim) for nn.MultiheadAttention.
        x = x.permute(1, 0, 2)

        for layer in self.layers:
            x = layer(x, mask=mask)

        if self.CLS:
            return self.classifier(x[0, :, :])
        else:
            pooled = x.mean(dim=0)  # (batch_size, embed_dim)
            return self.classifier(pooled)


class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        # GELU activation.
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        # Post-LN structure (normalization after the residual connection).
        # Attention block.
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # Feed-forward block.
        ff_output = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x
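A minimal usage sketch of the Bert encoder above, assuming the Algorithm directory is importable as a package and using made-up sizes; note that the positional embedding ties the sequence length to input_dim, so this example sets them equal.

import torch
from Algorithm.bert import Bert

# Illustrative sizes only: batch of 2, 16 tokens of dimension 16 each.
model = Bert(input_dim=16, output_dim=1)
x = torch.randn(2, 16, 16)   # (batch_size, seq_len, input_dim) with seq_len == input_dim
out = model(x)               # (2, 1): mean-pooled over the sequence, then classified
print(out.shape)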
80
Algorithm/buffer.py
Normal file
@@ -0,0 +1,80 @@
import os
import numpy as np
import torch


class RolloutBuffer:
    # TODO: state and action are list
    def __init__(self, buffer_size, state_shape, action_shape, device):
        self._n = 0
        self._p = 0
        self.buffer_size = buffer_size

        self.states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device)
        # self.states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device)
        self.actions = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
        self.rewards = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device)
        self.dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device)
        self.tm_dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device)
        self.log_pis = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device)
        self.next_states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device)
        # self.next_states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device)
        self.means = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
        self.stds = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)

    def append(self, state, action, reward, done, tm_dones, log_pi, next_state, next_state_gail, means, stds):
        self.states[self._p].copy_(state)
        # self.states_gail[self._p].copy_(state_gail)
        self.actions[self._p].copy_(torch.from_numpy(action))
        self.rewards[self._p] = float(reward)
        self.dones[self._p] = int(done)
        self.tm_dones[self._p] = int(tm_dones)
        self.log_pis[self._p] = float(log_pi)
        self.next_states[self._p].copy_(torch.from_numpy(next_state))
        # self.next_states_gail[self._p].copy_(torch.from_numpy(next_state_gail))
        self.means[self._p].copy_(torch.from_numpy(means))
        self.stds[self._p].copy_(torch.from_numpy(stds))

        self._p = (self._p + 1) % self.buffer_size
        self._n = min(self._n + 1, self.buffer_size)

    def get(self):
        assert self._p % self.buffer_size == 0
        idxes = slice(0, self.buffer_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def sample(self, batch_size):
        assert self._p % self.buffer_size == 0
        idxes = np.random.randint(low=0, high=self._n, size=batch_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def clear(self):
        self.states[:, :] = 0
        self.actions[:, :] = 0
        self.rewards[:, :] = 0
        self.dones[:, :] = 0
        self.tm_dones[:, :] = 0
        self.log_pis[:, :] = 0
        self.next_states[:, :] = 0
        self.means[:, :] = 0
        self.stds[:, :] = 0
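A small sketch of how RolloutBuffer is filled and drained, with invented shapes and a CPU device; next_state_gail is accepted but currently unused because the *_gail tensors above are commented out.

import numpy as np
import torch
from Algorithm.buffer import RolloutBuffer

buf = RolloutBuffer(buffer_size=4, state_shape=(8,), action_shape=(2,), device=torch.device("cpu"))
for _ in range(4):
    buf.append(torch.zeros(8), np.zeros(2, dtype=np.float32), reward=0.0, done=False,
               tm_dones=False, log_pi=0.0, next_state=np.zeros(8, dtype=np.float32),
               next_state_gail=None, means=np.zeros(2, dtype=np.float32),
               stds=np.ones(2, dtype=np.float32))

# get() requires the buffer to be exactly full (or wrapped around): _p % buffer_size == 0.
states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds = buf.get()
buf.clear()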
44
Algorithm/disc.py
Normal file
@@ -0,0 +1,44 @@
import torch
from torch import nn
from .bert import Bert


DISC_LOGIT_INIT_SCALE = 1.0


class GAILDiscrim(Bert):

    def __init__(self, input_dim, reward_i_coef=1.0, reward_t_coef=1.0, normalizer=None, device=None):
        super().__init__(input_dim=input_dim, output_dim=1, TANH=False)
        self.device = device
        self.reward_t_coef = reward_t_coef
        self.reward_i_coef = reward_i_coef
        self.normalizer = normalizer

    def calculate_reward(self, states_gail, next_states_gail, rewards_t):
        # PPO(GAIL) is to maximize E_{\pi} [-log(1 - D)].
        states_gail = states_gail.clone()
        next_states_gail = next_states_gail.clone()
        states = torch.cat([states_gail, next_states_gail], dim=-1)
        with torch.no_grad():
            if self.normalizer is not None:
                states = self.normalizer.normalize_torch(states, self.device)
            rewards_t = self.reward_t_coef * rewards_t
            d = self.forward(states)
            prob = 1 / (1 + torch.exp(-d))
            rewards_i = self.reward_i_coef * (
                -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device))))
            rewards = rewards_t + rewards_i
        return rewards, rewards_t / (self.reward_t_coef + 1e-10), rewards_i / (self.reward_i_coef + 1e-10)

    def get_disc_logit_weights(self):
        return torch.flatten(self.classifier.weight)

    def get_disc_weights(self):
        weights = []
        for m in self.layers.modules():
            if isinstance(m, nn.Linear):
                weights.append(torch.flatten(m.weight))

        weights.append(torch.flatten(self.classifier.weight))
        return weights
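The imitation reward above is -log(1 - sigmoid(d)), clamped away from zero; a standalone numeric sketch of that mapping (logit values are made up), using the softplus identity as a cross-check.

import torch
import torch.nn.functional as F

d = torch.tensor([-2.0, 0.0, 3.0])                      # raw discriminator logits
prob = torch.sigmoid(d)                                 # D(s, s') in (0, 1)
reward_i = -torch.log(torch.clamp(1 - prob, min=1e-4))  # as in calculate_reward
print(reward_i, F.softplus(d))                          # softplus(d) == -log(1 - sigmoid(d))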
149
Algorithm/magail.py
Normal file
@@ -0,0 +1,149 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .disc import GAILDiscrim
from .ppo import PPO
from .utils import Normalizer


class MAGAIL(PPO):
    def __init__(self, buffer_exp, input_dim, device,
                 disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
                 ):
        super().__init__(state_shape=input_dim, device=device)
        self.learning_steps = 0
        self.learning_steps_disc = 0

        self.disc = GAILDiscrim(input_dim=input_dim)
        self.disc_grad_penalty = disc_grad_penalty
        self.disc_coef = disc_coef
        self.disc_logit_reg = disc_logit_reg
        self.disc_weight_decay = disc_weight_decay
        self.lr_disc = lr_disc
        self.epoch_disc = epoch_disc
        self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)

        self.normalizer = None
        if use_gail_norm:
            self.normalizer = Normalizer(self.state_shape[0] * 2)

        self.batch_size = batch_size
        self.buffer_exp = buffer_exp

    def update_disc(self, states, states_exp, writer):
        states_cp = states.clone()
        states_exp_cp = states_exp.clone()

        # Output of discriminator is (-inf, inf), not [0, 1].
        logits_pi = self.disc(states_cp)
        logits_exp = self.disc(states_exp_cp)

        # Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = 0.5 * (loss_pi + loss_exp)

        # Logit regularization.
        logit_weights = self.disc.get_disc_logit_weights()
        disc_logit_loss = torch.sum(torch.square(logit_weights))

        # Gradient penalty on expert samples.
        sample_expert = states_exp_cp
        sample_expert.requires_grad = True
        # Forward pass of the Bert-based discriminator (it has no separate trunk/linear split).
        disc = self.disc(sample_expert)
        ones = torch.ones(disc.size(), device=disc.device)
        disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                             grad_outputs=ones,
                                             create_graph=True, retain_graph=True, only_inputs=True)
        disc_demo_grad = disc_demo_grad[0]
        disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
        grad_pen_loss = torch.mean(disc_demo_grad)

        # Weight decay.
        disc_weights = self.disc.get_disc_weights()
        disc_weights = torch.cat(disc_weights, dim=-1)
        disc_weight_decay = torch.sum(torch.square(disc_weights))

        loss = self.disc_coef * loss_disc + self.disc_grad_penalty * grad_pen_loss + \
            self.disc_logit_reg * disc_logit_loss + self.disc_weight_decay * disc_weight_decay

        self.optim_d.zero_grad()
        loss.backward()
        self.optim_d.step()

        if self.learning_steps_disc % self.epoch_disc == 0:
            writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)

            # Discriminator's accuracies.
            with torch.no_grad():
                acc_pi = (logits_pi < 0).float().mean().item()
                acc_exp = (logits_exp > 0).float().mean().item()

            writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
            writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)

    def update(self, writer, total_steps):
        self.learning_steps += 1
        for _ in range(self.epoch_disc):
            self.learning_steps_disc += 1

            # Samples from current policy trajectories.
            # buffer.sample returns (states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds).
            samples_policy = self.buffer.sample(self.batch_size)
            states, next_states = samples_policy[0], samples_policy[-3]
            states = torch.cat([states, next_states], dim=-1)

            # Samples from expert demonstrations.
            samples_expert = self.buffer_exp.sample(self.batch_size)
            states_exp, next_states_exp = samples_expert[0], samples_expert[1]
            states_exp = torch.cat([states_exp, next_states_exp], dim=-1)

            if self.normalizer is not None:
                with torch.no_grad():
                    states = self.normalizer.normalize_torch(states, self.device)
                    states_exp = self.normalizer.normalize_torch(states_exp, self.device)

            # Update discriminator.
            self.update_disc(states, states_exp, writer)

            # Calculates the running mean and std of the data stream.
            if self.normalizer is not None:
                self.normalizer.update(states.cpu().numpy())
                self.normalizer.update(states_exp.cpu().numpy())

        states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()

        # Calculate rewards.
        rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)

        writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(),
                          self.learning_steps)
        writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps)
        writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps)

        # Update PPO using estimated rewards.
        self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer,
                        total_steps)
        self.buffer.clear()
        return rewards_t.mean().item() + rewards_i.mean().item()

    def save_models(self, path):
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'disc': self.disc.state_dict(),
            'optim_actor': self.optim_actor.state_dict(),
            'optim_critic': self.optim_critic.state_dict(),
            'optim_d': self.optim_d.state_dict()
        }, os.path.join(path, 'model.pth'))

    def load_models(self, path, load_optimizer=True):
        loaded_dict = torch.load(path, map_location='cuda:0')
        self.actor.load_state_dict(loaded_dict['actor'])
        self.critic.load_state_dict(loaded_dict['critic'])
        self.disc.load_state_dict(loaded_dict['disc'])
        if load_optimizer:
            self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
            self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
            self.optim_d.load_state_dict(loaded_dict['optim_d'])
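For reference, the core of update_disc is a pair of binary logistic losses that push policy logits negative and expert logits positive; a standalone sketch with invented logits, mirroring the loss and accuracy computations above.

import torch
import torch.nn.functional as F

logits_pi = torch.tensor([-1.5, 0.2, -0.7])   # discriminator outputs on policy samples
logits_exp = torch.tensor([2.0, 0.1, 1.3])    # discriminator outputs on expert samples

loss_pi = -F.logsigmoid(-logits_pi).mean()    # -E_pi[log(1 - D)]
loss_exp = -F.logsigmoid(logits_exp).mean()   # -E_exp[log D]
loss_disc = 0.5 * (loss_pi + loss_exp)

acc_pi = (logits_pi < 0).float().mean()       # policy samples classified as non-expert
acc_exp = (logits_exp > 0).float().mean()     # expert samples classified as expert
print(loss_disc.item(), acc_pi.item(), acc_exp.item())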
31
Algorithm/policy.py
Normal file
@@ -0,0 +1,31 @@
import torch
import numpy as np
from torch import nn
from .utils import build_mlp, reparameterize, evaluate_lop_pi


class StateIndependentPolicy(nn.Module):

    def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        super().__init__()

        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
        self.means = None

    def forward(self, states):
        return torch.tanh(self.net(states))

    def sample(self, states):
        self.means = self.net(states)
        actions, log_pis = reparameterize(self.means, self.log_stds)
        return actions, log_pis

    def evaluate_log_pi(self, states, actions):
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)
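A quick sketch of the two evaluation paths of StateIndependentPolicy, assuming the Algorithm package is importable and using an invented 8-dimensional observation and 2-dimensional action.

import torch
from Algorithm.policy import StateIndependentPolicy

policy = StateIndependentPolicy(state_shape=(8,), action_shape=(2,))
states = torch.randn(5, 8)

actions, log_pis = policy.sample(states)                     # tanh-squashed Gaussian actions in (-1, 1)
log_pis_check = policy.evaluate_log_pi(states, actions)
deterministic = policy(states)                               # tanh(mean), the path used by exploit()
print(actions.shape, (log_pis - log_pis_check).abs().max())  # difference should be small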
267
Algorithm/ppo.py
Normal file
@@ -0,0 +1,267 @@
import os
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from .buffer import RolloutBuffer
from .bert import Bert
from .policy import StateIndependentPolicy
from abc import ABC, abstractmethod


class Algorithm(ABC):

    def __init__(self, state_shape, device, gamma):
        self.learning_steps = 0
        self.state_shape = state_shape
        self.device = device
        self.gamma = gamma

    def explore(self, state_list):
        action_list = []
        log_pi_list = []
        if not isinstance(state_list, torch.Tensor):
            state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action, log_pi = self.actor.sample(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
                log_pi_list.append(log_pi.item())
        return action_list, log_pi_list

    def exploit(self, state_list):
        action_list = []
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action = self.actor(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
        return action_list

    @abstractmethod
    def is_update(self, step):
        pass

    @abstractmethod
    def update(self, writer, total_steps):
        pass

    @abstractmethod
    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)


class PPO(Algorithm):

    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
                 units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                 lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                 value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
                 lr_disc=1e-3, auto_lr=True, use_adv_norm=True, max_steps=10000000):
        # action_shape defaults to a 2-dim (steering, throttle) action, matching the MetaDrive vehicles in Env/.
        super().__init__(state_shape, device, gamma)

        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.lr_disc = lr_disc
        self.auto_lr = auto_lr

        self.use_adv_norm = use_adv_norm

        # Rollout buffer.
        self.buffer = RolloutBuffer(
            buffer_size=rollout_length,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )

        # Actor.
        self.actor = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        ).to(device)

        # Critic.
        self.critic = Bert(
            input_dim=state_shape,
            output_dim=1
        ).to(device)

        self.learning_steps_ppo = 0
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.max_grad_norm = max_grad_norm
        self.desired_kl = desired_kl
        self.surrogate_loss_coef = surrogate_loss_coef
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.bounds_loss_coef = bounds_loss_coef
        self.max_steps = max_steps

        self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor)
        # self.optim_actor = Adam([
        #     {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor},
        #     {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor/3}])
        self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic)

    def is_update(self, step):
        return step % self.rollout_length == 0

    def step(self, env, state_list, state_gail):
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device)
        action_list, log_pi_list = self.explore(state_list)
        next_state, reward, terminated, truncated, info = env.step(np.array(action_list))
        next_state_gail = env.state_gail
        done = terminated or truncated

        means = self.actor.means.detach().cpu().numpy()[0]
        stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0]

        # RolloutBuffer.append expects
        # (state, action, reward, done, tm_dones, log_pi, next_state, next_state_gail, means, stds);
        # state_gail itself is not stored because the *_gail slots in the buffer are commented out.
        self.buffer.append(state_list, action_list, reward, done, terminated, log_pi_list,
                           next_state, next_state_gail, means, stds)

        if done:
            next_state = env.reset()
            next_state_gail = env.state_gail

        return next_state, next_state_gail, info

    def update(self, writer, total_steps):
        pass

    def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list, next_states, mus, sigmas, writer,
                   total_steps):
        with torch.no_grad():
            values = self.critic(states.detach())
            next_values = self.critic(next_states.detach())

        targets, gaes = self.calculate_gae(
            values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)

        state_list = states.permute(1, 0, 2)
        action_list = actions.permute(1, 0, 2)

        for _ in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            for state, action, log_pi in zip(state_list, action_list, log_pi_list):
                self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)

        # self.lr_decay(total_steps, writer)

    def update_critic(self, states, targets, writer):
        loss_critic = (self.critic(states) - targets).pow_(2).mean()
        loss_critic = loss_critic * self.value_loss_coef

        self.optim_critic.zero_grad()
        loss_critic.backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/critic', loss_critic.item(), self.learning_steps)

    def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer):
        self.optim_actor.zero_grad()
        log_pis = self.actor.evaluate_log_pi(states, actions)
        mus = self.actor.means
        sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
        entropy = -log_pis.mean()

        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * gaes
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()
        loss_actor = loss_actor * self.surrogate_loss_coef
        if self.auto_lr:
            # desired_kl: 0.01
            with torch.inference_mode():
                kl = torch.sum(torch.log(sigmas / sigmas_old + 1.e-5) +
                               (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                               (2.0 * torch.square(sigmas)) - 0.5, dim=-1)

                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.lr_actor = max(1e-5, self.lr_actor / 1.5)
                    self.lr_critic = max(1e-5, self.lr_critic / 1.5)
                    self.lr_disc = max(1e-5, self.lr_disc / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.lr_actor = min(1e-2, self.lr_actor * 1.5)
                    self.lr_critic = min(1e-2, self.lr_critic * 1.5)
                    self.lr_disc = min(1e-2, self.lr_disc * 1.5)

            for param_group in self.optim_actor.param_groups:
                param_group['lr'] = self.lr_actor
            for param_group in self.optim_critic.param_groups:
                param_group['lr'] = self.lr_critic
            for param_group in self.optim_d.param_groups:
                param_group['lr'] = self.lr_disc

        loss = loss_actor  # + b_loss * 0 - self.entropy_coef * entropy * 0

        loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/actor', loss_actor.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/entropy', entropy.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/learning_rate', self.lr_actor, self.learning_steps)

    def lr_decay(self, total_steps, writer):
        lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps))
        lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps))
        lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps))
        for p in self.optim_actor.param_groups:
            p['lr'] = lr_a_now
        for p in self.optim_critic.param_groups:
            p['lr'] = lr_c_now
        for p in self.optim_d.param_groups:
            p['lr'] = lr_d_now

        writer.add_scalar(
            'Loss/learning_rate', lr_a_now, self.learning_steps)

    def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd):
        """
        Calculate the advantage using GAE.
        'tm_dones=True' means dead or win, i.e. there is no next state s'.
        'dones=True' marks the end of an episode (dead, win, or reaching max_episode_steps).
        When calculating the advantage, gae is reset to 0 wherever dones=True.
        Reference: https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py
        """
        with torch.no_grad():
            # Calculate TD errors.
            deltas = rewards + gamma * next_values * (1 - tm_dones) - values
            # Initialize gae.
            gaes = torch.empty_like(rewards)

            # Calculate gae recursively from behind.
            gaes[-1] = deltas[-1]
            for t in reversed(range(rewards.size(0) - 1)):
                gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]

            v_target = gaes + values
            if self.use_adv_norm:
                gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)

        return v_target, gaes

    def save_models(self, save_dir):
        pass
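calculate_gae builds the advantages backwards from the TD errors; a tiny self-contained numeric sketch of the same recursion on a 3-step rollout with invented numbers (gamma=0.9, lambda=0.95).

import torch

gamma, lambd = 0.9, 0.95
rewards     = torch.tensor([[1.0], [1.0], [1.0]])
values      = torch.tensor([[0.5], [0.6], [0.7]])
next_values = torch.tensor([[0.6], [0.7], [0.0]])
dones       = torch.tensor([[0], [0], [1]])
tm_dones    = torch.tensor([[0], [0], [1]])   # the last step terminates, so it is not bootstrapped

deltas = rewards + gamma * next_values * (1 - tm_dones) - values
gaes = torch.empty_like(rewards)
gaes[-1] = deltas[-1]
for t in reversed(range(rewards.size(0) - 1)):
    gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
v_target = gaes + values
print(deltas.squeeze(), gaes.squeeze(), v_target.squeeze())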
108
Algorithm/utils.py
Normal file
@@ -0,0 +1,108 @@
import math
import torch
import numpy as np
from torch import nn
from typing import Tuple


class RunningMeanStd(object):
    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        """
        Calculates the running mean and std of a data stream.
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        :param epsilon: helps with arithmetic issues
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
        new_var = m_2 / (self.count + batch_count)

        new_count = batch_count + self.count

        self.mean = new_mean
        self.var = new_var
        self.count = new_count


class Normalizer(RunningMeanStd):
    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon
        self.clip_obs = clip_obs

    def normalize(self, input):
        return np.clip(
            (input - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip_obs, self.clip_obs)

    def normalize_torch(self, input, device):
        mean_torch = torch.tensor(
            self.mean, device=device, dtype=torch.float32)
        std_torch = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32))
        return torch.clamp(
            (input - mean_torch) / std_torch, -self.clip_obs, self.clip_obs)

    def update_normalizer(self, rollouts, expert_loader):
        policy_data_generator = rollouts.feed_forward_generator_amp(
            None, mini_batch_size=expert_loader.batch_size)
        expert_data_generator = expert_loader.dataset.feed_forward_generator_amp(
            expert_loader.batch_size)

        for expert_batch, policy_batch in zip(expert_data_generator, policy_data_generator):
            self.update(
                torch.vstack(tuple(policy_batch) + tuple(expert_batch)).cpu().numpy())


def build_mlp(input_dim, output_dim, hidden_units=[64, 64],
              hidden_activation=nn.Tanh(), output_activation=None):
    layers = []
    units = input_dim
    for next_units in hidden_units:
        layers.append(nn.Linear(units, next_units))
        layers.append(hidden_activation)
        units = next_units
    layers.append(nn.Linear(units, output_dim))
    if output_activation is not None:
        layers.append(output_activation)
    return nn.Sequential(*layers)


def calculate_log_pi(log_stds, noises, actions):
    gaussian_log_probs = (-0.5 * noises.pow(2) - log_stds).sum(
        dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)

    return gaussian_log_probs - torch.log(
        1 - actions.pow(2) + 1e-6).sum(dim=-1, keepdim=True)


def reparameterize(means, log_stds):
    noises = torch.randn_like(means)
    us = means + noises * log_stds.exp()
    actions = torch.tanh(us)
    return actions, calculate_log_pi(log_stds, noises, actions)


def atanh(x):
    return 0.5 * (torch.log(1 + x + 1e-6) - torch.log(1 - x + 1e-6))


def evaluate_lop_pi(means, log_stds, actions):
    noises = (atanh(actions) - means) / (log_stds.exp() + 1e-8)
    return calculate_log_pi(log_stds, noises, actions)
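A short usage sketch of the running normalizer and the reparameterized sampling helper; the dimensions are illustrative and the Algorithm package is assumed importable.

import numpy as np
import torch
from Algorithm.utils import Normalizer, reparameterize

norm = Normalizer(input_dim=4)                # running mean/var over a 4-dim stream
batch = np.random.randn(32, 4)
norm.update(batch)                            # parallel-algorithm update of mean/var/count
clipped = norm.normalize_torch(torch.from_numpy(batch).float(), device="cpu")
print(clipped.shape, float(norm.count))

means = torch.zeros(1, 2)
log_stds = torch.zeros(1, 2)
actions, log_pi = reparameterize(means, log_stds)   # tanh-squashed sample and its log-probability
print(actions, log_pi)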
BIN
Env/exp_converted/dataset_mapping.pkl
Normal file
Binary file not shown.
BIN
Env/exp_converted/dataset_summary.pkl
Normal file
Binary file not shown.
BIN
Env/exp_converted/exp_converted_0/dataset_mapping.pkl
Normal file
Binary file not shown.
BIN
Env/exp_converted/exp_converted_0/dataset_summary.pkl
Normal file
Binary file not shown.
(Additional binary scenario files under Env/exp_converted/ are not shown.)
41
Env/run_multiagent_env.py
Normal file
@@ -0,0 +1,41 @@
from scenario_env import MultiAgentScenarioEnv
from Env.simple_idm_policy import ConstantVelocityPolicy
from metadrive.engine.asset_loader import AssetLoader

WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env"


def main():
    env = MultiAgentScenarioEnv(
        config={
            # "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
            "data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
            "is_multi_agent": True,
            "num_controlled_agents": 3,
            "horizon": 300,
            "use_render": True,
            "sequential_seed": True,
            "reactive_traffic": True,
            "manual_control": True,
        },
        agent2policy=ConstantVelocityPolicy(target_speed=50)
    )

    obs = env.reset(0)
    for step in range(10000):
        actions = {
            aid: env.controlled_agents[aid].policy.act()
            for aid in env.controlled_agents
        }

        obs, rewards, dones, infos = env.step(actions)
        env.render(mode="topdown")

        if dones["__all__"]:
            break

    env.close()


if __name__ == "__main__":
    main()
204
Env/scenario_env.py
Normal file
@@ -0,0 +1,204 @@
import numpy as np
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
import math
import logging
from collections import defaultdict
from typing import Union, Dict, AnyStr
from metadrive.engine.logger import get_logger, set_log_level
from metadrive.type import MetaDriveType


class PolicyVehicle(DefaultVehicle):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.policy = None
        self.destination = None

    def set_policy(self, policy):
        self.policy = policy

    def set_destination(self, des):
        self.destination = des

    def act(self, observation, policy=None):
        if self.policy is not None:
            return self.policy.act(observation)
        else:
            return self.action_space.sample()

    def before_step(self, action):
        self.last_position = self.position  # 2D vector
        self.last_velocity = self.velocity  # 2D vector
        self.last_speed = self.speed  # Scalar
        self.last_heading_dir = self.heading
        if action is not None:
            self.last_current_action.append(action)
            self._set_action(action)

    def is_done(self):
        # arrive or crash
        pass


vehicle_class_to_type[PolicyVehicle] = "default"


class MultiAgentScenarioEnv(ScenarioEnv):
    @classmethod
    def default_config(cls):
        config = super().default_config()
        config.update(dict(
            data_directory=None,
            num_controlled_agents=3,
            horizon=1000,
        ))
        return config

    def __init__(self, config, agent2policy):
        self.policy = agent2policy
        self.controlled_agents = {}
        self.controlled_agent_ids = []
        self.obs_list = []
        self.round = 0
        super().__init__(config)

    def reset(self, seed: Union[None, int] = None):
        self.round = 0
        if self.logger is None:
            self.logger = get_logger()
            log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
            set_log_level(log_level)

        self.lazy_init()
        self._reset_global_seed(seed)
        if self.engine is None:
            raise ValueError("Broken MetaDrive instance.")

        # Record the spawn information of every expert vehicle in the dataset, then remove those
        # trajectories; only the spawn position, heading and destination are kept so that
        # policy-controlled vehicles can be generated in their place later.
        _obj_to_clean_this_frame = []
        self.car_birth_info_list = []
        for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
            if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
                continue
            else:
                if track["type"] == MetaDriveType.VEHICLE:
                    _obj_to_clean_this_frame.append(scenario_id)
                    valid = track['state']['valid']
                    first_show = np.argmax(valid) if valid.any() else -1
                    last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
                    # id, first frame in which the car appears, spawn position, spawn heading, destination
                    self.car_birth_info_list.append({
                        'id': track['metadata']['object_id'],
                        'show_time': first_show,
                        'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
                        'heading': track['state']['heading'][first_show],
                        'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
                    })

        for scenario_id in _obj_to_clean_this_frame:
            self.engine.traffic_manager.current_traffic_data.pop(scenario_id)

        self.engine.reset()
        self.reset_sensors()
        self.engine.taskMgr.step()

        self.lanes = self.engine.map_manager.current_map.road_network.graph

        if self.top_down_renderer is not None:
            self.top_down_renderer.clear()
            self.engine.top_down_renderer = None

        self.dones = {}
        self.episode_rewards = defaultdict(float)
        self.episode_lengths = defaultdict(int)

        self.controlled_agents.clear()
        self.controlled_agent_ids.clear()

        super().reset(seed)  # initialize the scenario
        self._spawn_controlled_agents()

        return self._get_all_obs()

    def _spawn_controlled_agents(self):
        # ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
        # ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
        for car in self.car_birth_info_list:
            if car['show_time'] == self.round:
                agent_id = f"controlled_{car['id']}"

                vehicle = self.engine.spawn_object(
                    PolicyVehicle,
                    vehicle_config={},
                    position=car['begin'],
                    heading=car['heading']
                )
                vehicle.reset(position=car['begin'], heading=car['heading'])

                vehicle.set_policy(self.policy)
                vehicle.set_destination(car['end'])

                self.controlled_agents[agent_id] = vehicle
                self.controlled_agent_ids.append(agent_id)

                # Key step: register the vehicle with the engine's active_agents
                # so that it takes part in the physics update.
                self.engine.agent_manager.active_agents[agent_id] = vehicle

    def _get_all_obs(self):
        # position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
        self.obs_list = []
        for agent_id, vehicle in self.controlled_agents.items():
            state = vehicle.get_state()

            traffic_light = 0
            for lane in self.lanes.values():
                if lane.lane.point_on_lane(state['position'][:2]):
                    if self.engine.light_manager.has_traffic_light(lane.lane.index):
                        traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
                        if traffic_light == 'TRAFFIC_LIGHT_GREEN':
                            traffic_light = 1
                        elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
                            traffic_light = 2
                        elif traffic_light == 'TRAFFIC_LIGHT_RED':
                            traffic_light = 3
                        else:
                            traffic_light = 0
                    break

            lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
                                                             physics_world=self.engine.physics_world.dynamic_world)
            side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
                                                                          base_vehicle=vehicle,
                                                                          physics_world=self.engine.physics_world.static_world)
            lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
                                                                                    base_vehicle=vehicle,
                                                                                    physics_world=self.engine.physics_world.static_world)

            obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
                   + lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
                   + list(vehicle.destination))
            self.obs_list.append(obs)
        return self.obs_list

    def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
        self.round += 1

        for agent_id, action in action_dict.items():
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].before_step(action)

        self.engine.step()

        for agent_id in action_dict:
            if agent_id in self.controlled_agents:
                self.controlled_agents[agent_id].after_step()

        self._spawn_controlled_agents()
        obs = self._get_all_obs()
        rewards = {aid: 0.0 for aid in self.controlled_agents}
        dones = {aid: False for aid in self.controlled_agents}
        dones["__all__"] = self.episode_step >= self.config["horizon"]
        infos = {aid: {} for aid in self.controlled_agents}
        return obs, rewards, dones, infos
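From _get_all_obs, each controlled agent's observation is a flat list; a small sketch of its layout and total length, derived from the sensor parameters used above.

# position (x, y)                 -> 2
# velocity (vx, vy)               -> 2
# heading_theta                   -> 1
# lidar, 80 lasers, 30 m          -> 80
# side detector, 10 lasers, 8 m   -> 10
# lane line detector, 10 lasers   -> 10
# traffic light code              -> 1   (0 none/unknown, 1 green, 2 yellow, 3 red)
# destination (x, y)              -> 2
obs_dim = 2 + 2 + 1 + 80 + 10 + 10 + 1 + 2
print(obs_dim)  # 108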
18
Env/simple_idm_policy.py
Normal file
@@ -0,0 +1,18 @@
import numpy as np


class ConstantVelocityPolicy:
    def __init__(self, target_speed=50):
        self.target_speed = target_speed
        self.step_num = 0

    def act(self):
        self.step_num += 1
        if self.step_num % 30 < 15:
            throttle = 1.0
        else:
            throttle = 1.0

        steering = 0.1

        # return [steering, throttle]

        return [0.0, 0.05]
14
Env/utils.py
Normal file
@@ -0,0 +1,14 @@
import numpy as np
import torch
import random


def set_seed(seed):
    if seed == -1:
        seed = np.random.randint(0, 10000)
    print('Random seed: {}'.format(seed))

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)