178 lines
6.2 KiB
Python
178 lines
6.2 KiB
Python
import sys
|
|
import os
|
|
|
|
# Make project-local modules importable: prepend <project_root>/Env and <project_root>
# to sys.path (presumably so that `scenario_env` can be imported from the Env directory).
# Resulting order is [env_dir, project_root, ...existing entries], so modules in Env
# shadow same-named modules in the project root.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
env_dir = os.path.join(project_root, "Env")
sys.path[:0] = [env_dir, project_root]
|
|
|
|
from scenario_env import MultiAgentScenarioEnv
|
|
from metadrive.engine.asset_loader import AssetLoader
|
|
import numpy as np
|
|
|
|
class DummyPolicy:
    """Placeholder policy used only so the environment can be constructed for data inspection.

    ``MultiAgentScenarioEnv`` is given a policy via ``agent2policy``; this object
    satisfies that requirement without making any real decisions.
    """

    def act(self, *args, **kwargs):
        """Ignore every argument and return a new two-element all-zero float action."""
        zero_action = np.zeros(2)
        return zero_action
|
|
|
|
def check_available_fields(
    data_root=r"/home/huangfukk/mdsn",
    sub_dir="exp_filtered",
    num_controlled_agents=3,
):
    """Inspect which fields are actually present in Waymo data converted to MetaDrive format.

    Builds a ``MultiAgentScenarioEnv`` over the converted dataset and resets it so a
    scenario is loaded. It then takes the first track whose ``type`` is ``"VEHICLE"``
    and prints:

    - the track's structure;
    - which required and optional fields exist;
    - a suggested ``trajectory_data`` layout;
    - a sample entry of ``env.expert_trajectories``, if the env exposes one.

    Args:
        data_root: Root directory of the converted dataset. Defaults to the path
            that was previously hard-coded.
        sub_dir: Sub-directory of ``data_root`` passed to ``AssetLoader.file_path``.
        num_controlled_agents: Forwarded to the env config key of the same name.

    Returns:
        None. All results are printed. The environment is closed on every exit
        path, including when the inspection raises.
    """
    data_dir = AssetLoader.file_path(data_root, sub_dir, unix_style=False)

    # The env constructor requires agent2policy. No actions are taken here, so a
    # placeholder policy is sufficient.
    env = MultiAgentScenarioEnv(
        config={
            "data_directory": data_dir,
            "is_multi_agent": True,
            "num_controlled_agents": num_controlled_agents,
            "use_render": False,
            "sequential_seed": True,
        },
        agent2policy=DummyPolicy(),
    )
    print("✓ 环境初始化成功")

    # try/finally so the env is released even if reset/inspection raises.
    try:
        completed = _inspect_env(env)
    finally:
        env.close()

    if completed:
        print("\n✓ 分析完成")


def _inspect_env(env):
    """Reset ``env`` and print the whole field analysis.

    Returns:
        True if a vehicle track was found and analysed, False otherwise.
    """
    print("正在加载场景数据...")
    env.reset()

    # expert_trajectories only exists if scenario_env.py stores trajectories.
    if hasattr(env, "expert_trajectories"):
        print(f"✓ expert_trajectories属性存在,包含 {len(env.expert_trajectories)} 条轨迹")
    else:
        print("⚠️ expert_trajectories属性不存在,请先修改scenario_env.py添加轨迹存储功能")

    found = _find_sample_vehicle_track(env)
    if found is None:
        print("未找到车辆轨迹数据")
        return False
    object_id, sample_track = found
    print(f"\n找到样本车辆: object_id = {object_id}")

    _print_track_structure(sample_track)
    all_required_exist, optional_locations = _report_fields(sample_track)
    _print_recommendation(all_required_exist, optional_locations)

    expert_trajectories = getattr(env, "expert_trajectories", None)
    if expert_trajectories:
        _print_expert_sample(expert_trajectories)
    return True


def _find_sample_vehicle_track(env):
    """Return ``(object_id, track)`` for the first ``"VEHICLE"`` track, or None if there is none."""
    # current_traffic_data maps per-object track ids to tracks (the scenario's
    # "tracks" dict); the keys are not scenario ids.
    traffic_data = env.engine.traffic_manager.current_traffic_data
    return next(
        (
            (object_id, track)
            for object_id, track in traffic_data.items()
            if track.get("type") == "VEHICLE"
        ),
        None,
    )


def _first_valid_index(state):
    """Return the first flat index where ``state["valid"]`` is truthy, or None.

    ``np.argmax`` would return 0 when no timestep is valid, wrongly presenting an
    invalid timestep as an example; ``np.flatnonzero`` distinguishes that case.
    """
    if "valid" not in state:
        return None
    valid_indices = np.flatnonzero(state["valid"])
    return int(valid_indices[0]) if valid_indices.size else None


