source file to data

This commit is contained in:
QuanyiLi
2023-05-06 21:22:49 +01:00
parent 7501a04618
commit 3963c88e41

View File

@@ -20,6 +20,8 @@ except ImportError as e:
from metadrive.scenario import ScenarioDescription as SD
from metadrive.type import MetaDriveType
SPLIT_KEY = "|"
def extract_poly(message):
x = [i.x for i in message]
@@ -352,31 +354,31 @@ def compute_width(map):
def convert_waymo_scenario(scenario, version):
scenario_pb2 = Scenario()
scenario_pb2.ParseFromString(scenario)
scenario = scenario
md_scenario = SD()
md_scenario[SD.ID] = scenario_pb2.scenario_id
id_end = scenario.scenario_id.find(SPLIT_KEY)
md_scenario[SD.ID] = scenario.scenario_id[:id_end]
md_scenario[SD.VERSION] = version
# Please note that the SDC track index is not the same as sdc_id.
# sdc_id is the unique identifier of a track, while sdc_track_index is only the position of the SDC track
# in the tracks data structure.
track_length = len(list(scenario_pb2.timestamps_seconds))
track_length = len(list(scenario.timestamps_seconds))
tracks, sdc_id = extract_tracks(scenario_pb2.tracks, scenario_pb2.sdc_track_index, track_length)
tracks, sdc_id = extract_tracks(scenario.tracks, scenario.sdc_track_index, track_length)
md_scenario[SD.LENGTH] = track_length
md_scenario[SD.TRACKS] = tracks
dynamic_states = extract_dynamic_map_states(scenario_pb2.dynamic_map_states, track_length)
dynamic_states = extract_dynamic_map_states(scenario.dynamic_map_states, track_length)
md_scenario[SD.DYNAMIC_MAP_STATES] = dynamic_states
map_features = extract_map_features(scenario_pb2.map_features)
map_features = extract_map_features(scenario.map_features)
md_scenario[SD.MAP_FEATURES] = map_features
compute_width(md_scenario[SD.MAP_FEATURES])
@@ -384,25 +386,25 @@ def convert_waymo_scenario(scenario, version):
md_scenario[SD.METADATA] = {}
md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario_pb2.timestamps_seconds), dtype=np.float32)
md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario.timestamps_seconds), dtype=np.float32)
md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
md_scenario[SD.METADATA][SD.SDC_ID] = str(sdc_id)
md_scenario[SD.METADATA]["dataset"] = "waymo"
md_scenario[SD.METADATA]["scenario_id"] = scenario_pb2.scenario_id
md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id[:id_end]
# TODO LQY Can we infer it?
# md_scenario[SD.METADATA]["source_file"] = str(file)
md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id[id_end + 1:]
md_scenario[SD.METADATA]["track_length"] = track_length
# === Waymo specific data. Storing them here ===
md_scenario[SD.METADATA]["current_time_index"] = scenario_pb2.current_time_index
md_scenario[SD.METADATA]["sdc_track_index"] = scenario_pb2.sdc_track_index
md_scenario[SD.METADATA]["current_time_index"] = scenario.current_time_index
md_scenario[SD.METADATA]["sdc_track_index"] = scenario.sdc_track_index
# obj id
md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario_pb2.objects_of_interest]
md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario.objects_of_interest]
track_index = [obj.track_index for obj in scenario_pb2.tracks_to_predict]
track_id = [str(scenario_pb2.tracks[ind].id) for ind in track_index]
track_difficulty = [obj.difficulty for obj in scenario_pb2.tracks_to_predict]
track_index = [obj.track_index for obj in scenario.tracks_to_predict]
track_id = [str(scenario.tracks[ind].id) for ind in track_index]
track_difficulty = [obj.difficulty for obj in scenario.tracks_to_predict]
track_obj_type = [tracks[id]["type"] for id in track_id]
md_scenario[SD.METADATA]["tracks_to_predict"] = {
id: {
@@ -426,5 +428,10 @@ def get_waymo_scenarios(waymo_data_direction):
file_path = os.path.join(waymo_data_direction, file)
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
continue
scenarios += [s for s in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator()]
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
scenario = Scenario()
scenario.ParseFromString(data)
# a trick for logging the file name
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
scenarios.append(scenario)
return scenarios