diff --git a/scenarionet/converter/waymo/utils.py b/scenarionet/converter/waymo/utils.py index d055cb2..c2e8bba 100644 --- a/scenarionet/converter/waymo/utils.py +++ b/scenarionet/converter/waymo/utils.py @@ -20,6 +20,8 @@ except ImportError as e: from metadrive.scenario import ScenarioDescription as SD from metadrive.type import MetaDriveType +SPLIT_KEY = "|" + def extract_poly(message): x = [i.x for i in message] @@ -352,31 +354,31 @@ def compute_width(map): def convert_waymo_scenario(scenario, version): - scenario_pb2 = Scenario() - scenario_pb2.ParseFromString(scenario) - + scenario = scenario md_scenario = SD() - md_scenario[SD.ID] = scenario_pb2.scenario_id + id_end = scenario.scenario_id.find(SPLIT_KEY) + + md_scenario[SD.ID] = scenario.scenario_id[:id_end] md_scenario[SD.VERSION] = version # Please note that SDC track index is not identical to sdc_id. # sdc_id is a unique indicator to a track, while sdc_track_index is only the index of the sdc track # in the tracks datastructure. - track_length = len(list(scenario_pb2.timestamps_seconds)) + track_length = len(list(scenario.timestamps_seconds)) - tracks, sdc_id = extract_tracks(scenario_pb2.tracks, scenario_pb2.sdc_track_index, track_length) + tracks, sdc_id = extract_tracks(scenario.tracks, scenario.sdc_track_index, track_length) md_scenario[SD.LENGTH] = track_length md_scenario[SD.TRACKS] = tracks - dynamic_states = extract_dynamic_map_states(scenario_pb2.dynamic_map_states, track_length) + dynamic_states = extract_dynamic_map_states(scenario.dynamic_map_states, track_length) md_scenario[SD.DYNAMIC_MAP_STATES] = dynamic_states - map_features = extract_map_features(scenario_pb2.map_features) + map_features = extract_map_features(scenario.map_features) md_scenario[SD.MAP_FEATURES] = map_features compute_width(md_scenario[SD.MAP_FEATURES]) @@ -384,25 +386,25 @@ def convert_waymo_scenario(scenario, version): md_scenario[SD.METADATA] = {} md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID] md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO - md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario_pb2.timestamps_seconds), dtype=np.float32) + md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario.timestamps_seconds), dtype=np.float32) md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False md_scenario[SD.METADATA][SD.SDC_ID] = str(sdc_id) md_scenario[SD.METADATA]["dataset"] = "waymo" - md_scenario[SD.METADATA]["scenario_id"] = scenario_pb2.scenario_id + md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id[:id_end] # TODO LQY Can we infer it? - # md_scenario[SD.METADATA]["source_file"] = str(file) + md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id[id_end + 1:] md_scenario[SD.METADATA]["track_length"] = track_length # === Waymo specific data. Storing them here === - md_scenario[SD.METADATA]["current_time_index"] = scenario_pb2.current_time_index - md_scenario[SD.METADATA]["sdc_track_index"] = scenario_pb2.sdc_track_index + md_scenario[SD.METADATA]["current_time_index"] = scenario.current_time_index + md_scenario[SD.METADATA]["sdc_track_index"] = scenario.sdc_track_index # obj id - md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario_pb2.objects_of_interest] + md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario.objects_of_interest] - track_index = [obj.track_index for obj in scenario_pb2.tracks_to_predict] - track_id = [str(scenario_pb2.tracks[ind].id) for ind in track_index] - track_difficulty = [obj.difficulty for obj in scenario_pb2.tracks_to_predict] + track_index = [obj.track_index for obj in scenario.tracks_to_predict] + track_id = [str(scenario.tracks[ind].id) for ind in track_index] + track_difficulty = [obj.difficulty for obj in scenario.tracks_to_predict] track_obj_type = [tracks[id]["type"] for id in track_id] md_scenario[SD.METADATA]["tracks_to_predict"] = { id: { @@ -426,5 +428,10 @@ def get_waymo_scenarios(waymo_data_direction): file_path = os.path.join(waymo_data_direction, file) if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)): continue - scenarios += [s for s in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator()] + for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator(): + scenario = Scenario() + scenario.ParseFromString(data) + # a trick for loging file name + scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file + scenarios.append(scenario) return scenarios