commit 947871a7200bb10050a49fbad834551b7cbc4175 Author: ZHY Date: Sun Sep 28 18:57:04 2025 +0800 magail4autodrive: first commit diff --git a/Algorithm/bert.py b/Algorithm/bert.py new file mode 100644 index 0000000..ff854a2 --- /dev/null +++ b/Algorithm/bert.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn + + +class Bert(nn.Module): + def __init__(self, input_dim, output_dim, embed_dim=128, + num_layers=4, ff_dim=512, num_heads=4, dropout=0.1, CLS=False, TANH=False): + super().__init__() + self.CLS = CLS + self.projection = nn.Linear(input_dim, embed_dim) + if self.CLS: + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim)) + else: + self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim)) + + self.layers = nn.ModuleList([ + TransformerLayer(embed_dim, num_heads, ff_dim, dropout) + for _ in range(num_layers) + ]) + + if TANH: + self.classifier = nn.Sequential(nn.Linear(embed_dim, output_dim), nn.Tanh()) + else: + self.classifier = nn.Linear(embed_dim, output_dim) + + self.layers.train() + self.classifier.train() + + def forward(self, x, mask=None): + # x: (batch_size, seq_len, input_dim) + # Linear projection + x = self.projection(x) # (batch_size, input_dim, embed_dim) + + batch_size = x.size(0) + if self.CLS: + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + x = torch.cat([cls_tokens, x], dim=1) # (batch_size, 29, embed_dim) + + # Add positional encoding + x = x + self.pos_embed + + # Transpose to (seq_len, batch_size, embed_dim) + x = x.permute(1, 0, 2) + + for layer in self.layers: + x = layer(x, mask=mask) + + if self.CLS: + return self.classifier(x[0, :, :]) + else: + pooled = x.mean(dim=0) # (batch_size, embed_dim) + return self.classifier(pooled) + + + +class TransformerLayer(nn.Module): + def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout) + self.linear1 = nn.Linear(embed_dim, ff_dim) + self.linear2 = nn.Linear(ff_dim, embed_dim) + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout = nn.Dropout(dropout) + # Use the GELU activation function + self.activation = nn.GELU() + + def forward(self, x, mask=None): + # Post-LN structure (LayerNorm after the residual connection) + # Attention block + attn_output, _ = self.self_attn(x, x, x, attn_mask=mask) + x = x + self.dropout(attn_output) + x = self.norm1(x) + + # Feed-forward (FFN) block + ff_output = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = x + self.dropout(ff_output) + x = self.norm2(x) + + return x \ No newline at end of file diff --git a/Algorithm/buffer.py b/Algorithm/buffer.py new file mode 100644 index 0000000..33d5f46 --- /dev/null +++ b/Algorithm/buffer.py @@ -0,0 +1,80 @@ +import os +import numpy as np +import torch + + +class RolloutBuffer: + # TODO: state and action are list + def __init__(self, buffer_size, state_shape, action_shape, device): + self._n = 0 + self._p = 0 + self.buffer_size = buffer_size + + self.states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device) + # self.states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device) + self.actions = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device) + self.rewards = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device) + self.dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device) + self.tm_dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device) + self.log_pis = 
torch.empty((self.buffer_size, 1), dtype=torch.float, device=device) + self.next_states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device) + # self.next_states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device) + self.means = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device) + self.stds = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device) + + def append(self, state, action, reward, done, tm_dones, log_pi, next_state, next_state_gail, means, stds): + self.states[self._p].copy_(state) + # self.states_gail[self._p].copy_(state_gail) + self.actions[self._p].copy_(torch.from_numpy(action)) + self.rewards[self._p] = float(reward) + self.dones[self._p] = int(done) + self.tm_dones[self._p] = int(tm_dones) + self.log_pis[self._p] = float(log_pi) + self.next_states[self._p].copy_(torch.from_numpy(next_state)) + # self.next_states_gail[self._p].copy_(torch.from_numpy(next_state_gail)) + self.means[self._p].copy_(torch.from_numpy(means)) + self.stds[self._p].copy_(torch.from_numpy(stds)) + + self._p = (self._p + 1) % self.buffer_size + self._n = min(self._n + 1, self.buffer_size) + + def get(self): + assert self._p % self.buffer_size == 0 + idxes = slice(0, self.buffer_size) + return ( + self.states[idxes], + self.actions[idxes], + self.rewards[idxes], + self.dones[idxes], + self.tm_dones[idxes], + self.log_pis[idxes], + self.next_states[idxes], + self.means[idxes], + self.stds[idxes] + ) + + def sample(self, batch_size): + assert self._p % self.buffer_size == 0 + idxes = np.random.randint(low=0, high=self._n, size=batch_size) + return ( + self.states[idxes], + self.actions[idxes], + self.rewards[idxes], + self.dones[idxes], + self.tm_dones[idxes], + self.log_pis[idxes], + self.next_states[idxes], + self.means[idxes], + self.stds[idxes] + ) + + def clear(self): + self.states[:, :] = 0 + self.actions[:, :] = 0 + self.rewards[:, :] = 0 + self.dones[:, :] = 0 + self.tm_dones[:, :] = 0 + self.log_pis[:, :] = 0 + self.next_states[:, :] = 0 + self.means[:, :] = 0 + self.stds[:, :] = 0 diff --git a/Algorithm/disc.py b/Algorithm/disc.py new file mode 100644 index 0000000..3363d08 --- /dev/null +++ b/Algorithm/disc.py @@ -0,0 +1,44 @@ +import torch +from torch import nn +from .bert import Bert + + +DISC_LOGIT_INIT_SCALE = 1.0 + + +class GAILDiscrim(Bert): + + def __init__(self, input_dim, reward_i_coef=1.0, reward_t_coef=1.0, normalizer=None, device=None): + super().__init__(input_dim=input_dim, output_dim=1, TANH=False) + self.device = device + self.reward_t_coef = reward_t_coef + self.reward_i_coef = reward_i_coef + self.normalizer = normalizer + + def calculate_reward(self, states_gail, next_states_gail, rewards_t): + # PPO(GAIL) is to maximize E_{\pi} [-log(1 - D)]. 
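+ # A sketch of the reward formed below: with logits d = D(s, s') and prob = sigmoid(d),
+ # the imitation term is r_i = reward_i_coef * -log(max(1 - sigmoid(d), 1e-4)),
+ # i.e. softplus(d) up to the clamp, so "expert-like" pairs (d >> 0) earn large rewards.
+ # The total is reward_t_coef * r_task + r_i; the two parts are also returned divided
+ # back by their coefficients so they can be logged separately.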
+ states_gail = states_gail.clone() + next_states_gail = next_states_gail.clone() + states = torch.cat([states_gail, next_states_gail], dim=-1) + with torch.no_grad(): + if self.normalizer is not None: + states = self.normalizer.normalize_torch(states, self.device) + rewards_t = self.reward_t_coef * rewards_t + d = self.forward(states) + prob = 1 / (1 + torch.exp(-d)) + rewards_i = self.reward_i_coef * ( + -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device)))) + rewards = rewards_t + rewards_i + return rewards, rewards_t / (self.reward_t_coef + 1e-10), rewards_i / (self.reward_i_coef + 1e-10) + + def get_disc_logit_weights(self): + return torch.flatten(self.classifier.weight) + + def get_disc_weights(self): + weights = [] + for m in self.layers.modules(): + if isinstance(m, nn.Linear): + weights.append(torch.flatten(m.weight)) + + weights.append(torch.flatten(self.classifier.weight)) + return weights diff --git a/Algorithm/magail.py b/Algorithm/magail.py new file mode 100644 index 0000000..227bfce --- /dev/null +++ b/Algorithm/magail.py @@ -0,0 +1,149 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from .disc import GAILDiscrim +from .ppo import PPO +from .utils import Normalizer + + +class MAGAIL(PPO): + def __init__(self, buffer_exp, input_dim, device, + disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005, + lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True + ): + super().__init__(state_shape=input_dim, device=device) + self.learning_steps = 0 + self.learning_steps_disc = 0 + + self.disc = GAILDiscrim(input_dim=input_dim) + self.disc_grad_penalty = disc_grad_penalty + self.disc_coef = disc_coef + self.disc_logit_reg = disc_logit_reg + self.disc_weight_decay = disc_weight_decay + self.lr_disc = lr_disc + self.epoch_disc = epoch_disc + self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc) + + self.normalizer = None + if use_gail_norm: + self.normalizer = Normalizer(self.state_shape[0]*2) + + self.batch_size = batch_size + self.buffer_exp = buffer_exp + + def update_disc(self, states, states_exp, writer): + states_cp = states.clone() + states_exp_cp = states_exp.clone() + + # Output of discriminator is (-inf, inf), not [0, 1]. + logits_pi = self.disc(states_cp) + logits_exp = self.disc(states_exp_cp) + + # Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)]. 
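+ # Since the network outputs logits, maximizing that objective equals minimizing
+ #     0.5 * (softplus(logits_pi) + softplus(-logits_exp)),
+ # and -logsigmoid(-x) == softplus(x), which is exactly what loss_pi / loss_exp compute below.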
+ loss_pi = -F.logsigmoid(-logits_pi).mean() + loss_exp = -F.logsigmoid(logits_exp).mean() + loss_disc = 0.5 * (loss_pi + loss_exp) + + # logit reg + logit_weights = self.disc.get_disc_logit_weights() + disc_logit_loss = torch.sum(torch.square(logit_weights)) + + # grad penalty + sample_expert = states_exp_cp + sample_expert.requires_grad = True + disc = self.disc(sample_expert) + ones = torch.ones(disc.size(), device=disc.device) + disc_demo_grad = torch.autograd.grad(disc, sample_expert, + grad_outputs=ones, + create_graph=True, retain_graph=True, only_inputs=True) + disc_demo_grad = disc_demo_grad[0] + disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1) + grad_pen_loss = torch.mean(disc_demo_grad) + + # weight decay + disc_weights = self.disc.get_disc_weights() + disc_weights = torch.cat(disc_weights, dim=-1) + disc_weight_decay = torch.sum(torch.square(disc_weights)) + + loss = self.disc_coef * loss_disc + self.disc_grad_penalty * grad_pen_loss + \ + self.disc_logit_reg * disc_logit_loss + self.disc_weight_decay * disc_weight_decay + + self.optim_d.zero_grad() + loss.backward() + self.optim_d.step() + + if self.learning_steps_disc % self.epoch_disc == 0: + writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps) + + # Discriminator's accuracies. + with torch.no_grad(): + acc_pi = (logits_pi < 0).float().mean().item() + acc_exp = (logits_exp > 0).float().mean().item() + + writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps) + writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps) + + def update(self, writer, total_steps): + self.learning_steps += 1 + for _ in range(self.epoch_disc): + self.learning_steps_disc += 1 + + # Samples from current policy trajectories. + samples_policy = self.buffer.sample(self.batch_size) + states, next_states = samples_policy[0], samples_policy[-3] + states = torch.cat([states, next_states], dim=-1) + + # Samples from expert demonstrations. + samples_expert = self.buffer_exp.sample(self.batch_size) + states_exp, next_states_exp = samples_expert[0], samples_expert[1] + states_exp = torch.cat([states_exp, next_states_exp], dim=-1) + + if self.normalizer is not None: + with torch.no_grad(): + states = self.normalizer.normalize_torch(states, self.device) + states_exp = self.normalizer.normalize_torch(states_exp, self.device) + + # Update the discriminator. + self.update_disc(states, states_exp, writer) + + # Calculates the running mean and std of the data stream. + if self.normalizer is not None: + self.normalizer.update(states.cpu().numpy()) + self.normalizer.update(states_exp.cpu().numpy()) + + states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get() + + # Calculate rewards. + rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards) + + writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(), + self.learning_steps) + writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps) + writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps) + + # Update PPO using estimated rewards. 
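+ # The mixed task + imitation rewards replace the environment reward in the standard
+ # clipped-surrogate PPO update (GAE targets, critic regression, KL-adaptive learning rate) from ppo.py.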
+ self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer, + total_steps) + self.buffer.clear() + return rewards_t.mean().item() + rewards_i.mean().item() + + def save_models(self, path): + torch.save({ + 'actor': self.actor.state_dict(), + 'critic': self.critic.state_dict(), + 'disc': self.disc.state_dict(), + 'optim_actor': self.optim_actor.state_dict(), + 'optim_critic': self.optim_critic.state_dict(), + 'optim_d': self.optim_d.state_dict() + }, os.path.join(path, 'model.pth')) + + def load_models(self, path, load_optimizer=True): + loaded_dict = torch.load(path, map_location='cuda:0') + self.actor.load_state_dict(loaded_dict['actor']) + self.critic.load_state_dict(loaded_dict['critic']) + self.disc.load_state_dict(loaded_dict['disc']) + if load_optimizer: + self.optim_actor.load_state_dict(loaded_dict['optim_actor']) + self.optim_critic.load_state_dict(loaded_dict['optim_critic']) + self.optim_d.load_state_dict(loaded_dict['optim_d']) diff --git a/Algorithm/policy.py b/Algorithm/policy.py new file mode 100644 index 0000000..f3035a9 --- /dev/null +++ b/Algorithm/policy.py @@ -0,0 +1,31 @@ +import torch +import numpy as np +from torch import nn +from .utils import build_mlp, reparameterize, evaluate_lop_pi + +class StateIndependentPolicy(nn.Module): + + def __init__(self, state_shape, action_shape, hidden_units=(64, 64), + hidden_activation=nn.Tanh()): + super().__init__() + + self.net = build_mlp( + input_dim=state_shape[0], + output_dim=action_shape[0], + hidden_units=hidden_units, + hidden_activation=hidden_activation + ) + self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0])) + self.means = None + + def forward(self, states): + return torch.tanh(self.net(states)) + + def sample(self, states): + self.means = self.net(states) + actions, log_pis = reparameterize(self.means, self.log_stds) + return actions, log_pis + + def evaluate_log_pi(self, states, actions): + self.means = self.net(states) + return evaluate_lop_pi(self.means, self.log_stds, actions) \ No newline at end of file diff --git a/Algorithm/ppo.py b/Algorithm/ppo.py new file mode 100644 index 0000000..64278c5 --- /dev/null +++ b/Algorithm/ppo.py @@ -0,0 +1,267 @@ +import os +import torch +import numpy as np +from torch import nn +from torch.optim import Adam +from buffer import RolloutBuffer +from bert import Bert +from policy import StateIndependentPolicy +from abc import ABC, abstractmethod + + +class Algorithm(ABC): + + def __init__(self, state_shape, device, gamma): + + self.learning_steps = 0 + self.state_shape = state_shape + self.device = device + self.gamma = gamma + + def explore(self, state_list): + action_list = [] + log_pi_list = [] + if type(state_list).__module__ != "torch": + state_list = torch.tensor(state_list, dtype=torch.float, device=self.device) + with torch.no_grad(): + for state in state_list: + action, log_pi = self.actor.sample(state.unsqueeze(0)) + action_list.append(action.cpu().numpy()[0]) + log_pi_list.append(log_pi.item()) + return action_list, log_pi_list + + def exploit(self, state_list): + action_list = [] + state_list = torch.tensor(state_list, dtype=torch.float, device=self.device) + with torch.no_grad(): + for state in state_list: + action = self.actor(state.unsqueeze(0)) + action_list.append(action.cpu().numpy()[0]) + return action_list + + @abstractmethod + def is_update(self, step): + pass + + @abstractmethod + def update(self, writer, total_steps): + pass + + @abstractmethod + def save_models(self, save_dir): + if not 
os.path.exists(save_dir): + os.makedirs(save_dir) + + +class PPO(Algorithm): + + def __init__(self, state_shape, device, gamma=0.995, rollout_length=2048, + units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2, + lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2., + value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3, + lr_disc=1e-3, auto_lr=True, use_adv_norm=True, max_steps=10000000): + super().__init__(state_shape, device, gamma) + + self.lr_actor = lr_actor + self.lr_critic = lr_critic + self.lr_disc = lr_disc + self.auto_lr = auto_lr + + self.use_adv_norm = use_adv_norm + + # Rollout buffer. + self.buffer = RolloutBuffer( + buffer_size=rollout_length, + state_shape=state_shape, + action_shape=action_shape, + device=device + ) + + # Actor. + self.actor = StateIndependentPolicy( + state_shape=state_shape, + action_shape=action_shape, + hidden_units=units_actor, + hidden_activation=nn.Tanh() + ).to(device) + + # Critic. + self.critic = Bert( + input_dim=state_shape, + output_dim=1 + ).to(device) + + self.learning_steps_ppo = 0 + self.rollout_length = rollout_length + self.epoch_ppo = epoch_ppo + self.clip_eps = clip_eps + self.lambd = lambd + self.max_grad_norm = max_grad_norm + self.desired_kl = desired_kl + self.surrogate_loss_coef = surrogate_loss_coef + self.value_loss_coef = value_loss_coef + self.entropy_coef = entropy_coef + self.bounds_loss_coef = bounds_loss_coef + self.max_steps = max_steps + + self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor) + # self.optim_actor = Adam([ + # {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor}, + # {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor/3}]) + self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic) + + def is_update(self, step): + return step % self.rollout_length == 0 + + def step(self, env, state_list, state_gail): + state_list = torch.tensor(state_list, dtype=torch.float, device=self.device) + state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device) + action_list, log_pi_list = self.explore(state_list) + next_state, reward, terminated, truncated, info = env.step(np.array(action_list)) + next_state_gail = env.state_gail + done = terminated or truncated + + means = self.actor.means.detach().cpu().numpy()[0] + stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0] + + self.buffer.append(state_list, state_gail, action_list, reward, done, terminated, log_pi_list, + next_state, next_state_gail, means, stds) + + if done: + next_state = env.reset() + next_state_gail = env.state_gail + + return next_state, next_state_gail, info + + def update(self, writer, total_steps): + pass + + def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list, next_states, mus, sigmas, writer, + total_steps): + with torch.no_grad(): + values = self.critic(states.detach()) + next_values = self.critic(next_states.detach()) + + targets, gaes = self.calculate_gae( + values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd) + + state_list = states.permute(1, 0, 2) + action_list = actions.permute(1, 0, 2) + + for i in range(self.epoch_ppo): + self.learning_steps_ppo += 1 + self.update_critic(states, targets, writer) + for state, action, log_pi in state_list, action_list, log_pi_list: + self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer) + + # self.lr_decay(total_steps, writer) + + def update_critic(self, states, targets, writer): + loss_critic = (self.critic(states) - 
targets).pow_(2).mean() + loss_critic = loss_critic * self.value_loss_coef + + self.optim_critic.zero_grad() + loss_critic.backward(retain_graph=False) + nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm) + self.optim_critic.step() + + if self.learning_steps_ppo % self.epoch_ppo == 0: + writer.add_scalar( + 'Loss/critic', loss_critic.item(), self.learning_steps) + + def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer): + self.optim_actor.zero_grad() + log_pis = self.actor.evaluate_log_pi(states, actions) + mus = self.actor.means + sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1) + entropy = -log_pis.mean() + + ratios = (log_pis - log_pis_old).exp_() + loss_actor1 = -ratios * gaes + loss_actor2 = -torch.clamp( + ratios, + 1.0 - self.clip_eps, + 1.0 + self.clip_eps + ) * gaes + loss_actor = torch.max(loss_actor1, loss_actor2).mean() + loss_actor = loss_actor * self.surrogate_loss_coef + if self.auto_lr: + # desired_kl: 0.01 + with torch.inference_mode(): + kl = torch.sum(torch.log(sigmas / sigmas_old + 1.e-5) + + (torch.square(sigmas_old) + torch.square(mus_old - mus)) / + (2.0 * torch.square(sigmas)) - 0.5, axis=-1) + + kl_mean = torch.mean(kl) + + if kl_mean > self.desired_kl * 2.0: + self.lr_actor = max(1e-5, self.lr_actor / 1.5) + self.lr_critic = max(1e-5, self.lr_critic / 1.5) + self.lr_disc = max(1e-5, self.lr_disc / 1.5) + elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: + self.lr_actor = min(1e-2, self.lr_actor * 1.5) + self.lr_critic = min(1e-2, self.lr_critic * 1.5) + self.lr_disc = min(1e-2, self.lr_disc * 1.5) + + for param_group in self.optim_actor.param_groups: + param_group['lr'] = self.lr_actor + for param_group in self.optim_critic.param_groups: + param_group['lr'] = self.lr_critic + for param_group in self.optim_d.param_groups: + param_group['lr'] = self.lr_disc + + loss = loss_actor # + b_loss * 0 - self.entropy_coef * entropy * 0 + + loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) + self.optim_actor.step() + + if self.learning_steps_ppo % self.epoch_ppo == 0: + writer.add_scalar( + 'Loss/actor', loss_actor.item(), self.learning_steps) + writer.add_scalar( + 'Loss/entropy', entropy.item(), self.learning_steps) + writer.add_scalar( + 'Loss/learning_rate', self.lr_actor, self.learning_steps) + + def lr_decay(self, total_steps, writer): + lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps)) + lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps)) + lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps)) + for p in self.optim_actor.param_groups: + p['lr'] = lr_a_now + for p in self.optim_critic.param_groups: + p['lr'] = lr_c_now + for p in self.optim_d.param_groups: + p['lr'] = lr_d_now + + writer.add_scalar( + 'Loss/learning_rate', lr_a_now, self.learning_steps) + + def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd): + """ + Calculate the advantage using GAE + 'tm_dones=True' means dead or win, there is no next state s' + 'dones=True' represents the terminal of an episode(dead or win or reaching the max_episode_steps). + When calculating the adv, if dones=True, gae=0 + Reference: https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py + """ + with torch.no_grad(): + # Calculate TD errors. + deltas = rewards + gamma * next_values * (1 - tm_dones) - values + # Initialize gae. 
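+ # The recursion below implements, per step t,
+ #     delta_t = r_t + gamma * V(s_{t+1}) * (1 - tm_done_t) - V(s_t)
+ #     gae_t   = delta_t + gamma * lambd * (1 - done_t) * gae_{t+1}
+ # and the critic target is v_target = gae + V(s), optionally followed by advantage normalization.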
+ gaes = torch.empty_like(rewards) + + # Calculate gae recursively from behind. + gaes[-1] = deltas[-1] + for t in reversed(range(rewards.size(0) - 1)): + gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1] + + v_target = gaes + values + if self.use_adv_norm: + gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8) + + return v_target, gaes + + def save_models(self, save_dir): + pass diff --git a/Algorithm/utils.py b/Algorithm/utils.py new file mode 100644 index 0000000..724eb13 --- /dev/null +++ b/Algorithm/utils.py @@ -0,0 +1,108 @@ +import math +import torch +import numpy as np +from torch import nn +from typing import Tuple + + +class RunningMeanStd(object): + def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): + """ + Calulates the running mean and std of a data stream + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + :param epsilon: helps with arithmetic issues + :param shape: the shape of the data stream's output + """ + self.mean = np.zeros(shape, np.float64) + self.var = np.ones(shape, np.float64) + self.count = epsilon + + def update(self, arr: np.ndarray) -> None: + batch_mean = np.mean(arr, axis=0) + batch_var = np.var(arr, axis=0) + batch_count = arr.shape[0] + self.update_from_moments(batch_mean, batch_var, batch_count) + + def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None: + delta = batch_mean - self.mean + tot_count = self.count + batch_count + + new_mean = self.mean + delta * batch_count / tot_count + m_a = self.var * self.count + m_b = batch_var * batch_count + m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) + new_var = m_2 / (self.count + batch_count) + + new_count = batch_count + self.count + + self.mean = new_mean + self.var = new_var + self.count = new_count + + +class Normalizer(RunningMeanStd): + def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0): + super().__init__(shape=input_dim) + self.epsilon = epsilon + self.clip_obs = clip_obs + + def normalize(self, input): + return np.clip( + (input - self.mean) / np.sqrt(self.var + self.epsilon), + -self.clip_obs, self.clip_obs) + + def normalize_torch(self, input, device): + mean_torch = torch.tensor( + self.mean, device=device, dtype=torch.float32) + std_torch = torch.sqrt(torch.tensor( + self.var + self.epsilon, device=device, dtype=torch.float32)) + return torch.clamp( + (input - mean_torch) / std_torch, -self.clip_obs, self.clip_obs) + + def update_normalizer(self, rollouts, expert_loader): + policy_data_generator = rollouts.feed_forward_generator_amp( + None, mini_batch_size=expert_loader.batch_size) + expert_data_generator = expert_loader.dataset.feed_forward_generator_amp( + expert_loader.batch_size) + + for expert_batch, policy_batch in zip(expert_data_generator, policy_data_generator): + self.update( + torch.vstack(tuple(policy_batch) + tuple(expert_batch)).cpu().numpy()) + + +def build_mlp(input_dim, output_dim, hidden_units=[64, 64], + hidden_activation=nn.Tanh(), output_activation=None): + layers = [] + units = input_dim + for next_units in hidden_units: + layers.append(nn.Linear(units, next_units)) + layers.append(hidden_activation) + units = next_units + layers.append(nn.Linear(units, output_dim)) + if output_activation is not None: + layers.append(output_activation) + return nn.Sequential(*layers) + + +def calculate_log_pi(log_stds, noises, actions): + gaussian_log_probs = (-0.5 * noises.pow(2) - log_stds).sum( + dim=-1, 
keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1) + + return gaussian_log_probs - torch.log( + 1 - actions.pow(2) + 1e-6).sum(dim=-1, keepdim=True) + + +def reparameterize(means, log_stds): + noises = torch.randn_like(means) + us = means + noises * log_stds.exp() + actions = torch.tanh(us) + return actions, calculate_log_pi(log_stds, noises, actions) + + +def atanh(x): + return 0.5 * (torch.log(1 + x + 1e-6) - torch.log(1 - x + 1e-6)) + + +def evaluate_lop_pi(means, log_stds, actions): + noises = (atanh(actions) - means) / (log_stds.exp() + 1e-8) + return calculate_log_pi(log_stds, noises, actions) diff --git a/Env/exp_converted/dataset_mapping.pkl b/Env/exp_converted/dataset_mapping.pkl new file mode 100644 index 0000000..3573b0b Binary files /dev/null and b/Env/exp_converted/dataset_mapping.pkl differ diff --git a/Env/exp_converted/dataset_summary.pkl b/Env/exp_converted/dataset_summary.pkl new file mode 100644 index 0000000..d44611f Binary files /dev/null and b/Env/exp_converted/dataset_summary.pkl differ diff --git a/Env/exp_converted/exp_converted_0/dataset_mapping.pkl b/Env/exp_converted/exp_converted_0/dataset_mapping.pkl new file mode 100644 index 0000000..00940bb Binary files /dev/null and b/Env/exp_converted/exp_converted_0/dataset_mapping.pkl differ diff --git a/Env/exp_converted/exp_converted_0/dataset_summary.pkl b/Env/exp_converted/exp_converted_0/dataset_summary.pkl new file mode 100644 index 0000000..52dc960 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/dataset_summary.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_104202f4f2590dff.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_104202f4f2590dff.pkl new file mode 100644 index 0000000..1cba11b Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_104202f4f2590dff.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_11e688db089d222.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_11e688db089d222.pkl new file mode 100644 index 0000000..4aafd79 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_11e688db089d222.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_197433a84d86f4b6.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_197433a84d86f4b6.pkl new file mode 100644 index 0000000..3666929 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_197433a84d86f4b6.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_21f584ad2dd5d7b8.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_21f584ad2dd5d7b8.pkl new file mode 100644 index 0000000..af20392 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_21f584ad2dd5d7b8.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2364a51095c69102.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2364a51095c69102.pkl new file mode 100644 index 0000000..fd4831d Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2364a51095c69102.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_25f57f7ef66cdfe6.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_25f57f7ef66cdfe6.pkl new file mode 100644 index 0000000..2dbdef6 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_25f57f7ef66cdfe6.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_27e52c5f34743a32.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_27e52c5f34743a32.pkl new file mode 100644 index 
0000000..0fb5b15 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_27e52c5f34743a32.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2aa43fad083efbf3.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2aa43fad083efbf3.pkl new file mode 100644 index 0000000..d7db9b5 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2aa43fad083efbf3.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2bc07893b2abbb07.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2bc07893b2abbb07.pkl new file mode 100644 index 0000000..de1d0c1 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2bc07893b2abbb07.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2e0e37f5efeb70af.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2e0e37f5efeb70af.pkl new file mode 100644 index 0000000..804e713 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_2e0e37f5efeb70af.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3114f7fbaa8cc086.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3114f7fbaa8cc086.pkl new file mode 100644 index 0000000..2b1234d Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3114f7fbaa8cc086.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_32f0ee473bcb2854.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_32f0ee473bcb2854.pkl new file mode 100644 index 0000000..d1977b9 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_32f0ee473bcb2854.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3946229358696c01.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3946229358696c01.pkl new file mode 100644 index 0000000..94674ac Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3946229358696c01.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3ec9f6dfb2b48d65.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3ec9f6dfb2b48d65.pkl new file mode 100644 index 0000000..8689fe2 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_3ec9f6dfb2b48d65.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_41600af30ab8cc55.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_41600af30ab8cc55.pkl new file mode 100644 index 0000000..920fe40 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_41600af30ab8cc55.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_430a2693b92ba127.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_430a2693b92ba127.pkl new file mode 100644 index 0000000..713f549 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_430a2693b92ba127.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_44dc56e65fc65a82.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_44dc56e65fc65a82.pkl new file mode 100644 index 0000000..aabe86e Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_44dc56e65fc65a82.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_4a116dacc9ccc4df.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_4a116dacc9ccc4df.pkl new file mode 100644 index 0000000..c750fec Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_4a116dacc9ccc4df.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_4bf1d627f1771287.pkl 
b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_4bf1d627f1771287.pkl new file mode 100644 index 0000000..8218643 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_4bf1d627f1771287.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_546259711161a341.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_546259711161a341.pkl new file mode 100644 index 0000000..e57e43e Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_546259711161a341.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5510d6a966ccc52f.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5510d6a966ccc52f.pkl new file mode 100644 index 0000000..def48c5 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5510d6a966ccc52f.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_568458c3148c034.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_568458c3148c034.pkl new file mode 100644 index 0000000..289c881 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_568458c3148c034.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5915f8cd44872858.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5915f8cd44872858.pkl new file mode 100644 index 0000000..03a00ca Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5915f8cd44872858.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5cac897a524d2f40.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5cac897a524d2f40.pkl new file mode 100644 index 0000000..bc02dac Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_5cac897a524d2f40.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_66854d30a65d1216.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_66854d30a65d1216.pkl new file mode 100644 index 0000000..001d42d Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_66854d30a65d1216.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_671dec7d5e2fa9fb.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_671dec7d5e2fa9fb.pkl new file mode 100644 index 0000000..4d84f6b Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_671dec7d5e2fa9fb.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_6ecd6ab6d573b137.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_6ecd6ab6d573b137.pkl new file mode 100644 index 0000000..10adc48 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_6ecd6ab6d573b137.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_707f27ea3927b4f5.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_707f27ea3927b4f5.pkl new file mode 100644 index 0000000..e9fe575 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_707f27ea3927b4f5.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_715dfdaa4cf40df5.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_715dfdaa4cf40df5.pkl new file mode 100644 index 0000000..f126e65 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_715dfdaa4cf40df5.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_79776dd1931a3d26.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_79776dd1931a3d26.pkl new file mode 100644 index 0000000..f0280c1 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_79776dd1931a3d26.pkl differ diff --git 
a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7ac22c9e42d05c79.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7ac22c9e42d05c79.pkl new file mode 100644 index 0000000..eafaeb2 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7ac22c9e42d05c79.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7c003f7b2af6419e.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7c003f7b2af6419e.pkl new file mode 100644 index 0000000..48aba21 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7c003f7b2af6419e.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7ce76dd8013b8b9e.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7ce76dd8013b8b9e.pkl new file mode 100644 index 0000000..84cebdf Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7ce76dd8013b8b9e.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7f94d7eac202a8f6.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7f94d7eac202a8f6.pkl new file mode 100644 index 0000000..d143cbe Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_7f94d7eac202a8f6.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_852bf8201e701c22.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_852bf8201e701c22.pkl new file mode 100644 index 0000000..58235e8 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_852bf8201e701c22.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8998493f69081ab0.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8998493f69081ab0.pkl new file mode 100644 index 0000000..5ee1b34 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8998493f69081ab0.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8a720808ec3c0864.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8a720808ec3c0864.pkl new file mode 100644 index 0000000..91b87ed Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8a720808ec3c0864.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8bd2c7c34e2a7e91.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8bd2c7c34e2a7e91.pkl new file mode 100644 index 0000000..d51063b Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8bd2c7c34e2a7e91.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8c4eaec3edd72d1b.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8c4eaec3edd72d1b.pkl new file mode 100644 index 0000000..2e45887 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_8c4eaec3edd72d1b.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_90cdf8a7cb0e097f.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_90cdf8a7cb0e097f.pkl new file mode 100644 index 0000000..b11db38 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_90cdf8a7cb0e097f.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_90f674b6f7dad649.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_90f674b6f7dad649.pkl new file mode 100644 index 0000000..f0c4b98 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_90f674b6f7dad649.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_914a764cef3668a2.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_914a764cef3668a2.pkl new file mode 100644 index 0000000..e7389c0 Binary files /dev/null and 
b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_914a764cef3668a2.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9570610abd87b982.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9570610abd87b982.pkl new file mode 100644 index 0000000..c36aa34 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9570610abd87b982.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_96d0ccbfef0829e5.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_96d0ccbfef0829e5.pkl new file mode 100644 index 0000000..970e339 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_96d0ccbfef0829e5.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9859cd1b4315b7de.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9859cd1b4315b7de.pkl new file mode 100644 index 0000000..ca351d1 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9859cd1b4315b7de.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_98cd28f72a641e8f.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_98cd28f72a641e8f.pkl new file mode 100644 index 0000000..a4fe969 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_98cd28f72a641e8f.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9d7563e9b6486022.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9d7563e9b6486022.pkl new file mode 100644 index 0000000..25d81d7 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9d7563e9b6486022.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9ec70bf90d6fe529.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9ec70bf90d6fe529.pkl new file mode 100644 index 0000000..639eed5 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_9ec70bf90d6fe529.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a7a3a82d61f0e91e.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a7a3a82d61f0e91e.pkl new file mode 100644 index 0000000..3956417 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a7a3a82d61f0e91e.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a7eb07f173d68ce5.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a7eb07f173d68ce5.pkl new file mode 100644 index 0000000..55c0652 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a7eb07f173d68ce5.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a88f67a4ee877e62.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a88f67a4ee877e62.pkl new file mode 100644 index 0000000..c64389b Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_a88f67a4ee877e62.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_ab06fe2bdd70dee8.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_ab06fe2bdd70dee8.pkl new file mode 100644 index 0000000..f11145b Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_ab06fe2bdd70dee8.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b29b20e997b76ea3.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b29b20e997b76ea3.pkl new file mode 100644 index 0000000..de17f19 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b29b20e997b76ea3.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b35779bf7e37ece9.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b35779bf7e37ece9.pkl 
new file mode 100644 index 0000000..7aa5083 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b35779bf7e37ece9.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b6152ff56baf6817.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b6152ff56baf6817.pkl new file mode 100644 index 0000000..bd27536 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_b6152ff56baf6817.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_beeee363eeb3f708.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_beeee363eeb3f708.pkl new file mode 100644 index 0000000..55ececc Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_beeee363eeb3f708.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_bffa6dd429936879.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_bffa6dd429936879.pkl new file mode 100644 index 0000000..01af3d3 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_bffa6dd429936879.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c354985f8a63a390.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c354985f8a63a390.pkl new file mode 100644 index 0000000..626eefc Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c354985f8a63a390.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c453b2059c68c41c.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c453b2059c68c41c.pkl new file mode 100644 index 0000000..1c41f0c Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c453b2059c68c41c.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c692808f8d63a7ec.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c692808f8d63a7ec.pkl new file mode 100644 index 0000000..ad79328 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c692808f8d63a7ec.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c93b188ee1c507d5.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c93b188ee1c507d5.pkl new file mode 100644 index 0000000..3d48fe4 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_c93b188ee1c507d5.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cb3413b9e69ae5ab.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cb3413b9e69ae5ab.pkl new file mode 100644 index 0000000..17f9030 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cb3413b9e69ae5ab.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cbcf4099dfd4f9fb.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cbcf4099dfd4f9fb.pkl new file mode 100644 index 0000000..2b8b562 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cbcf4099dfd4f9fb.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cebdf28156152fd6.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cebdf28156152fd6.pkl new file mode 100644 index 0000000..142f353 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_cebdf28156152fd6.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d5327587f925c58e.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d5327587f925c58e.pkl new file mode 100644 index 0000000..9240fb8 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d5327587f925c58e.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d6d0e0ed8c763a8a.pkl 
b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d6d0e0ed8c763a8a.pkl new file mode 100644 index 0000000..6b03bf5 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d6d0e0ed8c763a8a.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d8036e6d7e2a86a2.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d8036e6d7e2a86a2.pkl new file mode 100644 index 0000000..08cea8d Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_d8036e6d7e2a86a2.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_df74ad0cc6823304.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_df74ad0cc6823304.pkl new file mode 100644 index 0000000..4b5cfd8 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_df74ad0cc6823304.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_dfbfcbfc1b6f7f7a.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_dfbfcbfc1b6f7f7a.pkl new file mode 100644 index 0000000..afc6707 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_dfbfcbfc1b6f7f7a.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_e12f0928016d6956.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_e12f0928016d6956.pkl new file mode 100644 index 0000000..700f5e4 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_e12f0928016d6956.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_e2030d66ebfe7b6b.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_e2030d66ebfe7b6b.pkl new file mode 100644 index 0000000..4dc7558 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_e2030d66ebfe7b6b.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_eb073968f66914c7.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_eb073968f66914c7.pkl new file mode 100644 index 0000000..9c92a9e Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_eb073968f66914c7.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f346701fdc8818d1.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f346701fdc8818d1.pkl new file mode 100644 index 0000000..49cff18 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f346701fdc8818d1.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f6e89ee29a5f20b2.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f6e89ee29a5f20b2.pkl new file mode 100644 index 0000000..cd36807 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f6e89ee29a5f20b2.pkl differ diff --git a/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f84a2c81fec0b16.pkl b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f84a2c81fec0b16.pkl new file mode 100644 index 0000000..952e254 Binary files /dev/null and b/Env/exp_converted/exp_converted_0/sd_waymo_v1.2_f84a2c81fec0b16.pkl differ diff --git a/Env/run_multiagent_env.py b/Env/run_multiagent_env.py new file mode 100644 index 0000000..a6a21bc --- /dev/null +++ b/Env/run_multiagent_env.py @@ -0,0 +1,41 @@ +from scenario_env import MultiAgentScenarioEnv +from Env.simple_idm_policy import ConstantVelocityPolicy +from metadrive.engine.asset_loader import AssetLoader + +WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env" + +def main(): + env = MultiAgentScenarioEnv( + config={ + # "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False), + "data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", 
unix_style=False), + "is_multi_agent": True, + "num_controlled_agents": 3, + "horizon": 300, + "use_render": True, + "sequential_seed": True, + "reactive_traffic": True, + "manual_control": True, + }, + agent2policy=ConstantVelocityPolicy(target_speed=50) + ) + + obs = env.reset(0) + for step in range(10000): + actions = { + aid: env.controlled_agents[aid].policy.act() + for aid in env.controlled_agents + } + + obs, rewards, dones, infos = env.step(actions) + env.render(mode="topdown") + + if dones["__all__"]: + break + + env.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Env/scenario_env.py b/Env/scenario_env.py new file mode 100644 index 0000000..ec1e3b9 --- /dev/null +++ b/Env/scenario_env.py @@ -0,0 +1,204 @@ +import numpy as np +from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type +import math +import logging +from collections import defaultdict +from typing import Union, Dict, AnyStr +from metadrive.engine.logger import get_logger, set_log_level +from metadrive.type import MetaDriveType + + +class PolicyVehicle(DefaultVehicle): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.policy = None + self.destination = None + + def set_policy(self, policy): + self.policy = policy + + def set_destination(self, des): + self.destination = des + + def act(self, observation, policy=None): + if self.policy is not None: + return self.policy.act(observation) + else: + return self.action_space.sample() + + def before_step(self, action): + self.last_position = self.position # 2D vector + self.last_velocity = self.velocity # 2D vector + self.last_speed = self.speed # Scalar + self.last_heading_dir = self.heading + if action is not None: + self.last_current_action.append(action) + self._set_action(action) + + def is_done(self): + # arrive or crash + pass + + +vehicle_class_to_type[PolicyVehicle] = "default" + + +class MultiAgentScenarioEnv(ScenarioEnv): + @classmethod + def default_config(cls): + config = super().default_config() + config.update(dict( + data_directory=None, + num_controlled_agents=3, + horizon=1000, + )) + return config + + def __init__(self, config, agent2policy): + self.policy = agent2policy + self.controlled_agents = {} + self.controlled_agent_ids = [] + self.obs_list = [] + self.round = 0 + super().__init__(config) + + def reset(self, seed: Union[None, int] = None): + self.round = 0 + if self.logger is None: + self.logger = get_logger() + log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO) + set_log_level(log_level) + + self.lazy_init() + self._reset_global_seed(seed) + if self.engine is None: + raise ValueError("Broken MetaDrive instance.") + + # Record each vehicle in the expert data, then remove them all, keeping only spawn position, heading, etc. for later re-spawning + _obj_to_clean_this_frame = [] + self.car_birth_info_list = [] + for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items(): + if scenario_id == self.engine.traffic_manager.sdc_scenario_id: + continue + else: + if track["type"] == MetaDriveType.VEHICLE: + _obj_to_clean_this_frame.append(scenario_id) + valid = track['state']['valid'] + first_show = np.argmax(valid) if valid.any() else -1 + last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1 + # id, first-show time, spawn position, spawn heading, destination + self.car_birth_info_list.append({ + 'id': 
track['metadata']['object_id'], + 'show_time': first_show, + 'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]), + 'heading': track['state']['heading'][first_show], + 'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1]) + }) + + for scenario_id in _obj_to_clean_this_frame: + self.engine.traffic_manager.current_traffic_data.pop(scenario_id) + + self.engine.reset() + self.reset_sensors() + self.engine.taskMgr.step() + + self.lanes = self.engine.map_manager.current_map.road_network.graph + + if self.top_down_renderer is not None: + self.top_down_renderer.clear() + self.engine.top_down_renderer = None + + self.dones = {} + self.episode_rewards = defaultdict(float) + self.episode_lengths = defaultdict(int) + + self.controlled_agents.clear() + self.controlled_agent_ids.clear() + + super().reset(seed) # initialize the scenario + self._spawn_controlled_agents() + + return self._get_all_obs() + + def _spawn_controlled_agents(self): + # ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent") + # ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0]) + for car in self.car_birth_info_list: + if car['show_time'] == self.round: + agent_id = f"controlled_{car['id']}" + + vehicle = self.engine.spawn_object( + PolicyVehicle, + vehicle_config={}, + position=car['begin'], + heading=car['heading'] + ) + vehicle.reset(position=car['begin'], heading=car['heading']) + + vehicle.set_policy(self.policy) + vehicle.set_destination(car['end']) + + self.controlled_agents[agent_id] = vehicle + self.controlled_agent_ids.append(agent_id) + + # Key step: register with the engine's active_agents so the vehicle takes part in physics updates + self.engine.agent_manager.active_agents[agent_id] = vehicle + + def _get_all_obs(self): + # position, velocity, heading, lidar, navigation, TODO: trafficlight -> list + self.obs_list = [] + for agent_id, vehicle in self.controlled_agents.items(): + state = vehicle.get_state() + + traffic_light = 0 + for lane in self.lanes.values(): + if lane.lane.point_on_lane(state['position'][:2]): + if self.engine.light_manager.has_traffic_light(lane.lane.index): + traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status + if traffic_light == 'TRAFFIC_LIGHT_GREEN': + traffic_light = 1 + elif traffic_light == 'TRAFFIC_LIGHT_YELLOW': + traffic_light = 2 + elif traffic_light == 'TRAFFIC_LIGHT_RED': + traffic_light = 3 + else: + traffic_light = 0 + break + + lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle, + physics_world=self.engine.physics_world.dynamic_world) + side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8, + base_vehicle=vehicle, + physics_world=self.engine.physics_world.static_world) + lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3, + base_vehicle=vehicle, + physics_world=self.engine.physics_world.static_world) + + obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']] + + lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light] + + list(vehicle.destination)) + self.obs_list.append(obs) + return self.obs_list + + def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]): + self.round += 1 + + for agent_id, action in action_dict.items(): + if agent_id in self.controlled_agents: + self.controlled_agents[agent_id].before_step(action) + + self.engine.step() + + for agent_id in action_dict: + if agent_id in self.controlled_agents: + 
self.controlled_agents[agent_id].after_step() + + self._spawn_controlled_agents() + obs = self._get_all_obs() + rewards = {aid: 0.0 for aid in self.controlled_agents} + dones = {aid: False for aid in self.controlled_agents} + dones["__all__"] = self.episode_step >= self.config["horizon"] + infos = {aid: {} for aid in self.controlled_agents} + return obs, rewards, dones, infos diff --git a/Env/simple_idm_policy.py b/Env/simple_idm_policy.py new file mode 100644 index 0000000..2f50be6 --- /dev/null +++ b/Env/simple_idm_policy.py @@ -0,0 +1,18 @@ +import numpy as np + +class ConstantVelocityPolicy: + def __init__(self, target_speed=50): + self.target_speed = target_speed + self.step_num = 0 + + def act(self): + self.step_num += 1 + # NOTE: both branches currently set the same throttle, so the schedule is effectively constant. + if self.step_num % 30 < 15: + throttle = 1.0 + else: + throttle = 1.0 + + steering = 0.1 + + # return [steering, throttle] + + # Placeholder action: fixed [steering, throttle], independent of target_speed. + return [0.0, 0.05] diff --git a/Env/utils.py b/Env/utils.py new file mode 100644 index 0000000..c19bf24 --- /dev/null +++ b/Env/utils.py @@ -0,0 +1,14 @@ +import numpy as np +import torch +import random + +def set_seed(seed): + if seed == -1: + seed = np.random.randint(0, 10000) + print('Random seed: {}'.format(seed)) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) \ No newline at end of file