Modified the algorithm code and set up a simple training script. Changed bert to handle 2-D input and removed the permute step in PPO.
@@ -3,9 +3,14 @@ import torch
 import numpy as np
 from torch import nn
 from torch.optim import Adam
-from buffer import RolloutBuffer
-from bert import Bert
-from policy import StateIndependentPolicy
+try:
+    from .buffer import RolloutBuffer
+    from .bert import Bert
+    from .policy import StateIndependentPolicy
+except ImportError:
+    from buffer import RolloutBuffer
+    from bert import Bert
+    from policy import StateIndependentPolicy
 from abc import ABC, abstractmethod
 
 
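The try/except block is the usual way to let the same file work both when it is imported as part of a package and when it is run as a standalone script. The sketch below is the diff's own import code with explanatory comments; the ppo.py filename and the package name are assumptions, nothing else about the project layout is implied:

    try:
        # Loaded as a submodule (e.g. `python -m <package>.ppo`): the relative
        # imports resolve against the parent package.
        from .buffer import RolloutBuffer
        from .bert import Bert
        from .policy import StateIndependentPolicy
    except ImportError:
        # Run directly (e.g. `python ppo.py`): there is no parent package, the
        # relative form raises ImportError, and the plain imports pick up the
        # sibling files next to the script instead.
        from buffer import RolloutBuffer
        from bert import Bert
        from policy import StateIndependentPolicy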
@@ -55,7 +60,7 @@ class Algorithm(ABC):
 
 class PPO(Algorithm):
 
-    def __init__(self, state_shape, device, gamma=0.995, rollout_length=2048,
+    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
                  units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                  lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                  value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
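For reference, a hypothetical call showing the new keyword in use; the state dimension and device are made-up values, and the remaining constructor arguments are assumed to keep the defaults shown above:

    import torch

    # `PPO` is the class defined in this file; 45 and (2,) are illustrative shapes only.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = PPO(state_shape=(45,), device=device, action_shape=(2,))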
@@ -66,6 +71,7 @@ class PPO(Algorithm):
         self.lr_critic = lr_critic
         self.lr_disc = lr_disc
         self.auto_lr = auto_lr
+        self.action_shape = action_shape
 
         self.use_adv_norm = use_adv_norm
 
@@ -86,8 +92,10 @@ class PPO(Algorithm):
         ).to(device)
 
         # Critic.
+        # If state_shape is a tuple, take its first element.
+        state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
         self.critic = Bert(
-            input_dim=state_shape,
+            input_dim=state_dim,
             output_dim=1
         ).to(device)
 
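A small self-contained sketch of what the new guard does, independent of the Bert class; the example dimensions are illustrative:

    # state_shape can arrive as a tuple (e.g. an observation-space shape) or as a
    # plain int; the critic's input_dim needs the int in either case.
    def to_state_dim(state_shape):
        return state_shape[0] if isinstance(state_shape, tuple) else state_shape

    assert to_state_dim((45,)) == 45   # tuple input
    assert to_state_dim(45) == 45      # already an int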
@@ -145,14 +153,12 @@ class PPO(Algorithm):
         targets, gaes = self.calculate_gae(
             values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)
 
-        state_list = states.permute(1, 0, 2)
-        action_list = actions.permute(1, 0, 2)
-
+        # Process the batch as a whole (no need to group by agent: the buffer already mixes all agents' data).
         for i in range(self.epoch_ppo):
             self.learning_steps_ppo += 1
             self.update_critic(states, targets, writer)
-            for state, action, log_pi in state_list, action_list, log_pi_list:
-                self.update_actor(state, action, log_pi, gaes, mus, sigmas, writer)
+            # Use the whole batch for the actor update directly.
+            self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)
 
         # self.lr_decay(total_steps, writer)
 
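The change above assumes the rollout buffer now hands back tensors already flattened to (batch, dim) with all agents' transitions mixed together, so each PPO epoch can consume the whole batch in one critic/actor call instead of permuting and looping per agent. A rough, runnable sketch of that shape assumption with stand-in update functions (all sizes and names here are illustrative, not taken from the repository):

    import torch

    def run_ppo_epochs(states, actions, log_pis, targets, gaes, epoch_ppo=10):
        """Stand-in for the loop in the diff: every epoch sees the full mixed batch."""
        for _ in range(epoch_ppo):
            update_critic(states, targets)                 # whole batch, no per-agent split
            update_actor(states, actions, log_pis, gaes)   # whole batch, no permute

    def update_critic(states, targets):
        assert states.dim() == 2 and targets.dim() == 2    # (batch, state_dim), (batch, 1)

    def update_actor(states, actions, log_pis, gaes):
        assert states.shape[0] == actions.shape[0] == log_pis.shape[0] == gaes.shape[0]

    batch, state_dim, action_dim = 2048, 45, 2             # illustrative sizes
    run_ppo_epochs(
        states=torch.randn(batch, state_dim),
        actions=torch.randn(batch, action_dim),
        log_pis=torch.randn(batch, 1),
        targets=torch.randn(batch, 1),
        gaes=torch.randn(batch, 1),
    )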