"""
|
||
MAGAIL训练脚本
|
||
|
||
将Algorithm模块中的MAGAIL算法应用到多智能体环境中进行训练
|
||
"""
|
||
|
||
import os
import sys
import torch
import numpy as np
import pickle
import time
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

# Add module paths
sys.path.append(os.path.join(os.path.dirname(__file__), 'Algorithm'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'Env'))

from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer
from Env.scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader


class ExpertBuffer:
    """
    Expert data buffer.

    Loads expert trajectories from the Waymo dataset for training the GAIL discriminator.
    """

    def __init__(self, data_dir, device, max_samples=100000):
        """
        Initialize the expert buffer.

        Args:
            data_dir: directory containing the expert data
            device: compute device
            max_samples: maximum number of samples to keep
        """
        self.device = device
        self.max_samples = max_samples
        self.states = []
        self.next_states = []

        print(f"📚 Loading expert data from: {data_dir}")
        self._load_expert_data(data_dir)

        # The data is already converted to tensors on the target device in _extract_trajectories.
        if len(self.states) > 0:
            print(f"✅ Loading finished: {len(self.states)} expert state transitions")
        else:
            print("⚠️ Warning: no expert data found")

    def _load_expert_data(self, data_dir):
        """
        Load expert data from pkl files.

        Note: this may need to be adapted to the actual data format.
        """
        # Find all pkl files
        pkl_files = []
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.pkl') and 'sd_waymo' in file:
                    pkl_files.append(os.path.join(root, file))

        if not pkl_files:
            print(f"⚠️ No expert data files found in {data_dir}")
            return

        print(f"   Found {len(pkl_files)} data file(s)")

        # Only the first file is loaded as an example;
        # in practice, multiple files can be loaded.
        for pkl_file in pkl_files[:1]:  # load the first file only
            try:
                with open(pkl_file, 'rb') as f:
                    data = pickle.load(f)
                print(f"   Processing: {os.path.basename(pkl_file)}")
                self._extract_trajectories(data)

                if len(self.states) >= self.max_samples:
                    break
            except Exception as e:
                print(f"   ⚠️ Failed to load {pkl_file}: {e}")

    def _extract_trajectories(self, scenario_data):
        """
        Extract vehicle trajectories from MetaDrive scenario data.

        Args:
            scenario_data: scenario data dictionary in MetaDrive format
        """
        try:
            # Case 1: a dict with a 'tracks' key
            if isinstance(scenario_data, dict) and 'tracks' in scenario_data:
                for vehicle_id, track_data in scenario_data['tracks'].items():
                    if track_data.get('type') == 'VEHICLE':
                        states = track_data.get('state', {})

                        # Get the valid-frame mask
                        valid = states.get('valid', [])
                        if not hasattr(valid, 'any') or not valid.any():
                            continue

                        # Extract position, velocity, heading, etc.
                        positions = states.get('position', [])
                        velocities = states.get('velocity', [])
                        headings = states.get('heading', [])

                        # Build the state sequence
                        for t in range(len(positions) - 1):
                            if valid[t] and valid[t + 1]:
                                # Current state
                                state = np.concatenate([
                                    positions[t][:2],   # x, y
                                    velocities[t],      # vx, vy
                                    [headings[t]],      # heading
                                    # ... remaining observation dims (lidar etc.) are zero-filled for now
                                ])

                                # Next state
                                next_state = np.concatenate([
                                    positions[t + 1][:2],
                                    velocities[t + 1],
                                    [headings[t + 1]],
                                ])

                                # Pad to 108 dims to match the actual observation dimension
                                state = np.pad(state, (0, 108 - len(state)))
                                next_state = np.pad(next_state, (0, 108 - len(next_state)))

                                # Convert to tensors on the target device
                                self.states.append(torch.tensor(state, dtype=torch.float32, device=self.device))
                                self.next_states.append(torch.tensor(next_state, dtype=torch.float32, device=self.device))

                                if len(self.states) >= self.max_samples:
                                    return

            # Case 2: other possible formats
            # ...

        except Exception as e:
            print(f"   ⚠️ Trajectory extraction failed: {e}")
            import traceback
            traceback.print_exc()

    def sample(self, batch_size):
        """
        Randomly sample a batch of expert data.

        Returns:
            (states, next_states)
        """
        if len(self.states) == 0:
            # No expert data: return zero tensors
            return (torch.zeros(batch_size, 108, device=self.device),
                    torch.zeros(batch_size, 108, device=self.device))

        # Sample indices with numpy (with replacement), avoiding tensor-indexing issues
        indices = np.random.randint(0, len(self.states), size=batch_size)
        # Stack the list of tensors into a batch
        states_batch = torch.stack([self.states[i] for i in indices])
        next_states_batch = torch.stack([self.next_states[i] for i in indices])
        return states_batch, next_states_batch
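
# A minimal usage sketch of ExpertBuffer (illustrative only; the path below is hypothetical):
#   buf = ExpertBuffer("./Env/exp_converted", device=torch.device("cpu"))
#   exp_states, exp_next_states = buf.sample(64)   # each is a (64, 108) float32 tensor
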
class MAGAILPolicy:
    """
    MAGAIL policy wrapper.

    Wraps the MAGAIL agent behind the policy interface expected by the environment.
    """

    def __init__(self, magail_agent, device):
        self.magail = magail_agent
        self.device = device

    def act(self, observation=None):
        """
        Return an action (environment-compatible interface).

        Note: because of how the environment invokes per-agent policies, this is a
        simplified placeholder; during training, actions are produced centrally
        in the main loop.
        """
        # This method is not used during training.
        # Training obtains actions in a single batch via magail.explore().
        return [0.0, 0.0]


def collect_observations(env):
    """
    Collect the observations of all agents.

    Args:
        env: the multi-agent environment

    Returns:
        obs_array: numpy array of shape (n_agents, obs_dim)
    """
    obs_list = env.obs_list
    if len(obs_list) == 0:
        return np.array([])
    return np.array(obs_list)
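
# Example: with 5 controlled agents and obs_dim = 108 (the values used in
# train_magail below), collect_observations returns an array of shape (5, 108).
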
def train_magail(
    data_dir,
    output_dir="./outputs",
    num_episodes=1000,
    horizon=300,
    rollout_length=512,   # changed to 512, better suited to 300-step episodes
    batch_size=128,       # reduced batch size
    lr_actor=3e-4,
    lr_critic=3e-4,
    lr_disc=3e-4,
    epoch_disc=5,
    epoch_ppo=10,
    render=False,
    device="cuda",
):
    """
    MAGAIL training entry point.

    Args:
        data_dir: Waymo data directory
        output_dir: output directory (models, logs)
        num_episodes: number of training episodes
        horizon: maximum steps per episode
        rollout_length: PPO update interval (in environment steps)
        batch_size: batch size
        lr_actor: actor learning rate
        lr_critic: critic learning rate
        lr_disc: discriminator learning rate
        epoch_disc: discriminator update epochs
        epoch_ppo: PPO update epochs
        render: whether to render
        device: compute device
    """

    # Create output directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(output_dir, f"magail_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(os.path.join(run_dir, "models"), exist_ok=True)

    # TensorBoard
    writer = SummaryWriter(os.path.join(run_dir, "logs"))

    # Device
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Observation and action dimensions
    # Per _get_all_obs() in scenario_env.py:
    # position(2) + velocity(2) + heading(1) + lidar(80) + side detectors(10) + lane lines(10) + traffic light(1) + goal(2) = 108
    obs_dim = 108
    action_dim = 2  # [steering, throttle/brake]
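    # Sanity check of the per-feature breakdown documented above (sums to 108);
    # adjust if scenario_env.py changes its observation layout.
    assert obs_dim == 2 + 2 + 1 + 80 + 10 + 10 + 1 + 2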

    print(f"📊 Observation dim: {obs_dim}, action dim: {action_dim}")

    # Load expert data
    expert_buffer = ExpertBuffer(
        data_dir=os.path.join(data_dir, "exp_converted"),
        device=device,
        max_samples=50000
    )

    # Initialize MAGAIL
    print("🤖 Initializing the MAGAIL agent...")
    magail = MAGAIL(
        buffer_exp=expert_buffer,
        input_dim=(obs_dim,),
        device=device,
        action_shape=(action_dim,),
        rollout_length=rollout_length,
        disc_coef=20.0,
        disc_grad_penalty=0.1,
        disc_logit_reg=0.25,
        disc_weight_decay=0.0005,
        lr_disc=lr_disc,
        lr_actor=lr_actor,
        lr_critic=lr_critic,
        epoch_disc=epoch_disc,
        epoch_ppo=epoch_ppo,
        batch_size=batch_size,
        use_gail_norm=True,
        gamma=0.995,
        lambd=0.97,
    )

    # Create the policy wrapper
    policy = MAGAILPolicy(magail, device)

    # Environment configuration (the environment itself is created per episode below)
    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "exp_converted", unix_style=False),
        "is_multi_agent": True,
        "num_controlled_agents": 5,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": True,
        "manual_control": False,
        "filter_offroad_vehicles": True,
        "lane_tolerance": 3.0,
        "max_controlled_vehicles": 5,
        "debug_lane_filter": False,
        "debug_traffic_light": False,
    }

    env = None  # a new environment is created for each episode

    print(f"\n{'='*60}")
    print("🚀 Starting MAGAIL training")
    print(f"{'='*60}")
    print(f"Episodes: {num_episodes}")
    print(f"Steps per episode: {horizon}")
    print(f"Update interval: {rollout_length}")
    print(f"Output directory: {run_dir}")
    print(f"{'='*60}\n")

    # Training loop
    total_steps = 0
    best_reward = -float('inf')

    for episode in range(num_episodes):
        # Create a fresh environment per episode (avoids MetaDrive object clean-up issues)
        if env is not None:
            env.close()

        print(f"🌍 Setting up the environment for episode {episode + 1}...")
        env = MultiAgentScenarioEnv(config=env_config, agent2policy=policy)

        # Reset the environment (the scenario index must stay in range, so scenarios are cycled)
        scenario_index = episode % 3  # only 3 scenarios available, reuse them cyclically
        obs_list = env.reset(scenario_index)
        episode_reward = 0
        episode_length = 0

        # Check that there are controllable vehicles
        if len(env.controlled_agents) == 0:
            print(f"⚠️ Episode {episode}: no controllable vehicles, skipping")
            continue

        print(f"\n📍 Episode {episode + 1}/{num_episodes}")
        print(f"   Controlled vehicles: {len(env.controlled_agents)}")

        for step in range(horizon):
            # Collect observations
            obs_array = collect_observations(env)

            if len(obs_array) == 0:
                break

            # Sample actions from the policy
            actions, log_pis = magail.explore(obs_array)

            # Debug: print the first action of the run (to inspect the action range)
            if step == 0 and episode == 0:
                print("\n🔍 Debug info - first action:")
                print(f"   Number of actions: {len(actions)}")
                if len(actions) > 0:
                    print(f"   First action: {actions[0]}")
                    print(f"   Action range: [{np.min(actions):.3f}, {np.max(actions):.3f}]")
                    # Check the initial vehicle state
                    first_vehicle = list(env.controlled_agents.values())[0]
                    print(f"   First vehicle initial position: {first_vehicle.position}")
                    print(f"   First vehicle initial speed: {first_vehicle.speed:.2f} m/s")

            # Print position changes every 50 steps (first episode only)
            if step % 50 == 0 and step > 0 and episode == 0:
                if len(env.controlled_agents) > 0:
                    first_vehicle = list(env.controlled_agents.values())[0]
                    print(f"   Step {step}: position={first_vehicle.position}, speed={first_vehicle.speed:.2f} m/s")

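            # Note: the loops below assume that iterating env.controlled_agents yields the
            # agents in the same order as the rows of obs_array (i.e. the order of env.obs_list).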
            # Build the action dict
            action_dict = {}
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(actions):
                    action_dict[agent_id] = actions[i]

            # Step the environment
            next_obs_list, rewards, dones, infos = env.step(action_dict)
            next_obs_array = collect_observations(env)

            # Rendering
            if render:
                env.render(mode="topdown")
                time.sleep(0.02)  # 20 ms delay for smoother rendering, roughly 50 fps

            # Store the experience in the buffer (one entry per agent)
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(obs_array) and i < len(actions) and i < len(next_obs_array):
                    # Per-agent data
                    state = obs_array[i]
                    action = actions[i]
                    reward = rewards.get(agent_id, 0.0)
                    done = dones.get(agent_id, False)
                    tm_done = done  # use the same done flag for now
                    log_pi = log_pis[i]
                    next_state = next_obs_array[i]

                    # Policy parameters for this agent
                    mean = magail.actor.means[i].detach().cpu().numpy() if i < len(magail.actor.means) else np.zeros(action_dim)
                    std = magail.actor.log_stds.exp()[0].detach().cpu().numpy()

                    # Store in the buffer
                    try:
                        magail.buffer.append(
                            state=torch.tensor(state, dtype=torch.float32, device=device),
                            action=action,
                            reward=reward,
                            done=done,
                            tm_dones=tm_done,  # corrected parameter name
                            log_pi=log_pi,
                            next_state=next_state,
                            next_state_gail=next_state,
                            means=mean,
                            stds=std
                        )
                        # Debug: print once during the first episode
                        if episode == 0 and step == 0 and i == 0:
                            print(f"   ✅ First transition stored (buffer._n={magail.buffer._n})")
                    except Exception as e:
                        # Report the error
                        if episode == 0 and step == 0:
                            print(f"   ❌ Failed to store in buffer: {e}")
                            import traceback
                            traceback.print_exc()

            # Average reward across agents
            avg_reward = np.mean(list(rewards.values())) if rewards else 0.0
            episode_reward += avg_reward
            episode_length += 1
            total_steps += 1

            # Check whether the episode is over
            if dones.get("__all__", False):
                break

            # Periodic update
            if total_steps % rollout_length == 0 and total_steps > 0:
                print(f"\n   🔄 Step {total_steps}: updating the model...")
                print(f"   Buffer status: _n={magail.buffer._n}, _p={magail.buffer._p}, buffer_size={magail.buffer.buffer_size}")

                # Make sure the buffer holds enough data
                if magail.buffer._n < batch_size:
                    print(f"   ⚠️ Not enough data in buffer: _n={magail.buffer._n} < batch_size={batch_size}, skipping this update")
                    continue

                try:
                    # Run the MAGAIL update
                    gail_reward = magail.update(writer, total_steps)
                    print(f"   GAIL reward: {gail_reward:.4f}")

                    writer.add_scalar('Training/GAILReward', gail_reward, total_steps)
                except Exception as e:
                    print(f"   ⚠️ Update failed: {e}")
                    if episode < 5:  # print details only during the first few episodes
                        import traceback
                        traceback.print_exc()

        # Log training metrics
        writer.add_scalar('Training/EpisodeReward', episode_reward, episode)
        writer.add_scalar('Training/EpisodeLength', episode_length, episode)

        # End of episode
        avg_episode_reward = episode_reward / max(episode_length, 1)
        print(f"   ✅ Episode {episode + 1} finished:")
        print(f"      Steps: {episode_length}")
        print(f"      Total reward: {episode_reward:.2f}")
        print(f"      Average reward: {avg_episode_reward:.4f}")
        print(f"      Vehicles: {len(env.controlled_agents)}")

        # Log to TensorBoard
        writer.add_scalar('Episode/Reward', episode_reward, episode)
        writer.add_scalar('Episode/Length', episode_length, episode)
        writer.add_scalar('Episode/AvgReward', avg_episode_reward, episode)
        writer.add_scalar('Episode/NumVehicles', len(env.controlled_agents), episode)
        writer.add_scalar('Training/TotalSteps', total_steps, episode)

        # Save the best model
        if episode_reward > best_reward:
            best_reward = episode_reward
            save_path = os.path.join(run_dir, "models", "best_model")
            magail.save_models(save_path)
            print(f"   💾 Saved best model (reward: {best_reward:.2f})")

        # Periodic checkpoints
        if (episode + 1) % 50 == 0:
            save_path = os.path.join(run_dir, "models", f"checkpoint_{episode + 1}")
            magail.save_models(save_path)
            print(f"   💾 Saved checkpoint: {save_path}")

    # Training finished
    print(f"\n{'='*60}")
    print("✅ Training finished!")
    print(f"{'='*60}")
    print(f"Total steps: {total_steps}")
    print(f"Best reward: {best_reward:.2f}")
    print(f"Models saved to: {run_dir}/models")
    print(f"Logs saved to: {run_dir}/logs")

    # Clean up
    if env is not None:
        env.close()
    writer.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the MAGAIL algorithm")
    parser.add_argument("--data-dir", type=str,
                        default="/home/huangfukk/MAGAIL4AutoDrive/Env",
                        help="Waymo data directory")
    parser.add_argument("--output-dir", type=str,
                        default="./outputs",
                        help="Output directory")
    parser.add_argument("--episodes", type=int, default=1000,
                        help="Number of training episodes")
    parser.add_argument("--horizon", type=int, default=300,
                        help="Maximum steps per episode")
    parser.add_argument("--rollout-length", type=int, default=2048,
                        help="PPO update interval")
    parser.add_argument("--batch-size", type=int, default=256,
                        help="Batch size")
    parser.add_argument("--lr-actor", type=float, default=3e-4,
                        help="Actor learning rate")
    parser.add_argument("--lr-critic", type=float, default=3e-4,
                        help="Critic learning rate")
    parser.add_argument("--lr-disc", type=float, default=3e-4,
                        help="Discriminator learning rate")
    parser.add_argument("--epoch-disc", type=int, default=5,
                        help="Discriminator update epochs")
    parser.add_argument("--epoch-ppo", type=int, default=10,
                        help="PPO update epochs")
    parser.add_argument("--render", action="store_true",
                        help="Render the environment")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "cpu"],
                        help="Compute device")

    args = parser.parse_args()

    train_magail(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        rollout_length=args.rollout_length,
        batch_size=args.batch_size,
        lr_actor=args.lr_actor,
        lr_critic=args.lr_critic,
        lr_disc=args.lr_disc,
        epoch_disc=args.epoch_disc,
        epoch_ppo=args.epoch_ppo,
        render=args.render,
        device=args.device,
    )
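
# Example invocation (the script name is illustrative; the data path is just the argparse default above):
#   python train_magail.py --data-dir /home/huangfukk/MAGAIL4AutoDrive/Env --episodes 500 --render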