import os
from abc import ABC, abstractmethod

import numpy as np
import torch
from torch import nn
from torch.optim import Adam

try:
    from .buffer import RolloutBuffer
    from .bert import Bert
    from .policy import StateIndependentPolicy
except ImportError:
    from buffer import RolloutBuffer
    from bert import Bert
    from policy import StateIndependentPolicy


class Algorithm(ABC):

    def __init__(self, state_shape, device, gamma):
        self.learning_steps = 0
        self.state_shape = state_shape
        self.device = device
        self.gamma = gamma

    def explore(self, state_list):
        """Sample stochastic actions and their log-probabilities for a batch of states."""
        action_list = []
        log_pi_list = []
        if not isinstance(state_list, torch.Tensor):
            state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action, log_pi = self.actor.sample(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
                log_pi_list.append(log_pi.item())
        return action_list, log_pi_list

    def exploit(self, state_list):
        """Select deterministic actions for a batch of states."""
        action_list = []
        if not isinstance(state_list, torch.Tensor):
            state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        with torch.no_grad():
            for state in state_list:
                action = self.actor(state.unsqueeze(0))
                action_list.append(action.cpu().numpy()[0])
        return action_list

    @abstractmethod
    def is_update(self, step):
        pass

    @abstractmethod
    def update(self, writer, total_steps):
        pass

    @abstractmethod
    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
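
# NOTE (interface assumption, inferred from the methods above): subclasses must
# create `self.actor` before `explore` / `exploit` are called.  `explore` relies
# on a stochastic policy exposing `actor.sample(state) -> (action, log_pi)`,
# while `exploit` calls the policy directly, `actor(state) -> action`, which is
# assumed to return a deterministic (typically mean) action.
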
class PPO(Algorithm):

    def __init__(self, state_shape, device, action_shape=(2,), gamma=0.995,
                 rollout_length=2048, units_actor=(64, 64), epoch_ppo=10,
                 clip_eps=0.2, lambd=0.97, max_grad_norm=1.0, desired_kl=0.01,
                 surrogate_loss_coef=2., value_loss_coef=5., entropy_coef=0.,
                 bounds_loss_coef=10., lr_actor=1e-3, lr_critic=1e-3, lr_disc=1e-3,
                 auto_lr=True, use_adv_norm=True, max_steps=10000000):
        super().__init__(state_shape, device, gamma)

        self.lr_actor = lr_actor
        self.lr_critic = lr_critic
        self.lr_disc = lr_disc
        self.auto_lr = auto_lr
        self.action_shape = action_shape
        self.use_adv_norm = use_adv_norm

        # Rollout buffer.
        self.buffer = RolloutBuffer(
            buffer_size=rollout_length,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device
        )

        # Actor.
        self.actor = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        ).to(device)

        # Critic. If state_shape is a tuple, take its first element as the input dimension.
        state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
        self.critic = Bert(
            input_dim=state_dim,
            output_dim=1
        ).to(device)

        self.learning_steps_ppo = 0
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.max_grad_norm = max_grad_norm
        self.desired_kl = desired_kl
        self.surrogate_loss_coef = surrogate_loss_coef
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.bounds_loss_coef = bounds_loss_coef
        self.max_steps = max_steps

        self.optim_actor = Adam([{'params': self.actor.parameters()}], lr=lr_actor)
        # self.optim_actor = Adam([
        #     {'params': self.actor.net.f_net.parameters(), 'lr': lr_actor},
        #     {'params': self.actor.net.k_net.parameters(), 'lr': lr_actor / 3}])
        self.optim_critic = Adam([{'params': self.critic.parameters()}], lr=lr_critic)

    def is_update(self, step):
        return step % self.rollout_length == 0

    def step(self, env, state_list, state_gail):
        state_list = torch.tensor(state_list, dtype=torch.float, device=self.device)
        state_gail = torch.tensor(state_gail, dtype=torch.float, device=self.device)

        action_list, log_pi_list = self.explore(state_list)
        next_state, reward, terminated, truncated, info = env.step(np.array(action_list))
        next_state_gail = env.state_gail
        done = terminated or truncated

        means = self.actor.means.detach().cpu().numpy()[0]
        stds = (self.actor.log_stds.exp()).detach().cpu().numpy()[0]

        self.buffer.append(state_list, state_gail, action_list, reward, done, terminated,
                           log_pi_list, next_state, next_state_gail, means, stds)

        if done:
            next_state = env.reset()
            next_state_gail = env.state_gail

        return next_state, next_state_gail, info

    def update(self, writer, total_steps):
        pass

    def update_ppo(self, states, actions, rewards, dones, tm_dones, log_pi_list,
                   next_states, mus, sigmas, writer, total_steps):
        with torch.no_grad():
            values = self.critic(states.detach())
            next_values = self.critic(next_states.detach())
        targets, gaes = self.calculate_gae(
            values, rewards, dones, tm_dones, next_values, self.gamma, self.lambd)

        # Process the whole batch at once; the buffer already mixes the data
        # from all agents, so no per-agent grouping is needed.
        for i in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            # Update the actor directly on the full batch.
            self.update_actor(states, actions, log_pi_list, gaes, mus, sigmas, writer)

        # self.lr_decay(total_steps, writer)

    def update_critic(self, states, targets, writer):
        loss_critic = (self.critic(states) - targets).pow_(2).mean()
        loss_critic = loss_critic * self.value_loss_coef

        self.optim_critic.zero_grad()
        loss_critic.backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/critic', loss_critic.item(), self.learning_steps)
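
    # The actor update below minimises the (negated) PPO clipped surrogate
    # objective of Schulman et al. (2017):
    #
    #   r_t(theta) = exp(log pi_theta(a_t|s_t) - log pi_old(a_t|s_t))
    #   L_actor    = E_t[ max( -r_t * A_t,  -clip(r_t, 1 - eps, 1 + eps) * A_t ) ]
    #
    # When `auto_lr` is enabled, the learning rates are adapted from the KL
    # divergence between the old and new diagonal-Gaussian policies,
    #
    #   KL(old || new) = sum_i [ log(sigma_new_i / sigma_old_i)
    #                            + (sigma_old_i^2 + (mu_old_i - mu_new_i)^2)
    #                              / (2 * sigma_new_i^2) - 0.5 ],
    #
    # shrinking all rates by 1.5x when the mean KL exceeds 2 * desired_kl and
    # growing them by 1.5x when it falls below desired_kl / 2 (clamped to
    # [1e-5, 1e-2]).
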
    def update_actor(self, states, actions, log_pis_old, gaes, mus_old, sigmas_old, writer):
        self.optim_actor.zero_grad()

        log_pis = self.actor.evaluate_log_pi(states, actions)
        mus = self.actor.means
        sigmas = (self.actor.log_stds.exp()).repeat(mus.shape[0], 1)
        entropy = -log_pis.mean()

        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios, 1.0 - self.clip_eps, 1.0 + self.clip_eps
        ) * gaes
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()
        loss_actor = loss_actor * self.surrogate_loss_coef

        if self.auto_lr:
            # desired_kl: 0.01
            with torch.inference_mode():
                kl = torch.sum(
                    torch.log(sigmas / sigmas_old + 1.e-5)
                    + (torch.square(sigmas_old) + torch.square(mus_old - mus))
                    / (2.0 * torch.square(sigmas)) - 0.5,
                    dim=-1)
                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.lr_actor = max(1e-5, self.lr_actor / 1.5)
                    self.lr_critic = max(1e-5, self.lr_critic / 1.5)
                    self.lr_disc = max(1e-5, self.lr_disc / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.lr_actor = min(1e-2, self.lr_actor * 1.5)
                    self.lr_critic = min(1e-2, self.lr_critic * 1.5)
                    self.lr_disc = min(1e-2, self.lr_disc * 1.5)

                for param_group in self.optim_actor.param_groups:
                    param_group['lr'] = self.lr_actor
                for param_group in self.optim_critic.param_groups:
                    param_group['lr'] = self.lr_critic
                # The discriminator optimizer is expected from a subclass
                # (e.g. GAIL); guard against plain PPO, which does not define it.
                if hasattr(self, 'optim_d'):
                    for param_group in self.optim_d.param_groups:
                        param_group['lr'] = self.lr_disc

        loss = loss_actor  # + b_loss * 0 - self.entropy_coef * entropy * 0
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'Loss/actor', loss_actor.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/entropy', entropy.item(), self.learning_steps)
            writer.add_scalar(
                'Loss/learning_rate', self.lr_actor, self.learning_steps)

    def lr_decay(self, total_steps, writer):
        # Linearly decay all learning rates towards a floor of 1e-5.
        lr_a_now = max(1e-5, self.lr_actor * (1 - total_steps / self.max_steps))
        lr_c_now = max(1e-5, self.lr_critic * (1 - total_steps / self.max_steps))
        lr_d_now = max(1e-5, self.lr_disc * (1 - total_steps / self.max_steps))
        for p in self.optim_actor.param_groups:
            p['lr'] = lr_a_now
        for p in self.optim_critic.param_groups:
            p['lr'] = lr_c_now
        if hasattr(self, 'optim_d'):
            for p in self.optim_d.param_groups:
                p['lr'] = lr_d_now
        writer.add_scalar(
            'Loss/learning_rate', lr_a_now, self.learning_steps)

    def calculate_gae(self, values, rewards, dones, tm_dones, next_values, gamma, lambd):
        """
        Calculate the advantages with Generalized Advantage Estimation (GAE).

        'tm_dones=True' means dead or win, so there is no next state s'.
        'dones=True' marks the end of an episode (dead, win, or reaching
        max_episode_steps); when computing the advantage, the recursion is
        reset (gae=0) where dones=True.

        Reference:
        https://github.com/Lizhi-sjtu/DRL-code-pytorch/blob/main/5.PPO-continuous/ppo_continuous.py
        """
        with torch.no_grad():
            # TD errors.
            deltas = rewards + gamma * next_values * (1 - tm_dones) - values

            # Compute GAE recursively, starting from the last step.
            gaes = torch.empty_like(rewards)
            gaes[-1] = deltas[-1]
            for t in reversed(range(rewards.size(0) - 1)):
                gaes[t] = deltas[t] + gamma * lambd * (1 - dones[t]) * gaes[t + 1]

            v_target = gaes + values

            if self.use_adv_norm:
                gaes = (gaes - gaes.mean()) / (gaes.std(dim=0) + 1e-8)

        return v_target, gaes

    def save_models(self, save_dir):
        pass
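

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the training pipeline). It assumes the
# sibling modules `buffer`, `bert` and `policy` are importable, and it only
# exercises construction plus the GAE computation on random tensors; the
# environment interaction in `step()` is not covered here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cpu")
    agent = PPO(state_shape=(8,), device=device, rollout_length=16)

    # Fake a rollout of 16 transitions with one value/reward per step.
    T = 16
    values = torch.randn(T, 1)
    next_values = torch.randn(T, 1)
    rewards = torch.randn(T, 1)
    dones = torch.zeros(T, 1)
    tm_dones = torch.zeros(T, 1)
    dones[-1] = 1.0
    tm_dones[-1] = 1.0

    targets, advantages = agent.calculate_gae(
        values, rewards, dones, tm_dones, next_values, agent.gamma, agent.lambd)
    print("targets:", tuple(targets.shape), "advantages:", tuple(advantages.shape))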