新增scripts工具

This commit is contained in:
2025-10-25 21:44:11 +08:00
parent 62e638c4d2
commit c94571ddaa
17 changed files with 1193 additions and 66 deletions

View File

@@ -320,6 +320,7 @@ class MultiAgentScenarioEnv(ScenarioEnv):
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
_obj_to_clean_this_frame = [] # 需要清理的对象ID列表
self.car_birth_info_list = [] # 车辆生成信息列表
self.expert_trajectories = {} # 专家数据轨迹字典
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
# 跳过自车SDC - Self Driving Car
@@ -334,6 +335,53 @@ class MultiAgentScenarioEnv(ScenarioEnv):
first_show = np.argmax(valid) if valid.any() else -1
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
if first_show == -1 or last_show == -1:
continue
object_id = track["metadata"]["object_id"]
# 提取完整轨迹数据(只使用确认存在的字段)
trajectory_data = {
"object_id": object_id,
"scenario_id": scenario_id,
"valid_mask": valid[first_show:last_show+1].copy(), # 有效性掩码
"positions": track["state"]["position"][first_show:last_show+1].copy(), # (T, 3)
"headings": track["state"]["heading"][first_show:last_show+1].copy(), # (T,)
"velocities": track["state"]["velocity"][first_show:last_show+1].copy(), # (T, 2)
"timesteps": np.arange(first_show, last_show+1), # 时间戳
"start_timestep": first_show,
"end_timestep": last_show,
"length": last_show - first_show + 1
}
# 可选:如果数据中有车辆尺寸信息,则添加
# 方法1: 尝试从state中获取
if "length" in track["state"]:
trajectory_data["vehicle_length"] = track["state"]["length"][first_show]
if "width" in track["state"]:
trajectory_data["vehicle_width"] = track["state"]["width"][first_show]
if "height" in track["state"]:
trajectory_data["vehicle_height"] = track["state"]["height"][first_show]
# 方法2: 尝试从metadata中获取
if "vehicle_length" not in trajectory_data and "length" in track.get("metadata", {}):
trajectory_data["vehicle_length"] = track["metadata"]["length"]
if "vehicle_width" not in trajectory_data and "width" in track.get("metadata", {}):
trajectory_data["vehicle_width"] = track["metadata"]["width"]
if "vehicle_height" not in trajectory_data and "height" in track.get("metadata", {}):
trajectory_data["vehicle_height"] = track["metadata"]["height"]
# 方法3: 使用默认值(如果以上都没有)
if "vehicle_length" not in trajectory_data:
trajectory_data["vehicle_length"] = 4.5 # MetaDrive默认车长
if "vehicle_width" not in trajectory_data:
trajectory_data["vehicle_width"] = 2.0 # MetaDrive默认车宽
if "vehicle_height" not in trajectory_data:
trajectory_data["vehicle_height"] = 1.5 # MetaDrive默认车高
# 存储到专家轨迹字典
self.expert_trajectories[object_id] = trajectory_data
# 提取车辆关键信息
car_info = {
'id': track['metadata']['object_id'],