magail4autodrive: first commit
This commit is contained in:
149
Algorithm/magail.py
Normal file
149
Algorithm/magail.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .disc import GAILDiscrim
|
||||
from .ppo import PPO
|
||||
from .utils import Normalizer
|
||||
|
||||
|
||||
class MAGAIL(PPO):
|
||||
def __init__(self, buffer_exp, input_dim, device,
|
||||
disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
|
||||
lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
|
||||
):
|
||||
super().__init__(state_shape=input_dim, device=device)
|
||||
self.learning_steps = 0
|
||||
self.learning_steps_disc = 0
|
||||
|
||||
self.disc = GAILDiscrim(input_dim=input_dim)
|
||||
self.disc_grad_penalty = disc_grad_penalty
|
||||
self.disc_coef = disc_coef
|
||||
self.disc_logit_reg = disc_logit_reg
|
||||
self.disc_weight_decay = disc_weight_decay
|
||||
self.lr_disc = lr_disc
|
||||
self.epoch_disc = epoch_disc
|
||||
self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)
|
||||
|
||||
self.normalizer = None
|
||||
if use_gail_norm:
|
||||
self.normalizer = Normalizer(self.state_shape[0]*2)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.buffer_exp = buffer_exp
|
||||
|
||||
def update_disc(self, states, states_exp, writer):
|
||||
states_cp = states.clone()
|
||||
states_exp_cp = states_exp.clone()
|
||||
|
||||
# Output of discriminator is (-inf, inf), not [0, 1].
|
||||
logits_pi = self.disc(states_cp)
|
||||
logits_exp = self.disc(states_exp_cp)
|
||||
|
||||
# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].
|
||||
loss_pi = -F.logsigmoid(-logits_pi).mean()
|
||||
loss_exp = -F.logsigmoid(logits_exp).mean()
|
||||
loss_disc = 0.5 * (loss_pi + loss_exp)
|
||||
|
||||
# logit reg
|
||||
logit_weights = self.disc.get_disc_logit_weights()
|
||||
disc_logit_loss = torch.sum(torch.square(logit_weights))
|
||||
|
||||
# grad penalty
|
||||
sample_expert = states_exp_cp
|
||||
sample_expert.requires_grad = True
|
||||
disc = self.disc.linear(self.disc.trunk(sample_expert))
|
||||
ones = torch.ones(disc.size(), device=disc.device)
|
||||
disc_demo_grad = torch.autograd.grad(disc, sample_expert,
|
||||
grad_outputs=ones,
|
||||
create_graph=True, retain_graph=True, only_inputs=True)
|
||||
disc_demo_grad = disc_demo_grad[0]
|
||||
disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
|
||||
grad_pen_loss = torch.mean(disc_demo_grad)
|
||||
|
||||
# weight decay
|
||||
disc_weights = self.disc.get_disc_weights()
|
||||
disc_weights = torch.cat(disc_weights, dim=-1)
|
||||
disc_weight_decay = torch.sum(torch.square(disc_weights))
|
||||
|
||||
loss = self.disc_coef * loss_disc + self.disc_grad_penalty * grad_pen_loss + \
|
||||
self.disc_logit_reg * disc_logit_loss + self.disc_weight_decay * disc_weight_decay
|
||||
|
||||
self.optim_d.zero_grad()
|
||||
loss.backward()
|
||||
self.optim_d.step()
|
||||
|
||||
if self.learning_steps_disc % self.epoch_disc == 0:
|
||||
writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)
|
||||
|
||||
# Discriminator's accuracies.
|
||||
with torch.no_grad():
|
||||
acc_pi = (logits_pi < 0).float().mean().item()
|
||||
acc_exp = (logits_exp > 0).float().mean().item()
|
||||
|
||||
writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
|
||||
writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)
|
||||
|
||||
def update(self, writer, total_steps):
|
||||
self.learning_steps += 1
|
||||
for _ in range(self.epoch_disc):
|
||||
self.learning_steps_disc += 1
|
||||
|
||||
# Samples from current policy trajectories.
|
||||
samples_policy = self.buffer.sample(self.batch_size)
|
||||
states, next_states = samples_policy[1], samples_policy[-3]
|
||||
states = torch.cat([states, next_states], dim=-1)
|
||||
|
||||
# Samples from expert demonstrations.
|
||||
samples_expert = self.buffer_exp.sample(self.batch_size)
|
||||
states_exp, next_states_exp = samples_expert[0], samples_expert[1]
|
||||
states_exp = torch.cat([states_exp, next_states_exp], dim=-1)
|
||||
|
||||
if self.normalizer is not None:
|
||||
with torch.no_grad():
|
||||
states = self.normalizer.normalize_torch(states, self.device)
|
||||
states_exp = self.normalizer.normalize_torch(states_exp, self.device)
|
||||
|
||||
# Update discriminator and us encoder.
|
||||
self.update_disc(states, states_exp, writer)
|
||||
|
||||
# Calulates the running mean and std of a data stream
|
||||
if self.normalizer is not None:
|
||||
self.normalizer.update(states.cpu().numpy())
|
||||
self.normalizer.update(states_exp.cpu().numpy())
|
||||
|
||||
states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()
|
||||
|
||||
# Calculate rewards.
|
||||
rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)
|
||||
|
||||
writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(),
|
||||
self.learning_steps)
|
||||
writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps)
|
||||
writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps)
|
||||
|
||||
# Update PPO using estimated rewards.
|
||||
self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer,
|
||||
total_steps)
|
||||
self.buffer.clear()
|
||||
return rewards_t.mean().item() + rewards_i.mean().item()
|
||||
|
||||
def save_models(self, path):
|
||||
torch.save({
|
||||
'actor': self.actor.state_dict(),
|
||||
'critic': self.critic.state_dict(),
|
||||
'disc': self.disc.state_dict(),
|
||||
'optim_actor': self.optim_actor.state_dict(),
|
||||
'optim_critic': self.optim_critic.state_dict(),
|
||||
'optim_d': self.optim_d.state_dict()
|
||||
}, os.path.join(path, 'model.pth'))
|
||||
|
||||
def load_models(self, path, load_optimizer=True):
|
||||
loaded_dict = torch.load(path, map_location='cuda:0')
|
||||
self.actor.load_state_dict(loaded_dict['actor'])
|
||||
self.critic.load_state_dict(loaded_dict['critic'])
|
||||
self.disc.load_state_dict(loaded_dict['disc'])
|
||||
if load_optimizer:
|
||||
self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
|
||||
self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
|
||||
self.optim_d.load_state_dict(loaded_dict['optim_d'])
|
||||
Reference in New Issue
Block a user