# MAGAIL4AutoDrive/Algorithm/magail.py

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from .disc import GAILDiscrim
    from .ppo import PPO
    from .utils import Normalizer
except ImportError:
    from disc import GAILDiscrim
    from ppo import PPO
    from utils import Normalizer
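

# MAGAIL ties a GAIL-style discriminator to a PPO learner: the discriminator
# is trained to tell expert state transitions (s, s') apart from policy
# transitions, and disc.calculate_reward() then folds its judgment into the
# reward that PPO maximizes (see update() below).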
class MAGAIL(PPO):
    def __init__(self, buffer_exp, input_dim, device, action_shape=(2,),
                 disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25,
                 disc_weight_decay=0.0005, lr_disc=1e-3, epoch_disc=50,
                 batch_size=1000, use_gail_norm=True,
                 **kwargs  # remaining keyword arguments are forwarded to PPO
                 ):
        super().__init__(state_shape=input_dim, device=device, action_shape=action_shape, **kwargs)
        self.learning_steps = 0
        self.learning_steps_disc = 0

        # If input_dim is a tuple, take its first element.
        state_dim = input_dim[0] if isinstance(input_dim, tuple) else input_dim
        # The discriminator input is state concatenated with next_state,
        # so its dimension is state_dim * 2.
        self.disc = GAILDiscrim(input_dim=state_dim * 2).to(device)  # move to the target device
        self.disc_grad_penalty = disc_grad_penalty
        self.disc_coef = disc_coef
        self.disc_logit_reg = disc_logit_reg
        self.disc_weight_decay = disc_weight_decay
        self.lr_disc = lr_disc
        self.epoch_disc = epoch_disc
        self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)

        self.normalizer = None
        if use_gail_norm:
            # state_shape is already stored in tuple form.
            state_dim = self.state_shape[0] if isinstance(self.state_shape, tuple) else self.state_shape
            self.normalizer = Normalizer(state_dim * 2)

        self.batch_size = batch_size
        self.buffer_exp = buffer_exp
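
    # Construction sketch (illustrative; the variable names below are
    # hypothetical, and PPO's own keyword arguments live in ppo.py):
    #
    #     agent = MAGAIL(buffer_exp=expert_buffer,
    #                    input_dim=(obs_dim,),
    #                    device=torch.device('cuda:0'))
    #
    # buffer_exp.sample(batch_size) is expected to return a tuple whose first
    # two entries are (states, next_states); see update() below.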

    def update_disc(self, states, states_exp, writer):
        states_cp = states.clone()
        states_exp_cp = states_exp.clone()

        # Output of the discriminator is a raw logit in (-inf, inf), not [0, 1].
        logits_pi = self.disc(states_cp)
        logits_exp = self.disc(states_exp_cp)

        # Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = 0.5 * (loss_pi + loss_exp)
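        # With D = sigmoid(logit), -logsigmoid(-logit) = -log(1 - D) and
        # -logsigmoid(logit) = -log(D), so minimizing loss_disc implements the
        # objective above in a numerically stable form (see the smoke test at
        # the bottom of this file).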

        # Logit regularization: penalize the squared weights of the
        # discriminator's final (logit) layer.
        logit_weights = self.disc.get_disc_logit_weights()
        disc_logit_loss = torch.sum(torch.square(logit_weights))

        # Gradient penalty on expert samples.
        sample_expert = states_exp_cp
        sample_expert.requires_grad = True
        disc = self.disc(sample_expert)  # call forward directly
        ones = torch.ones(disc.size(), device=disc.device)
        disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                             grad_outputs=ones,
                                             create_graph=True, retain_graph=True,
                                             only_inputs=True)[0]
        disc_demo_grad = torch.sum(torch.square(disc_demo_grad), dim=-1)
        grad_pen_loss = torch.mean(disc_demo_grad)
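        # Penalizing the squared gradient of D at expert samples keeps the
        # discriminator smooth around the expert data, which stabilizes
        # adversarial training and keeps its reward signal informative.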

        # Weight decay over all discriminator weights.
        disc_weights = self.disc.get_disc_weights()
        disc_weights = torch.cat(disc_weights, dim=-1)
        disc_weight_decay = torch.sum(torch.square(disc_weights))

        loss = self.disc_coef * loss_disc \
            + self.disc_grad_penalty * grad_pen_loss \
            + self.disc_logit_reg * disc_logit_loss \
            + self.disc_weight_decay * disc_weight_decay

        self.optim_d.zero_grad()
        loss.backward()
        self.optim_d.step()

        if self.learning_steps_disc % self.epoch_disc == 0:
            writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)
            # Discriminator accuracies: policy samples should be classified as
            # negative (logit < 0) and expert samples as positive (logit > 0).
            with torch.no_grad():
                acc_pi = (logits_pi < 0).float().mean().item()
                acc_exp = (logits_exp > 0).float().mean().item()
            writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
            writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)

    def update(self, writer, total_steps):
        self.learning_steps += 1

        for _ in range(self.epoch_disc):
            self.learning_steps_disc += 1

            # Samples from current policy trajectories. self.buffer.sample returns:
            # (states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds)
            samples_policy = self.buffer.sample(self.batch_size)
            states, next_states = samples_policy[0], samples_policy[6]  # corrected: use states, not actions
            states = torch.cat([states, next_states], dim=-1)

            # Samples from expert demonstrations.
            samples_expert = self.buffer_exp.sample(self.batch_size)
            states_exp, next_states_exp = samples_expert[0], samples_expert[1]
            states_exp = torch.cat([states_exp, next_states_exp], dim=-1)

            if self.normalizer is not None:
                with torch.no_grad():
                    states = self.normalizer.normalize_torch(states, self.device)
                    states_exp = self.normalizer.normalize_torch(states_exp, self.device)

            # Update the discriminator.
            self.update_disc(states, states_exp, writer)

            # Update the running mean and std of the normalizer's data stream.
            if self.normalizer is not None:
                self.normalizer.update(states.cpu().numpy())
                self.normalizer.update(states_exp.cpu().numpy())

        states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()

        # Calculate rewards.
        rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)
        writer.add_scalar('Reward/rewards', rewards_t.mean().item() + rewards_i.mean().item(),
                          self.learning_steps)
        writer.add_scalar('Reward/rewards_t', rewards_t.mean().item(), self.learning_steps)
        writer.add_scalar('Reward/rewards_i', rewards_i.mean().item(), self.learning_steps)
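        # Note (interpretation, not verified against disc.py): judging by the
        # names and by GAIL convention, rewards_t and rewards_i are most likely
        # the task (environment) and imitation (discriminator-based) components
        # of the shaped reward.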

        # Update PPO using the estimated rewards.
        self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas,
                        writer, total_steps)
        self.buffer.clear()
        return rewards_t.mean().item() + rewards_i.mean().item()

    def save_models(self, path):
        # Make sure the target directory exists.
        os.makedirs(path, exist_ok=True)
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'disc': self.disc.state_dict(),
            'optim_actor': self.optim_actor.state_dict(),
            'optim_critic': self.optim_critic.state_dict(),
            'optim_d': self.optim_d.state_dict()
        }, os.path.join(path, 'model.pth'))

    def load_models(self, path, load_optimizer=True):
        # Map tensors onto this agent's device rather than a hard-coded GPU,
        # so checkpoints also load on CPU-only machines.
        loaded_dict = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(loaded_dict['actor'])
        self.critic.load_state_dict(loaded_dict['critic'])
        self.disc.load_state_dict(loaded_dict['disc'])
        if load_optimizer:
            self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
            self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
            self.optim_d.load_state_dict(loaded_dict['optim_d'])
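

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the training
    # pipeline): verifies the numerically stable identities behind
    # update_disc's loss terms, i.e. -logsigmoid(-x) == -log(1 - sigmoid(x))
    # and -logsigmoid(x) == -log(sigmoid(x)).
    x = torch.randn(1000).clamp(-6.0, 6.0)  # clamp to avoid float32 underflow in the naive form
    assert torch.allclose(-F.logsigmoid(-x), -torch.log(1.0 - torch.sigmoid(x)), atol=1e-4)
    assert torch.allclose(-F.logsigmoid(x), -torch.log(torch.sigmoid(x)), atol=1e-4)
    print("discriminator loss identities hold")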