# Source: MAGAIL4AutoDrive/Algorithm/magail.py (2025-09-28 18:57:04 +08:00)
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Package-relative imports, used when this module is loaded as part of
    # its package.
    from .disc import GAILDiscrim
    from .ppo import PPO
    from .utils import Normalizer
except ImportError:
    # Relative imports raise ImportError when there is no parent package
    # (e.g. the module's directory is on sys.path); fall back to top-level
    # imports of the sibling modules.
    from disc import GAILDiscrim
    from ppo import PPO
    from utils import Normalizer
class MAGAIL(PPO):
    """PPO agent whose rewards are produced with a GAIL discriminator.

    The discriminator is trained to separate policy transitions from expert
    transitions. Each transition is fed to it as the concatenation
    ``[state, next_state]`` along the last dimension. :meth:`update` trains
    the discriminator, asks it for rewards on the current rollout buffer,
    and then runs a PPO update with those rewards.

    Args:
        buffer_exp: Expert-demonstration buffer. ``sample(n)`` must return a
            sequence whose items 0 and 1 are states and next states.
        input_dim: State dimension (int) or a shape tuple/list whose first
            entry is the state dimension. Forwarded to PPO as ``state_shape``.
        device: Torch device the networks are placed on.
        action_shape: Action shape forwarded to :class:`PPO`.
        disc_coef: Weight of the discriminator classification loss.
        disc_grad_penalty: Weight of the gradient penalty on expert inputs.
        disc_logit_reg: Weight of the L2 penalty on the tensor returned by
            ``disc.get_disc_logit_weights()`` (presumably the output layer).
        disc_weight_decay: Weight of the L2 penalty on the tensors returned
            by ``disc.get_disc_weights()``.
        lr_disc: Adam learning rate of the discriminator.
        epoch_disc: Number of discriminator gradient steps per :meth:`update`.
        batch_size: Mini-batch size of each discriminator step.
        use_gail_norm: If true, discriminator inputs are normalized with a
            running :class:`Normalizer`.
        **kwargs: Remaining PPO hyper-parameters, forwarded to :class:`PPO`.
    """

    def __init__(self, buffer_exp, input_dim, device, action_shape=(2,),
                 disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True,
                 **kwargs):  # accepts the remaining PPO parameters
        super().__init__(state_shape=input_dim, device=device, action_shape=action_shape, **kwargs)

        self.learning_steps = 0       # incremented once per update() call
        self.learning_steps_disc = 0  # incremented once per discriminator step

        # A shape tuple/list carries the state dimension in its first entry.
        state_dim = input_dim[0] if isinstance(input_dim, (tuple, list)) else input_dim
        # The discriminator input is [state, next_state], i.e. 2 * state_dim.
        self.disc = GAILDiscrim(input_dim=state_dim * 2).to(device)

        self.disc_grad_penalty = disc_grad_penalty
        self.disc_coef = disc_coef
        self.disc_logit_reg = disc_logit_reg
        self.disc_weight_decay = disc_weight_decay
        self.lr_disc = lr_disc
        self.epoch_disc = epoch_disc
        self.optim_d = torch.optim.Adam(self.disc.parameters(), lr=self.lr_disc)

        # The normalizer acts on exactly the discriminator input, so it must
        # use the same dimension as the discriminator.
        self.normalizer = Normalizer(state_dim * 2) if use_gail_norm else None

        self.batch_size = batch_size
        self.buffer_exp = buffer_exp

    def update_disc(self, states, states_exp, writer):
        """Take one Adam step on the discriminator.

        Args:
            states: Policy discriminator inputs ``[state, next_state]``
                (already normalized when normalization is enabled).
            states_exp: Expert discriminator inputs in the same format.
            writer: Summary writer exposing ``add_scalar``.
        """
        states_cp = states.clone()
        states_exp_cp = states_exp.clone()

        # The discriminator outputs logits in (-inf, inf), not values in [0, 1].
        logits_pi = self.disc(states_cp)
        logits_exp = self.disc(states_exp_cp)

        # The discriminator maximizes E_pi[log(1 - D)] + E_exp[log(D)] with
        # D = sigmoid(logit); minimize the negated, averaged objective.
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = 0.5 * (loss_pi + loss_exp)

        # L2 regularization of the logit weights.
        logit_weights = self.disc.get_disc_logit_weights()
        disc_logit_loss = torch.sum(torch.square(logit_weights))

        # Gradient penalty: mean squared norm of d(logit)/d(input) on expert
        # samples. create_graph=True keeps the penalty differentiable.
        sample_expert = states_exp_cp.requires_grad_(True)
        disc_expert = self.disc(sample_expert)
        (disc_demo_grad,) = torch.autograd.grad(
            disc_expert, sample_expert,
            grad_outputs=torch.ones_like(disc_expert),
            create_graph=True, retain_graph=True)
        grad_pen_loss = torch.mean(torch.sum(torch.square(disc_demo_grad), dim=-1))

        # Weight decay. Summing squares per tensor equals the squared norm of
        # their concatenation, but does not require flat, cat-compatible tensors.
        weight_decay_loss = sum(torch.sum(torch.square(w)) for w in self.disc.get_disc_weights())

        loss = (self.disc_coef * loss_disc
                + self.disc_grad_penalty * grad_pen_loss
                + self.disc_logit_reg * disc_logit_loss
                + self.disc_weight_decay * weight_decay_loss)

        self.optim_d.zero_grad()
        loss.backward()
        self.optim_d.step()

        # update() performs epoch_disc steps per call, so this logs once per
        # update() call (on its last discriminator step).
        if self.learning_steps_disc % self.epoch_disc == 0:
            writer.add_scalar('Loss/disc', loss_disc.item(), self.learning_steps)
            # Discriminator accuracies.
            with torch.no_grad():
                acc_pi = (logits_pi < 0).float().mean().item()
                acc_exp = (logits_exp > 0).float().mean().item()
            writer.add_scalar('Acc/acc_pi', acc_pi, self.learning_steps)
            writer.add_scalar('Acc/acc_exp', acc_exp, self.learning_steps)

    def _sample_disc_inputs(self):
        """Sample one raw (unnormalized) policy batch and one expert batch of ``[state, next_state]``."""
        # buffer.sample() yields (states, actions, rewards, dones, tm_dones,
        # log_pis, next_states, means, stds); only states and next_states are used.
        samples_policy = self.buffer.sample(self.batch_size)
        policy_input = torch.cat([samples_policy[0], samples_policy[6]], dim=-1)
        # The expert buffer yields states at index 0 and next states at index 1.
        samples_expert = self.buffer_exp.sample(self.batch_size)
        expert_input = torch.cat([samples_expert[0], samples_expert[1]], dim=-1)
        return policy_input, expert_input

    def update(self, writer, total_steps):
        """Train the discriminator, relabel the rollout rewards and run PPO.

        Args:
            writer: Summary writer exposing ``add_scalar``.
            total_steps: Forwarded to ``update_ppo``.

        Returns:
            float: Mean of ``rewards_t`` plus mean of ``rewards_i``, as
            returned by ``disc.calculate_reward`` for the rollout buffer.
        """
        self.learning_steps += 1

        for _ in range(self.epoch_disc):
            self.learning_steps_disc += 1
            raw_policy, raw_expert = self._sample_disc_inputs()

            if self.normalizer is not None:
                with torch.no_grad():
                    disc_policy = self.normalizer.normalize_torch(raw_policy, self.device)
                    disc_expert = self.normalizer.normalize_torch(raw_expert, self.device)
            else:
                disc_policy, disc_expert = raw_policy, raw_expert

            self.update_disc(disc_policy, disc_expert, writer)

            # Update the running mean/std with the RAW inputs. Feeding it the
            # normalized tensors would make the statistics describe an
            # already-standardized stream instead of the real data distribution.
            if self.normalizer is not None:
                self.normalizer.update(raw_policy.detach().cpu().numpy())
                self.normalizer.update(raw_expert.detach().cpu().numpy())

        states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas = self.buffer.get()

        # NOTE(review): with use_gail_norm the discriminator is trained on
        # normalized [state, next_state] inputs, but calculate_reward() receives
        # raw states and is not given self.normalizer. Confirm it normalizes
        # internally; otherwise rewards are computed on a different input
        # distribution than the one the discriminator was trained on.
        rewards, rewards_t, rewards_i = self.disc.calculate_reward(states, next_states, rewards)

        mean_t = rewards_t.mean().item()
        mean_i = rewards_i.mean().item()
        writer.add_scalar('Reward/rewards', mean_t + mean_i, self.learning_steps)
        writer.add_scalar('Reward/rewards_t', mean_t, self.learning_steps)
        writer.add_scalar('Reward/rewards_i', mean_i, self.learning_steps)

        # Update PPO with the discriminator-derived rewards.
        self.update_ppo(states, actions, rewards, dones, tm_dones, log_pis, next_states, mus, sigmas, writer,
                        total_steps)
        self.buffer.clear()
        return mean_t + mean_i

    def save_models(self, path):
        """Save network and optimizer states to ``<path>/model.pth``.

        Args:
            path: Output directory; created if it does not exist.
        """
        os.makedirs(path, exist_ok=True)
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'disc': self.disc.state_dict(),
            'optim_actor': self.optim_actor.state_dict(),
            'optim_critic': self.optim_critic.state_dict(),
            'optim_d': self.optim_d.state_dict(),
        }, os.path.join(path, 'model.pth'))

    def load_models(self, path, load_optimizer=True):
        """Restore networks, and optionally optimizers, from a checkpoint.

        Args:
            path: Checkpoint file, or a directory previously passed to
                :meth:`save_models` (its ``model.pth`` is loaded).
            load_optimizer: Also restore the actor, critic and discriminator
                optimizer states.
        """
        if os.path.isdir(path):
            path = os.path.join(path, 'model.pth')
        # Map tensors onto this agent's device rather than a hard-coded
        # 'cuda:0', so checkpoints also load on CPU-only machines or other GPUs.
        loaded_dict = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(loaded_dict['actor'])
        self.critic.load_state_dict(loaded_dict['critic'])
        self.disc.load_state_dict(loaded_dict['disc'])
        if load_optimizer:
            self.optim_actor.load_state_dict(loaded_dict['optim_actor'])
            self.optim_critic.load_state_dict(loaded_dict['optim_critic'])
            self.optim_d.load_state_dict(loaded_dict['optim_d'])