source file to data
This commit is contained in:
@@ -20,6 +20,8 @@ except ImportError as e:
|
|||||||
from metadrive.scenario import ScenarioDescription as SD
|
from metadrive.scenario import ScenarioDescription as SD
|
||||||
from metadrive.type import MetaDriveType
|
from metadrive.type import MetaDriveType
|
||||||
|
|
||||||
|
SPLIT_KEY = "|"
|
||||||
|
|
||||||
|
|
||||||
def extract_poly(message):
|
def extract_poly(message):
|
||||||
x = [i.x for i in message]
|
x = [i.x for i in message]
|
||||||
@@ -352,31 +354,31 @@ def compute_width(map):
|
|||||||
|
|
||||||
|
|
||||||
def convert_waymo_scenario(scenario, version):
|
def convert_waymo_scenario(scenario, version):
|
||||||
scenario_pb2 = Scenario()
|
scenario = scenario
|
||||||
scenario_pb2.ParseFromString(scenario)
|
|
||||||
|
|
||||||
md_scenario = SD()
|
md_scenario = SD()
|
||||||
|
|
||||||
md_scenario[SD.ID] = scenario_pb2.scenario_id
|
id_end = scenario.scenario_id.find(SPLIT_KEY)
|
||||||
|
|
||||||
|
md_scenario[SD.ID] = scenario.scenario_id[:id_end]
|
||||||
md_scenario[SD.VERSION] = version
|
md_scenario[SD.VERSION] = version
|
||||||
|
|
||||||
# Please note that SDC track index is not identical to sdc_id.
|
# Please note that SDC track index is not identical to sdc_id.
|
||||||
# sdc_id is a unique indicator to a track, while sdc_track_index is only the index of the sdc track
|
# sdc_id is a unique indicator to a track, while sdc_track_index is only the index of the sdc track
|
||||||
# in the tracks datastructure.
|
# in the tracks datastructure.
|
||||||
|
|
||||||
track_length = len(list(scenario_pb2.timestamps_seconds))
|
track_length = len(list(scenario.timestamps_seconds))
|
||||||
|
|
||||||
tracks, sdc_id = extract_tracks(scenario_pb2.tracks, scenario_pb2.sdc_track_index, track_length)
|
tracks, sdc_id = extract_tracks(scenario.tracks, scenario.sdc_track_index, track_length)
|
||||||
|
|
||||||
md_scenario[SD.LENGTH] = track_length
|
md_scenario[SD.LENGTH] = track_length
|
||||||
|
|
||||||
md_scenario[SD.TRACKS] = tracks
|
md_scenario[SD.TRACKS] = tracks
|
||||||
|
|
||||||
dynamic_states = extract_dynamic_map_states(scenario_pb2.dynamic_map_states, track_length)
|
dynamic_states = extract_dynamic_map_states(scenario.dynamic_map_states, track_length)
|
||||||
|
|
||||||
md_scenario[SD.DYNAMIC_MAP_STATES] = dynamic_states
|
md_scenario[SD.DYNAMIC_MAP_STATES] = dynamic_states
|
||||||
|
|
||||||
map_features = extract_map_features(scenario_pb2.map_features)
|
map_features = extract_map_features(scenario.map_features)
|
||||||
md_scenario[SD.MAP_FEATURES] = map_features
|
md_scenario[SD.MAP_FEATURES] = map_features
|
||||||
|
|
||||||
compute_width(md_scenario[SD.MAP_FEATURES])
|
compute_width(md_scenario[SD.MAP_FEATURES])
|
||||||
@@ -384,25 +386,25 @@ def convert_waymo_scenario(scenario, version):
|
|||||||
md_scenario[SD.METADATA] = {}
|
md_scenario[SD.METADATA] = {}
|
||||||
md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
|
md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
|
||||||
md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
|
md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
|
||||||
md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario_pb2.timestamps_seconds), dtype=np.float32)
|
md_scenario[SD.METADATA][SD.TIMESTEP] = np.asarray(list(scenario.timestamps_seconds), dtype=np.float32)
|
||||||
md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
|
md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
|
||||||
md_scenario[SD.METADATA][SD.SDC_ID] = str(sdc_id)
|
md_scenario[SD.METADATA][SD.SDC_ID] = str(sdc_id)
|
||||||
md_scenario[SD.METADATA]["dataset"] = "waymo"
|
md_scenario[SD.METADATA]["dataset"] = "waymo"
|
||||||
md_scenario[SD.METADATA]["scenario_id"] = scenario_pb2.scenario_id
|
md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id[:id_end]
|
||||||
# TODO LQY Can we infer it?
|
# TODO LQY Can we infer it?
|
||||||
# md_scenario[SD.METADATA]["source_file"] = str(file)
|
md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id[id_end + 1:]
|
||||||
md_scenario[SD.METADATA]["track_length"] = track_length
|
md_scenario[SD.METADATA]["track_length"] = track_length
|
||||||
|
|
||||||
# === Waymo specific data. Storing them here ===
|
# === Waymo specific data. Storing them here ===
|
||||||
md_scenario[SD.METADATA]["current_time_index"] = scenario_pb2.current_time_index
|
md_scenario[SD.METADATA]["current_time_index"] = scenario.current_time_index
|
||||||
md_scenario[SD.METADATA]["sdc_track_index"] = scenario_pb2.sdc_track_index
|
md_scenario[SD.METADATA]["sdc_track_index"] = scenario.sdc_track_index
|
||||||
|
|
||||||
# obj id
|
# obj id
|
||||||
md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario_pb2.objects_of_interest]
|
md_scenario[SD.METADATA]["objects_of_interest"] = [str(obj) for obj in scenario.objects_of_interest]
|
||||||
|
|
||||||
track_index = [obj.track_index for obj in scenario_pb2.tracks_to_predict]
|
track_index = [obj.track_index for obj in scenario.tracks_to_predict]
|
||||||
track_id = [str(scenario_pb2.tracks[ind].id) for ind in track_index]
|
track_id = [str(scenario.tracks[ind].id) for ind in track_index]
|
||||||
track_difficulty = [obj.difficulty for obj in scenario_pb2.tracks_to_predict]
|
track_difficulty = [obj.difficulty for obj in scenario.tracks_to_predict]
|
||||||
track_obj_type = [tracks[id]["type"] for id in track_id]
|
track_obj_type = [tracks[id]["type"] for id in track_id]
|
||||||
md_scenario[SD.METADATA]["tracks_to_predict"] = {
|
md_scenario[SD.METADATA]["tracks_to_predict"] = {
|
||||||
id: {
|
id: {
|
||||||
@@ -426,5 +428,10 @@ def get_waymo_scenarios(waymo_data_direction):
|
|||||||
file_path = os.path.join(waymo_data_direction, file)
|
file_path = os.path.join(waymo_data_direction, file)
|
||||||
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
||||||
continue
|
continue
|
||||||
scenarios += [s for s in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator()]
|
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
||||||
|
scenario = Scenario()
|
||||||
|
scenario.ParseFromString(data)
|
||||||
|
# a trick for loging file name
|
||||||
|
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
|
||||||
|
scenarios.append(scenario)
|
||||||
return scenarios
|
return scenarios
|
||||||
|
|||||||
Reference in New Issue
Block a user