source file to data
This commit is contained in:
@@ -20,6 +20,8 @@ except ImportError as e:
|
||||
from metadrive.scenario import ScenarioDescription as SD
|
||||
from metadrive.type import MetaDriveType
|
||||
|
||||
SPLIT_KEY = "|"
|
||||
|
||||
|
||||
def extract_poly(message):
|
||||
x = [i.x for i in message]
|
||||
@@ -352,31 +354,31 @@ def compute_width(map):
|
||||
|
||||
|
||||
def convert_waymo_scenario(scenario, version):
|
||||
scenario_pb2 = Scenario()
|
||||
scenario_pb2.ParseFromString(scenario)
|
||||
|
||||
scenario = scenario
|
||||
md_scenario = SD()
|
||||
|
||||
md_scenario[SD.ID] = scenario_pb2.scenario_id
|
||||
id_end = scenario.scenario_id.find(SPLIT_KEY)
|
||||
|
||||
md_scenario[SD.ID] = scenario.scenario_id[:id_end]
|
||||
md_scenario[SD.VERSION] = version
|
||||
|
||||
# Please note that SDC track index is not identical to sdc_id.
|
||||
# sdc_id is a unique indicator to a track, while sdc_track_index is only the index of the sdc track
|
||||
# in the tracks datastructure.
|
||||
|
||||
track_length = len(list(scenario_pb2.timestamps_seconds))
|
||||
track_length = len(list(scenario.timestamps_seconds))
|
||||
|
||||
tracks, sdc_id = extract_tracks(scenario_pb2.tracks, scenario_pb2.sdc_track_index, track_length)
|
||||
tracks, sdc_id = extract_tracks(scenario.tracks, scenario.sdc_track_index, track_length)
|
||||
|
||||
md_scenario[SD.LENGTH] = track_length
|
||||
|
||||
md_scenario[SD.TRACKS] = tracks
|
||||
|
||||
dynamic_states = extract_dynamic_map_states(scenario_pb2.dynamic_map_states, track_length)
|
||||
dynamic_states = extract_dynamic_map_states(scenario.dynamic_map_states, track_length)
|
||||
|
||||
md_scenario[SD.DYNAMIC_MAP_STATES] = dynamic_states
|
||||
|
||||
map_features = extract_map_features(scenario_pb2.map_features)
|
||||
map_features = extract_map_features(scenario.map_features)
|
||||
md_scenario[SD.MAP_FEATURES] = map_features
|
||||
|
||||
compute_width(md_scenario[SD.MAP_FEATURES])
|
||||
@@ -384,25 +386,25 @@ def convert_waymo_scenario(scenario, version):
|
||||
md_scenario[SD.METADATA] = {}
|
||||
md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
|
||||
md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
|
||||
md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario_pb2.timestamps_seconds), dtype=np.float32)
|
||||
md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario.timestamps_seconds), dtype=np.float32)
|
||||
md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
|
||||
md_scenario[SD.METADATA][SD.SDC_ID] = str(sdc_id)
|
||||
md_scenario[SD.METADATA]["dataset"] = "waymo"
|
||||
md_scenario[SD.METADATA]["scenario_id"] = scenario_pb2.scenario_id
|
||||
md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id[:id_end]
|
||||
# TODO LQY Can we infer it?
|
||||
# md_scenario[SD.METADATA]["source_file"] = str(file)
|
||||
md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id[id_end + 1:]
|
||||
md_scenario[SD.METADATA]["track_length"] = track_length
|
||||
|
||||
# === Waymo specific data. Storing them here ===
|
||||
md_scenario[SD.METADATA]["current_time_index"] = scenario_pb2.current_time_index
|
||||
md_scenario[SD.METADATA]["sdc_track_index"] = scenario_pb2.sdc_track_index
|
||||
md_scenario[SD.METADATA]["current_time_index"] = scenario.current_time_index
|
||||
md_scenario[SD.METADATA]["sdc_track_index"] = scenario.sdc_track_index
|
||||
|
||||
# obj id
|
||||
md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario_pb2.objects_of_interest]
|
||||
md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario.objects_of_interest]
|
||||
|
||||
track_index = [obj.track_index for obj in scenario_pb2.tracks_to_predict]
|
||||
track_id = [str(scenario_pb2.tracks[ind].id) for ind in track_index]
|
||||
track_difficulty = [obj.difficulty for obj in scenario_pb2.tracks_to_predict]
|
||||
track_index = [obj.track_index for obj in scenario.tracks_to_predict]
|
||||
track_id = [str(scenario.tracks[ind].id) for ind in track_index]
|
||||
track_difficulty = [obj.difficulty for obj in scenario.tracks_to_predict]
|
||||
track_obj_type = [tracks[id]["type"] for id in track_id]
|
||||
md_scenario[SD.METADATA]["tracks_to_predict"] = {
|
||||
id: {
|
||||
@@ -426,5 +428,10 @@ def get_waymo_scenarios(waymo_data_direction):
|
||||
file_path = os.path.join(waymo_data_direction, file)
|
||||
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
||||
continue
|
||||
scenarios += [s for s in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator()]
|
||||
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
||||
scenario = Scenario()
|
||||
scenario.ParseFromString(data)
|
||||
# a trick for loging file name
|
||||
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
|
||||
scenarios.append(scenario)
|
||||
return scenarios
|
||||
|
||||
Reference in New Issue
Block a user