magail4autodrive: first commit
81
Algorithm/bert.py
Normal file
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn


class Bert(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim=128,
                 num_layers=4, ff_dim=512, num_heads=4, dropout=0.1, CLS=False, TANH=False):
        super().__init__()
        self.CLS = CLS
        self.projection = nn.Linear(input_dim, embed_dim)
        if self.CLS:
            self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim + 1, embed_dim))
        else:
            self.pos_embed = nn.Parameter(torch.randn(1, input_dim, embed_dim))

        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        if TANH:
            self.classifier = nn.Sequential(nn.Linear(embed_dim, output_dim), nn.Tanh())
        else:
            self.classifier = nn.Linear(embed_dim, output_dim)

        self.layers.train()
        self.classifier.train()

    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, input_dim)
        # Linear projection
        x = self.projection(x)  # (batch_size, seq_len, embed_dim)

        batch_size = x.size(0)
        if self.CLS:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)  # (batch_size, seq_len + 1, embed_dim)

        # Add positional encoding
        x = x + self.pos_embed

        # Transpose to (seq_len, batch_size, embed_dim)
        x = x.permute(1, 0, 2)

        for layer in self.layers:
            x = layer(x, mask=mask)

        if self.CLS:
            return self.classifier(x[0, :, :])
        else:
            pooled = x.mean(dim=0)  # (batch_size, embed_dim)
            return self.classifier(pooled)


class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Use GELU as the activation function
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        # Post-LN structure (normalize after the residual connection)
        # Attention block
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # Feed-forward block
        ff_output = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x
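
A minimal shape check for the Bert module above (an illustrative sketch, not part of the commit): the dimensions are hypothetical, pos_embed is sized by input_dim so the sequence length is expected to equal input_dim, and Algorithm/ is assumed to be importable as a package.

import torch
from Algorithm.bert import Bert  # assumes Algorithm/ is a package on the import path

# Hypothetical sizes: 4 samples, 8 tokens per sample, 8 features per token (seq_len == input_dim).
model = Bert(input_dim=8, output_dim=2, CLS=True)
x = torch.randn(4, 8, 8)   # (batch_size, seq_len, input_dim)
out = model(x)             # the prepended CLS token is fed to the classifier
print(out.shape)           # torch.Size([4, 2])
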
80
Algorithm/buffer.py
Normal file
@@ -0,0 +1,80 @@
import os
import numpy as np
import torch


class RolloutBuffer:
    # TODO: state and action are list
    def __init__(self, buffer_size, state_shape, action_shape, device):
        self._n = 0
        self._p = 0
        self.buffer_size = buffer_size

        self.states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device)
        # self.states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device)
        self.actions = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
        self.rewards = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device)
        self.dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device)
        self.tm_dones = torch.empty((self.buffer_size, 1), dtype=torch.int, device=device)
        self.log_pis = torch.empty((self.buffer_size, 1), dtype=torch.float, device=device)
        self.next_states = torch.empty((self.buffer_size, *state_shape), dtype=torch.float, device=device)
        # self.next_states_gail = torch.empty((self.buffer_size, *state_gail_shape), dtype=torch.float, device=device)
        self.means = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)
        self.stds = torch.empty((self.buffer_size, *action_shape), dtype=torch.float, device=device)

    def append(self, state, action, reward, done, tm_dones, log_pi, next_state, next_state_gail, means, stds):
        self.states[self._p].copy_(state)
        # self.states_gail[self._p].copy_(state_gail)
        self.actions[self._p].copy_(torch.from_numpy(action))
        self.rewards[self._p] = float(reward)
        self.dones[self._p] = int(done)
        self.tm_dones[self._p] = int(tm_dones)
        self.log_pis[self._p] = float(log_pi)
        self.next_states[self._p].copy_(torch.from_numpy(next_state))
        # self.next_states_gail[self._p].copy_(torch.from_numpy(next_state_gail))
        self.means[self._p].copy_(torch.from_numpy(means))
        self.stds[self._p].copy_(torch.from_numpy(stds))

        self._p = (self._p + 1) % self.buffer_size
        self._n = min(self._n + 1, self.buffer_size)

    def get(self):
        assert self._p % self.buffer_size == 0
        idxes = slice(0, self.buffer_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def sample(self, batch_size):
        assert self._p % self.buffer_size == 0
        idxes = np.random.randint(low=0, high=self._n, size=batch_size)
        return (
            self.states[idxes],
            self.actions[idxes],
            self.rewards[idxes],
            self.dones[idxes],
            self.tm_dones[idxes],
            self.log_pis[idxes],
            self.next_states[idxes],
            self.means[idxes],
            self.stds[idxes]
        )

    def clear(self):
        self.states[:, :] = 0
        self.actions[:, :] = 0
        self.rewards[:, :] = 0
        self.dones[:, :] = 0
        self.tm_dones[:, :] = 0
        self.log_pis[:, :] = 0
        self.next_states[:, :] = 0
        self.means[:, :] = 0
        self.stds[:, :] = 0
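
A small usage sketch of RolloutBuffer (illustrative only; shapes and values are made up, and Algorithm/ is assumed importable). append expects a tensor state and NumPy arrays for action, next_state, means and stds, and get only succeeds once the write pointer has wrapped to zero:

import numpy as np
import torch
from Algorithm.buffer import RolloutBuffer

buf = RolloutBuffer(buffer_size=4, state_shape=(3,), action_shape=(2,), device='cpu')
for _ in range(4):
    buf.append(state=torch.zeros(3),
               action=np.zeros(2, dtype=np.float32),
               reward=0.0, done=False, tm_dones=False, log_pi=0.0,
               next_state=np.zeros(3, dtype=np.float32),
               next_state_gail=None,            # accepted but not stored
               means=np.zeros(2, dtype=np.float32),
               stds=np.ones(2, dtype=np.float32))
states, actions, *_ = buf.get()          # requires the buffer to be exactly full
print(states.shape, actions.shape)       # torch.Size([4, 3]) torch.Size([4, 2])
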
44
Algorithm/disc.py
Normal file
@@ -0,0 +1,44 @@
import torch
from torch import nn
from .bert import Bert


DISC_LOGIT_INIT_SCALE = 1.0


class GAILDiscrim(Bert):

    def __init__(self, input_dim, reward_i_coef=1.0, reward_t_coef=1.0, normalizer=None, device=None):
        super().__init__(input_dim=input_dim, output_dim=1, TANH=False)
        self.device = device
        self.reward_t_coef = reward_t_coef
        self.reward_i_coef = reward_i_coef
        self.normalizer = normalizer

    def calculate_reward(self, states_gail, next_states_gail, rewards_t):
        # PPO (GAIL) maximizes E_{\pi} [-log(1 - D)].
        states_gail = states_gail.clone()
        next_states_gail = next_states_gail.clone()
        states = torch.cat([states_gail, next_states_gail], dim=-1)
        with torch.no_grad():
            if self.normalizer is not None:
                states = self.normalizer.normalize_torch(states, self.device)
            rewards_t = self.reward_t_coef * rewards_t
            d = self.forward(states)
            prob = 1 / (1 + torch.exp(-d))
            rewards_i = self.reward_i_coef * (
                -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device))))
            rewards = rewards_t + rewards_i
        return rewards, rewards_t / (self.reward_t_coef + 1e-10), rewards_i / (self.reward_i_coef + 1e-10)

    def get_disc_logit_weights(self):
        return torch.flatten(self.classifier.weight)

    def get_disc_weights(self):
        weights = []
        for m in self.layers.modules():
            if isinstance(m, nn.Linear):
                weights.append(torch.flatten(m.weight))

        weights.append(torch.flatten(self.classifier.weight))
        return weights
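
A quick numeric check of the imitation-reward term in calculate_reward (standalone sketch, not repository code): -log(1 - sigmoid(d)) equals the softplus of the logit, so the clamped formulation above can be compared against F.softplus for moderate logits.

import torch
import torch.nn.functional as F

d = torch.randn(5)                                   # hypothetical discriminator logits
prob = torch.sigmoid(d)
r_clamped = -torch.log(torch.clamp(1 - prob, min=1e-4))
r_softplus = F.softplus(d)                           # -log(1 - sigmoid(d)) == log(1 + exp(d))
print(torch.allclose(r_clamped, r_softplus, atol=1e-3))
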
149
Algorithm/magail.py
Normal file
@@ -0,0 +1,149 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .disc import GAILDiscrim
from .ppo import PPO
from .utils import Normalizer


class MAGAIL(PPO):
    def __init__(self, buffer_exp, input_dim, device,
                 disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
                 ):
        super().__init__(state_shape=input_dim, device=device)
        self.learning_steps = 0
        self.learning_steps_disc = 0

        self.disc = GAILDiscrim(input_dim=input_dim)
        self.disc_grad_penalty = disc_grad_penalty
        self.disc_coef = disc_coef
        self.disc_logit_reg = disc_logit_reg
        self.disc_weight_decay = disc_weight_decay
        self.lr_disc = lr_disc
        self.epoch_disc = epoch_disc
        self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)

        self.normalizer = None
        if use_gail_norm:
            self.normalizer = Normalizer(self.state_shape[0] * 2)

        self.batch_size = batch_size
        self.buffer_exp = buffer_exp

    def update_disc(self, states, states_exp, writer):
        states_cp = states.clone()
        states_exp_cp = states_exp.clone()

        # Output of the discriminator is in (-inf, inf), not [0, 1].
        logits_pi = self.disc(states_cp)
        logits_exp = self.disc(states_exp_cp)

        # The discriminator maximizes E_{\pi} [log(1 - D)] + E_{exp} [log(D)].
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = 0.5 * (loss_pi + loss_exp)

        # Logit regularization.
        logit_weights = self.disc.get_disc_logit_weights()
        disc_logit_loss = torch.sum(torch.square(logit_weights))

        # Gradient penalty on expert samples.
        sample_expert = states_exp_cp
        sample_expert.requires_grad = True
        # GAILDiscrim (a Bert) exposes no separate trunk/linear head, so its forward pass provides the logits.
        disc = self.disc(sample_expert)
        ones = torch.ones(disc.size(), device=disc.device)
        disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                             grad_outputs=ones,
                                             create_graph=True, retain_graph=True, only_inputs=True)
        disc_demo_grad = disc_demo_grad[0]
        disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
        grad_pen_loss = torch.mean(disc_demo_grad)

        # Weight decay.
        disc_weights = self.disc.get_disc_weights()
        disc_weights = torch.cat(disc_weights, dim=-1)
        disc_weight_decay = torch.sum(torch.square(disc_weights))

        loss = self.disc_coef * loss_disc + self.disc_grad_penalty * grad_pen_loss + \
            self.disc_logit_reg * disc_logit_loss + self.disc_weight_decay * disc_weight_decay

        self.optim_d.zero_grad()
        loss.backward()
        self.optim_d.step()

        if self.learning_steps_disc % self.epoch_disc == 0:
            writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)

            # Discriminator's accuracies.
            with torch.no_grad():
                acc_pi = (logits_pi < 0).float().mean().item()
                acc_exp = (logits_exp > 0).float().mean().item()

            writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
            writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)

    def update(self, writer, total_steps):
        self.learning_steps += 1
        for _ in range(self.epoch_disc):
            self.learning_steps_disc += 1

            # Samples from current policy trajectories.
            samples_policy = self.buffer.sample(self.batch_size)
            states, next_states = samples_policy[0], samples_policy[-3]  # (state, next_state) pairs
            states = torch.cat([states, next_states], dim=-1)

            # Samples from expert demonstrations.
            samples_expert = self.buffer_exp.sample(self.batch_size)
            states_exp, next_states_exp = samples_expert[0], samples_expert[1]
            states_exp = torch.cat([states_exp, next_states_exp], dim=-1)

            if self.normalizer is not None:
                with torch.no_grad():
                    states = self.normalizer.normalize_torch(states, self.device)
                    states_exp = self.normalizer.normalize_torch(states_exp, self.device)

            # Update the discriminator.
            self.update_disc(states, states_exp, writer)

            # Update the running mean and std of the data stream.
            if self.normalizer is not None:
                self.normalizer.update(states.cpu().numpy())
                self.normalizer.update(states_exp.cpu().numpy())

        states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()

        # Calculate rewards.
        rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)

        writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(),
                          self.learning_steps)
        writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps)
        writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps)

        # Update PPO using the estimated rewards.
        self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer,
                        total_steps)
        self.buffer.clear()
        return rewards_t.mean().item() + rewards_i.mean().item()

    def save_models(self, path):
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'disc': self.disc.state_dict(),
            'optim_actor': self.optim_actor.state_dict(),
            'optim_critic': self.optim_critic.state_dict(),
            'optim_d': self.optim_d.state_dict()
        }, os.path.join(path, 'model.pth'))

    def load_models(self, path, load_optimizer=True):
        loaded_dict = torch.load(path, map_location='cuda:0')
        self.actor.load_state_dict(loaded_dict['actor'])
        self.critic.load_state_dict(loaded_dict['critic'])
        self.disc.load_state_dict(loaded_dict['disc'])
        if load_optimizer:
            self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
            self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
            self.optim_d.load_state_dict(loaded_dict['optim_d'])
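
The discriminator objective in update_disc is the standard GAIL form written with log-sigmoids; a toy standalone check with random logits (illustrative only, not tied to the classes above):

import torch
import torch.nn.functional as F

logits_pi = torch.randn(8, 1)    # discriminator outputs on policy (state, next_state) pairs
logits_exp = torch.randn(8, 1)   # discriminator outputs on expert pairs

# Maximizing E_pi[log(1 - D)] + E_exp[log(D)] amounts to minimizing these two -logsigmoid terms.
loss_pi = -F.logsigmoid(-logits_pi).mean()   # -log(1 - sigmoid(logit_pi))
loss_exp = -F.logsigmoid(logits_exp).mean()  # -log(sigmoid(logit_exp))
loss_disc = 0.5 * (loss_pi + loss_exp)
print(loss_disc.item())
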
31
Algorithm/policy.py
Normal file
@@ -0,0 +1,31 @@
import torch
import numpy as np
from torch import nn
from .utils import build_mlp, reparameterize, evaluate_lop_pi


class StateIndependentPolicy(nn.Module):

    def __init__(self, state_shape, action_shape, hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        super().__init__()

        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        self.log_stds = nn.Parameter(torch.zeros(1, action_shape[0]))
        self.means = None

    def forward(self, states):
        return torch.tanh(self.net(states))

    def sample(self, states):
        self.means = self.net(states)
        actions, log_pis = reparameterize(self.means, self.log_stds)
        return actions, log_pis

    def evaluate_log_pi(self, states, actions):
        self.means = self.net(states)
        return evaluate_lop_pi(self.means, self.log_stds, actions)
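
A shape check for StateIndependentPolicy (illustrative sketch with hypothetical dimensions; assumes Algorithm/ is importable as a package). sample returns a tanh-squashed Gaussian action and its log-probability:

import torch
from Algorithm.policy import StateIndependentPolicy

policy = StateIndependentPolicy(state_shape=(10,), action_shape=(2,))
states = torch.randn(5, 10)
actions, log_pis = policy.sample(states)
print(actions.shape, log_pis.shape)   # torch.Size([5, 2]) torch.Size([5, 1])
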
267
Algorithm/ppo.py
Normal file
@@ -0,0 +1,267 @@
import os
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from .buffer import RolloutBuffer
from .bert import Bert
from .policy import StateIndependentPolicy
from abc import ABC, abstractmethod


class Algorithm(ABC):

    def __init__(self, state_shape, device, gamma):
        self.learning_steps = 0
        self.state_shape = state_shape
        self.device = device
        self.gamma = gamma

    def explore(self, state_list):
        action_list = []
        log_pi_list = []
        if type(state_list).__module__ != "torch":
            state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action, log_pi = self.actor.sample(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
                log_pi_list.append(log_pi.item())
        return action_list, log_pi_list

    def exploit(self, state_list):
        action_list = []
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action = self.actor(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
        return action_list

    @abstractmethod
    def is_update(self, step):
        pass

    @abstractmethod
    def update(self, writer, total_steps):
        pass

    @abstractmethod
    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)


class PPO(Algorithm):

    def __init__(self, state_shape, device, gamma=0.995, rollout_length=2048,
                 units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                 lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                 value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
                 lr_disc=1e-3, auto_lr=True, use_adv_norm=True, max_steps=10000000):
        super().__init__(state_shape, device, gamma)

        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.lr_disc = lr_disc
        self.auto_lr = auto_lr

        self.use_adv_norm = use_adv_norm

        # Rollout buffer.
        self.buffer = RolloutBuffer(
            buffer_size=rollout_length,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )

        # Actor.
        self.actor = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        ).to(device)

        # Critic.
        self.critic = Bert(
            input_dim=state_shape,
            output_dim=1
        ).to(device)

        self.learning_steps_ppo = 0
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.max_grad_norm = max_grad_norm
        self.desired_kl = desired_kl
        self.surrogate_loss_coef = surrogate_loss_coef
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.bounds_loss_coef = bounds_loss_coef
        self.max_steps = max_steps

        self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor)
        # self.optim_actor = Adam([
        #     {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor},
        #     {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor/3}])
        self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic)

    def is_update(self, step):
        return step % self.rollout_length == 0

    def step(self, env, state_list, state_gail):
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device)
        action_list, log_pi_list = self.explore(state_list)
        next_state, reward, terminated, truncated, info = env.step(np.array(action_list))
        next_state_gail = env.state_gail
        done = terminated or truncated

        means = self.actor.means.detach().cpu().numpy()[0]
        stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0]
        # RolloutBuffer does not store state_gail (see buffer.py), so it is not forwarded here.
        self.buffer.append(state_list, action_list, reward, done, terminated, log_pi_list,
                           next_state, next_state_gail, means, stds)
        if done:
            next_state = env.reset()
            next_state_gail = env.state_gail

        return next_state, next_state_gail, info

    def update(self, writer, total_steps):
        pass

    def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list, next_states, mus, sigmas, writer,
                   total_steps):
        with torch.no_grad():
            values = self.critic(states.detach())
            next_values = self.critic(next_states.detach())

        targets, gaes = self.calculate_gae(
            values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)

        state_list = states.permute(1, 0, 2)
        action_list = actions.permute(1, 0, 2)

        for i in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            for state, action, log_pi in zip(state_list, action_list, log_pi_list):
                self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)

        # self.lr_decay(total_steps, writer)

    def update_critic(self, states, targets, writer):
        loss_critic = (self.critic(states) - targets).pow_(2).mean()
        loss_critic = loss_critic * self.value_loss_coef

        self.optim_critic.zero_grad()
        loss_critic.backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/critic', loss_critic.item(), self.learning_steps)

    def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer):
        self.optim_actor.zero_grad()
        log_pis = self.actor.evaluate_log_pi(states, actions)
        mus = self.actor.means
        sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
        entropy = -log_pis.mean()

        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * gaes
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()
        loss_actor = loss_actor * self.surrogate_loss_coef
        if self.auto_lr:
            # desired_kl: 0.01
            with torch.inference_mode():
                kl = torch.sum(torch.log(sigmas / sigmas_old + 1.e-5) +
                               (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                               (2.0 * torch.square(sigmas)) - 0.5, axis=-1)

                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.lr_actor = max(1e-5, self.lr_actor / 1.5)
                    self.lr_critic = max(1e-5, self.lr_critic / 1.5)
                    self.lr_disc = max(1e-5, self.lr_disc / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.lr_actor = min(1e-2, self.lr_actor * 1.5)
                    self.lr_critic = min(1e-2, self.lr_critic * 1.5)
                    self.lr_disc = min(1e-2, self.lr_disc * 1.5)

            for param_group in self.optim_actor.param_groups:
                param_group['lr'] = self.lr_actor
            for param_group in self.optim_critic.param_groups:
                param_group['lr'] = self.lr_critic
            for param_group in self.optim_d.param_groups:
                param_group['lr'] = self.lr_disc

        loss = loss_actor  # + b_loss * 0 - self.entropy_coef * entropy * 0

        loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/actor', loss_actor.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/entropy', entropy.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/learning_rate', self.lr_actor, self.learning_steps)

    def lr_decay(self, total_steps, writer):
        lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps))
        lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps))
        lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps))
        for p in self.optim_actor.param_groups:
            p['lr'] = lr_a_now
        for p in self.optim_critic.param_groups:
            p['lr'] = lr_c_now
        for p in self.optim_d.param_groups:
            p['lr'] = lr_d_now

        writer.add_scalar(
            'Loss/learning_rate', lr_a_now, self.learning_steps)

    def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd):
        """
        Calculate the advantage using GAE.
        'tm_dones=True' means dead or win; there is no next state s'.
        'dones=True' marks the terminal of an episode (dead, win, or reaching max_episode_steps).
        When calculating the advantage, if dones=True then gae = 0.
        Reference: https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py
        """
        with torch.no_grad():
            # Calculate TD errors.
            deltas = rewards + gamma * next_values * (1 - tm_dones) - values
            # Initialize gae.
            gaes = torch.empty_like(rewards)

            # Calculate gae recursively from behind.
            gaes[-1] = deltas[-1]
            for t in reversed(range(rewards.size(0) - 1)):
                gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]

            v_target = gaes + values
            if self.use_adv_norm:
                gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)

        return v_target, gaes

    def save_models(self, save_dir):
        pass
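
The recursion in calculate_gae can be reproduced on a tiny hand-made batch (standalone sketch; values are arbitrary): tm_dones removes the bootstrap term in the TD error, while dones cuts the backward recursion.

import torch

gamma, lambd = 0.995, 0.97
rewards = torch.tensor([[1.0], [1.0], [1.0]])
values = torch.tensor([[0.5], [0.5], [0.5]])
next_values = torch.tensor([[0.5], [0.5], [0.5]])
dones = torch.tensor([[0], [0], [1]])     # episode ends at the last step
tm_dones = dones.clone()                  # terminal: no next state to bootstrap from

deltas = rewards + gamma * next_values * (1 - tm_dones) - values   # TD errors
gaes = torch.empty_like(rewards)
gaes[-1] = deltas[-1]
for t in reversed(range(rewards.size(0) - 1)):
    gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
v_target = gaes + values
print(gaes.squeeze(-1), v_target.squeeze(-1))
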
108
Algorithm/utils.py
Normal file
@@ -0,0 +1,108 @@
import math
import torch
import numpy as np
from torch import nn
from typing import Tuple


class RunningMeanStd(object):
    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
        """
        Calculates the running mean and std of a data stream.
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        :param epsilon: helps with arithmetic issues
        :param shape: the shape of the data stream's output
        """
        self.mean = np.zeros(shape, np.float64)
        self.var = np.ones(shape, np.float64)
        self.count = epsilon

    def update(self, arr: np.ndarray) -> None:
        batch_mean = np.mean(arr, axis=0)
        batch_var = np.var(arr, axis=0)
        batch_count = arr.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * batch_count / tot_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
        new_var = m_2 / (self.count + batch_count)

        new_count = batch_count + self.count

        self.mean = new_mean
        self.var = new_var
        self.count = new_count


class Normalizer(RunningMeanStd):
    def __init__(self, input_dim, epsilon=1e-4, clip_obs=10.0):
        super().__init__(shape=input_dim)
        self.epsilon = epsilon
        self.clip_obs = clip_obs

    def normalize(self, input):
        return np.clip(
            (input - self.mean) / np.sqrt(self.var + self.epsilon),
            -self.clip_obs, self.clip_obs)

    def normalize_torch(self, input, device):
        mean_torch = torch.tensor(
            self.mean, device=device, dtype=torch.float32)
        std_torch = torch.sqrt(torch.tensor(
            self.var + self.epsilon, device=device, dtype=torch.float32))
        return torch.clamp(
            (input - mean_torch) / std_torch, -self.clip_obs, self.clip_obs)

    def update_normalizer(self, rollouts, expert_loader):
        policy_data_generator = rollouts.feed_forward_generator_amp(
            None, mini_batch_size=expert_loader.batch_size)
        expert_data_generator = expert_loader.dataset.feed_forward_generator_amp(
            expert_loader.batch_size)

        for expert_batch, policy_batch in zip(expert_data_generator, policy_data_generator):
            self.update(
                torch.vstack(tuple(policy_batch) + tuple(expert_batch)).cpu().numpy())


def build_mlp(input_dim, output_dim, hidden_units=[64, 64],
              hidden_activation=nn.Tanh(), output_activation=None):
    layers = []
    units = input_dim
    for next_units in hidden_units:
        layers.append(nn.Linear(units, next_units))
        layers.append(hidden_activation)
        units = next_units
    layers.append(nn.Linear(units, output_dim))
    if output_activation is not None:
        layers.append(output_activation)
    return nn.Sequential(*layers)


def calculate_log_pi(log_stds, noises, actions):
    gaussian_log_probs = (-0.5 * noises.pow(2) - log_stds).sum(
        dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)

    return gaussian_log_probs - torch.log(
        1 - actions.pow(2) + 1e-6).sum(dim=-1, keepdim=True)


def reparameterize(means, log_stds):
    noises = torch.randn_like(means)
    us = means + noises * log_stds.exp()
    actions = torch.tanh(us)
    return actions, calculate_log_pi(log_stds, noises, actions)


def atanh(x):
    return 0.5 * (torch.log(1 + x + 1e-6) - torch.log(1 - x + 1e-6))


def evaluate_lop_pi(means, log_stds, actions):
    noises = (atanh(actions) - means) / (log_stds.exp() + 1e-8)
    return calculate_log_pi(log_stds, noises, actions)
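
A short usage sketch of Normalizer (illustrative; dimensions are hypothetical, and Algorithm/ is assumed importable). It keeps NumPy running statistics and can normalize either NumPy arrays or torch tensors:

import numpy as np
import torch
from Algorithm.utils import Normalizer

norm = Normalizer(input_dim=4)                      # running mean/std over 4-dim observations
norm.update(np.random.randn(32, 4))                 # update statistics from a batch
obs = torch.randn(8, 4)
normed = norm.normalize_torch(obs, device='cpu')    # standardized and clipped to [-clip_obs, clip_obs]
print(normed.shape)                                 # torch.Size([8, 4])
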