新增scripts工具
This commit is contained in:
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
256
scripts/analyze_expert_data.py
Normal file
256
scripts/analyze_expert_data.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
env_dir = os.path.join(project_root, "Env")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, env_dir)
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import pickle
|
||||
import os
|
||||
|
||||
class DummyPolicy:
|
||||
"""占位策略"""
|
||||
def act(self, *args, **kwargs):
|
||||
return np.array([0.0, 0.0])
|
||||
|
||||
class ExpertDataAnalyzer:
|
||||
def __init__(self, data_directory):
|
||||
self.data_directory = data_directory
|
||||
self.env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": data_directory,
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"use_render": False,
|
||||
"sequential_seed": True,
|
||||
},
|
||||
agent2policy=DummyPolicy() # 添加必需参数
|
||||
)
|
||||
|
||||
self.statistics = {
|
||||
"num_scenarios": 0,
|
||||
"num_trajectories": 0,
|
||||
"trajectory_lengths": [],
|
||||
"velocities": [],
|
||||
"speeds": [], # 速度大小
|
||||
"accelerations": [],
|
||||
"heading_changes": [],
|
||||
"inter_vehicle_distances": [],
|
||||
"num_vehicles_per_scenario": [],
|
||||
"static_vehicles": 0, # 统计静止车辆
|
||||
}
|
||||
|
||||
def analyze_all_scenarios(self, num_scenarios=None):
|
||||
"""遍历所有场景并收集统计信息"""
|
||||
scenario_count = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
obs = self.env.reset()
|
||||
|
||||
if not hasattr(self.env, 'expert_trajectories'):
|
||||
print("⚠️ 环境缺少expert_trajectories属性")
|
||||
break
|
||||
|
||||
expert_trajs = self.env.expert_trajectories
|
||||
|
||||
if len(expert_trajs) == 0:
|
||||
continue
|
||||
|
||||
scenario_count += 1
|
||||
self.statistics["num_scenarios"] += 1
|
||||
self.statistics["num_vehicles_per_scenario"].append(len(expert_trajs))
|
||||
|
||||
# 分析每条轨迹
|
||||
for obj_id, traj in expert_trajs.items():
|
||||
self.analyze_single_trajectory(traj)
|
||||
|
||||
# 分析车辆间交互
|
||||
self.analyze_vehicle_interactions(expert_trajs)
|
||||
|
||||
print(f"已分析场景 {scenario_count}/{num_scenarios}, 车辆数: {len(expert_trajs)}")
|
||||
|
||||
if num_scenarios and scenario_count >= num_scenarios:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"场景 {scenario_count} 处理失败: {e}")
|
||||
break
|
||||
|
||||
self.env.close()
|
||||
|
||||
def analyze_single_trajectory(self, traj):
|
||||
"""分析单条轨迹"""
|
||||
self.statistics["num_trajectories"] += 1
|
||||
|
||||
length = traj["length"]
|
||||
self.statistics["trajectory_lengths"].append(length)
|
||||
|
||||
# 速度分析
|
||||
velocities = traj["velocities"]
|
||||
speeds = np.linalg.norm(velocities, axis=1)
|
||||
self.statistics["velocities"].extend(velocities.tolist())
|
||||
self.statistics["speeds"].extend(speeds.tolist())
|
||||
|
||||
# 检查是否为静止车辆
|
||||
if np.max(speeds) < 0.5: # 最大速度小于0.5m/s视为静止
|
||||
self.statistics["static_vehicles"] += 1
|
||||
|
||||
# 加速度分析
|
||||
if length > 1:
|
||||
accelerations = np.diff(speeds) * 10 # 10Hz数据
|
||||
self.statistics["accelerations"].extend(accelerations.tolist())
|
||||
|
||||
# 航向角变化
|
||||
headings = traj["headings"]
|
||||
if length > 1:
|
||||
heading_changes = np.diff(headings)
|
||||
heading_changes = np.arctan2(np.sin(heading_changes), np.cos(heading_changes))
|
||||
self.statistics["heading_changes"].extend(heading_changes.tolist())
|
||||
|
||||
def analyze_vehicle_interactions(self, expert_trajs):
|
||||
"""分析车辆间的距离"""
|
||||
if len(expert_trajs) < 2:
|
||||
return
|
||||
|
||||
traj_list = list(expert_trajs.values())
|
||||
|
||||
for i in range(len(traj_list)):
|
||||
for j in range(i+1, len(traj_list)):
|
||||
traj_i = traj_list[i]
|
||||
traj_j = traj_list[j]
|
||||
|
||||
start_time = max(traj_i["start_timestep"], traj_j["start_timestep"])
|
||||
end_time = min(traj_i["end_timestep"], traj_j["end_timestep"])
|
||||
|
||||
if start_time >= end_time:
|
||||
continue
|
||||
|
||||
idx_i_start = start_time - traj_i["start_timestep"]
|
||||
idx_i_end = end_time - traj_i["start_timestep"]
|
||||
idx_j_start = start_time - traj_j["start_timestep"]
|
||||
idx_j_end = end_time - traj_j["start_timestep"]
|
||||
|
||||
pos_i = traj_i["positions"][idx_i_start:idx_i_end, :2]
|
||||
pos_j = traj_j["positions"][idx_j_start:idx_j_end, :2]
|
||||
|
||||
distances = np.linalg.norm(pos_i - pos_j, axis=1)
|
||||
self.statistics["inter_vehicle_distances"].extend(distances.tolist())
|
||||
|
||||
def generate_report(self, save_dir="./analysis_results"):
|
||||
"""生成统计报告"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
stats = self.statistics
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("专家数据集统计报告")
|
||||
print("="*60)
|
||||
print(f"总场景数: {stats['num_scenarios']}")
|
||||
print(f"总轨迹数: {stats['num_trajectories']}")
|
||||
print(f"静止车辆数: {stats['static_vehicles']} ({stats['static_vehicles']/stats['num_trajectories']*100:.1f}%)")
|
||||
print(f"平均每场景车辆数: {np.mean(stats['num_vehicles_per_scenario']):.2f} ± {np.std(stats['num_vehicles_per_scenario']):.2f}")
|
||||
|
||||
print(f"\n轨迹长度统计 (帧数 @ 10Hz):")
|
||||
print(f" 平均: {np.mean(stats['trajectory_lengths']):.2f} 帧 ({np.mean(stats['trajectory_lengths'])*0.1:.2f}秒)")
|
||||
print(f" 中位数: {np.median(stats['trajectory_lengths']):.2f} 帧")
|
||||
print(f" 最小/最大: {np.min(stats['trajectory_lengths'])} / {np.max(stats['trajectory_lengths'])} 帧")
|
||||
|
||||
print(f"\n速度统计 (m/s):")
|
||||
speeds = np.array(stats['speeds'])
|
||||
print(f" 平均: {np.mean(speeds):.2f} ± {np.std(speeds):.2f}")
|
||||
print(f" 中位数: {np.median(speeds):.2f}")
|
||||
print(f" 最小/最大: {np.min(speeds):.2f} / {np.max(speeds):.2f}")
|
||||
print(f" 静止帧(<0.5m/s): {np.sum(speeds < 0.5)} ({np.sum(speeds < 0.5)/len(speeds)*100:.1f}%)")
|
||||
|
||||
print(f"\n加速度统计 (m/s²):")
|
||||
accs = np.array(stats['accelerations'])
|
||||
print(f" 平均: {np.mean(accs):.4f} ± {np.std(accs):.2f}")
|
||||
print(f" 最小/最大: {np.min(accs):.2f} / {np.max(accs):.2f}")
|
||||
|
||||
if len(stats['inter_vehicle_distances']) > 0:
|
||||
dists = np.array(stats['inter_vehicle_distances'])
|
||||
print(f"\n车辆间距离统计 (m):")
|
||||
print(f" 平均: {np.mean(dists):.2f} ± {np.std(dists):.2f}")
|
||||
print(f" 最小: {np.min(dists):.2f}")
|
||||
print(f" 近距离交互(<5m): {np.sum(dists < 5.0)} ({np.sum(dists < 5.0)/len(dists)*100:.2f}%)")
|
||||
|
||||
# 保存数据
|
||||
with open(os.path.join(save_dir, "statistics.pkl"), "wb") as f:
|
||||
pickle.dump(stats, f)
|
||||
|
||||
# 绘制可视化
|
||||
self.plot_distributions(save_dir)
|
||||
|
||||
print(f"\n✓ 报告已保存到: {save_dir}")
|
||||
|
||||
def plot_distributions(self, save_dir):
|
||||
"""绘制分布图"""
|
||||
stats = self.statistics
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
||||
|
||||
# 1. 轨迹长度分布
|
||||
axes[0, 0].hist(stats['trajectory_lengths'], bins=50, edgecolor='black')
|
||||
axes[0, 0].set_xlabel('Trajectory Length (frames @ 10Hz)')
|
||||
axes[0, 0].set_ylabel('Frequency')
|
||||
axes[0, 0].set_title('Trajectory Length Distribution')
|
||||
axes[0, 0].axvline(np.mean(stats['trajectory_lengths']), color='red',
|
||||
linestyle='--', label=f'Mean: {np.mean(stats["trajectory_lengths"]):.1f}')
|
||||
axes[0, 0].legend()
|
||||
|
||||
# 2. 速度分布
|
||||
axes[0, 1].hist(stats['speeds'], bins=50, edgecolor='black')
|
||||
axes[0, 1].set_xlabel('Speed (m/s)')
|
||||
axes[0, 1].set_ylabel('Frequency')
|
||||
axes[0, 1].set_title('Speed Distribution')
|
||||
axes[0, 1].axvline(np.mean(stats['speeds']), color='red',
|
||||
linestyle='--', label=f'Mean: {np.mean(stats["speeds"]):.2f}')
|
||||
axes[0, 1].legend()
|
||||
|
||||
# 3. 加速度分布
|
||||
axes[0, 2].hist(stats['accelerations'], bins=50, edgecolor='black')
|
||||
axes[0, 2].set_xlabel('Acceleration (m/s²)')
|
||||
axes[0, 2].set_ylabel('Frequency')
|
||||
axes[0, 2].set_title('Acceleration Distribution')
|
||||
|
||||
# 4. 每场景车辆数
|
||||
axes[1, 0].hist(stats['num_vehicles_per_scenario'], bins=30, edgecolor='black')
|
||||
axes[1, 0].set_xlabel('Vehicles per Scenario')
|
||||
axes[1, 0].set_ylabel('Frequency')
|
||||
axes[1, 0].set_title('Vehicles per Scenario')
|
||||
|
||||
# 5. 航向角变化
|
||||
axes[1, 1].hist(stats['heading_changes'], bins=50, edgecolor='black')
|
||||
axes[1, 1].set_xlabel('Heading Change (rad)')
|
||||
axes[1, 1].set_ylabel('Frequency')
|
||||
axes[1, 1].set_title('Heading Change Distribution')
|
||||
|
||||
# 6. 车辆间距离
|
||||
if len(stats['inter_vehicle_distances']) > 0:
|
||||
axes[1, 2].hist(stats['inter_vehicle_distances'], bins=50,
|
||||
range=(0, 50), edgecolor='black')
|
||||
axes[1, 2].set_xlabel('Inter-vehicle Distance (m)')
|
||||
axes[1, 2].set_ylabel('Frequency')
|
||||
axes[1, 2].set_title('Distance Distribution')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(save_dir, "distributions.png"), dpi=300)
|
||||
print(f" ✓ 分布图已保存")
|
||||
|
||||
if __name__ == "__main__":
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)
|
||||
|
||||
print("开始分析专家数据...")
|
||||
analyzer = ExpertDataAnalyzer(data_dir)
|
||||
analyzer.analyze_all_scenarios(num_scenarios=100) # 分析100个场景
|
||||
analyzer.generate_report()
|
||||
47
scripts/check_database_info.py
Normal file
47
scripts/check_database_info.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pickle
|
||||
import os
|
||||
|
||||
# 检查过滤后的数据库
|
||||
filtered_db = "/home/huangfukk/mdsn/exp_filtered"
|
||||
|
||||
print("="*60)
|
||||
print("过滤后数据库信息")
|
||||
print("="*60)
|
||||
|
||||
# 读取summary
|
||||
summary_path = os.path.join(filtered_db, "dataset_summary.pkl")
|
||||
with open(summary_path, 'rb') as f:
|
||||
summary = pickle.load(f)
|
||||
|
||||
print(f"\n总场景数: {len(summary)}")
|
||||
print(f"场景ID列表(前10个): {list(summary.keys())[:10]}")
|
||||
|
||||
# 读取mapping
|
||||
mapping_path = os.path.join(filtered_db, "dataset_mapping.pkl")
|
||||
with open(mapping_path, 'rb') as f:
|
||||
mapping = pickle.load(f)
|
||||
|
||||
print(f"\n映射关系数量: {len(mapping)}")
|
||||
|
||||
# 检查第一个场景的详细信息
|
||||
first_scenario_id = list(summary.keys())[0]
|
||||
first_scenario_info = summary[first_scenario_id]
|
||||
print(f"\n第一个场景详细信息:")
|
||||
print(f" 场景ID: {first_scenario_id}")
|
||||
print(f" 元数据: {first_scenario_info}")
|
||||
|
||||
# 检查映射的文件路径
|
||||
first_scenario_path = mapping[first_scenario_id]
|
||||
print(f" 场景文件路径(相对): {first_scenario_path}")
|
||||
|
||||
# 检查文件是否存在
|
||||
abs_path = os.path.join(filtered_db, first_scenario_path)
|
||||
print(f" 场景文件路径(绝对): {abs_path}")
|
||||
print(f" 文件存在: {os.path.exists(abs_path)}")
|
||||
|
||||
# 统计源数据库的场景文件
|
||||
converted_db = "/home/huangfukk/mdsn/exp_converted"
|
||||
converted_files = [f for f in os.listdir(converted_db) if f.endswith('.pkl') and f.startswith('sd_')]
|
||||
print(f"\n源数据库 exp_converted:")
|
||||
print(f" 场景文件数量: {len(converted_files)}")
|
||||
print(f" 示例文件: {converted_files[:5]}")
|
||||
177
scripts/check_track_fields.py
Normal file
177
scripts/check_track_fields.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
env_dir = os.path.join(project_root, "Env")
|
||||
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, env_dir)
|
||||
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import numpy as np
|
||||
|
||||
class DummyPolicy:
|
||||
"""
|
||||
占位策略,用于数据检查时初始化环境
|
||||
不需要实际执行动作,只是为了满足环境初始化要求
|
||||
"""
|
||||
def act(self, *args, **kwargs):
|
||||
# 返回零动作 [throttle, steering]
|
||||
return np.array([0.0, 0.0])
|
||||
|
||||
def check_available_fields():
|
||||
"""
|
||||
检查Waymo转MetaDrive数据中实际可用的字段
|
||||
"""
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)
|
||||
|
||||
# 创建占位策略
|
||||
dummy_policy = DummyPolicy()
|
||||
|
||||
# 初始化环境,传入必需的agent2policy参数
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": data_dir,
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"use_render": False,
|
||||
"sequential_seed": True,
|
||||
},
|
||||
agent2policy=dummy_policy # 添加这个必需参数
|
||||
)
|
||||
|
||||
print("✓ 环境初始化成功")
|
||||
|
||||
# 重置环境以加载数据
|
||||
print("正在加载场景数据...")
|
||||
env.reset()
|
||||
|
||||
# 检查是否有expert_trajectories属性
|
||||
if hasattr(env, 'expert_trajectories'):
|
||||
print(f"✓ expert_trajectories属性存在,包含 {len(env.expert_trajectories)} 条轨迹")
|
||||
else:
|
||||
print("⚠️ expert_trajectories属性不存在,请先修改scenario_env.py添加轨迹存储功能")
|
||||
|
||||
# 获取一个track样本
|
||||
sample_track = None
|
||||
for scenario_id, track in env.engine.traffic_manager.current_traffic_data.items():
|
||||
if track["type"] == "VEHICLE":
|
||||
sample_track = track
|
||||
print(f"\n找到样本车辆: scenario_id = {scenario_id}")
|
||||
break
|
||||
|
||||
if sample_track is None:
|
||||
print("未找到车辆轨迹数据")
|
||||
env.close()
|
||||
return
|
||||
|
||||
print("="*60)
|
||||
print("Track数据结构分析")
|
||||
print("="*60)
|
||||
|
||||
# 1. 顶层字段
|
||||
print("\n1. Track顶层字段:")
|
||||
for key in sample_track.keys():
|
||||
print(f" - {key}: {type(sample_track[key])}")
|
||||
|
||||
# 2. metadata字段
|
||||
print("\n2. track['metadata']字段:")
|
||||
if "metadata" in sample_track:
|
||||
for key, value in sample_track["metadata"].items():
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
print(f" - {key}: {type(value).__name__} = {value}")
|
||||
else:
|
||||
print(f" - {key}: {type(value).__name__}")
|
||||
|
||||
# 3. state字段
|
||||
print("\n3. track['state']字段:")
|
||||
if "state" in sample_track:
|
||||
for key, value in sample_track["state"].items():
|
||||
if isinstance(value, np.ndarray):
|
||||
print(f" - {key}: shape={value.shape}, dtype={value.dtype}")
|
||||
# 打印第一个有效值
|
||||
if "valid" in sample_track["state"]:
|
||||
valid_idx = np.argmax(sample_track["state"]["valid"])
|
||||
if valid_idx >= 0 and valid_idx < len(value):
|
||||
print(f" 示例值 (index {valid_idx}): {value[valid_idx]}")
|
||||
else:
|
||||
print(f" - {key}: {type(value)} = {value}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("建议存储的字段:")
|
||||
print("="*60)
|
||||
|
||||
# 检查必需字段
|
||||
required_fields = ["position", "heading", "velocity", "valid"]
|
||||
print("\n必需字段:")
|
||||
all_required_exist = True
|
||||
for field in required_fields:
|
||||
if "state" in sample_track and field in sample_track["state"]:
|
||||
print(f" ✓ {field} (存在)")
|
||||
else:
|
||||
print(f" ✗ {field} (缺失)")
|
||||
all_required_exist = False
|
||||
|
||||
# 检查可选字段
|
||||
optional_fields = ["length", "width", "height", "bbox"]
|
||||
print("\n可选字段:")
|
||||
available_optional = []
|
||||
for field in optional_fields:
|
||||
if "state" in sample_track and field in sample_track["state"]:
|
||||
print(f" + {field} (在state中)")
|
||||
available_optional.append(field)
|
||||
elif "metadata" in sample_track and field in sample_track["metadata"]:
|
||||
print(f" + {field} (在metadata中)")
|
||||
available_optional.append(field)
|
||||
else:
|
||||
print(f" - {field} (不存在)")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("推荐的trajectory_data结构:")
|
||||
print("="*60)
|
||||
|
||||
if all_required_exist:
|
||||
print("""
|
||||
trajectory_data = {
|
||||
"object_id": object_id,
|
||||
"scenario_id": scenario_id,
|
||||
"valid_mask": valid[first_show:last_show+1].copy(),
|
||||
"positions": track["state"]["position"][first_show:last_show+1].copy(),
|
||||
"headings": track["state"]["heading"][first_show:last_show+1].copy(),
|
||||
"velocities": track["state"]["velocity"][first_show:last_show+1].copy(),
|
||||
"timesteps": np.arange(first_show, last_show+1),
|
||||
"start_timestep": first_show,
|
||||
"end_timestep": last_show,
|
||||
"length": last_show - first_show + 1
|
||||
}
|
||||
""")
|
||||
|
||||
if available_optional:
|
||||
print("如果需要车辆尺寸,可选添加:")
|
||||
for field in available_optional:
|
||||
if field in ["length", "width", "height"]:
|
||||
print(f' trajectory_data["vehicle_{field}"] = track["state" or "metadata"]["{field}"][first_show]')
|
||||
else:
|
||||
print("⚠️ 缺少必需字段,请检查数据转换流程")
|
||||
|
||||
# 如果有expert_trajectories,展示一个样本
|
||||
if hasattr(env, 'expert_trajectories') and len(env.expert_trajectories) > 0:
|
||||
print("\n" + "="*60)
|
||||
print("expert_trajectories样本:")
|
||||
print("="*60)
|
||||
sample_traj = list(env.expert_trajectories.values())[0]
|
||||
for key, value in sample_traj.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
print(f" {key}: shape={value.shape}, dtype={value.dtype}")
|
||||
else:
|
||||
print(f" {key}: {type(value).__name__} = {value}")
|
||||
|
||||
env.close()
|
||||
print("\n✓ 分析完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_available_fields()
|
||||
105
scripts/visualize_expert_trajectory.py
Normal file
105
scripts/visualize_expert_trajectory.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
env_dir = os.path.join(project_root, "Env")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, env_dir)
|
||||
|
||||
# 现在可以导入了
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
|
||||
class DummyPolicy:
|
||||
"""
|
||||
占位策略,用于数据检查时初始化环境
|
||||
不需要实际执行动作,只是为了满足环境初始化要求
|
||||
"""
|
||||
def act(self, *args, **kwargs):
|
||||
# 返回零动作 [throttle, steering]
|
||||
return np.array([0.0, 0.0])
|
||||
|
||||
def visualize_expert_trajectory(env, scenario_idx=0):
|
||||
"""
|
||||
可视化专家轨迹的俯视图动画
|
||||
"""
|
||||
env.reset()
|
||||
expert_trajs = env.expert_trajectories
|
||||
|
||||
if len(expert_trajs) == 0:
|
||||
print("当前场景无专家轨迹")
|
||||
return
|
||||
|
||||
# 设置绘图
|
||||
fig, ax = plt.subplots(figsize=(12, 12))
|
||||
|
||||
# 获取所有轨迹的最大时间长度
|
||||
max_timestep = max(traj["end_timestep"] for traj in expert_trajs.values())
|
||||
min_timestep = min(traj["start_timestep"] for traj in expert_trajs.values())
|
||||
|
||||
# 绘制完整轨迹(淡色)
|
||||
colors = plt.cm.tab10(np.linspace(0, 1, len(expert_trajs)))
|
||||
for idx, (obj_id, traj) in enumerate(expert_trajs.items()):
|
||||
positions = traj["positions"][:, :2]
|
||||
ax.plot(positions[:, 0], positions[:, 1],
|
||||
color=colors[idx], alpha=0.3, linewidth=1,
|
||||
label=f'Vehicle {obj_id[:6]}')
|
||||
|
||||
# 初始化当前位置标记
|
||||
scatter = ax.scatter([], [], s=200, c='red', marker='o', edgecolors='black', linewidths=2)
|
||||
time_text = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=14)
|
||||
|
||||
ax.set_xlabel('X (m)')
|
||||
ax.set_ylabel('Y (m)')
|
||||
ax.set_title(f'Expert Trajectory Visualization - Scenario {scenario_idx}')
|
||||
ax.legend(loc='upper right', fontsize=8)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.axis('equal')
|
||||
|
||||
def update(frame):
|
||||
current_time = min_timestep + frame
|
||||
|
||||
# 收集当前时间所有车辆的位置
|
||||
current_positions = []
|
||||
for traj in expert_trajs.values():
|
||||
if traj["start_timestep"] <= current_time <= traj["end_timestep"]:
|
||||
idx = current_time - traj["start_timestep"]
|
||||
pos = traj["positions"][idx, :2]
|
||||
current_positions.append(pos)
|
||||
|
||||
if len(current_positions) > 0:
|
||||
current_positions = np.array(current_positions)
|
||||
scatter.set_offsets(current_positions)
|
||||
|
||||
time_text.set_text(f'Time: {frame * 0.1:.1f}s (Frame {frame})')
|
||||
return scatter, time_text
|
||||
|
||||
anim = FuncAnimation(fig, update, frames=max_timestep-min_timestep+1,
|
||||
interval=100, blit=True, repeat=True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
return anim
|
||||
|
||||
if __name__ == "__main__":
|
||||
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
|
||||
data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)
|
||||
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
"data_directory": data_dir,
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"use_render": False,
|
||||
},
|
||||
agent2policy=DummyPolicy()
|
||||
)
|
||||
|
||||
# 可视化第一个场景
|
||||
anim = visualize_expert_trajectory(env, scenario_idx=0)
|
||||
Reference in New Issue
Block a user