def _print_track_structure(track):
    """Print the top-level keys, ``metadata`` entries and ``state`` entries of one track."""
    print("=" * 60)
    print("Track数据结构分析")
    print("=" * 60)

    # 1. Top-level fields.
    print("\n1. Track顶层字段:")
    for key, value in track.items():
        print(f" - {key}: {type(value)}")

    # 2. Metadata fields: show the value only for simple scalars.
    print("\n2. track['metadata']字段:")
    for key, value in (track.get("metadata") or {}).items():
        if isinstance(value, (str, int, float, bool)):
            print(f" - {key}: {type(value).__name__} = {value}")
        else:
            print(f" - {key}: {type(value).__name__}")

    # 3. State fields: arrays get shape/dtype plus one example at a valid timestep.
    print("\n3. track['state']字段:")
    state = track.get("state") or {}
    example_idx = _first_valid_index(state)  # computed once, not per key
    for key, value in state.items():
        if isinstance(value, np.ndarray):
            print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
            # ndim guard: len() raises on 0-d arrays.
            if example_idx is not None and value.ndim > 0 and example_idx < len(value):
                print(f" 示例值 (index {example_idx}): {value[example_idx]}")
        else:
            print(f" - {key}: {type(value)} = {value}")


def _report_fields(track):
    """Print which required/optional fields exist.

    Returns:
        ``(all_required_exist, optional_locations)``. ``optional_locations`` maps each
        available optional field to the sub-dict (``"state"`` or ``"metadata"``) it
        was found in. ``"state"`` takes precedence.
    """
    state = track.get("state") or {}
    metadata = track.get("metadata") or {}

    print("\n" + "=" * 60)
    print("建议存储的字段:")
    print("=" * 60)

    required_fields = ["position", "heading", "velocity", "valid"]
    print("\n必需字段:")
    all_required_exist = True
    for field in required_fields:
        if field in state:
            print(f" ✓ {field} (存在)")
        else:
            print(f" ✗ {field} (缺失)")
            all_required_exist = False

    optional_fields = ["length", "width", "height", "bbox"]
    print("\n可选字段:")
    optional_locations = {}
    for field in optional_fields:
        if field in state:
            print(f" + {field} (在state中)")
            optional_locations[field] = "state"
        elif field in metadata:
            print(f" + {field} (在metadata中)")
            optional_locations[field] = "metadata"
        else:
            print(f" - {field} (不存在)")
    return all_required_exist, optional_locations


def _print_recommendation(all_required_exist, optional_locations):
    """Print a suggested ``trajectory_data`` layout, or a warning if required fields are missing."""
    print("\n" + "=" * 60)
    print("推荐的trajectory_data结构:")
    print("=" * 60)

    if not all_required_exist:
        print("⚠️ 缺少必需字段,请检查数据转换流程")
        return

    print("""
trajectory_data = {
    "object_id": object_id,
    "scenario_id": scenario_id,
    "valid_mask": valid[first_show:last_show+1].copy(),
    "positions": track["state"]["position"][first_show:last_show+1].copy(),
    "headings": track["state"]["heading"][first_show:last_show+1].copy(),
    "velocities": track["state"]["velocity"][first_show:last_show+1].copy(),
    "timesteps": np.arange(first_show, last_show+1),
    "start_timestep": first_show,
    "end_timestep": last_show,
    "length": last_show - first_show + 1
}
""")

    if optional_locations:
        print("如果需要车辆尺寸,可选添加:")
        for field, location in optional_locations.items():
            if field not in ("length", "width", "height"):
                continue
            # Emit the real location instead of a literal `"state" or "metadata"`.
            # State entries are indexed per timestep. Metadata entries are printed
            # without an index (assumed scalar; TODO confirm against the converter).
            index_suffix = "[first_show]" if location == "state" else ""
            print(f' trajectory_data["vehicle_{field}"] = track["{location}"]["{field}"]{index_suffix}')


def _print_expert_sample(expert_trajectories):
    """Print the keys and array shapes/values of the first entry in ``expert_trajectories``."""
    print("\n" + "=" * 60)
    print("expert_trajectories样本:")
    print("=" * 60)
    sample_traj = next(iter(expert_trajectories.values()))
    for key, value in sample_traj.items():
        if isinstance(value, np.ndarray):
            print(f" {key}: shape={value.shape}, dtype={value.dtype}")
        else:
            print(f" {key}: {type(value).__name__} = {value}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the field inspection only when executed directly,
    # not when this module is imported.
    check_available_fields()
|