# MAGAIL4AutoDrive/train_magail.py
"""
MAGAIL训练脚本
将Algorithm模块中的MAGAIL算法应用到多智能体环境中进行训练
"""
import os
import sys
import torch
import numpy as np
import pickle
import time
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
# 添加路径
sys.path.append(os.path.join(os.path.dirname(__file__), 'Algorithm'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'Env'))
from Algorithm.magail import MAGAIL
from Algorithm.buffer import RolloutBuffer
from Env.scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader


class ExpertBuffer:
    """
    Expert data buffer.
    Loads expert trajectories from the Waymo dataset for training the GAIL
    discriminator.
    """

    def __init__(self, data_dir, device, max_samples=100000):
        """
        Initialize the expert buffer.

        Args:
            data_dir: directory containing the expert data
            device: compute device
            max_samples: maximum number of samples to keep
        """
        self.device = device
        self.max_samples = max_samples
        self.states = []
        self.next_states = []

        print(f"📚 Loading expert data from: {data_dir}")
        self._load_expert_data(data_dir)

        # The data is already converted to tensors and moved to the device
        # inside _extract_trajectories.
        if len(self.states) > 0:
            print(f"✅ Loading complete: {len(self.states)} expert transitions")
        else:
            print(f"⚠️ Warning: no expert data found")

    def _load_expert_data(self, data_dir):
        """
        Load expert data from pkl files.
        Note: this may need to be adapted to the actual data format.
        """
        # Find all pkl files
        pkl_files = []
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.pkl') and 'sd_waymo' in file:
                    pkl_files.append(os.path.join(root, file))

        if not pkl_files:
            print(f"⚠️ No expert data files found in {data_dir}")
            return

        print(f"   Found {len(pkl_files)} data files")

        # Only load the first file as an example;
        # in practice multiple files can be loaded.
        for pkl_file in pkl_files[:1]:  # first file only
            try:
                with open(pkl_file, 'rb') as f:
                    data = pickle.load(f)
                print(f"   Processing: {os.path.basename(pkl_file)}")
                self._extract_trajectories(data)
                if len(self.states) >= self.max_samples:
                    break
            except Exception as e:
                print(f"   ⚠️ Failed to load {pkl_file}: {e}")

    def _extract_trajectories(self, scenario_data):
        """
        Extract vehicle trajectories from MetaDrive scenario data.

        Args:
            scenario_data: scenario data dictionary in MetaDrive format
        """
        try:
            # Case 1: a dict with a 'tracks' key
            if isinstance(scenario_data, dict) and 'tracks' in scenario_data:
                for vehicle_id, track_data in scenario_data['tracks'].items():
                    if track_data.get('type') == 'VEHICLE':
                        states = track_data.get('state', {})

                        # Keep only tracks with at least one valid frame
                        valid = states.get('valid', [])
                        if not hasattr(valid, 'any') or not valid.any():
                            continue

                        # Extract position, velocity and heading
                        positions = states.get('position', [])
                        velocities = states.get('velocity', [])
                        headings = states.get('heading', [])

                        # Build (state, next_state) pairs
                        for t in range(len(positions) - 1):
                            if valid[t] and valid[t + 1]:
                                # Current state
                                state = np.concatenate([
                                    positions[t][:2],   # x, y
                                    velocities[t],      # vx, vy
                                    [headings[t]],      # heading
                                    # remaining observation dims (lidar, etc.)
                                    # are zero-padded below
                                ])
                                # Next state
                                next_state = np.concatenate([
                                    positions[t + 1][:2],
                                    velocities[t + 1],
                                    [headings[t + 1]],
                                ])

                                # Pad to 108 dims to match the actual observation size
                                state = np.pad(state, (0, 108 - len(state)))
                                next_state = np.pad(next_state, (0, 108 - len(next_state)))

                                # Convert to tensors on the target device
                                self.states.append(torch.tensor(state, dtype=torch.float32, device=self.device))
                                self.next_states.append(torch.tensor(next_state, dtype=torch.float32, device=self.device))

                                if len(self.states) >= self.max_samples:
                                    return
            # Case 2: other possible formats
            # ...
        except Exception as e:
            print(f"   ⚠️ Failed to extract trajectories: {e}")
            import traceback
            traceback.print_exc()

    def sample(self, batch_size):
        """
        Sample a random batch of expert data.

        Returns:
            (states, next_states)
        """
        if len(self.states) == 0:
            # No expert data: return zero tensors
            return (torch.zeros(batch_size, 108, device=self.device),
                    torch.zeros(batch_size, 108, device=self.device))

        # Use numpy for random sampling to avoid indexing issues
        indices = np.random.randint(0, len(self.states), size=batch_size)

        # Stack the sampled tensors into a batch
        states_batch = torch.stack([self.states[i] for i in indices])
        next_states_batch = torch.stack([self.next_states[i] for i in indices])
        return states_batch, next_states_batch
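
# A minimal sanity-check sketch (not part of the training flow): before a full
# run it can help to confirm that ExpertBuffer yields batches of the expected
# shape. The data path below is illustrative only.
#
#   buf = ExpertBuffer(data_dir="./Env/exp_converted", device=torch.device("cpu"))
#   states, next_states = buf.sample(32)
#   assert states.shape == (32, 108) and next_states.shape == (32, 108)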


class MAGAILPolicy:
    """
    MAGAIL policy wrapper.
    Wraps the MAGAIL agent behind the policy interface expected by the
    environment.
    """

    def __init__(self, magail_agent, device):
        self.magail = magail_agent
        self.device = device

    def act(self, observation=None):
        """
        Return an action (environment-compatible interface).
        Note: due to how the environment invokes policies, this is a simplified
        stub; during training, actions are obtained centrally in the main loop
        via magail.explore().
        """
        # This method is not used during training.
        return [0.0, 0.0]


def collect_observations(env):
    """
    Collect the observations of all agents.

    Args:
        env: multi-agent environment

    Returns:
        obs_array: numpy array of shape (n_agents, obs_dim)
    """
    obs_list = env.obs_list
    if len(obs_list) == 0:
        return np.array([])
    return np.array(obs_list)


def train_magail(
    data_dir,
    output_dir="./outputs",
    num_episodes=1000,
    horizon=300,
    rollout_length=512,  # 512 is better suited to 300-step episodes
    batch_size=128,      # smaller batch size
    lr_actor=3e-4,
    lr_critic=3e-4,
    lr_disc=3e-4,
    epoch_disc=5,
    epoch_ppo=10,
    render=False,
    device="cuda",
):
    """
    MAGAIL training entry point.

    Args:
        data_dir: Waymo data directory
        output_dir: output directory (models, logs)
        num_episodes: number of training episodes
        horizon: maximum steps per episode
        rollout_length: PPO update interval (in environment steps)
        batch_size: batch size
        lr_actor: actor learning rate
        lr_critic: critic learning rate
        lr_disc: discriminator learning rate
        epoch_disc: discriminator update epochs
        epoch_ppo: PPO update epochs
        render: whether to render
        device: compute device
    """
    # Create output directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(output_dir, f"magail_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(os.path.join(run_dir, "models"), exist_ok=True)

    # TensorBoard
    writer = SummaryWriter(os.path.join(run_dir, "logs"))

    # Device
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Observation and action dimensions.
    # From _get_all_obs() in scenario_env.py:
    # position(2) + velocity(2) + heading(1) + lidar(80) + side detector(10)
    # + lane lines(10) + traffic light(1) + navigation target(2) = 108
    obs_dim = 108
    action_dim = 2  # [steering, throttle/brake]
    print(f"📊 Observation dim: {obs_dim}, action dim: {action_dim}")

    # Load expert data
    expert_buffer = ExpertBuffer(
        data_dir=os.path.join(data_dir, "exp_converted"),
        device=device,
        max_samples=50000
    )

    # Initialize MAGAIL
    print(f"🤖 Initializing MAGAIL...")
    magail = MAGAIL(
        buffer_exp=expert_buffer,
        input_dim=(obs_dim,),
        device=device,
        action_shape=(action_dim,),
        rollout_length=rollout_length,
        disc_coef=20.0,
        disc_grad_penalty=0.1,
        disc_logit_reg=0.25,
        disc_weight_decay=0.0005,
        lr_disc=lr_disc,
        lr_actor=lr_actor,
        lr_critic=lr_critic,
        epoch_disc=epoch_disc,
        epoch_ppo=epoch_ppo,
        batch_size=batch_size,
        use_gail_norm=True,
        gamma=0.995,
        lambd=0.97,
    )

    # Create the policy wrapper
    policy = MAGAILPolicy(magail, device)

    # Environment config; a fresh environment is created for each episode
    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "exp_converted", unix_style=False),
        "is_multi_agent": True,
        "num_controlled_agents": 5,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": True,
        "manual_control": False,
        "filter_offroad_vehicles": True,
        "lane_tolerance": 3.0,
        "max_controlled_vehicles": 5,
        "debug_lane_filter": False,
        "debug_traffic_light": False,
    }
    env = None  # a new environment is created per episode

    print(f"\n{'='*60}")
    print(f"🚀 Starting MAGAIL training")
    print(f"{'='*60}")
    print(f"Episodes: {num_episodes}")
    print(f"Steps per episode: {horizon}")
    print(f"Update interval: {rollout_length}")
    print(f"Output directory: {run_dir}")
    print(f"{'='*60}\n")
    # Training loop
    total_steps = 0
    best_reward = -float('inf')

    for episode in range(num_episodes):
        # Create a new environment for each episode to avoid MetaDrive's
        # object clean-up issues
        if env is not None:
            env.close()
        print(f"🌍 Initializing environment for episode {episode + 1}...")
        env = MultiAgentScenarioEnv(config=env_config, agent2policy=policy)

        # Reset the environment (the scenario index must stay in range, so cycle)
        scenario_index = episode % 3  # only 3 scenarios, reused cyclically
        obs_list = env.reset(scenario_index)

        episode_reward = 0
        episode_length = 0

        # Check that there are controllable vehicles
        if len(env.controlled_agents) == 0:
            print(f"⚠️ Episode {episode}: no controllable vehicles, skipping")
            continue

        print(f"\n📍 Episode {episode + 1}/{num_episodes}")
        print(f"   Controlled vehicles: {len(env.controlled_agents)}")

        for step in range(horizon):
            # Collect observations
            obs_array = collect_observations(env)
            if len(obs_array) == 0:
                break

            # Sample actions from the policy
            actions, log_pis = magail.explore(obs_array)

            # Debug: print the first action of the first episode (to check the action range)
            if step == 0 and episode == 0:
                print(f"\n🔍 Debug info - first action:")
                print(f"   Number of actions: {len(actions)}")
                if len(actions) > 0:
                    print(f"   First action: {actions[0]}")
                    print(f"   Action range: [{np.min(actions):.3f}, {np.max(actions):.3f}]")
                # Check the initial vehicle state
                first_vehicle = list(env.controlled_agents.values())[0]
                print(f"   First vehicle initial position: {first_vehicle.position}")
                print(f"   First vehicle initial speed: {first_vehicle.speed:.2f} m/s")

            # Print position changes every 50 steps during the first episode
            if step % 50 == 0 and step > 0 and episode == 0:
                if len(env.controlled_agents) > 0:
                    first_vehicle = list(env.controlled_agents.values())[0]
                    print(f"   Step {step}: position={first_vehicle.position}, speed={first_vehicle.speed:.2f} m/s")

            # Build the action dict
            action_dict = {}
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(actions):
                    action_dict[agent_id] = actions[i]

            # Step the environment
            next_obs_list, rewards, dones, infos = env.step(action_dict)
            next_obs_array = collect_observations(env)

            # Rendering
            if render:
                env.render(mode="topdown")
                time.sleep(0.02)  # 20 ms delay for smoother rendering (~50 fps)

            # Store the experience in the buffer, one entry per agent
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(obs_array) and i < len(actions) and i < len(next_obs_array):
                    # Per-agent data
                    state = obs_array[i]
                    action = actions[i]
                    reward = rewards.get(agent_id, 0.0)
                    done = dones.get(agent_id, False)
                    tm_done = done  # use the same done flag for now
                    log_pi = log_pis[i]
                    next_state = next_obs_array[i]

                    # Policy parameters
                    mean = magail.actor.means[i].detach().cpu().numpy() if i < len(magail.actor.means) else np.zeros(action_dim)
                    std = magail.actor.log_stds.exp()[0].detach().cpu().numpy()

                    # Append to the buffer
                    try:
                        magail.buffer.append(
                            state=torch.tensor(state, dtype=torch.float32, device=device),
                            action=action,
                            reward=reward,
                            done=done,
                            tm_dones=tm_done,  # corrected keyword name
                            log_pi=log_pi,
                            next_state=next_state,
                            next_state_gail=next_state,
                            means=mean,
                            stds=std
                        )
                        # Debug: print only once, during the first episode
                        if episode == 0 and step == 0 and i == 0:
                            print(f"   ✅ Stored the first transition (buffer._n={magail.buffer._n})")
                    except Exception as e:
                        # Report the error
                        if episode == 0 and step == 0:
                            print(f"   ❌ Buffer append failed: {e}")
                            import traceback
                            traceback.print_exc()

            # Average reward over agents
            avg_reward = np.mean(list(rewards.values())) if rewards else 0.0
            episode_reward += avg_reward
            episode_length += 1
            total_steps += 1

            # Check whether the episode has ended
            if dones.get("__all__", False):
                break

            # Periodic update
            if total_steps % rollout_length == 0 and total_steps > 0:
                print(f"\n   🔄 Step {total_steps}: updating the model...")
                print(f"   Buffer state: _n={magail.buffer._n}, _p={magail.buffer._p}, buffer_size={magail.buffer.buffer_size}")

                # Check that the buffer holds enough data
                if magail.buffer._n < batch_size:
                    print(f"   ⚠️ Not enough data in buffer: _n={magail.buffer._n} < batch_size={batch_size}, skipping this update")
                    continue

                try:
                    # Run the MAGAIL update
                    gail_reward = magail.update(writer, total_steps)
                    print(f"   GAIL reward: {gail_reward:.4f}")
                    writer.add_scalar('Training/GAILReward', gail_reward, total_steps)
                except Exception as e:
                    print(f"   ⚠️ Update failed: {e}")
                    if episode < 5:  # only print details in the first few episodes
                        import traceback
                        traceback.print_exc()

        # Log training metrics
        writer.add_scalar('Training/EpisodeReward', episode_reward, episode)
        writer.add_scalar('Training/EpisodeLength', episode_length, episode)

        # Episode finished
        avg_episode_reward = episode_reward / max(episode_length, 1)
        print(f"   ✅ Episode {episode + 1} finished:")
        print(f"      Steps: {episode_length}")
        print(f"      Total reward: {episode_reward:.2f}")
        print(f"      Average reward: {avg_episode_reward:.4f}")
        print(f"      Vehicles: {len(env.controlled_agents)}")

        # Log to TensorBoard
        writer.add_scalar('Episode/Reward', episode_reward, episode)
        writer.add_scalar('Episode/Length', episode_length, episode)
        writer.add_scalar('Episode/AvgReward', avg_episode_reward, episode)
        writer.add_scalar('Episode/NumVehicles', len(env.controlled_agents), episode)
        writer.add_scalar('Training/TotalSteps', total_steps, episode)

        # Save the best model
        if episode_reward > best_reward:
            best_reward = episode_reward
            save_path = os.path.join(run_dir, "models", "best_model")
            magail.save_models(save_path)
            print(f"   💾 Saved best model (reward: {best_reward:.2f})")

        # Periodic checkpoint
        if (episode + 1) % 50 == 0:
            save_path = os.path.join(run_dir, "models", f"checkpoint_{episode + 1}")
            magail.save_models(save_path)
            print(f"   💾 Saved checkpoint: {save_path}")

    # Training finished
    print(f"\n{'='*60}")
    print(f"✅ Training finished!")
    print(f"{'='*60}")
    print(f"Total steps: {total_steps}")
    print(f"Best reward: {best_reward:.2f}")
    print(f"Models saved to: {run_dir}/models")
    print(f"Logs saved to: {run_dir}/logs")

    # Clean up
    env.close()
    writer.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the MAGAIL algorithm")
    parser.add_argument("--data-dir", type=str,
                        default="/home/huangfukk/MAGAIL4AutoDrive/Env",
                        help="Waymo data directory")
    parser.add_argument("--output-dir", type=str,
                        default="./outputs",
                        help="Output directory")
    parser.add_argument("--episodes", type=int, default=1000,
                        help="Number of training episodes")
    parser.add_argument("--horizon", type=int, default=300,
                        help="Maximum steps per episode")
    parser.add_argument("--rollout-length", type=int, default=2048,
                        help="PPO update interval")
    parser.add_argument("--batch-size", type=int, default=256,
                        help="Batch size")
    parser.add_argument("--lr-actor", type=float, default=3e-4,
                        help="Actor learning rate")
    parser.add_argument("--lr-critic", type=float, default=3e-4,
                        help="Critic learning rate")
    parser.add_argument("--lr-disc", type=float, default=3e-4,
                        help="Discriminator learning rate")
    parser.add_argument("--epoch-disc", type=int, default=5,
                        help="Discriminator update epochs")
    parser.add_argument("--epoch-ppo", type=int, default=10,
                        help="PPO update epochs")
    parser.add_argument("--render", action="store_true",
                        help="Enable rendering")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "cpu"],
                        help="Compute device")
    args = parser.parse_args()

    train_magail(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        num_episodes=args.episodes,
        horizon=args.horizon,
        rollout_length=args.rollout_length,
        batch_size=args.batch_size,
        lr_actor=args.lr_actor,
        lr_critic=args.lr_critic,
        lr_disc=args.lr_disc,
        epoch_disc=args.epoch_disc,
        epoch_ppo=args.epoch_ppo,
        render=args.render,
        device=args.device,
    )
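
# Example invocation (a sketch; it assumes the converted Waymo scenarios live
# under <data-dir>/exp_converted, as required by ExpertBuffer and the env config):
#
#   python train_magail.py \
#       --data-dir /home/huangfukk/MAGAIL4AutoDrive/Env \
#       --episodes 1000 --horizon 300 \
#       --rollout-length 2048 --batch-size 256 --device cuda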