7.3 KiB
7.3 KiB
问题解决记录
运行 train_magail.py 遇到的问题及解决方案
❌ 问题1: 模块导入错误
错误信息:
ImportError: attempted relative import with no known parent package
原因:
Algorithm/文件夹中的文件使用了不一致的导入方式- 有些文件使用相对导入(
.ppo,.utils),有些使用绝对导入
解决方案:
- 修改所有导入语句为兼容模式(try相对导入,except绝对导入)
- 创建
Algorithm/__init__.py和Env/__init__.py使其成为Python包
修改的文件:
- ✅
Algorithm/ppo.py- 添加try-except导入 - ✅
Algorithm/policy.py- 添加try-except导入 - ✅
Algorithm/disc.py- 添加try-except导入 - ✅
Algorithm/magail.py- 添加try-except导入 - ✅
Algorithm/__init__.py- 新建包初始化文件 - ✅
Env/__init__.py- 新建包初始化文件
❌ 问题2: action_shape 参数缺失
错误信息:
TypeError: MAGAIL.__init__() got an unexpected keyword argument 'action_shape'
原因:
MAGAIL类的__init__()方法没有定义action_shape参数- 但训练脚本传递了这个参数
解决方案:
在 MAGAIL.__init__() 中添加 action_shape 参数并传递给父类
修改代码:
# 修改前
class MAGAIL(PPO):
def __init__(self, buffer_exp, input_dim, device, ...):
super().__init__(state_shape=input_dim, device=device)
# 修改后
class MAGAIL(PPO):
def __init__(self, buffer_exp, input_dim, device, action_shape=(2,), **kwargs):
super().__init__(state_shape=input_dim, device=device,
action_shape=action_shape, **kwargs)
❌ 问题3: 维度类型不匹配
错误信息:
TypeError: empty(): argument 'size' must be tuple of ints,
but found element of type tuple at pos 2
原因:
input_dim传入时是元组(108,)- 但
Bert的nn.Linear(input_dim, embed_dim)期望input_dim是整数
根本原因:
# 训练脚本中
magail = MAGAIL(
input_dim=(108,), # 传入的是元组
...
)
# 在 PPO.__init__ 中
self.critic = Bert(
input_dim=state_shape, # state_shape = (108,) 仍是元组
output_dim=1
)
# 在 Bert.__init__ 中
self.projection = nn.Linear(input_dim, embed_dim) # 期望 input_dim=108
解决方案:
在使用 state_shape 或 input_dim 创建 Bert 时,提取整数值
修改代码:
- 在
Algorithm/ppo.py中:
# 修改前
self.critic = Bert(
input_dim=state_shape,
output_dim=1
).to(device)
# 修改后
state_dim = state_shape[0] if isinstance(state_shape, tuple) else state_shape
self.critic = Bert(
input_dim=state_dim,
output_dim=1
).to(device)
- 在
Algorithm/magail.py中:
# 修改前
self.disc = GAILDiscrim(input_dim=input_dim)
# 修改后
state_dim = input_dim[0] if isinstance(input_dim, tuple) else input_dim
self.disc = GAILDiscrim(input_dim=state_dim)
- 在
Algorithm/magail.py的Normalizer中:
# 修改前
self.normalizer = Normalizer(self.state_shape[0]*2)
# 修改后
state_dim = self.state_shape[0] if isinstance(self.state_shape, tuple) else self.state_shape
self.normalizer = Normalizer(state_dim*2)
所有修改总结
修改的文件列表
| 文件 | 修改内容 |
|---|---|
Algorithm/ppo.py |
1. 添加try-except导入 2. 添加action_shape参数 3. 修复state_shape元组问题 |
Algorithm/magail.py |
1. 添加try-except导入 2. 添加action_shape参数 3. 修复input_dim元组问题 |
Algorithm/policy.py |
添加try-except导入 |
Algorithm/disc.py |
添加try-except导入 |
Algorithm/__init__.py |
新建包初始化文件 |
Env/__init__.py |
新建包初始化文件 |
新增的文件
| 文件 | 用途 |
|---|---|
train_magail.py |
完整的MAGAIL训练脚本 |
MAGAIL算法应用指南.md |
详细的使用文档 |
快速测试.sh |
快速测试脚本 |
问题解决记录.md |
本文件 |
如何运行
方法1: 直接运行(推荐)
# 确保在conda环境中
conda activate metadrive
# 运行训练(少量episode测试)
python train_magail.py --episodes 5 --horizon 100
# 带可视化
python train_magail.py --episodes 5 --render
# 完整训练
python train_magail.py --episodes 1000 --horizon 300
方法2: 使用测试脚本
conda activate metadrive
bash 快速测试.sh
方法3: 查看帮助
python train_magail.py --help
验证是否修复成功
运行以下命令测试:
# 测试1: 验证导入
python -c "from Algorithm.magail import MAGAIL; print('✅ 导入成功!')"
# 测试2: 验证初始化
python -c "
from Algorithm.magail import MAGAIL
import torch
class DummyBuffer:
def __init__(self, device):
self.device = device
def sample(self, batch_size):
return torch.randn(batch_size, 108, device=self.device), \
torch.randn(batch_size, 108, device=self.device)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
buffer = DummyBuffer(device)
magail = MAGAIL(
buffer_exp=buffer,
input_dim=(108,),
device=device,
action_shape=(2,)
)
print('✅ MAGAIL初始化成功!')
"
# 测试3: 运行1个episode
python train_magail.py --episodes 1 --horizon 50
❌ 问题4: 保存模型时目录不存在
错误信息:
FileNotFoundError: [Errno 2] 没有那个文件或目录: './outputs/magail_xxx/models/best_model/model.pth'
原因:
save_models()方法直接保存文件,但没有创建目录
解决方案:
在保存前使用 os.makedirs(path, exist_ok=True) 创建目录
修改代码:
# 修改前
def save_models(self, path):
torch.save({...}, os.path.join(path, 'model.pth'))
# 修改后
def save_models(self, path):
os.makedirs(path, exist_ok=True) # 确保目录存在
torch.save({...}, os.path.join(path, 'model.pth'))
剩余的已知问题
1. 专家数据加载未实现
现象:
⚠️ 警告: 未找到专家数据
原因:
ExpertBuffer._extract_trajectories() 方法是空的,需要根据实际Waymo数据格式实现
TODO:
- 分析Waymo pkl文件的实际结构
- 实现轨迹提取逻辑
2. 多智能体buffer存储
现象: 当前buffer是为单智能体设计的
TODO:
- 实现多智能体经验存储逻辑
- 或者修改为将所有智能体作为一个batch处理
3. 环境奖励为0
现象: 环境返回的奖励全是0
TODO:
- 设计合理的任务奖励函数
- 或者完全依赖GAIL的内在奖励
下一步计划
- ✅ 已完成: 修复导入问题
- ✅ 已完成: 修复参数问题
- ✅ 已完成: 修复维度问题
- ✅ 已完成: 修复保存模型目录问题
- ✅ 已完成: 测试完整训练流程 - 训练脚本可以正常运行!
- 📝 待办: 实现专家数据加载
- 📝 待办: 完善多智能体buffer
- 📝 待办: 设计奖励函数
- 📝 待办: 超参数调优
参考文档
技术说明文档.md- 完整的技术实现细节MAGAIL算法应用指南.md- 使用指南和示例README.md- 项目概述
最后更新: 2024 状态: ✅ 主要问题已解决,可以开始训练