Modified the algorithm code and added a simple training script; updated Bert to handle 2D inputs and removed the permute handling in PPO.

2025-10-22 16:56:12 +08:00
parent b626702cbb
commit 3f7e183c4b
101 changed files with 3837 additions and 39 deletions


@@ -3,9 +3,14 @@ import torch
 import numpy as np
 from torch import nn
 from torch.optim import Adam
-from buffer import RolloutBuffer
-from bert import Bert
-from policy import StateIndependentPolicy
+try:
+    from .buffer import RolloutBuffer
+    from .bert import Bert
+    from .policy import StateIndependentPolicy
+except ImportError:
+    from buffer import RolloutBuffer
+    from bert import Bert
+    from policy import StateIndependentPolicy
 from abc import ABC, abstractmethod
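
The try/except import makes the module usable both as a package member and as a standalone script. A minimal sketch of the two invocation modes, assuming a package layout that is not shown in this commit:

    # Hypothetical layout (not part of this diff):
    #   algo/
    #       __init__.py
    #       ppo.py      <- the file changed above
    #       buffer.py, bert.py, policy.py
    #
    # Imported as a package module, the relative form resolves:
    #   python -m algo.ppo          ->  from .buffer import RolloutBuffer
    # Run as a plain script from inside algo/, the relative import raises
    # ImportError ("no known parent package"), so the fallback runs:
    #   cd algo && python ppo.py    ->  from buffer import RolloutBuffer
    try:
        from .buffer import RolloutBuffer   # package-relative import
    except ImportError:
        from buffer import RolloutBuffer    # script-style fallback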
@@ -55,7 +60,7 @@ class PPO(Algorithm):
 class PPO(Algorithm):
-    def __init__(self, state_shape, device, gamma=0.995, rollout_length=2048,
+    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
                  units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                  lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                  value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
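
Exposing action_shape in the constructor lets callers size the policy for their environment, and the (2,) default means existing call sites that do not pass it keep working. A hedged construction sketch; the shapes and device handling below are illustrative, only the new keyword itself comes from this diff:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    algo = PPO(
        state_shape=(45,),      # illustrative observation size
        device=device,
        action_shape=(12,),     # new keyword introduced by this commit
    )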
@@ -66,6 +71,7 @@ class PPO(Algorithm):
         self.lr_critic = lr_critic
         self.lr_disc = lr_disc
         self.auto_lr = auto_lr
+        self.action_shape = action_shape
         self.use_adv_norm = use_adv_norm
@@ -86,8 +92,10 @@
         ).to(device)
         # Critic.
+        # If state_shape is a tuple, extract its first element.
+        state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
         self.critic = Bert(
-            input_dim=state_shape,
+            input_dim=state_dim,
             output_dim=1
         ).to(device)
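
The new guard means the critic accepts state_shape either as a tuple such as (45,) or as a bare int such as 45, reducing both to the same scalar input_dim for Bert. In isolation (the helper name below is hypothetical):

    def unwrap_dim(state_shape):
        # Same expression as the diff: take a tuple's first element, pass ints through.
        return state_shape[0] if isinstance(state_shape, tuple) else state_shape

    assert unwrap_dim((45,)) == 45
    assert unwrap_dim(45) == 45

Note that for a multi-element tuple only the leading dimension is used.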
@@ -145,14 +153,12 @@
         targets, gaes = self.calculate_gae(
             values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)
-        state_list = states.permute(1, 0, 2)
-        action_list = actions.permute(1, 0, 2)
+        # Process the whole batch; no per-agent grouping is needed, since the buffer already mixes all agents' data.
         for i in range(self.epoch_ppo):
             self.learning_steps_ppo += 1
             self.update_critic(states, targets, writer)
-            for state, action, log_pi in state_list, action_list, log_pi_list:
-                self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)
+            # Update the actor directly with the whole batch.
+            self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)
         # self.lr_decay(total_steps, writer)
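
The removed path appears to have treated states as a [steps, agents, dim] tensor, transposed it to [agents, steps, dim], and updated the actor once per agent; the removed loop was also buggy, iterating over the bare tuple (state_list, action_list, log_pi_list) rather than zip(...) of the three. With the buffer already interleaving all agents, one flat batch now feeds update_actor per PPO epoch. A shape sketch under those assumed dimensions:

    import torch

    steps, agents, dim = 2048, 4, 45   # illustrative sizes

    # Old path (removed): transpose so the leading axis is the agent index,
    # then update once per agent.
    states_3d = torch.randn(steps, agents, dim)
    per_agent = states_3d.permute(1, 0, 2)            # [agents, steps, dim]

    # New path: the buffer already mixes agents, so one flat [batch, dim]
    # tensor goes through update_actor in a single call per PPO epoch.
    states_flat = states_3d.reshape(steps * agents, dim)
    assert states_flat.shape == (steps * agents, dim)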