新增scripts工具
This commit is contained in:
BIN
Env/__pycache__/expert_replay_policy.cpython-310.pyc
Normal file
BIN
Env/__pycache__/expert_replay_policy.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -320,6 +320,7 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
|
||||
_obj_to_clean_this_frame = [] # 需要清理的对象ID列表
|
||||
self.car_birth_info_list = [] # 车辆生成信息列表
|
||||
self.expert_trajectories = {} # 专家数据轨迹字典
|
||||
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
# 跳过自车(SDC - Self Driving Car)
|
||||
@@ -334,6 +335,53 @@ class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
|
||||
if first_show == -1 or last_show == -1:
|
||||
continue
|
||||
object_id = track["metadata"]["object_id"]
|
||||
|
||||
# 提取完整轨迹数据(只使用确认存在的字段)
|
||||
trajectory_data = {
|
||||
"object_id": object_id,
|
||||
"scenario_id": scenario_id,
|
||||
"valid_mask": valid[first_show:last_show+1].copy(), # 有效性掩码
|
||||
"positions": track["state"]["position"][first_show:last_show+1].copy(), # (T, 3)
|
||||
"headings": track["state"]["heading"][first_show:last_show+1].copy(), # (T,)
|
||||
"velocities": track["state"]["velocity"][first_show:last_show+1].copy(), # (T, 2)
|
||||
"timesteps": np.arange(first_show, last_show+1), # 时间戳
|
||||
"start_timestep": first_show,
|
||||
"end_timestep": last_show,
|
||||
"length": last_show - first_show + 1
|
||||
}
|
||||
|
||||
# 可选:如果数据中有车辆尺寸信息,则添加
|
||||
# 方法1: 尝试从state中获取
|
||||
if "length" in track["state"]:
|
||||
trajectory_data["vehicle_length"] = track["state"]["length"][first_show]
|
||||
if "width" in track["state"]:
|
||||
trajectory_data["vehicle_width"] = track["state"]["width"][first_show]
|
||||
if "height" in track["state"]:
|
||||
trajectory_data["vehicle_height"] = track["state"]["height"][first_show]
|
||||
|
||||
# 方法2: 尝试从metadata中获取
|
||||
if "vehicle_length" not in trajectory_data and "length" in track.get("metadata", {}):
|
||||
trajectory_data["vehicle_length"] = track["metadata"]["length"]
|
||||
if "vehicle_width" not in trajectory_data and "width" in track.get("metadata", {}):
|
||||
trajectory_data["vehicle_width"] = track["metadata"]["width"]
|
||||
if "vehicle_height" not in trajectory_data and "height" in track.get("metadata", {}):
|
||||
trajectory_data["vehicle_height"] = track["metadata"]["height"]
|
||||
|
||||
# 方法3: 使用默认值(如果以上都没有)
|
||||
if "vehicle_length" not in trajectory_data:
|
||||
trajectory_data["vehicle_length"] = 4.5 # MetaDrive默认车长
|
||||
if "vehicle_width" not in trajectory_data:
|
||||
trajectory_data["vehicle_width"] = 2.0 # MetaDrive默认车宽
|
||||
if "vehicle_height" not in trajectory_data:
|
||||
trajectory_data["vehicle_height"] = 1.5 # MetaDrive默认车高
|
||||
|
||||
|
||||
# 存储到专家轨迹字典
|
||||
self.expert_trajectories[object_id] = trajectory_data
|
||||
|
||||
# 提取车辆关键信息
|
||||
car_info = {
|
||||
'id': track['metadata']['object_id'],
|
||||
|
||||
Reference in New Issue
Block a user