# MAGAIL4AutoDrive/Algorithm/ppo.py
import os
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
try:
    from .buffer import RolloutBuffer
    from .bert import Bert
    from .policy import StateIndependentPolicy
except ImportError:
    from buffer import RolloutBuffer
    from bert import Bert
    from policy import StateIndependentPolicy
from abc import ABC, abstractmethod


class Algorithm(ABC):
    """Abstract base for on-policy algorithms; concrete subclasses (e.g. PPO
    below) are expected to provide ``self.actor``."""

    def __init__(self, state_shape, device, gamma):
        self.learning_steps = 0
        self.state_shape = state_shape
        self.device = device
        self.gamma = gamma

    def explore(self, state_list):
        """Sample a stochastic action and its log-probability for each agent state."""
        action_list = []
        log_pi_list = []
        if not torch.is_tensor(state_list):
            state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action, log_pi = self.actor.sample(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
                log_pi_list.append(log_pi.item())
        return action_list, log_pi_list

    def exploit(self, state_list):
        """Return the deterministic (mean) action for each agent state."""
        action_list = []
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action = self.actor(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
        return action_list

    @abstractmethod
    def is_update(self, step):
        pass

    @abstractmethod
    def update(self, writer, total_steps):
        pass

    @abstractmethod
    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
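

# ---------------------------------------------------------------------------
# Usage sketch (added for clarity, not part of the original training code):
# a concrete subclass such as PPO below supplies ``self.actor``; ``explore``
# then accepts either a numpy batch or a tensor of per-agent states. The
# shapes here (4 agents, 45-dim observations) are illustrative assumptions.
#
#   algo = PPO(state_shape=(45,), device=torch.device("cpu"))
#   states = np.zeros((4, 45), dtype=np.float32)
#   actions, log_pis = algo.explore(states)   # two lists of length 4
#   greedy_actions = algo.exploit(states)     # deterministic mean actions
# ---------------------------------------------------------------------------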


class PPO(Algorithm):

    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
                 units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                 lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                 value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
                 lr_disc=1e-3, auto_lr=True, use_adv_norm=True, max_steps=10000000):
        super().__init__(state_shape, device, gamma)
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.lr_disc = lr_disc
        self.auto_lr = auto_lr
        self.action_shape = action_shape
        self.use_adv_norm = use_adv_norm

        # Rollout buffer.
        self.buffer = RolloutBuffer(
            buffer_size=rollout_length,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )

        # Actor.
        self.actor = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        ).to(device)

        # Critic.
        # If state_shape is a tuple, extract its first element as the input dimension.
        state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
        self.critic = Bert(
            input_dim=state_dim,
            output_dim=1
        ).to(device)

        self.learning_steps_ppo = 0
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.max_grad_norm = max_grad_norm
        self.desired_kl = desired_kl
        self.surrogate_loss_coef = surrogate_loss_coef
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.bounds_loss_coef = bounds_loss_coef
        self.max_steps = max_steps

        self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor)
        # self.optim_actor = Adam([
        #     {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor},
        #     {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor/3}])
        self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic)

    def is_update(self, step):
        # Trigger an update once every `rollout_length` environment steps.
        return step % self.rollout_length == 0

    def step(self, env, state_list, state_gail):
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device)
        action_list, log_pi_list = self.explore(state_list)
        next_state, reward, terminated, truncated, info = env.step(np.array(action_list))
        next_state_gail = env.state_gail
        done = terminated or truncated
        # The policy exposes the mean and log-std of its most recent forward pass.
        means = self.actor.means.detach().cpu().numpy()[0]
        stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0]
        self.buffer.append(state_list, state_gail, action_list, reward, done, terminated, log_pi_list,
                           next_state, next_state_gail, means, stds)
        if done:
            next_state = env.reset()
            next_state_gail = env.state_gail
        return next_state, next_state_gail, info

    def update(self, writer, total_steps):
        pass

    def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list, next_states, mus, sigmas, writer,
                   total_steps):
        with torch.no_grad():
            values = self.critic(states.detach())
            next_values = self.critic(next_states.detach())
        targets, gaes = self.calculate_gae(
            values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)

        # Work on the whole batch directly: the buffer already mixes data from
        # all agents, so there is no need to group transitions by agent.
        for i in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            # Update the actor on the entire batch as well.
            self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)

        # self.lr_decay(total_steps, writer)

    def update_critic(self, states, targets, writer):
        loss_critic = (self.critic(states) - targets).pow_(2).mean()
        loss_critic = loss_critic * self.value_loss_coef

        self.optim_critic.zero_grad()
        loss_critic.backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/critic', loss_critic.item(), self.learning_steps)

    def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer):
        self.optim_actor.zero_grad()
        log_pis = self.actor.evaluate_log_pi(states, actions)
        mus = self.actor.means
        sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
        entropy = -log_pis.mean()

        # Clipped surrogate objective.
        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * gaes
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()
        loss_actor = loss_actor * self.surrogate_loss_coef

        if self.auto_lr:
            # Adapt the learning rates so the KL divergence between the old and
            # new Gaussian policies stays close to desired_kl (default 0.01).
            with torch.inference_mode():
                kl = torch.sum(torch.log(sigmas / sigmas_old + 1.e-5) +
                               (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                               (2.0 * torch.square(sigmas)) - 0.5, axis=-1)
                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.lr_actor = max(1e-5, self.lr_actor / 1.5)
                    self.lr_critic = max(1e-5, self.lr_critic / 1.5)
                    self.lr_disc = max(1e-5, self.lr_disc / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.lr_actor = min(1e-2, self.lr_actor * 1.5)
                    self.lr_critic = min(1e-2, self.lr_critic * 1.5)
                    self.lr_disc = min(1e-2, self.lr_disc * 1.5)

                # self.optim_d (the discriminator optimizer) is expected to be
                # created by a GAIL-style subclass; plain PPO does not define it.
                for param_group in self.optim_actor.param_groups:
                    param_group['lr'] = self.lr_actor
                for param_group in self.optim_critic.param_groups:
                    param_group['lr'] = self.lr_critic
                for param_group in self.optim_d.param_groups:
                    param_group['lr'] = self.lr_disc

        loss = loss_actor  # + b_loss * 0 - self.entropy_coef * entropy * 0
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/actor', loss_actor.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/entropy', entropy.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/learning_rate', self.lr_actor, self.learning_steps)
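
    # Worked illustration of the adaptive rule above (added comment; numbers are
    # hypothetical): with desired_kl = 0.01, a measured kl_mean of 0.025 exceeds
    # 2 * 0.01, so lr_actor, lr_critic and lr_disc all shrink by a factor of 1.5
    # (floored at 1e-5); a kl_mean of 0.004 falls below 0.01 / 2, so they all grow
    # by 1.5 (capped at 1e-2). Values in between leave the learning rates unchanged.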

    def lr_decay(self, total_steps, writer):
        # Linear decay of all learning rates over the course of training.
        lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps))
        lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps))
        lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps))
        for p in self.optim_actor.param_groups:
            p['lr'] = lr_a_now
        for p in self.optim_critic.param_groups:
            p['lr'] = lr_c_now
        for p in self.optim_d.param_groups:
            p['lr'] = lr_d_now
        writer.add_scalar(
            'Loss/learning_rate', lr_a_now, self.learning_steps)

    def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd):
        """
        Calculate the advantages using GAE.
        'tm_dones=True' means dead or win, i.e. there is no next state s'.
        'dones=True' marks the end of an episode (dead, win, or reaching
        max_episode_steps). When computing the advantage, dones=True cuts the
        recursion so the next step's gae does not propagate back.
        Reference: https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py
        A standalone numeric sketch of this recursion is included at the bottom
        of this file.
        """
        with torch.no_grad():
            # Calculate TD errors.
            deltas = rewards + gamma * next_values * (1 - tm_dones) - values
            # Initialize gae.
            gaes = torch.empty_like(rewards)
            # Calculate gae recursively from behind.
            gaes[-1] = deltas[-1]
            for t in reversed(range(rewards.size(0) - 1)):
                gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
            v_target = gaes + values
            if self.use_adv_norm:
                gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)
        return v_target, gaes

    def save_models(self, save_dir):
        pass
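

# ---------------------------------------------------------------------------
# Standalone sketch (added for illustration, not used by the training code):
# reproduces the GAE recursion from PPO.calculate_gae on a toy rollout so the
# indexing can be checked in isolation. The horizon, rewards and coefficients
# below are made-up demo values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    T = 5
    gamma, lambd = 0.995, 0.97
    rewards = torch.ones(T, 1)
    values = torch.zeros(T, 1)
    next_values = torch.zeros(T, 1)
    dones = torch.zeros(T, 1)       # no episode ends in this toy rollout
    tm_dones = torch.zeros(T, 1)    # no terminal (dead/win) states either

    # TD errors, then the backward recursion used by calculate_gae.
    deltas = rewards + gamma * next_values * (1 - tm_dones) - values
    gaes = torch.empty_like(rewards)
    gaes[-1] = deltas[-1]
    for t in reversed(range(T - 1)):
        gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
    v_target = gaes + values

    print("advantages:   ", [round(x, 4) for x in gaes.squeeze(-1).tolist()])
    print("value targets:", [round(x, 4) for x in v_target.squeeze(-1).tolist()])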