"""
|
|||
|
|
MAGAIL训练脚本
|
|||
|
|
|
|||
|
|
将Algorithm模块中的MAGAIL算法应用到多智能体环境中进行训练
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import torch
|
|||
|
|
import numpy as np
|
|||
|
|
import pickle
|
|||
|
|
import time
|
|||
|
|
from datetime import datetime
|
|||
|
|
from torch.utils.tensorboard import SummaryWriter
|
|||
|
|
|
|||
|
|
# 添加路径
|
|||
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'Algorithm'))
|
|||
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'Env'))
|
|||
|
|
|
|||
|
|
from Algorithm.magail import MAGAIL
|
|||
|
|
from Algorithm.buffer import RolloutBuffer
|
|||
|
|
from Env.scenario_env import MultiAgentScenarioEnv
|
|||
|
|
from metadrive.engine.asset_loader import AssetLoader
|
|||
|
|
|
|||
|
|
|
|||
|
|


class ExpertBuffer:
    """
    Expert data buffer.

    Loads expert trajectories from the Waymo dataset for training the GAIL
    discriminator.
    """

    def __init__(self, data_dir, device, max_samples=100000):
        """
        Initialize the expert buffer.

        Args:
            data_dir: Directory containing the expert data.
            device: Compute device.
            max_samples: Maximum number of samples to keep.
        """
        self.device = device
        self.max_samples = max_samples
        self.states = []
        self.next_states = []

        print(f"📚 Loading expert data from: {data_dir}")
        self._load_expert_data(data_dir)

        # The data is already converted to tensors on the target device in _extract_trajectories
        if len(self.states) > 0:
            print(f"✅ Loaded {len(self.states)} expert transitions")
        else:
            print("⚠️ Warning: no expert data found")

    def _load_expert_data(self, data_dir):
        """
        Load expert data from pkl files.

        Note: this may need to be adapted to the actual data format.
        """
        # Find all pkl files
        pkl_files = []
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.endswith('.pkl') and 'sd_waymo' in file:
                    pkl_files.append(os.path.join(root, file))

        if not pkl_files:
            print(f"⚠️ No expert data files found in {data_dir}")
            return

        print(f"   Found {len(pkl_files)} data files")

        # Only the first file is loaded as an example;
        # multiple files can be loaded in real use.
        for pkl_file in pkl_files[:1]:  # load the first file only
            try:
                with open(pkl_file, 'rb') as f:
                    data = pickle.load(f)
                    print(f"   Processing: {os.path.basename(pkl_file)}")
                    self._extract_trajectories(data)

                if len(self.states) >= self.max_samples:
                    break
            except Exception as e:
                print(f"   ⚠️ Failed to load {pkl_file}: {e}")
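
    # Sketch of the scenario layout the parser below assumes (based on the MetaDrive
    # ScenarioDescription convention; field names and array shapes are assumptions and
    # should be checked against the actual pkl contents):
    #
    #   {
    #       "tracks": {
    #           "<object_id>": {
    #               "type": "VEHICLE",
    #               "state": {
    #                   "position": (T, 3) array,   # x, y, z per frame
    #                   "velocity": (T, 2) array,   # vx, vy per frame
    #                   "heading":  (T,)   array,   # yaw per frame
    #                   "valid":    (T,)   bool array,
    #               },
    #           },
    #           ...
    #       },
    #   }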

    def _extract_trajectories(self, scenario_data):
        """
        Extract vehicle trajectories from MetaDrive scenario data.

        Args:
            scenario_data: Scenario data dict in MetaDrive format.
        """
        try:
            # Case 1: a dict with a 'tracks' key
            if isinstance(scenario_data, dict) and 'tracks' in scenario_data:
                for vehicle_id, track_data in scenario_data['tracks'].items():
                    if track_data.get('type') == 'VEHICLE':
                        states = track_data.get('state', {})

                        # Valid frames
                        valid = states.get('valid', [])
                        if not hasattr(valid, 'any') or not valid.any():
                            continue

                        # Extract position, velocity, heading, etc.
                        positions = states.get('position', [])
                        velocities = states.get('velocity', [])
                        headings = states.get('heading', [])

                        # Build the state sequence
                        for t in range(len(positions) - 1):
                            if valid[t] and valid[t + 1]:
                                # Current state
                                state = np.concatenate([
                                    positions[t][:2],   # x, y
                                    velocities[t],      # vx, vy
                                    [headings[t]],      # heading
                                    # ... remaining observation dims (lidar etc.) are zero-padded below
                                ])

                                # Next state
                                next_state = np.concatenate([
                                    positions[t + 1][:2],
                                    velocities[t + 1],
                                    [headings[t + 1]],
                                ])

                                # Pad to 108 dims (matching the actual observation dimension)
                                state = np.pad(state, (0, 108 - len(state)))
                                next_state = np.pad(next_state, (0, 108 - len(next_state)))

                                # Convert to tensors on the target device
                                self.states.append(torch.tensor(state, dtype=torch.float32, device=self.device))
                                self.next_states.append(torch.tensor(next_state, dtype=torch.float32, device=self.device))

                                if len(self.states) >= self.max_samples:
                                    return

            # Case 2: other possible formats
            # ...

        except Exception as e:
            print(f"   ⚠️ Failed to extract trajectories: {e}")
            import traceback
            traceback.print_exc()

    def sample(self, batch_size):
        """
        Randomly sample a batch of expert data.

        Returns:
            (states, next_states)
        """
        if len(self.states) == 0:
            # No expert data available: return zero tensors
            return (torch.zeros(batch_size, 108, device=self.device),
                    torch.zeros(batch_size, 108, device=self.device))

        # Sample indices with numpy (avoids tensor indexing issues)
        indices = np.random.randint(0, len(self.states), size=batch_size)
        # Stack the list of tensors into a batch
        states_batch = torch.stack([self.states[i] for i in indices])
        next_states_batch = torch.stack([self.next_states[i] for i in indices])
        return states_batch, next_states_batch
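
# Optional quick check of the expert buffer in isolation (a sketch; assumes a local
# "exp_converted" folder containing MetaDrive-converted sd_waymo *.pkl files):
#
#     buf = ExpertBuffer("./Env/exp_converted", device=torch.device("cpu"), max_samples=1000)
#     states, next_states = buf.sample(32)   # each tensor has shape (32, 108)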


class MAGAILPolicy:
    """
    MAGAIL policy wrapper.

    Wraps the MAGAIL agent behind the policy interface expected by the environment.
    """

    def __init__(self, magail_agent, device):
        self.magail = magail_agent
        self.device = device

    def act(self, observation=None):
        """
        Return an action (environment-compatible interface).

        Note: because of how the environment invokes policies, this is a
        simplified stub; during training, actions are produced centrally in
        the main loop.
        """
        # This method is not used during training;
        # actions come from magail.explore() instead.
        return [0.0, 0.0]


def collect_observations(env):
    """
    Collect the observations of all agents.

    Args:
        env: The multi-agent environment.

    Returns:
        obs_array: numpy array of shape (n_agents, obs_dim)
    """
    obs_list = env.obs_list
    if len(obs_list) == 0:
        return np.array([])
    return np.array(obs_list)


def train_magail(
    data_dir,
    output_dir="./outputs",
    num_episodes=1000,
    horizon=300,
    rollout_length=512,  # 512 suits 300-step episodes better
    batch_size=128,      # reduced batch size
    lr_actor=3e-4,
    lr_critic=3e-4,
    lr_disc=3e-4,
    epoch_disc=5,
    epoch_ppo=10,
    render=False,
    device="cuda",
):
    """
    Main MAGAIL training function.

    Args:
        data_dir: Waymo data directory.
        output_dir: Output directory (models, logs).
        num_episodes: Number of training episodes.
        horizon: Maximum steps per episode.
        rollout_length: PPO update interval.
        batch_size: Batch size.
        lr_actor: Actor learning rate.
        lr_critic: Critic learning rate.
        lr_disc: Discriminator learning rate.
        epoch_disc: Discriminator update epochs.
        epoch_ppo: PPO update epochs.
        render: Whether to render.
        device: Compute device.
    """

    # Create the output directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(output_dir, f"magail_{timestamp}")
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(os.path.join(run_dir, "models"), exist_ok=True)

    # TensorBoard
    writer = SummaryWriter(os.path.join(run_dir, "logs"))

    # Device
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Observation and action dimensions
    # From _get_all_obs() in scenario_env.py:
    # position(2) + velocity(2) + heading(1) + lidar(80) + side detector(10) + lane line(10) + traffic light(1) + destination(2) = 108
    obs_dim = 108
    action_dim = 2  # [steering, throttle/brake]

    print(f"📊 Observation dim: {obs_dim}, action dim: {action_dim}")
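
    # Assumed per-agent observation layout (index ranges are illustrative, derived from
    # the breakdown above; verify against scenario_env._get_all_obs() before relying on them):
    #   [0:2]     position (x, y)
    #   [2:4]     velocity (vx, vy)
    #   [4:5]     heading
    #   [5:85]    lidar (80 beams)
    #   [85:95]   side detector (10)
    #   [95:105]  lane line detector (10)
    #   [105:106] traffic light (1)
    #   [106:108] destination (2)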

    # Load the expert data
    expert_buffer = ExpertBuffer(
        data_dir=os.path.join(data_dir, "exp_converted"),
        device=device,
        max_samples=50000
    )

    # Initialize MAGAIL
    print("🤖 Initializing the MAGAIL algorithm...")
    magail = MAGAIL(
        buffer_exp=expert_buffer,
        input_dim=(obs_dim,),
        device=device,
        action_shape=(action_dim,),
        rollout_length=rollout_length,
        disc_coef=20.0,
        disc_grad_penalty=0.1,
        disc_logit_reg=0.25,
        disc_weight_decay=0.0005,
        lr_disc=lr_disc,
        lr_actor=lr_actor,
        lr_critic=lr_critic,
        epoch_disc=epoch_disc,
        epoch_ppo=epoch_ppo,
        batch_size=batch_size,
        use_gail_norm=True,
        gamma=0.995,
        lambd=0.97,
    )

    # Create the policy wrapper
    policy = MAGAILPolicy(magail, device)

    # Environment config (the environment itself is created per episode)
    env_config = {
        "data_directory": AssetLoader.file_path(data_dir, "exp_converted", unix_style=False),
        "is_multi_agent": True,
        "num_controlled_agents": 5,
        "horizon": horizon,
        "use_render": render,
        "sequential_seed": True,
        "reactive_traffic": True,
        "manual_control": False,
        "filter_offroad_vehicles": True,
        "lane_tolerance": 3.0,
        "max_controlled_vehicles": 5,
        "debug_lane_filter": False,
        "debug_traffic_light": False,
    }

    env = None  # a new environment is created for every episode

    print(f"\n{'='*60}")
    print("🚀 Starting MAGAIL training")
    print(f"{'='*60}")
    print(f"Episodes: {num_episodes}")
    print(f"Steps per episode: {horizon}")
    print(f"Update interval: {rollout_length}")
    print(f"Output directory: {run_dir}")
    print(f"{'='*60}\n")

    # Training loop
    total_steps = 0
    best_reward = -float('inf')

    for episode in range(num_episodes):
        # Create a fresh environment per episode (avoids MetaDrive object cleanup issues)
        if env is not None:
            env.close()

        print(f"🌍 Initializing environment for episode {episode + 1}...")
        env = MultiAgentScenarioEnv(config=env_config, agent2policy=policy)

        # Reset the environment (keep the scenario index in range by cycling through scenarios)
        scenario_index = episode % 3  # only 3 scenarios, reuse them cyclically
        obs_list = env.reset(scenario_index)
        episode_reward = 0
        episode_length = 0

        # Skip episodes with no controllable vehicles
        if len(env.controlled_agents) == 0:
            print(f"⚠️ Episode {episode}: no controllable vehicles, skipping")
            continue

        print(f"\n📍 Episode {episode + 1}/{num_episodes}")
        print(f"   Controlled vehicles: {len(env.controlled_agents)}")

        for step in range(horizon):
            # Collect observations
            obs_array = collect_observations(env)

            if len(obs_array) == 0:
                break

            # Sample actions from the policy
            actions, log_pis = magail.explore(obs_array)

            # Debug: print the first action (to check the action range)
            if step == 0 and episode == 0:
                print("\n🔍 Debug info - first action:")
                print(f"   Number of actions: {len(actions)}")
                if len(actions) > 0:
                    print(f"   First action: {actions[0]}")
                    print(f"   Action range: [{np.min(actions):.3f}, {np.max(actions):.3f}]")
                    # Check the initial vehicle state
                    first_vehicle = list(env.controlled_agents.values())[0]
                    print(f"   First vehicle initial position: {first_vehicle.position}")
                    print(f"   First vehicle initial speed: {first_vehicle.speed:.2f} m/s")

            # Print position changes every 50 steps (first episode only)
            if step % 50 == 0 and step > 0 and episode == 0:
                if len(env.controlled_agents) > 0:
                    first_vehicle = list(env.controlled_agents.values())[0]
                    print(f"   Step {step}: position={first_vehicle.position}, speed={first_vehicle.speed:.2f} m/s")

            # Build the action dict
            action_dict = {}
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(actions):
                    action_dict[agent_id] = actions[i]

            # Environment step
            next_obs_list, rewards, dones, infos = env.step(action_dict)
            next_obs_array = collect_observations(env)

            # Rendering
            if render:
                env.render(mode="topdown")
                time.sleep(0.02)  # 20 ms delay for smoother rendering (~50 fps)

            # Store experience in the buffer (one entry per agent)
            for i, agent_id in enumerate(env.controlled_agents.keys()):
                if i < len(obs_array) and i < len(actions) and i < len(next_obs_array):
                    # Per-agent data
                    state = obs_array[i]
                    action = actions[i]
                    reward = rewards.get(agent_id, 0.0)
                    done = dones.get(agent_id, False)
                    tm_done = done  # use the same done flag for now
                    log_pi = log_pis[i]
                    next_state = next_obs_array[i]

                    # Policy distribution parameters
                    mean = magail.actor.means[i].detach().cpu().numpy() if i < len(magail.actor.means) else np.zeros(action_dim)
                    std = magail.actor.log_stds.exp()[0].detach().cpu().numpy()

                    # Store in the buffer
                    try:
                        magail.buffer.append(
                            state=torch.tensor(state, dtype=torch.float32, device=device),
                            action=action,
                            reward=reward,
                            done=done,
                            tm_dones=tm_done,  # corrected parameter name
                            log_pi=log_pi,
                            next_state=next_state,
                            next_state_gail=next_state,
                            means=mean,
                            stds=std
                        )
                        # Debug: print only once in the first episode
                        if episode == 0 and step == 0 and i == 0:
                            print(f"   ✅ First transition stored (buffer._n={magail.buffer._n})")
                    except Exception as e:
                        # Report the error
                        if episode == 0 and step == 0:
                            print(f"   ❌ Buffer append failed: {e}")
                            import traceback
                            traceback.print_exc()

            # Average reward across agents
            avg_reward = np.mean(list(rewards.values())) if rewards else 0.0
            episode_reward += avg_reward
            episode_length += 1
            total_steps += 1

            # Check for episode termination
            if dones.get("__all__", False):
                break

            # Periodic update
            if total_steps % rollout_length == 0 and total_steps > 0:
                print(f"\n   🔄 Step {total_steps}: updating model...")
                print(f"   Buffer state: _n={magail.buffer._n}, _p={magail.buffer._p}, buffer_size={magail.buffer.buffer_size}")

                # Make sure the buffer holds enough data
                if magail.buffer._n < batch_size:
                    print(f"   ⚠️ Not enough data in buffer: _n={magail.buffer._n} < batch_size={batch_size}, skipping this update")
                    continue

                try:
                    # Run the MAGAIL update
                    gail_reward = magail.update(writer, total_steps)
                    print(f"   GAIL reward: {gail_reward:.4f}")

                    writer.add_scalar('Training/GAILReward', gail_reward, total_steps)
                except Exception as e:
                    print(f"   ⚠️ Update failed: {e}")
                    if episode < 5:  # detailed traceback only for the first few episodes
                        import traceback
                        traceback.print_exc()

        # Log training metrics
        writer.add_scalar('Training/EpisodeReward', episode_reward, episode)
        writer.add_scalar('Training/EpisodeLength', episode_length, episode)

        # Episode summary
        avg_episode_reward = episode_reward / max(episode_length, 1)
        print(f"   ✅ Episode {episode + 1} finished:")
        print(f"      Steps: {episode_length}")
        print(f"      Total reward: {episode_reward:.2f}")
        print(f"      Average reward: {avg_episode_reward:.4f}")
        print(f"      Vehicles: {len(env.controlled_agents)}")

        # Log to TensorBoard
        writer.add_scalar('Episode/Reward', episode_reward, episode)
        writer.add_scalar('Episode/Length', episode_length, episode)
        writer.add_scalar('Episode/AvgReward', avg_episode_reward, episode)
        writer.add_scalar('Episode/NumVehicles', len(env.controlled_agents), episode)
        writer.add_scalar('Training/TotalSteps', total_steps, episode)

        # Save the best model
        if episode_reward > best_reward:
            best_reward = episode_reward
            save_path = os.path.join(run_dir, "models", "best_model")
            magail.save_models(save_path)
            print(f"   💾 Saved best model (reward: {best_reward:.2f})")

        # Periodic checkpoint
        if (episode + 1) % 50 == 0:
            save_path = os.path.join(run_dir, "models", f"checkpoint_{episode + 1}")
            magail.save_models(save_path)
            print(f"   💾 Saved checkpoint: {save_path}")

    # Training finished
    print(f"\n{'='*60}")
    print("✅ Training finished!")
    print(f"{'='*60}")
    print(f"Total steps: {total_steps}")
    print(f"Best reward: {best_reward:.2f}")
    print(f"Models saved to: {run_dir}/models")
    print(f"Logs saved to: {run_dir}/logs")

    # Clean up (guard against the case where no environment was ever created)
    if env is not None:
        env.close()
    writer.close()
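

# Example invocation (the script filename and dataset path are illustrative;
# adjust them to the local checkout):
#   python train_magail.py --data-dir /path/to/MAGAIL4AutoDrive/Env --episodes 500 --render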
if __name__ == "__main__":
|
|||
|
|
import argparse
|
|||
|
|
|
|||
|
|
parser = argparse.ArgumentParser(description="训练MAGAIL算法")
|
|||
|
|
parser.add_argument("--data-dir", type=str,
|
|||
|
|
default="/home/huangfukk/MAGAIL4AutoDrive/Env",
|
|||
|
|
help="Waymo数据目录")
|
|||
|
|
parser.add_argument("--output-dir", type=str,
|
|||
|
|
default="./outputs",
|
|||
|
|
help="输出目录")
|
|||
|
|
parser.add_argument("--episodes", type=int, default=1000,
|
|||
|
|
help="训练轮数")
|
|||
|
|
parser.add_argument("--horizon", type=int, default=300,
|
|||
|
|
help="每轮最大步数")
|
|||
|
|
parser.add_argument("--rollout-length", type=int, default=2048,
|
|||
|
|
help="PPO更新间隔")
|
|||
|
|
parser.add_argument("--batch-size", type=int, default=256,
|
|||
|
|
help="批次大小")
|
|||
|
|
parser.add_argument("--lr-actor", type=float, default=3e-4,
|
|||
|
|
help="Actor学习率")
|
|||
|
|
parser.add_argument("--lr-critic", type=float, default=3e-4,
|
|||
|
|
help="Critic学习率")
|
|||
|
|
parser.add_argument("--lr-disc", type=float, default=3e-4,
|
|||
|
|
help="判别器学习率")
|
|||
|
|
parser.add_argument("--epoch-disc", type=int, default=5,
|
|||
|
|
help="判别器更新轮数")
|
|||
|
|
parser.add_argument("--epoch-ppo", type=int, default=10,
|
|||
|
|
help="PPO更新轮数")
|
|||
|
|
parser.add_argument("--render", action="store_true",
|
|||
|
|
help="是否渲染")
|
|||
|
|
parser.add_argument("--device", type=str, default="cuda",
|
|||
|
|
choices=["cuda", "cpu"],
|
|||
|
|
help="计算设备")
|
|||
|
|
|
|||
|
|
args = parser.parse_args()
|
|||
|
|
|
|||
|
|
train_magail(
|
|||
|
|
data_dir=args.data_dir,
|
|||
|
|
output_dir=args.output_dir,
|
|||
|
|
num_episodes=args.episodes,
|
|||
|
|
horizon=args.horizon,
|
|||
|
|
rollout_length=args.rollout_length,
|
|||
|
|
batch_size=args.batch_size,
|
|||
|
|
lr_actor=args.lr_actor,
|
|||
|
|
lr_critic=args.lr_critic,
|
|||
|
|
lr_disc=args.lr_disc,
|
|||
|
|
epoch_disc=args.epoch_disc,
|
|||
|
|
epoch_ppo=args.epoch_ppo,
|
|||
|
|
render=args.render,
|
|||
|
|
device=args.device,
|
|||
|
|
)
|
|||
|
|
|