MAGAIL4AutoDrive/Algorithm/ppo.py

import os
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch import nn
from torch.optim import Adam

# Support both package-relative imports and running the file directly.
try:
    from .buffer import RolloutBuffer
    from .bert import Bert
    from .policy import StateIndependentPolicy
except ImportError:
    from buffer import RolloutBuffer
    from bert import Bert
    from policy import StateIndependentPolicy


class Algorithm(ABC):

    def __init__(self, state_shape, device, gamma):
        self.learning_steps = 0
        self.state_shape = state_shape
        self.device = device
        self.gamma = gamma

    def explore(self, state_list):
        action_list = []
        log_pi_list = []
        if type(state_list).__module__ != "torch":
            state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action, log_pi = self.actor.sample(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
                log_pi_list.append(log_pi.item())
        return action_list, log_pi_list

    def exploit(self, state_list):
        action_list = []
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action = self.actor(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
        return action_list

    @abstractmethod
    def is_update(self, step):
        pass

    @abstractmethod
    def update(self, writer, total_steps):
        pass

    @abstractmethod
    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)


class PPO(Algorithm):

    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995, rollout_length=2048,
                 units_actor=(64, 64), epoch_ppo=10, clip_eps=0.2,
                 lambd=0.97, max_grad_norm=1.0, desired_kl=0.01, surrogate_loss_coef=2.,
                 value_loss_coef=5., entropy_coef=0., bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3,
                 lr_disc=1e-3, auto_lr=True, use_adv_norm=True, max_steps=10000000):
        super().__init__(state_shape, device, gamma)
        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.lr_disc = lr_disc
        self.auto_lr = auto_lr
        self.action_shape = action_shape
        self.use_adv_norm = use_adv_norm

        # Rollout buffer.
        self.buffer = RolloutBuffer(
            buffer_size=rollout_length,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )

        # Actor.
        self.actor = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        ).to(device)

        # Critic.
        # If state_shape is a tuple, use its first element as the input dimension.
        state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
        self.critic = Bert(
            input_dim=state_dim,
            output_dim=1
        ).to(device)

        self.learning_steps_ppo = 0
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.max_grad_norm = max_grad_norm
        self.desired_kl = desired_kl
        self.surrogate_loss_coef = surrogate_loss_coef
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.bounds_loss_coef = bounds_loss_coef
        self.max_steps = max_steps

        self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor)
        # self.optim_actor = Adam([
        #     {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor},
        #     {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor/3}])
        self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic)
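
        # NOTE (assumption): no discriminator optimizer is created here; `self.optim_d`,
        # which is used together with `lr_disc` in `update_actor` and `lr_decay`, is
        # expected to be attached by a GAIL-style subclass and is not defined in this file.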

    def is_update(self, step):
        # An update is triggered once every `rollout_length` environment steps.
        return step % self.rollout_length == 0
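
    # NOTE (assumption): `step` below expects a Gymnasium-style environment whose
    # `step()` returns a (next_state, reward, terminated, truncated, info) 5-tuple
    # and which exposes the discriminator observation as an `env.state_gail`
    # attribute; neither interface is defined in this file.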
    def step(self, env, state_list, state_gail):
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device)
        action_list, log_pi_list = self.explore(state_list)
        next_state, reward, terminated, truncated, info = env.step(np.array(action_list))
        next_state_gail = env.state_gail
        done = terminated or truncated
        means = self.actor.means.detach().cpu().numpy()[0]
        stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0]
        self.buffer.append(state_list, state_gail, action_list, reward, done, terminated, log_pi_list,
                           next_state, next_state_gail, means, stds)
        if done:
            next_state = env.reset()
            next_state_gail = env.state_gail
        return next_state, next_state_gail, info

    def update(self, writer, total_steps):
        # No-op here: concrete training loops (e.g. a GAIL-style subclass) are
        # expected to compute rewards and call `update_ppo` directly.
        pass

    def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list, next_states, mus, sigmas, writer,
                   total_steps):
        with torch.no_grad():
            values = self.critic(states.detach())
            next_values = self.critic(next_states.detach())

        targets, gaes = self.calculate_gae(
            values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)

        # The batch is processed as a whole: the buffer already mixes the data of
        # all agents, so there is no need to group samples by agent here.
        for i in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            # Update the actor on the entire batch at once.
            self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)

        # self.lr_decay(total_steps, writer)

    def update_critic(self, states, targets, writer):
        loss_critic = (self.critic(states) - targets).pow_(2).mean()
        loss_critic = loss_critic * self.value_loss_coef

        self.optim_critic.zero_grad()
        loss_critic.backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/critic', loss_critic.item(), self.learning_steps)

    def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer):
        self.optim_actor.zero_grad()

        log_pis = self.actor.evaluate_log_pi(states, actions)
        mus = self.actor.means
        sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
        entropy = -log_pis.mean()

        # Clipped PPO surrogate objective.
        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * gaes
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()
        loss_actor = loss_actor * self.surrogate_loss_coef
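
        # The adaptive learning-rate rule below measures the KL divergence between
        # the old and the current diagonal-Gaussian policies,
        #     KL(old || new) = sum_i [ log(sigma_new_i / sigma_old_i)
        #                              + (sigma_old_i^2 + (mu_old_i - mu_new_i)^2)
        #                                / (2 * sigma_new_i^2) - 1/2 ],
        # and shrinks all learning rates when the mean KL exceeds 2 * desired_kl, or
        # grows them when it falls below desired_kl / 2. (The resemblance to the
        # rsl_rl-style KL-adaptive scheme is an observation, not stated in this repo.)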
        if self.auto_lr:
            # desired_kl: 0.01
            with torch.inference_mode():
                kl = torch.sum(torch.log(sigmas / sigmas_old + 1.e-5) +
                               (torch.square(sigmas_old) + torch.square(mus_old - mus)) /
                               (2.0 * torch.square(sigmas)) - 0.5, axis=-1)
                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.lr_actor = max(1e-5, self.lr_actor / 1.5)
                    self.lr_critic = max(1e-5, self.lr_critic / 1.5)
                    self.lr_disc = max(1e-5, self.lr_disc / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.lr_actor = min(1e-2, self.lr_actor * 1.5)
                    self.lr_critic = min(1e-2, self.lr_critic * 1.5)
                    self.lr_disc = min(1e-2, self.lr_disc * 1.5)

                for param_group in self.optim_actor.param_groups:
                    param_group['lr'] = self.lr_actor
                for param_group in self.optim_critic.param_groups:
                    param_group['lr'] = self.lr_critic
                for param_group in self.optim_d.param_groups:
                    param_group['lr'] = self.lr_disc

        loss = loss_actor  # + b_loss * 0 - self.entropy_coef * entropy * 0
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/actor', loss_actor.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/entropy', entropy.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/learning_rate', self.lr_actor, self.learning_steps)

    def lr_decay(self, total_steps, writer):
        # Linear learning-rate decay over `max_steps`; the call site in `update_ppo`
        # is currently commented out, so this path is inactive by default.
        lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps))
        lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps))
        lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps))
        for p in self.optim_actor.param_groups:
            p['lr'] = lr_a_now
        for p in self.optim_critic.param_groups:
            p['lr'] = lr_c_now
        for p in self.optim_d.param_groups:
            p['lr'] = lr_d_now
        writer.add_scalar(
            'Loss/learning_rate', lr_a_now, self.learning_steps)

    def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd):
        """
        Calculate the advantage using GAE.
        'tm_dones=True' means dead or win, i.e. there is no next state s'.
        'dones=True' marks the end of an episode (dead, win, or reaching max_episode_steps).
        When computing the advantage, dones=True stops the recursion, so no advantage
        is bootstrapped across episode boundaries.
        Reference: https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py
        """
        with torch.no_grad():
            # Calculate TD errors.
            deltas = rewards + gamma * next_values * (1 - tm_dones) - values
            # Initialize gae.
            gaes = torch.empty_like(rewards)
            # Calculate gae recursively from behind.
            gaes[-1] = deltas[-1]
            for t in reversed(range(rewards.size(0) - 1)):
                gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]
            v_target = gaes + values
            if self.use_adv_norm:
                gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)
        return v_target, gaes

    def save_models(self, save_dir):
        # Left as a no-op here; concrete checkpointing is expected to live in subclasses
        # (the abstract base version only creates `save_dir`).
        pass
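

# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the original file): the commented
# snippet below only illustrates how this class is typically driven. The
# observation size, the environment object, and the SummaryWriter path are
# hypothetical placeholders inferred from how `step`, `is_update`, and
# `update_ppo` consume their arguments.
#
#   from torch.utils.tensorboard import SummaryWriter
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   algo = PPO(state_shape=(45,), device=device, action_shape=(2,))  # shapes are placeholders
#   writer = SummaryWriter('runs/ppo_debug')
#
#   state = env.reset()                      # hypothetical env API
#   state_gail = env.state_gail
#   for global_step in range(1, 10 * algo.rollout_length + 1):
#       state, state_gail, info = algo.step(env, state, state_gail)
#       if algo.is_update(global_step):
#           # A discriminator-bearing subclass would pull the rollout from
#           # algo.buffer, compute rewards, and call algo.update_ppo(...) here.
#           pass
# ---------------------------------------------------------------------------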