Add support to AV2 (#48)
* add support to av2 --------- Co-authored-by: Alan-LanFeng <fenglan18@outook.com>
This commit is contained in:
77
scenarionet/convert_argoverse2.py
Normal file
77
scenarionet/convert_argoverse2.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
desc = "Build database from Argoverse scenarios"
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import shutil
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
|
||||||
|
from scenarionet.converter.utils import write_to_directory
|
||||||
|
from scenarionet.converter.argoverse2.utils import convert_av2_scenario, get_av2_scenarios, preprocess_av2_scenarios
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description=desc)
|
||||||
|
parser.add_argument(
|
||||||
|
"--database_path",
|
||||||
|
"-d",
|
||||||
|
default=os.path.join(SCENARIONET_DATASET_PATH, "av2"),
|
||||||
|
help="A directory, the path to place the converted data"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset_name", "-n", default="av2", help="Dataset name, will be used to generate scenario files"
|
||||||
|
)
|
||||||
|
parser.add_argument("--version", "-v", default='v2', help="version")
|
||||||
|
parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it")
|
||||||
|
parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
|
||||||
|
parser.add_argument(
|
||||||
|
"--raw_data_path",
|
||||||
|
default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
|
||||||
|
help="The directory stores all waymo tfrecord"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--start_file_index",
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help="Control how many files to use. We will list all files in the raw data folder "
|
||||||
|
"and select files[start_file_index: start_file_index+num_files]"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_files",
|
||||||
|
default=1000,
|
||||||
|
type=int,
|
||||||
|
help="Control how many files to use. We will list all files in the raw data folder "
|
||||||
|
"and select files[start_file_index: start_file_index+num_files]"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
overwrite = args.overwrite
|
||||||
|
dataset_name = args.dataset_name
|
||||||
|
output_path = args.database_path
|
||||||
|
version = args.version
|
||||||
|
|
||||||
|
save_path = output_path
|
||||||
|
if os.path.exists(output_path):
|
||||||
|
if not overwrite:
|
||||||
|
raise ValueError(
|
||||||
|
"Directory {} already exists! Abort. "
|
||||||
|
"\n Try setting overwrite=True or adding --overwrite".format(output_path)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shutil.rmtree(output_path)
|
||||||
|
|
||||||
|
av2_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
|
||||||
|
|
||||||
|
scenarios = get_av2_scenarios(av2_data_directory, args.start_file_index, args.num_files)
|
||||||
|
|
||||||
|
write_to_directory(
|
||||||
|
convert_func=convert_av2_scenario,
|
||||||
|
scenarios=scenarios,
|
||||||
|
output_path=output_path,
|
||||||
|
dataset_version=version,
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
overwrite=overwrite,
|
||||||
|
num_workers=args.num_workers,
|
||||||
|
preprocess=preprocess_av2_scenarios
|
||||||
|
)
|
||||||
68
scenarionet/converter/argoverse2/type.py
Normal file
68
scenarionet/converter/argoverse2/type.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from metadrive.type import MetaDriveType
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from av2.datasets.motion_forecasting.data_schema import ObjectType
|
||||||
|
from av2.map.lane_segment import LaneType, LaneMarkType
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning("Can not import av2-devkit: {}".format(e))
|
||||||
|
|
||||||
|
|
||||||
|
def get_traffic_obj_type(av2_obj_type):
|
||||||
|
if av2_obj_type == ObjectType.VEHICLE or av2_obj_type == ObjectType.BUS:
|
||||||
|
return MetaDriveType.VEHICLE
|
||||||
|
# elif av2_obj_type == ObjectType.MOTORCYCLIST:
|
||||||
|
# return MetaDriveType.MOTORCYCLIST
|
||||||
|
elif av2_obj_type == ObjectType.PEDESTRIAN:
|
||||||
|
return MetaDriveType.PEDESTRIAN
|
||||||
|
elif av2_obj_type == ObjectType.CYCLIST:
|
||||||
|
return MetaDriveType.CYCLIST
|
||||||
|
# elif av2_obj_type == ObjectType.BUS:
|
||||||
|
# return MetaDriveType.BUS
|
||||||
|
# elif av2_obj_type == ObjectType.STATIC:
|
||||||
|
# return MetaDriveType.STATIC
|
||||||
|
# elif av2_obj_type == ObjectType.CONSTRUCTION:
|
||||||
|
# return MetaDriveType.CONSTRUCTION
|
||||||
|
# elif av2_obj_type == ObjectType.BACKGROUND:
|
||||||
|
# return MetaDriveType.BACKGROUND
|
||||||
|
# elif av2_obj_type == ObjectType.RIDERLESS_BICYCLE:
|
||||||
|
# return MetaDriveType.RIDERLESS_BICYCLE
|
||||||
|
# elif av2_obj_type == ObjectType.UNKNOWN:
|
||||||
|
# return MetaDriveType.UNKNOWN
|
||||||
|
else:
|
||||||
|
return MetaDriveType.OTHER
|
||||||
|
|
||||||
|
|
||||||
|
def get_lane_type(av2_lane_type):
|
||||||
|
if av2_lane_type == LaneType.VEHICLE or av2_lane_type == LaneType.BUS:
|
||||||
|
return MetaDriveType.LANE_SURFACE_STREET
|
||||||
|
elif av2_lane_type == LaneType.BIKE:
|
||||||
|
return MetaDriveType.LANE_BIKE_LANE
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown nuplan lane type: {}".format(av2_lane_type))
|
||||||
|
|
||||||
|
|
||||||
|
def get_lane_mark_type(av2_mark_type):
|
||||||
|
conversion_dict = {
|
||||||
|
LaneMarkType.DOUBLE_SOLID_YELLOW: "ROAD_LINE_SOLID_DOUBLE_YELLOW",
|
||||||
|
LaneMarkType.DOUBLE_SOLID_WHITE: "ROAD_LINE_SOLID_DOUBLE_WHITE",
|
||||||
|
LaneMarkType.SOLID_YELLOW: "ROAD_LINE_SOLID_SINGLE_YELLOW",
|
||||||
|
LaneMarkType.SOLID_WHITE: "ROAD_LINE_SOLID_SINGLE_WHITE",
|
||||||
|
LaneMarkType.DASHED_WHITE: "ROAD_LINE_BROKEN_SINGLE_WHITE",
|
||||||
|
LaneMarkType.DASHED_YELLOW: "ROAD_LINE_BROKEN_SINGLE_YELLOW",
|
||||||
|
LaneMarkType.DASH_SOLID_YELLOW: "ROAD_LINE_SOLID_DOUBLE_YELLOW",
|
||||||
|
LaneMarkType.DASH_SOLID_WHITE: "ROAD_LINE_SOLID_DOUBLE_WHITE",
|
||||||
|
LaneMarkType.DOUBLE_DASH_YELLOW: "ROAD_LINE_BROKEN_SINGLE_YELLOW",
|
||||||
|
LaneMarkType.DOUBLE_DASH_WHITE: "ROAD_LINE_BROKEN_SINGLE_WHITE",
|
||||||
|
LaneMarkType.SOLID_DASH_WHITE: "ROAD_LINE_BROKEN_SINGLE_WHITE",
|
||||||
|
LaneMarkType.SOLID_DASH_YELLOW: "ROAD_LINE_BROKEN_SINGLE_YELLOW",
|
||||||
|
LaneMarkType.SOLID_BLUE: "UNKNOWN_LINE",
|
||||||
|
LaneMarkType.NONE: "UNKNOWN_LINE",
|
||||||
|
LaneMarkType.UNKNOWN: "UNKNOWN_LINE"
|
||||||
|
}
|
||||||
|
|
||||||
|
return conversion_dict.get(av2_mark_type, "UNKNOWN_LINE")
|
||||||
257
scenarionet/converter/argoverse2/utils.py
Normal file
257
scenarionet/converter/argoverse2/utils.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from scenarionet.converter.utils import mph_to_kmh
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from metadrive.scenario import ScenarioDescription as SD
|
||||||
|
from metadrive.type import MetaDriveType
|
||||||
|
|
||||||
|
from scenarionet.converter.argoverse2.type import get_traffic_obj_type, get_lane_type, get_lane_mark_type
|
||||||
|
from av2.datasets.motion_forecasting import scenario_serialization
|
||||||
|
from av2.map.map_api import ArgoverseStaticMap
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
_ESTIMATED_VEHICLE_LENGTH_M: Final[float] = 4.0
|
||||||
|
_ESTIMATED_VEHICLE_WIDTH_M: Final[float] = 2.0
|
||||||
|
_ESTIMATED_CYCLIST_LENGTH_M: Final[float] = 2.0
|
||||||
|
_ESTIMATED_CYCLIST_WIDTH_M: Final[float] = 0.7
|
||||||
|
_ESTIMATED_PEDESTRIAN_LENGTH_M: Final[float] = 0.5
|
||||||
|
_ESTIMATED_PEDESTRIAN_WIDTH_M: Final[float] = 0.5
|
||||||
|
_ESTIMATED_BUS_LENGTH_M: Final[float] = 12.0
|
||||||
|
_ESTIMATED_BUS_WIDTH_M: Final[float] = 2.5
|
||||||
|
|
||||||
|
_HIGHWAY_SPEED_LIMIT_MPH: Final[float] = 85.0
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tracks(tracks, sdc_idx, track_length):
|
||||||
|
ret = dict()
|
||||||
|
|
||||||
|
def _object_state_template(object_id):
|
||||||
|
return dict(type=None, state=dict(# Never add extra dim if the value is scalar.
|
||||||
|
position=np.zeros([track_length, 3], dtype=np.float32), length=np.zeros([track_length], dtype=np.float32),
|
||||||
|
width=np.zeros([track_length], dtype=np.float32), height=np.zeros([track_length], dtype=np.float32),
|
||||||
|
heading=np.zeros([track_length], dtype=np.float32), velocity=np.zeros([track_length, 2], dtype=np.float32),
|
||||||
|
valid=np.zeros([track_length], dtype=bool), ),
|
||||||
|
metadata=dict(track_length=track_length, type=None, object_id=object_id, dataset="av2"))
|
||||||
|
|
||||||
|
track_category = []
|
||||||
|
|
||||||
|
for obj in tracks:
|
||||||
|
object_id = obj.track_id
|
||||||
|
track_category.append(obj.category.value)
|
||||||
|
obj_state = _object_state_template(object_id)
|
||||||
|
# Transform it to Waymo type string
|
||||||
|
obj_state["type"] = get_traffic_obj_type(obj.object_type)
|
||||||
|
if obj_state["type"] == MetaDriveType.VEHICLE:
|
||||||
|
length = _ESTIMATED_VEHICLE_LENGTH_M
|
||||||
|
width = _ESTIMATED_VEHICLE_WIDTH_M
|
||||||
|
elif obj_state["type"] == MetaDriveType.PEDESTRIAN:
|
||||||
|
length = _ESTIMATED_PEDESTRIAN_LENGTH_M
|
||||||
|
width = _ESTIMATED_PEDESTRIAN_WIDTH_M
|
||||||
|
elif obj_state["type"] == MetaDriveType.CYCLIST:
|
||||||
|
length = _ESTIMATED_CYCLIST_LENGTH_M
|
||||||
|
width = _ESTIMATED_CYCLIST_WIDTH_M
|
||||||
|
# elif obj_state["type"] == MetaDriveType.BUS:
|
||||||
|
# length = _ESTIMATED_BUS_LENGTH_M
|
||||||
|
# width = _ESTIMATED_BUS_WIDTH_M
|
||||||
|
else:
|
||||||
|
length = 1
|
||||||
|
width = 1
|
||||||
|
|
||||||
|
for _, state in enumerate(obj.object_states):
|
||||||
|
step_count = state.timestep
|
||||||
|
obj_state["state"]["position"][step_count][0] = state.position[0]
|
||||||
|
obj_state["state"]["position"][step_count][1] = state.position[1]
|
||||||
|
obj_state["state"]["position"][step_count][2] = 0
|
||||||
|
|
||||||
|
# l = [state.length for state in obj.states]
|
||||||
|
# w = [state.width for state in obj.states]
|
||||||
|
# h = [state.height for state in obj.states]
|
||||||
|
# obj_state["state"]["size"] = np.stack([l, w, h], 1).astype("float32")
|
||||||
|
obj_state["state"]["length"][step_count] = length
|
||||||
|
obj_state["state"]["width"][step_count] = width
|
||||||
|
obj_state["state"]["height"][step_count] = 1
|
||||||
|
|
||||||
|
# heading = [state.heading for state in obj.states]
|
||||||
|
obj_state["state"]["heading"][step_count] = state.heading
|
||||||
|
|
||||||
|
obj_state["state"]["velocity"][step_count][0] = state.velocity[0]
|
||||||
|
obj_state["state"]["velocity"][step_count][1] = state.velocity[1]
|
||||||
|
|
||||||
|
obj_state["state"]["valid"][step_count] = True
|
||||||
|
|
||||||
|
obj_state["metadata"]["type"] = obj_state["type"]
|
||||||
|
|
||||||
|
ret[object_id] = obj_state
|
||||||
|
|
||||||
|
return ret, track_category
|
||||||
|
|
||||||
|
|
||||||
|
def extract_lane_mark(lane_mark):
|
||||||
|
line = dict()
|
||||||
|
line["type"] = get_lane_mark_type(lane_mark.mark_type)
|
||||||
|
line["polyline"] = lane_mark.polyline.astype(np.float32)
|
||||||
|
return line
|
||||||
|
|
||||||
|
|
||||||
|
def extract_map_features(map_features):
|
||||||
|
# with open(
|
||||||
|
# "/Users/fenglan/Desktop/vita-group/code/mdsn/scenarionet/data_sample/waymo_converted_0/sd_waymo_v1.2_7e8422433c66cc13.pkl",
|
||||||
|
# 'rb') as f:
|
||||||
|
# waymo_sample = pickle.load(f)
|
||||||
|
ret = {}
|
||||||
|
vector_lane_segments = map_features.get_scenario_lane_segments()
|
||||||
|
vector_drivable_areas = map_features.get_scenario_vector_drivable_areas()
|
||||||
|
ped_crossings = map_features.get_scenario_ped_crossings()
|
||||||
|
|
||||||
|
ids = map_features.get_scenario_lane_segment_ids()
|
||||||
|
|
||||||
|
max_id = max(ids)
|
||||||
|
for seg in vector_lane_segments:
|
||||||
|
center = {}
|
||||||
|
lane_id = str(seg.id)
|
||||||
|
|
||||||
|
left_id = str(seg.id + max_id + 1)
|
||||||
|
right_id = str(seg.id + max_id + 2)
|
||||||
|
left_marking = extract_lane_mark(seg.left_lane_marking)
|
||||||
|
right_marking = extract_lane_mark(seg.right_lane_marking)
|
||||||
|
|
||||||
|
ret[left_id] = left_marking
|
||||||
|
ret[right_id] = right_marking
|
||||||
|
|
||||||
|
center["speed_limit_mph"] = _HIGHWAY_SPEED_LIMIT_MPH
|
||||||
|
|
||||||
|
center["speed_limit_kmh"] = mph_to_kmh(_HIGHWAY_SPEED_LIMIT_MPH)
|
||||||
|
|
||||||
|
center["type"] = get_lane_type(seg.lane_type)
|
||||||
|
|
||||||
|
polyline = map_features.get_lane_segment_centerline(seg.id)
|
||||||
|
center["polyline"] = polyline.astype(np.float32)
|
||||||
|
|
||||||
|
center["interpolating"] = True
|
||||||
|
|
||||||
|
center["entry_lanes"] = [str(id) for id in seg.predecessors]
|
||||||
|
|
||||||
|
center["exit_lanes"] = [str(id) for id in seg.successors]
|
||||||
|
|
||||||
|
center["left_boundaries"] = []
|
||||||
|
|
||||||
|
center["right_boundaries"] = []
|
||||||
|
|
||||||
|
center["left_neighbor"] = []
|
||||||
|
|
||||||
|
center["right_neighbor"] = []
|
||||||
|
center['width'] = np.zeros([len(polyline), 2], dtype=np.float32)
|
||||||
|
|
||||||
|
ret[lane_id] = center
|
||||||
|
|
||||||
|
# for edge in vector_drivable_areas:
|
||||||
|
# bound = dict()
|
||||||
|
# bound["type"] = MetaDriveType.BOUNDARY_LINE
|
||||||
|
# bound["polyline"] = edge.xyz.astype(np.float32)
|
||||||
|
# ret[str(edge.id)] = bound
|
||||||
|
|
||||||
|
for cross in ped_crossings:
|
||||||
|
bound = dict()
|
||||||
|
bound["type"] = MetaDriveType.CROSSWALK
|
||||||
|
bound["polygon"] = cross.polygon.astype(np.float32)
|
||||||
|
ret[str(cross.id)] = bound
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_av2_scenarios(av2_data_directory, start_index, num):
|
||||||
|
# parse raw data from input path to output path,
|
||||||
|
# there is 1000 raw data in google cloud, each of them produce about 500 pkl file
|
||||||
|
logger.info("\nReading raw data")
|
||||||
|
|
||||||
|
all_scenario_files = sorted(Path(av2_data_directory).rglob("*.parquet"))
|
||||||
|
|
||||||
|
return all_scenario_files
|
||||||
|
|
||||||
|
|
||||||
|
def convert_av2_scenario(scenario, version):
|
||||||
|
md_scenario = SD()
|
||||||
|
|
||||||
|
md_scenario[SD.ID] = scenario.scenario_id
|
||||||
|
md_scenario[SD.VERSION] = version
|
||||||
|
|
||||||
|
# Please note that SDC track index is not identical to sdc_id.
|
||||||
|
# sdc_id is a unique indicator to a track, while sdc_track_index is only the index of the sdc track
|
||||||
|
# in the tracks datastructure.
|
||||||
|
|
||||||
|
track_length = scenario.timestamps_ns.shape[0]
|
||||||
|
|
||||||
|
tracks, category = extract_tracks(scenario.tracks, scenario.focal_track_id, track_length)
|
||||||
|
|
||||||
|
md_scenario[SD.LENGTH] = track_length
|
||||||
|
|
||||||
|
md_scenario[SD.TRACKS] = tracks
|
||||||
|
|
||||||
|
md_scenario[SD.DYNAMIC_MAP_STATES] = {}
|
||||||
|
|
||||||
|
map_features = extract_map_features(scenario.static_map)
|
||||||
|
md_scenario[SD.MAP_FEATURES] = map_features
|
||||||
|
|
||||||
|
# compute_width(md_scenario[SD.MAP_FEATURES])
|
||||||
|
|
||||||
|
md_scenario[SD.METADATA] = {}
|
||||||
|
md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
|
||||||
|
md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
|
||||||
|
md_scenario[SD.METADATA][SD.TIMESTEP] = np.array(list(range(track_length))) / 10
|
||||||
|
md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
|
||||||
|
md_scenario[SD.METADATA][SD.SDC_ID] = 'AV'
|
||||||
|
md_scenario[SD.METADATA]["dataset"] = "av2"
|
||||||
|
md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id
|
||||||
|
md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id
|
||||||
|
md_scenario[SD.METADATA]["track_length"] = track_length
|
||||||
|
|
||||||
|
# === Waymo specific data. Storing them here ===
|
||||||
|
md_scenario[SD.METADATA]["current_time_index"] = 49
|
||||||
|
md_scenario[SD.METADATA]["sdc_track_index"] = scenario.focal_track_id
|
||||||
|
|
||||||
|
# obj id
|
||||||
|
obj_keys = list(tracks.keys())
|
||||||
|
md_scenario[SD.METADATA]["objects_of_interest"] = [obj_keys[idx] for idx, cat in enumerate(category) if cat == 2]
|
||||||
|
|
||||||
|
track_index = [obj_keys.index(scenario.focal_track_id)]
|
||||||
|
track_id = [scenario.focal_track_id]
|
||||||
|
track_difficulty = [0]
|
||||||
|
track_obj_type = [tracks[id]["type"] for id in track_id]
|
||||||
|
md_scenario[SD.METADATA]["tracks_to_predict"] = {
|
||||||
|
id: {
|
||||||
|
"track_index": track_index[count],
|
||||||
|
"track_id": id,
|
||||||
|
"difficulty": track_difficulty[count],
|
||||||
|
"object_type": track_obj_type[count]
|
||||||
|
}
|
||||||
|
for count, id in enumerate(track_id)
|
||||||
|
}
|
||||||
|
# clean memory
|
||||||
|
del scenario
|
||||||
|
return md_scenario
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_av2_scenarios(files, worker_index):
|
||||||
|
"""
|
||||||
|
Convert the waymo files into scenario_pb2. This happens in each worker.
|
||||||
|
:param files: a list of file path
|
||||||
|
:param worker_index, the index for the worker
|
||||||
|
:return: a list of scenario_pb2
|
||||||
|
"""
|
||||||
|
|
||||||
|
for scenario_path in tqdm(files, desc="Process av2 scenarios for worker {}".format(worker_index)):
|
||||||
|
scenario_id = scenario_path.stem.split("_")[-1]
|
||||||
|
static_map_path = (scenario_path.parents[0] / f"log_map_archive_{scenario_id}.json")
|
||||||
|
scenario = scenario_serialization.load_argoverse_scenario_parquet(scenario_path)
|
||||||
|
static_map = ArgoverseStaticMap.from_json(static_map_path)
|
||||||
|
scenario.static_map = static_map
|
||||||
|
yield scenario
|
||||||
|
|
||||||
|
# logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios))) # return scenarios
|
||||||
@@ -19,7 +19,7 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.info(e)
|
logger.info(e)
|
||||||
|
|
||||||
import scenarionet.converter.waymo.waymo_protos.scenario_pb2
|
import scenarionet.converter.waymo.waymo_protos.scenario_pb2 as scenario_pb2
|
||||||
|
|
||||||
from metadrive.scenario import ScenarioDescription as SD
|
from metadrive.scenario import ScenarioDescription as SD
|
||||||
from metadrive.type import MetaDriveType
|
from metadrive.type import MetaDriveType
|
||||||
@@ -453,7 +453,7 @@ def preprocess_waymo_scenarios(files, worker_index):
|
|||||||
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
||||||
continue
|
continue
|
||||||
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
||||||
scenario = scenarionet.converter.waymo.waymo_protos.scenario_pb2.Scenario()
|
scenario = scenario_pb2.Scenario()
|
||||||
scenario.ParseFromString(data)
|
scenario.ParseFromString(data)
|
||||||
# a trick for loging file name
|
# a trick for loging file name
|
||||||
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
|
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
|
||||||
|
|||||||
Reference in New Issue
Block a user