import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from .disc import GAILDiscrim
from .ppo import PPO
from .utils import Normalizer


class MAGAIL(PPO):
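    """GAIL-style imitation learning on top of PPO.

    A discriminator (GAILDiscrim) is trained to tell expert (state, next_state)
    transitions apart from transitions produced by the current policy; its
    output is used by GAILDiscrim.calculate_reward to produce the rewards that
    drive the PPO update.
    """
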
    def __init__(self, buffer_exp, input_dim, device,
                 disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True):
        super().__init__(state_shape=input_dim, device=device)
        self.learning_steps = 0
        self.learning_steps_disc = 0

        # Discriminator over concatenated (state, next_state) transitions.
        self.disc = GAILDiscrim(input_dim=input_dim)
        self.disc_grad_penalty = disc_grad_penalty
        self.disc_coef = disc_coef
        self.disc_logit_reg = disc_logit_reg
        self.disc_weight_decay = disc_weight_decay
        self.lr_disc = lr_disc
        self.epoch_disc = epoch_disc
        self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)

        # Optional running normalization of the discriminator inputs
        # (state and next_state concatenated, hence state_shape[0] * 2).
        self.normalizer = None
        if use_gail_norm:
            self.normalizer = Normalizer(self.state_shape[0] * 2)

        self.batch_size = batch_size
        self.buffer_exp = buffer_exp

    def update_disc(self, states, states_exp, writer):
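        """Run one discriminator update step.

        `states` and `states_exp` are concatenated (state, next_state) batches
        from the policy buffer and the expert buffer respectively. The loss is
        the GAIL binary cross-entropy plus a logit L2 penalty, a gradient
        penalty on expert inputs, and weight decay on the discriminator weights.
        """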
        states_cp = states.clone()
        states_exp_cp = states_exp.clone()

        # The discriminator outputs raw logits in (-inf, inf), not probabilities in [0, 1].
        logits_pi = self.disc(states_cp)
        logits_exp = self.disc(states_exp_cp)

        # The discriminator is trained to maximize E_{\pi}[log(1 - D)] + E_{exp}[log(D)],
        # where D = sigmoid(logits).
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = 0.5 * (loss_pi + loss_exp)

        # L2 regularization on the weights of the final logit layer.
        logit_weights = self.disc.get_disc_logit_weights()
        disc_logit_loss = torch.sum(torch.square(logit_weights))

        # Gradient penalty: penalize the squared gradient norm of the
        # discriminator output with respect to the expert inputs.
        sample_expert = states_exp_cp
        sample_expert.requires_grad = True
        disc = self.disc.linear(self.disc.trunk(sample_expert))
        ones = torch.ones(disc.size(), device=disc.device)
        disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                             grad_outputs=ones,
                                             create_graph=True, retain_graph=True, only_inputs=True)
        disc_demo_grad = disc_demo_grad[0]
        disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
        grad_pen_loss = torch.mean(disc_demo_grad)

        # Weight decay on all discriminator weights.
        disc_weights = self.disc.get_disc_weights()
        disc_weights = torch.cat(disc_weights, dim=-1)
        disc_weight_decay = torch.sum(torch.square(disc_weights))

        loss = (self.disc_coef * loss_disc
                + self.disc_grad_penalty * grad_pen_loss
                + self.disc_logit_reg * disc_logit_loss
                + self.disc_weight_decay * disc_weight_decay)

        self.optim_d.zero_grad()
        loss.backward()
        self.optim_d.step()

        if self.learning_steps_disc % self.epoch_disc == 0:
            writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)

            # Discriminator accuracies: a positive logit means "classified as expert",
            # so policy samples should get negative logits and expert samples positive ones.
            with torch.no_grad():
                acc_pi = (logits_pi < 0).float().mean().item()
                acc_exp = (logits_exp > 0).float().mean().item()

            writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
            writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)

    def update(self, writer, total_steps):
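        """Run `epoch_disc` discriminator updates on minibatches sampled from the
        policy and expert buffers, then compute rewards for the collected rollout
        and perform a PPO update. Returns the sum of the mean values of the two
        reward components returned by the discriminator.
        """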
        self.learning_steps += 1
        for _ in range(self.epoch_disc):
            self.learning_steps_disc += 1

            # Samples from current policy trajectories.
            samples_policy = self.buffer.sample(self.batch_size)
            states, next_states = samples_policy[1], samples_policy[-3]
            states = torch.cat([states, next_states], dim=-1)

            # Samples from expert demonstrations.
            samples_expert = self.buffer_exp.sample(self.batch_size)
            states_exp, next_states_exp = samples_expert[0], samples_expert[1]
            states_exp = torch.cat([states_exp, next_states_exp], dim=-1)

            if self.normalizer is not None:
                with torch.no_grad():
                    states = self.normalizer.normalize_torch(states, self.device)
                    states_exp = self.normalizer.normalize_torch(states_exp, self.device)

            # Update the discriminator.
            self.update_disc(states, states_exp, writer)

            # Update the running mean and std of the input normalizer.
            if self.normalizer is not None:
                self.normalizer.update(states.cpu().numpy())
                self.normalizer.update(states_exp.cpu().numpy())

        states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()

        # Calculate rewards; the discriminator returns the combined reward plus
        # its two components, which are logged separately below.
        rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)

        writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(),
                          self.learning_steps)
        writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps)
        writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps)

        # Update PPO using the estimated rewards.
        self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer,
                        total_steps)
        self.buffer.clear()
        return rewards_t.mean().item() + rewards_i.mean().item()

    def save_models(self, path):
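        """Save actor, critic, discriminator, and optimizer states to <path>/model.pth."""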
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'disc': self.disc.state_dict(),
            'optim_actor': self.optim_actor.state_dict(),
            'optim_critic': self.optim_critic.state_dict(),
            'optim_d': self.optim_d.state_dict()
        }, os.path.join(path, 'model.pth'))

    def load_models(self, path, load_optimizer=True):
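        """Load model (and optionally optimizer) states from a checkpoint file."""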
        # Map tensors onto this agent's device rather than a hard-coded GPU.
        loaded_dict = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(loaded_dict['actor'])
        self.critic.load_state_dict(loaded_dict['critic'])
        self.disc.load_state_dict(loaded_dict['disc'])
        if load_optimizer:
            self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
            self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
            self.optim_d.load_state_dict(loaded_dict['optim_d'])
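

# Example usage (a minimal sketch, not part of the original module): the expert
# buffer, rollout collection, and SummaryWriter setup below are hypothetical
# placeholders whose exact APIs depend on the rest of this code base.
#
#   from torch.utils.tensorboard import SummaryWriter
#
#   algo = MAGAIL(buffer_exp=expert_buffer,      # buffer exposing .sample(batch_size)
#                 input_dim=state_shape,         # e.g. (obs_dim,)
#                 device=torch.device('cuda'))
#   writer = SummaryWriter('runs/magail')
#
#   # Collect a rollout into algo.buffer via the PPO base class, then:
#   mean_reward = algo.update(writer, total_steps)
#   algo.save_models('checkpoints/')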