新增scripts工具
This commit is contained in:
177
scripts/check_track_fields.py
Normal file
177
scripts/check_track_fields.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
env_dir = os.path.join(project_root, "Env")
|
||||
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, env_dir)
|
||||
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import numpy as np
|
||||
|
||||
class DummyPolicy:
|
||||
"""
|
||||
占位策略,用于数据检查时初始化环境
|
||||
不需要实际执行动作,只是为了满足环境初始化要求
|
||||
"""
|
||||
def act(self, *args, **kwargs):
|
||||
# 返回零动作 [throttle, steering]
|
||||
return np.array([0.0, 0.0])
|
||||
|
||||
def check_available_fields():
|
||||
"""
|
||||
检查Waymo转MetaDrive数据中实际可用的字段
|
||||
"""
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)
|
||||
|
||||
# 创建占位策略
|
||||
dummy_policy = DummyPolicy()
|
||||
|
||||
# 初始化环境,传入必需的agent2policy参数
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": data_dir,
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"use_render": False,
|
||||
"sequential_seed": True,
|
||||
},
|
||||
agent2policy=dummy_policy # 添加这个必需参数
|
||||
)
|
||||
|
||||
print("✓ 环境初始化成功")
|
||||
|
||||
# 重置环境以加载数据
|
||||
print("正在加载场景数据...")
|
||||
env.reset()
|
||||
|
||||
# 检查是否有expert_trajectories属性
|
||||
if hasattr(env, 'expert_trajectories'):
|
||||
print(f"✓ expert_trajectories属性存在,包含 {len(env.expert_trajectories)} 条轨迹")
|
||||
else:
|
||||
print("⚠️ expert_trajectories属性不存在,请先修改scenario_env.py添加轨迹存储功能")
|
||||
|
||||
# 获取一个track样本
|
||||
sample_track = None
|
||||
for scenario_id, track in env.engine.traffic_manager.current_traffic_data.items():
|
||||
if track["type"] == "VEHICLE":
|
||||
sample_track = track
|
||||
print(f"\n找到样本车辆: scenario_id = {scenario_id}")
|
||||
break
|
||||
|
||||
if sample_track is None:
|
||||
print("未找到车辆轨迹数据")
|
||||
env.close()
|
||||
return
|
||||
|
||||
print("="*60)
|
||||
print("Track数据结构分析")
|
||||
print("="*60)
|
||||
|
||||
# 1. 顶层字段
|
||||
print("\n1. Track顶层字段:")
|
||||
for key in sample_track.keys():
|
||||
print(f" - {key}: {type(sample_track[key])}")
|
||||
|
||||
# 2. metadata字段
|
||||
print("\n2. track['metadata']字段:")
|
||||
if "metadata" in sample_track:
|
||||
for key, value in sample_track["metadata"].items():
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
print(f" - {key}: {type(value).__name__} = {value}")
|
||||
else:
|
||||
print(f" - {key}: {type(value).__name__}")
|
||||
|
||||
# 3. state字段
|
||||
print("\n3. track['state']字段:")
|
||||
if "state" in sample_track:
|
||||
for key, value in sample_track["state"].items():
|
||||
if isinstance(value, np.ndarray):
|
||||
print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
|
||||
# 打印第一个有效值
|
||||
if "valid" in sample_track["state"]:
|
||||
valid_idx = np.argmax(sample_track["state"]["valid"])
|
||||
if valid_idx >= 0 and valid_idx < len(value):
|
||||
print(f" 示例值 (index {valid_idx}): {value[valid_idx]}")
|
||||
else:
|
||||
print(f" - {key}: {type(value)} = {value}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("建议存储的字段:")
|
||||
print("="*60)
|
||||
|
||||
# 检查必需字段
|
||||
required_fields = ["position", "heading", "velocity", "valid"]
|
||||
print("\n必需字段:")
|
||||
all_required_exist = True
|
||||
for field in required_fields:
|
||||
if "state" in sample_track and field in sample_track["state"]:
|
||||
print(f" ✓ {field} (存在)")
|
||||
else:
|
||||
print(f" ✗ {field} (缺失)")
|
||||
all_required_exist = False
|
||||
|
||||
# 检查可选字段
|
||||
optional_fields = ["length", "width", "height", "bbox"]
|
||||
print("\n可选字段:")
|
||||
available_optional = []
|
||||
for field in optional_fields:
|
||||
if "state" in sample_track and field in sample_track["state"]:
|
||||
print(f" + {field} (在state中)")
|
||||
available_optional.append(field)
|
||||
elif "metadata" in sample_track and field in sample_track["metadata"]:
|
||||
print(f" + {field} (在metadata中)")
|
||||
available_optional.append(field)
|
||||
else:
|
||||
print(f" - {field} (不存在)")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("推荐的trajectory_data结构:")
|
||||
print("="*60)
|
||||
|
||||
if all_required_exist:
|
||||
print("""
|
||||
trajectory_data = {
|
||||
"object_id": object_id,
|
||||
"scenario_id": scenario_id,
|
||||
"valid_mask": valid[first_show:last_show+1].copy(),
|
||||
"positions": track["state"]["position"][first_show:last_show+1].copy(),
|
||||
"headings": track["state"]["heading"][first_show:last_show+1].copy(),
|
||||
"velocities": track["state"]["velocity"][first_show:last_show+1].copy(),
|
||||
"timesteps": np.arange(first_show, last_show+1),
|
||||
"start_timestep": first_show,
|
||||
"end_timestep": last_show,
|
||||
"length": last_show - first_show + 1
|
||||
}
|
||||
""")
|
||||
|
||||
if available_optional:
|
||||
print("如果需要车辆尺寸,可选添加:")
|
||||
for field in available_optional:
|
||||
if field in ["length", "width", "height"]:
|
||||
print(f' trajectory_data["vehicle_{field}"] = track["state" or "metadata"]["{field}"][first_show]')
|
||||
else:
|
||||
print("⚠️ 缺少必需字段,请检查数据转换流程")
|
||||
|
||||
# 如果有expert_trajectories,展示一个样本
|
||||
if hasattr(env, 'expert_trajectories') and len(env.expert_trajectories) > 0:
|
||||
print("\n" + "="*60)
|
||||
print("expert_trajectories样本:")
|
||||
print("="*60)
|
||||
sample_traj = list(env.expert_trajectories.values())[0]
|
||||
for key, value in sample_traj.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
print(f" {key}: shape={value.shape}, dtype={value.dtype}")
|
||||
else:
|
||||
print(f" {key}: {type(value).__name__} = {value}")
|
||||
|
||||
env.close()
|
||||
print("\n✓ 分析完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_available_fields()
|
||||
Reference in New Issue
Block a user