Files
MAGAIL4AutoDrive/analyze_expert_data.py

104 lines
3.8 KiB
Python
Raw Normal View History

"""
分析Waymo专家数据的结构
运行: python analyze_expert_data.py
"""
import os
import pickle
from itertools import islice

import numpy as np
def _print_subdict_preview(mapping, limit=3):
    """Print a one-line summary for the first ``limit`` items of a nested dict.

    Each line shows the sub-key and the value's type, followed by a short
    type-specific detail (array shape/dtype, dict keys, sequence length, or
    the value itself).
    """
    # islice stops after ``limit`` items instead of copying every item first.
    for subkey, subvalue in islice(mapping.items(), limit):
        # No newline here: the branch below finishes the same output line.
        print(f" - {subkey}: {type(subvalue)}", end="")
        if isinstance(subvalue, np.ndarray):
            print(f" shape={subvalue.shape}, dtype={subvalue.dtype}")
        elif isinstance(subvalue, dict):
            print(f" keys={list(islice(subvalue.keys(), 5))}")
        elif isinstance(subvalue, (list, tuple)):
            print(f" len={len(subvalue)}")
        else:
            print(f" = {subvalue}")


def _print_value_details(value):
    """Print type-specific details for one value of the top-level dict.

    Dicts get their key list plus a preview of the first sub-entries, arrays
    get shape/dtype and their first five elements, lists/tuples get their
    length and the type of the first element. Other types print nothing.
    """
    if isinstance(value, dict):
        print(f" 子键: {list(value.keys())}")
        _print_subdict_preview(value)
    elif isinstance(value, np.ndarray):
        print(f" Shape: {value.shape}, dtype: {value.dtype}")
        # ravel() avoids the unconditional full copy that flatten() makes
        # just to show five numbers; the printed values are the same.
        print(f" 示例: {value.ravel()[:5]}")
    elif isinstance(value, (list, tuple)):
        print(f" 长度: {len(value)}")
        if value:
            print(f" 第一个元素: {type(value[0])}")


def analyze_pkl_file(filepath):
    """Load a pickle file and print a summary of its structure.

    Prints the type of the unpickled object and the file size, then:

    * for a dict: the number of keys, the first 10 keys, and details for the
      first 5 entries;
    * for a list/tuple: its length, the type of the first element and, when
      that element is a dict, its keys.

    Args:
        filepath: Path of the pickle file to inspect.

    Returns:
        The unpickled object, unchanged.

    Raises:
        OSError: If the file cannot be opened or its size cannot be read.
        pickle.UnpicklingError: If the file is not a valid pickle.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"分析文件: {os.path.basename(filepath)}")
    print(banner)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only run this on data files from a trusted source.
    with open(filepath, 'rb') as f:
        data = pickle.load(f)

    print(f"\n1. 数据类型: {type(data)}")
    print(f" 文件大小: {os.path.getsize(filepath) / 1024:.1f} KB")

    if isinstance(data, dict):
        print("\n2. 字典结构:")
        print(f" 键数量: {len(data)}")
        print(f" 键列表: {list(islice(data.keys(), 10))}")
        # Detailed look at the first few entries only; islice avoids building
        # a list of every (key, value) pair of a potentially large dict.
        for index, (key, value) in enumerate(islice(data.items(), 5), start=1):
            print(f"\n 键 [{index}]: '{key}'")
            print(f" 类型: {type(value)}")
            _print_value_details(value)
    elif isinstance(data, (list, tuple)):
        print("\n2. 列表/元组结构:")
        print(f" 长度: {len(data)}")
        if data:
            first = data[0]
            print(f" 第一个元素类型: {type(first)}")
            if isinstance(first, dict):
                print(f" 第一个元素的键: {list(first.keys())}")

    return data
def find_trajectory_data(data, max_depth=3, current_depth=0, path=""):
    """Recursively report dict fields that look like trajectory data.

    Walks nested dicts and prints every ``np.ndarray`` value that has at
    least two dimensions and more than 10 entries along its first axis,
    which this script treats as "probably a time series".

    Args:
        data: Object to inspect. Anything that is not a dict is ignored.
        max_depth: Deepest nesting level (0 is the top-level dict) whose
            entries are still inspected.
        current_depth: Nesting level of ``data``; external callers normally
            leave this at 0.
        path: Dotted key path that leads to ``data``; used as the prefix of
            the reported field names. Empty string for the top level.

    Returns:
        None. Findings are printed to stdout.
    """
    if current_depth > max_depth or not isinstance(data, dict):
        return

    for key, value in data.items():
        # Test for the empty root path explicitly and stringify the key.
        # A plain truthiness test dropped the prefix whenever a parent key
        # was falsy (e.g. the integer key 0), so nested fields were reported
        # under the wrong path.
        new_path = f"{path}.{key}" if path != "" else str(key)

        if isinstance(value, np.ndarray):
            # Heuristic: 2-D or higher with a long first axis.
            if value.ndim >= 2 and value.shape[0] > 10:
                print(f" 🎯 可能的轨迹数据: {new_path}")
                print(f" Shape: {value.shape}, dtype: {value.dtype}")
                print(f" 前3个值: {value[:3]}")
        elif isinstance(value, dict):
            # Descend one level; the depth check at the top bounds recursion.
            find_trajectory_data(value, max_depth, current_depth + 1, new_path)
if __name__ == "__main__":
    # Analyze the first data file found in the converted expert-data folder.
    data_dir = "Env/exp_converted/exp_converted_0"

    # os.listdir returns entries in arbitrary order and raises if the
    # directory does not exist. Sorting makes "the first file" reproducible,
    # and the isdir guard routes a missing directory to the message below
    # instead of a traceback.
    if os.path.isdir(data_dir):
        pkl_files = sorted(
            f for f in os.listdir(data_dir) if f.startswith('sd_waymo')
        )
    else:
        pkl_files = []

    if pkl_files:
        filepath = os.path.join(data_dir, pkl_files[0])
        data = analyze_pkl_file(filepath)

        print(f"\n\n{'='*80}")
        print("查找可能的轨迹数据...")
        print(f"{'='*80}")
        find_trajectory_data(data)
    else:
        print("未找到数据文件!")