Modified the algorithm code and added a simple training script. Updated BERT to handle 2D inputs and removed the permute parameter from PPO.

This commit is contained in:
2025-10-22 16:56:12 +08:00
parent b626702cbb
commit 3f7e183c4b
101 changed files with 3837 additions and 39 deletions
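The commit message mentions adapting BERT to 2D input, but that change is not part of the hunks shown below. As a hedged illustration only (class name, layer sizes, and pooling are hypothetical, not the repo's code), handling a flat (batch, features) observation with a BERT-style encoder usually means inserting a sequence dimension before the transformer layers:

import torch
import torch.nn as nn

class StateEncoder(nn.Module):
    """Hypothetical sketch: lift a 2D (batch, features) tensor to the
    (batch, seq_len, hidden) shape a BERT-style encoder expects."""
    def __init__(self, feature_dim, hidden_dim=64, n_heads=4, n_layers=2):
        super().__init__()
        self.proj = nn.Linear(feature_dim, hidden_dim)
        layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):
        if x.dim() == 2:                # (batch, features) -> (batch, 1, features)
            x = x.unsqueeze(1)
        h = self.encoder(self.proj(x))  # (batch, seq_len, hidden)
        return h.mean(dim=1)            # pool over the sequence dimension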


@@ -2,21 +2,30 @@ import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .disc import GAILDiscrim
-from .ppo import PPO
-from .utils import Normalizer
+try:
+    from .disc import GAILDiscrim
+    from .ppo import PPO
+    from .utils import Normalizer
+except ImportError:
+    from disc import GAILDiscrim
+    from ppo import PPO
+    from utils import Normalizer
 class MAGAIL(PPO):
-    def __init__(self, buffer_exp, input_dim, device,
+    def __init__(self, buffer_exp, input_dim, device, action_shape=(2,),
                  disc_coef=20.0, disc_grad_penalty=0.1, disc_logit_reg=0.25, disc_weight_decay=0.0005,
-                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True
+                 lr_disc=1e-3, epoch_disc=50, batch_size=1000, use_gail_norm=True,
+                 **kwargs  # accept any additional PPO keyword arguments
                  ):
-        super().__init__(state_shape=input_dim, device=device)
+        super().__init__(state_shape=input_dim, device=device, action_shape=action_shape, **kwargs)
         self.learning_steps = 0
         self.learning_steps_disc = 0
-        self.disc = GAILDiscrim(input_dim=input_dim)
+        # If input_dim is a tuple, take its first element
+        state_dim = input_dim[0] if isinstance(input_dim, tuple) else input_dim
+        # The discriminator input is state and next_state concatenated, so its size is state_dim*2
+        self.disc = GAILDiscrim(input_dim=state_dim*2).to(device)  # move to the target device
         self.disc_grad_penalty = disc_grad_penalty
         self.disc_coef = disc_coef
         self.disc_logit_reg = disc_logit_reg
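Two notes on this hunk. The try/except block lets the module be imported as part of the package or run directly as a script. The discriminator is now built with input_dim=state_dim*2 because it scores concatenated (state, next_state) pairs; GAILDiscrim itself is defined elsewhere in the repo, so the following is only a minimal sketch consistent with how it is used here (the trunk/linear attributes mirror the old call in a later hunk, but the layer sizes are assumptions):

import torch.nn as nn

class GAILDiscrim(nn.Module):
    """Illustrative sketch: an MLP mapping a concatenated
    (state, next_state) vector to a single real-valued logit."""
    def __init__(self, input_dim, hidden_units=(128, 128)):
        super().__init__()
        layers, last = [], input_dim
        for h in hidden_units:
            layers += [nn.Linear(last, h), nn.Tanh()]
            last = h
        self.trunk = nn.Sequential(*layers)
        self.linear = nn.Linear(last, 1)

    def forward(self, x):
        # x: (batch, state_dim*2), i.e. torch.cat([state, next_state], dim=-1)
        return self.linear(self.trunk(x))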
@@ -27,7 +36,9 @@ class MAGAIL(PPO):
         self.normalizer = None
         if use_gail_norm:
-            self.normalizer = Normalizer(self.state_shape[0]*2)
+            # state_shape is already stored as a tuple
+            state_dim = self.state_shape[0] if isinstance(self.state_shape, tuple) else self.state_shape
+            self.normalizer = Normalizer(state_dim*2)
         self.batch_size = batch_size
         self.buffer_exp = buffer_exp
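Normalizer is likewise sized for the concatenated (state, next_state) input. Its implementation is not part of this diff; a purely illustrative running mean/variance sketch (assuming a parallel-variance style update, not the repo's actual class) would be:

import torch

class Normalizer:
    """Illustrative running mean/std tracker for vectors of size `dim`
    (here dim = state_dim*2 for concatenated state/next_state pairs)."""
    def __init__(self, dim, eps=1e-8):
        self.mean = torch.zeros(dim)
        self.var = torch.ones(dim)
        self.count = eps

    def update(self, x):
        # Merge batch statistics into the running estimates.
        b_mean, b_var, b_count = x.mean(0), x.var(0, unbiased=False), x.shape[0]
        delta, total = b_mean - self.mean, self.count + b_count
        self.mean = self.mean + delta * b_count / total
        self.var = (self.var * self.count + b_var * b_count
                    + delta.pow(2) * self.count * b_count / total) / total
        self.count = total

    def normalize(self, x):
        return (x - self.mean) / (self.var.sqrt() + 1e-8)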
@@ -52,7 +63,7 @@ class MAGAIL(PPO):
         # grad penalty
         sample_expert = states_exp_cp
         sample_expert.requires_grad = True
-        disc = self.disc.linear(self.disc.trunk(sample_expert))
+        disc = self.disc(sample_expert)  # call the discriminator's forward directly
         ones = torch.ones(disc.size(), device=disc.device)
         disc_demo_grad = torch.autograd.grad(disc, sample_expert,
                                              grad_outputs=ones,
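The penalty term itself is assembled outside this hunk. For context, the usual GAIL-style gradient penalty built from exactly this kind of autograd.grad call looks roughly like the sketch below (the function name and loss wiring are assumptions, not the repo's code); the returned value would then be scaled by self.disc_grad_penalty when forming the discriminator loss.

import torch

def gradient_penalty(disc, states_exp):
    """Illustrative penalty: squared norm of the discriminator's gradient
    with respect to the expert inputs, as computed in the hunk above."""
    sample = states_exp.clone()
    sample.requires_grad = True
    out = disc(sample)
    ones = torch.ones_like(out)
    grad = torch.autograd.grad(out, sample, grad_outputs=ones,
                               create_graph=True, retain_graph=True,
                               only_inputs=True)[0]
    return grad.pow(2).sum(dim=-1).mean()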
@@ -91,7 +102,8 @@ class MAGAIL(PPO):
             # Samples from current policy trajectories.
             samples_policy = self.buffer.sample(self.batch_size)
-            states, next_states = samples_policy[1], samples_policy[-3]
+            # samples_policy returns: (states, actions, rewards, dones, tm_dones, log_pis, next_states, means, stds)
+            states, next_states = samples_policy[0], samples_policy[6]  # fix: use states rather than actions
             states = torch.cat([states, next_states], dim=-1)
             # Samples from expert demonstrations.
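The index fix above is correct, but since the sample layout is now documented, unpacking the tuple by name inside the same method would make the intent explicit and survive any future reordering. A suggestion, assuming the 9-tuple order stated in the comment:

# Assumes self.buffer.sample() returns the 9-tuple documented above.
(states, actions, rewards, dones, tm_dones,
 log_pis, next_states, means, stds) = self.buffer.sample(self.batch_size)
states = torch.cat([states, next_states], dim=-1)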
@@ -129,6 +141,8 @@ class MAGAIL(PPO):
         return rewards_t.mean().item() + rewards_i.mean().item()
     def save_models(self, path):
+        # make sure the target directory exists
+        os.makedirs(path, exist_ok=True)
         torch.save({
             'actor': self.actor.state_dict(),
             'critic': self.critic.state_dict(),
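The checkpoint dictionary is truncated by the hunk boundary, so the full key set and the file name are not visible here. A hypothetical load_models counterpart, mirroring only the keys shown above and assuming a self.device attribute from the PPO base class (the real filename is a placeholder):

import os
import torch

def load_models(self, path, filename='model.pt'):
    """Hypothetical counterpart to save_models; filename and key set
    beyond 'actor'/'critic' are assumptions, not the repo's code."""
    checkpoint = torch.load(os.path.join(path, filename), map_location=self.device)
    self.actor.load_state_dict(checkpoint['actor'])
    self.critic.load_state_dict(checkpoint['critic'])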