""" MAGAIL训练脚本 将Algorithm模块中的MAGAIL算法应用到多智能体环境中进行训练 """ import os import sys import torch import numpy as np import pickle import time from datetime import datetime from torch.utils.tensorboard import SummaryWriter # 添加路径 sys.path.append(os.path.join(os.path.dirname(__file__), 'Algorithm')) sys.path.append(os.path.join(os.path.dirname(__file__), 'Env')) from Algorithm.magail import MAGAIL from Algorithm.buffer import RolloutBuffer from Env.scenario_env import MultiAgentScenarioEnv from metadrive.engine.asset_loader import AssetLoader class ExpertBuffer: """ 专家数据缓冲区 从Waymo数据集中加载专家轨迹,用于GAIL判别器训练 """ def __init__(self, data_dir, device, max_samples=100000): """ 初始化专家缓冲区 Args: data_dir: 专家数据目录 device: 计算设备 max_samples: 最大样本数 """ self.device = device self.max_samples = max_samples self.states = [] self.next_states = [] print(f"📚 加载专家数据从: {data_dir}") self._load_expert_data(data_dir) # 数据已经在_extract_trajectories中转换为tensor并放到设备上了 if len(self.states) > 0: print(f"✅ 加载完成: {len(self.states)} 条专家轨迹") else: print(f"⚠️ 警告: 未找到专家数据") def _load_expert_data(self, data_dir): """ 从pkl文件加载专家数据 注意: 这里需要根据实际的数据格式进行调整 """ # 查找所有pkl文件 pkl_files = [] for root, dirs, files in os.walk(data_dir): for file in files: if file.endswith('.pkl') and 'sd_waymo' in file: pkl_files.append(os.path.join(root, file)) if not pkl_files: print(f"⚠️ 在 {data_dir} 中未找到专家数据文件") return print(f" 找到 {len(pkl_files)} 个数据文件") # 只加载第一个文件作为示例 # 实际使用时可以加载多个文件 for pkl_file in pkl_files[:1]: # 只加载第一个文件 try: with open(pkl_file, 'rb') as f: data = pickle.load(f) print(f" 正在处理: {os.path.basename(pkl_file)}") self._extract_trajectories(data) if len(self.states) >= self.max_samples: break except Exception as e: print(f" ⚠️ 加载 {pkl_file} 失败: {e}") def _extract_trajectories(self, scenario_data): """ 从MetaDrive场景数据中提取车辆轨迹 Args: scenario_data: MetaDrive格式的场景数据字典 """ try: # 方法1: 如果是字典且有'tracks'键 if isinstance(scenario_data, dict) and 'tracks' in scenario_data: for vehicle_id, track_data in scenario_data['tracks'].items(): if track_data.get('type') == 'VEHICLE': states = track_data.get('state', {}) # 获取有效帧 valid = states.get('valid', []) if not hasattr(valid, 'any') or not valid.any(): continue # 提取位置、速度、朝向等 positions = states.get('position', []) velocities = states.get('velocity', []) headings = states.get('heading', []) # 构建state序列 for t in range(len(positions) - 1): if valid[t] and valid[t+1]: # 当前状态 state = np.concatenate([ positions[t][:2], # x, y velocities[t], # vx, vy [headings[t]], # heading # ... 其他观测维度(激光雷达等暂时用0填充) ]) # 下一状态 next_state = np.concatenate([ positions[t+1][:2], velocities[t+1], [headings[t+1]], ]) # 补齐到108维(匹配实际观测维度) state = np.pad(state, (0, 108 - len(state))) next_state = np.pad(next_state, (0, 108 - len(next_state))) # 转换为tensor并移到指定设备 self.states.append(torch.tensor(state, dtype=torch.float32, device=self.device)) self.next_states.append(torch.tensor(next_state, dtype=torch.float32, device=self.device)) if len(self.states) >= self.max_samples: return # 方法2: 其他可能的格式 # ... 
        except Exception as e:
            print(f"   ⚠️ Failed to extract trajectories: {e}")
            import traceback
            traceback.print_exc()

    def sample(self, batch_size):
        """
        Randomly sample a batch of expert data.

        Returns:
            (states, next_states)
        """
        if len(self.states) == 0:
            # No expert data available: return zero tensors
            return (torch.zeros(batch_size, 108, device=self.device),
                    torch.zeros(batch_size, 108, device=self.device))

        # Sample indices with numpy (avoids tensor indexing issues)
        indices = np.random.randint(0, len(self.states), size=batch_size)

        # Stack the tensors from the list into a batch
        states_batch = torch.stack([self.states[i] for i in indices])
        next_states_batch = torch.stack([self.next_states[i] for i in indices])

        return states_batch, next_states_batch


class MAGAILPolicy:
    """
    MAGAIL policy wrapper.

    Wraps the MAGAIL agent behind the policy interface expected by the environment.
    """

    def __init__(self, magail_agent, device):
        self.magail = magail_agent
        self.device = device

    def act(self, observation=None):
        """
        Return an action (environment-compatible interface).

        Note: because of how the environment invokes policies, this is a
        simplified stub; during training, actions are obtained centrally
        in the main loop.
        """
        # This method is not used during training;
        # actions come from magail.explore() in the main loop instead.
        return [0.0, 0.0]


def collect_observations(env):
    """
    Collect the observations of all agents.

    Args:
        env: the multi-agent environment

    Returns:
        obs_array: numpy array of shape (n_agents, obs_dim)
    """
    obs_list = env.obs_list
    if len(obs_list) == 0:
        return np.array([])
    return np.array(obs_list)


def train_magail(
    data_dir,
    output_dir="./outputs",
    num_episodes=1000,
    horizon=300,
    rollout_length=512,   # 512 suits 300-step episodes better
    batch_size=128,       # reduced batch size
    lr_actor=3e-4,
    lr_critic=3e-4,
    lr_disc=3e-4,
    epoch_disc=5,
    epoch_ppo=10,
    render=False,
    device="cuda",
):
    """
    MAGAIL training entry point.

    Args:
        data_dir: Waymo data directory
        output_dir: output directory (models, logs)
        num_episodes: number of training episodes
        horizon: maximum steps per episode
        rollout_length: PPO update interval
        batch_size: batch size
        lr_actor: actor learning rate
        lr_critic: critic learning rate
        lr_disc: discriminator learning rate
        epoch_disc: discriminator update epochs
        epoch_ppo: PPO update epochs
        render: whether to render
        device: compute device
    """
    # Create output directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(output_dir, f"magail_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(os.path.join(run_dir, "models"), exist_ok=True)

    # TensorBoard
    writer = SummaryWriter(os.path.join(run_dir, "logs"))

    # Device
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Observation and action dimensions.
    # Per _get_all_obs() in scenario_env.py:
    # position(2) + velocity(2) + heading(1) + lidar(80) + side detectors(10)
    # + lane lines(10) + traffic light(1) + goal point(2) = 108
    obs_dim = 108
    action_dim = 2  # [steering, throttle/brake]
    print(f"📊 Observation dim: {obs_dim}, action dim: {action_dim}")

    # Load expert data
    expert_buffer = ExpertBuffer(
        data_dir=os.path.join(data_dir, "exp_converted"),
        device=device,
        max_samples=50000
    )

    # Initialize MAGAIL
    print("🤖 Initializing MAGAIL...")
    magail = MAGAIL(
        buffer_exp=expert_buffer,
        input_dim=(obs_dim,),
        device=device,
        action_shape=(action_dim,),
        rollout_length=rollout_length,
        disc_coef=20.0,
        disc_grad_penalty=0.1,
        disc_logit_reg=0.25,
        disc_weight_decay=0.0005,
        lr_disc=lr_disc,
        lr_actor=lr_actor,
        lr_critic=lr_critic,
        epoch_disc=epoch_disc,
        epoch_ppo=epoch_ppo,
        batch_size=batch_size,
        use_gail_norm=True,
        gamma=0.995,
        lambd=0.97,
    )

    # Create the policy wrapper
    policy = MAGAILPolicy(magail, device)

    # Environment config (an environment is created per episode, below)
    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "exp_converted", unix_style=False),
        "is_multi_agent": True,
        "num_controlled_agents": 5,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": True,
        "manual_control": False,
        "filter_offroad_vehicles": True,
        "lane_tolerance": 3.0,
        "max_controlled_vehicles": 5,
        "debug_lane_filter": False,
        "debug_traffic_light": False,
    }
    env = None  # a fresh environment is created for each episode

    print(f"\n{'=' * 60}")
    print("🚀 Starting MAGAIL training")
    print(f"{'=' * 60}")
    print(f"Episodes: {num_episodes}")
    print(f"Steps per episode: {horizon}")
    print(f"Update interval: {rollout_length}")
    print(f"Output directory: {run_dir}")
    print(f"{'=' * 60}\n")

    # Training loop
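    # Rough flow of each episode below: (1) roll out the current policy in a
    # MetaDrive scenario and store per-agent transitions in the MAGAIL buffer;
    # (2) every rollout_length steps, call magail.update(), which (per the
    # epoch_disc / epoch_ppo settings above) trains the discriminator against
    # expert transitions and then runs PPO on the resulting GAIL reward;
    # (3) checkpoint based on the environment's episode reward.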
    total_steps = 0
    best_reward = -float('inf')

    for episode in range(num_episodes):
        # Create a fresh environment for each episode
        # (avoids MetaDrive object clean-up issues).
        if env is not None:
            env.close()
        print(f"🌍 Initializing environment for episode {episode + 1}...")
        env = MultiAgentScenarioEnv(config=env_config, agent2policy=policy)

        # Reset the environment (the scenario index must stay in range,
        # so scenarios are reused cyclically).
        scenario_index = episode % 3  # only 3 scenarios available, cycle through them
        obs_list = env.reset(scenario_index)

        episode_reward = 0
        episode_length = 0

        # Check whether there are any controllable vehicles
        if len(env.controlled_agents) == 0:
            print(f"⚠️ Episode {episode}: no controllable vehicles, skipping")
            continue

        print(f"\n📍 Episode {episode + 1}/{num_episodes}")
        print(f"   Controlled vehicles: {len(env.controlled_agents)}")

        for step in range(horizon):
            # Collect observations
            obs_array = collect_observations(env)
            if len(obs_array) == 0:
                break

            # Sample actions from the policy
            actions, log_pis = magail.explore(obs_array)

            # Debug: print the first actions (to inspect the action range)
            if step == 0 and episode == 0:
                print("\n🔍 Debug info - first actions:")
                print(f"   Number of actions: {len(actions)}")
                if len(actions) > 0:
                    print(f"   First action: {actions[0]}")
                    print(f"   Action range: [{np.min(actions):.3f}, {np.max(actions):.3f}]")
                # Check the initial state of the first vehicle
                first_vehicle = list(env.controlled_agents.values())[0]
                print(f"   First vehicle initial position: {first_vehicle.position}")
                print(f"   First vehicle initial speed: {first_vehicle.speed:.2f} m/s")

            # Print position changes every 50 steps (first episode only)
            if step % 50 == 0 and step > 0 and episode == 0:
                if len(env.controlled_agents) > 0:
                    first_vehicle = list(env.controlled_agents.values())[0]
                    print(f"   Step {step}: position={first_vehicle.position}, "
                          f"speed={first_vehicle.speed:.2f} m/s")

            # Build the action dict
            action_dict = {}
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(actions):
                    action_dict[agent_id] = actions[i]

            # Step the environment
            next_obs_list, rewards, dones, infos = env.step(action_dict)
            next_obs_array = collect_observations(env)

            # Render
            if render:
                env.render(mode="topdown")
                time.sleep(0.02)  # 20 ms delay for smoother rendering, ~50 fps

            # Store experience in the buffer (one entry per agent)
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(obs_array) and i < len(actions) and i < len(next_obs_array):
                    # Data for this agent
                    state = obs_array[i]
                    action = actions[i]
                    reward = rewards.get(agent_id, 0.0)
                    done = dones.get(agent_id, False)
                    tm_done = done  # use the same done flag for now
                    log_pi = log_pis[i]
                    next_state = next_obs_array[i]

                    # Policy parameters
                    mean = (magail.actor.means[i].detach().cpu().numpy()
                            if i < len(magail.actor.means) else np.zeros(action_dim))
                    std = magail.actor.log_stds.exp()[0].detach().cpu().numpy()

                    # Store in the buffer
                    try:
                        magail.buffer.append(
                            state=torch.tensor(state, dtype=torch.float32, device=device),
                            action=action,
                            reward=reward,
                            done=done,
                            tm_dones=tm_done,  # corrected parameter name
                            log_pi=log_pi,
                            next_state=next_state,
                            next_state_gail=next_state,
                            means=mean,
                            stds=std
                        )
                        # Debug: print once in the first episode only
                        if episode == 0 and step == 0 and i == 0:
                            print(f"   ✅ Stored first transition (buffer._n={magail.buffer._n})")
                    except Exception as e:
                        # Print the error
                        if episode == 0 and step == 0:
                            print(f"   ❌ Buffer append failed: {e}")
                            import traceback
                            traceback.print_exc()

            # Average reward across agents
            avg_reward = np.mean(list(rewards.values())) if rewards else 0.0
            episode_reward += avg_reward
            episode_length += 1
            total_steps += 1

            # Check whether the episode has ended
            if dones.get("__all__", False):
                break

            # Periodic update
            if total_steps % rollout_length == 0 and total_steps > 0:
                print(f"\n   🔄 Step {total_steps}: updating model...")
                print(f"   Buffer status: _n={magail.buffer._n}, _p={magail.buffer._p}, "
                      f"buffer_size={magail.buffer.buffer_size}")

                # Check whether the buffer holds enough data
                if magail.buffer._n < batch_size:
                    print(f"   ⚠️ Not enough buffer data: _n={magail.buffer._n} < "
                          f"batch_size={batch_size}, skipping this update")
                    continue

                try:
                    # Run the MAGAIL update
                    gail_reward = magail.update(writer, total_steps)
                    print(f"   GAIL reward: {gail_reward:.4f}")
                    writer.add_scalar('Training/GAILReward', gail_reward, total_steps)
                except Exception as e:
                    print(f"   ⚠️ Update failed: {e}")
                    if episode < 5:  # print the full traceback only in early episodes
                        import traceback
                        traceback.print_exc()

        # Log training metrics
        writer.add_scalar('Training/EpisodeReward', episode_reward, episode)
        writer.add_scalar('Training/EpisodeLength', episode_length, episode)

        # Episode finished
        avg_episode_reward = episode_reward / max(episode_length, 1)
        print(f"   ✅ Episode {episode + 1} finished:")
        print(f"      Steps: {episode_length}")
        print(f"      Total reward: {episode_reward:.2f}")
        print(f"      Average reward: {avg_episode_reward:.4f}")
        print(f"      Vehicles: {len(env.controlled_agents)}")

        # Log to TensorBoard
        writer.add_scalar('Episode/Reward', episode_reward, episode)
        writer.add_scalar('Episode/Length', episode_length, episode)
        writer.add_scalar('Episode/AvgReward', avg_episode_reward, episode)
        writer.add_scalar('Episode/NumVehicles', len(env.controlled_agents), episode)
        writer.add_scalar('Training/TotalSteps', total_steps, episode)

        # Save the best model
        if episode_reward > best_reward:
            best_reward = episode_reward
            save_path = os.path.join(run_dir, "models", "best_model")
            magail.save_models(save_path)
            print(f"   💾 Saved best model (reward: {best_reward:.2f})")

        # Periodic checkpoint
        if (episode + 1) % 50 == 0:
            save_path = os.path.join(run_dir, "models", f"checkpoint_{episode + 1}")
            magail.save_models(save_path)
            print(f"   💾 Saved checkpoint: {save_path}")

    # Training finished
    print(f"\n{'=' * 60}")
    print("✅ Training finished!")
    print(f"{'=' * 60}")
    print(f"Total steps: {total_steps}")
    print(f"Best reward: {best_reward:.2f}")
    print(f"Models saved to: {run_dir}/models")
    print(f"Logs saved to: {run_dir}/logs")

    # Clean up
    if env is not None:
        env.close()
    writer.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the MAGAIL algorithm")
    parser.add_argument("--data-dir", type=str,
                        default="/home/huangfukk/MAGAIL4AutoDrive/Env",
                        help="Waymo data directory")
    parser.add_argument("--output-dir", type=str, default="./outputs",
                        help="Output directory")
    parser.add_argument("--episodes", type=int, default=1000,
                        help="Number of training episodes")
    parser.add_argument("--horizon", type=int, default=300,
                        help="Maximum steps per episode")
    parser.add_argument("--rollout-length", type=int, default=2048,
                        help="PPO update interval")
    parser.add_argument("--batch-size", type=int, default=256,
                        help="Batch size")
    parser.add_argument("--lr-actor", type=float, default=3e-4,
                        help="Actor learning rate")
    parser.add_argument("--lr-critic", type=float, default=3e-4,
                        help="Critic learning rate")
    parser.add_argument("--lr-disc", type=float, default=3e-4,
                        help="Discriminator learning rate")
    parser.add_argument("--epoch-disc", type=int, default=5,
                        help="Discriminator update epochs")
    parser.add_argument("--epoch-ppo", type=int, default=10,
                        help="PPO update epochs")
    parser.add_argument("--render", action="store_true",
                        help="Render the environment")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "cpu"],
                        help="Compute device")
    args = parser.parse_args()

    train_magail(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        rollout_length=args.rollout_length,
        batch_size=args.batch_size,
        lr_actor=args.lr_actor,
        lr_critic=args.lr_critic,
        lr_disc=args.lr_disc,
        epoch_disc=args.epoch_disc,
        epoch_ppo=args.epoch_ppo,
        render=args.render,
        device=args.device,
    )
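# Example invocations (a sketch; assumes this file is saved as train_magail.py
# and that converted Waymo data exists under <data-dir>/exp_converted):
#   python train_magail.py --data-dir /path/to/MAGAIL4AutoDrive/Env --episodes 500
#   python train_magail.py --data-dir /path/to/MAGAIL4AutoDrive/Env --render --device cpu
# Training curves can then be inspected with:
#   tensorboard --logdir ./outputs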