Modified the algorithm code and added a simple training script; updated bert to handle 2D input and removed the permute parameter from PPO.
@@ -2,21 +2,30 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .disc import GAILDiscrim
-from .ppo import PPO
-from .utils import Normalizer
+try:
+    from .disc import GAILDiscrim
+    from .ppo import PPO
+    from .utils import Normalizer
+except ImportError:
+    from disc import GAILDiscrim
+    from ppo import PPO
+    from utils import Normalizer
 
 
 class MAGAIL(PPO):
-    def __init__(self, buffer_exp, input_dim, device,
+    def __init__(self, buffer_exp, input_dim, device, action_shape=(2,),
                  disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
-                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
+                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True,
+                 **kwargs  # accept the remaining PPO arguments
                  ):
-        super().__init__(state_shape=input_dim, device=device)
+        super().__init__(state_shape=input_dim, device=device, action_shape=action_shape, **kwargs)
         self.learning_steps = 0
         self.learning_steps_disc = 0
 
-        self.disc = GAILDiscrim(input_dim=input_dim)
+        # if input_dim is a tuple, take its first element
+        state_dim = input_dim[0] if isinstance(input_dim, tuple) else input_dim
+        # the discriminator input is state + next_state concatenated, so the dimension is state_dim*2
+        self.disc = GAILDiscrim(input_dim=state_dim*2).to(device)  # move to the specified device
         self.disc_grad_penalty = disc_grad_penalty
         self.disc_coef = disc_coef
         self.disc_logit_reg = disc_logit_reg
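The state_dim*2 arithmetic above follows from the discriminator scoring (state, next_state) transition pairs rather than single states. A minimal sketch of that shape contract, assuming GAILDiscrim is an ordinary MLP ending in a single logit (the real network in disc.py may differ):

    import torch
    import torch.nn as nn

    class TinyDiscrim(nn.Module):
        # stand-in for GAILDiscrim: an MLP over a concatenated (state, next_state) pair
        def __init__(self, input_dim, hidden_dim=128):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(),
                                     nn.Linear(hidden_dim, 1))

        def forward(self, x):
            return self.net(x)

    state_dim = 10                                   # e.g. input_dim == (10,)
    disc = TinyDiscrim(input_dim=state_dim * 2)
    state = torch.randn(32, state_dim)               # batch of states
    next_state = torch.randn(32, state_dim)          # batch of next states
    logits = disc(torch.cat([state, next_state], dim=-1))   # shape (32, 1)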
@@ -27,7 +36,9 @@ class MAGAIL(PPO):
 
         self.normalizer = None
         if use_gail_norm:
-            self.normalizer = Normalizer(self.state_shape[0]*2)
+            # state_shape is already in tuple form
+            state_dim = self.state_shape[0] if isinstance(self.state_shape, tuple) else self.state_shape
+            self.normalizer = Normalizer(state_dim*2)
 
         self.batch_size = batch_size
         self.buffer_exp = buffer_exp
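Normalizer itself lives in utils.py and is not part of this diff; as a rough, hypothetical sketch of what a running-statistics normalizer over the state_dim*2 concatenated input typically looks like (the class name and update rule here are assumptions, not the repo's code):

    import torch

    class RunningNorm:
        # running mean/variance over a feature vector of size dim (here dim = state_dim * 2)
        def __init__(self, dim, eps=1e-8):
            self.mean = torch.zeros(dim)
            self.var = torch.ones(dim)
            self.count = eps

        def update(self, x):                        # x: (batch, dim)
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            batch_count = x.shape[0]
            delta = batch_mean - self.mean
            total = self.count + batch_count
            self.mean = self.mean + delta * batch_count / total
            self.var = (self.var * self.count + batch_var * batch_count
                        + delta.pow(2) * self.count * batch_count / total) / total
            self.count = total

        def normalize(self, x):
            return (x - self.mean) / (self.var.sqrt() + 1e-8)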
@@ -52,7 +63,7 @@ class MAGAIL(PPO):
         # grad penalty
         sample_expert = states_exp_cp
         sample_expert.requires_grad = True
-        disc = self.disc.linear(self.disc.trunk(sample_expert))
+        disc = self.disc(sample_expert)  # call the forward method directly
         ones = torch.ones(disc.size(), device=disc.device)
         disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                              grad_outputs=ones,
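The hunk ends before the gradient is reduced into a loss term; the usual GAIL/AMP-style penalty squares the per-sample gradient and scales it by disc_grad_penalty. A hedged sketch of that continuation, not necessarily the exact reduction used further down in this file:

    import torch

    def gradient_penalty(disc_net, expert_input, coef):
        # expert_input: (batch, state_dim * 2) with requires_grad=True
        out = disc_net(expert_input)
        ones = torch.ones_like(out)
        grad = torch.autograd.grad(out, expert_input,
                                   grad_outputs=ones,
                                   create_graph=True,
                                   retain_graph=True)[0]
        # squared L2 norm per sample, averaged over the batch
        return coef * grad.pow(2).sum(dim=-1).mean()

    # e.g. loss_pen = gradient_penalty(self.disc, sample_expert, self.disc_grad_penalty)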
@@ -91,7 +102,8 @@ class MAGAIL(PPO):
 
         # Samples from current policy trajectories.
         samples_policy = self.buffer.sample(self.batch_size)
-        states, next_states = samples_policy[1], samples_policy[-3]
+        # samples_policy returns: (states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds)
+        states, next_states = samples_policy[0], samples_policy[6]  # fix: use states instead of actions
         states = torch.cat([states, next_states], dim=-1)
 
         # Samples from expert demonstrations.
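Because the index fix above depends on remembering the buffer's tuple order, an equivalent and less error-prone pattern is to unpack the whole sample once (this assumes self.buffer.sample really returns the 9-tuple documented in the comment):

    (states, actions, rewards, dones, tm_dones,
     log_pis, next_states, means, stds) = self.buffer.sample(self.batch_size)
    # discriminator input: concatenated (state, next_state) transition pair
    states = torch.cat([states, next_states], dim=-1)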
@@ -129,6 +141,8 @@ class MAGAIL(PPO):
         return rewards_t.mean().item() + rewards_i.mean().item()
 
     def save_models(self, path):
+        # make sure the target directory exists
+        os.makedirs(path, exist_ok=True)
         torch.save({
             'actor': self.actor.state_dict(),
             'critic': self.critic.state_dict(),
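save_models now creates the directory before writing; the torch.save call is cut off by the hunk, so the checkpoint file name and any keys beyond 'actor' and 'critic' are not visible here. A minimal, hypothetical loader mirroring the two keys shown (the 'model.pt' file name is only a placeholder):

    import os
    import torch

    def load_models(self, path, filename='model.pt'):
        # counterpart to save_models; restores whatever state_dicts were saved above
        checkpoint = torch.load(os.path.join(path, filename), map_location='cpu')
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])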