Add support for AV2 (#48)

* add support for AV2

---------

Co-authored-by: Alan-LanFeng <fenglan18@outook.com>
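For context, here is a minimal sketch of how the new conversion pipeline is driven programmatically. It simply mirrors the script added below; the input and output paths are placeholders, not part of this commit.

# Sketch only; mirrors scenarionet/convert_argoverse2.py below. Paths are placeholders.
from scenarionet.converter.utils import write_to_directory
from scenarionet.converter.argoverse2.utils import convert_av2_scenario, get_av2_scenarios, preprocess_av2_scenarios

scenarios = get_av2_scenarios("/data/av2_origin", start_index=0, num=1000)  # placeholder raw-data dir
write_to_directory(
    convert_func=convert_av2_scenario,
    scenarios=scenarios,
    output_path="/data/scenarionet_av2",  # placeholder output dir
    dataset_version="v2",
    dataset_name="av2",
    overwrite=True,
    num_workers=8,
    preprocess=preprocess_av2_scenarios,
)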
scenarionet/convert_argoverse2.py  (new file, 77 lines)
@@ -0,0 +1,77 @@
desc = "Build database from Argoverse scenarios"

if __name__ == '__main__':
    import shutil
    import argparse
    import logging
    import os

    from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
    from scenarionet.converter.utils import write_to_directory
    from scenarionet.converter.argoverse2.utils import convert_av2_scenario, get_av2_scenarios, preprocess_av2_scenarios

    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument(
        "--database_path",
        "-d",
        default=os.path.join(SCENARIONET_DATASET_PATH, "av2"),
        help="A directory, the path to place the converted data"
    )
    parser.add_argument(
        "--dataset_name", "-n", default="av2", help="Dataset name, will be used to generate scenario files"
    )
    parser.add_argument("--version", "-v", default='v2', help="version")
    parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it")
    parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
    parser.add_argument(
        "--raw_data_path",
        default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
        help="The directory storing the raw Argoverse 2 scenario files"
    )
    parser.add_argument(
        "--start_file_index",
        default=0,
        type=int,
        help="Control how many files to use. We will list all files in the raw data folder "
        "and select files[start_file_index: start_file_index+num_files]"
    )
    parser.add_argument(
        "--num_files",
        default=1000,
        type=int,
        help="Control how many files to use. We will list all files in the raw data folder "
        "and select files[start_file_index: start_file_index+num_files]"
    )
    args = parser.parse_args()

    overwrite = args.overwrite
    dataset_name = args.dataset_name
    output_path = args.database_path
    version = args.version

    save_path = output_path
    if os.path.exists(output_path):
        if not overwrite:
            raise ValueError(
                "Directory {} already exists! Abort. "
                "\n Try setting overwrite=True or adding --overwrite".format(output_path)
            )
        else:
            shutil.rmtree(output_path)

    av2_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)

    scenarios = get_av2_scenarios(av2_data_directory, args.start_file_index, args.num_files)

    write_to_directory(
        convert_func=convert_av2_scenario,
        scenarios=scenarios,
        output_path=output_path,
        dataset_version=version,
        dataset_name=dataset_name,
        overwrite=overwrite,
        num_workers=args.num_workers,
        preprocess=preprocess_av2_scenarios
    )
scenarionet/converter/argoverse2/type.py  (new file, 68 lines)
@@ -0,0 +1,68 @@
import logging

from metadrive.type import MetaDriveType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    from av2.datasets.motion_forecasting.data_schema import ObjectType
    from av2.map.lane_segment import LaneType, LaneMarkType
except ImportError as e:
    logger.warning("Cannot import av2-devkit: {}".format(e))


def get_traffic_obj_type(av2_obj_type):
    if av2_obj_type == ObjectType.VEHICLE or av2_obj_type == ObjectType.BUS:
        return MetaDriveType.VEHICLE
    # elif av2_obj_type == ObjectType.MOTORCYCLIST:
    #     return MetaDriveType.MOTORCYCLIST
    elif av2_obj_type == ObjectType.PEDESTRIAN:
        return MetaDriveType.PEDESTRIAN
    elif av2_obj_type == ObjectType.CYCLIST:
        return MetaDriveType.CYCLIST
    # elif av2_obj_type == ObjectType.BUS:
    #     return MetaDriveType.BUS
    # elif av2_obj_type == ObjectType.STATIC:
    #     return MetaDriveType.STATIC
    # elif av2_obj_type == ObjectType.CONSTRUCTION:
    #     return MetaDriveType.CONSTRUCTION
    # elif av2_obj_type == ObjectType.BACKGROUND:
    #     return MetaDriveType.BACKGROUND
    # elif av2_obj_type == ObjectType.RIDERLESS_BICYCLE:
    #     return MetaDriveType.RIDERLESS_BICYCLE
    # elif av2_obj_type == ObjectType.UNKNOWN:
    #     return MetaDriveType.UNKNOWN
    else:
        return MetaDriveType.OTHER


def get_lane_type(av2_lane_type):
    if av2_lane_type == LaneType.VEHICLE or av2_lane_type == LaneType.BUS:
        return MetaDriveType.LANE_SURFACE_STREET
    elif av2_lane_type == LaneType.BIKE:
        return MetaDriveType.LANE_BIKE_LANE
    else:
        raise ValueError("Unknown AV2 lane type: {}".format(av2_lane_type))


def get_lane_mark_type(av2_mark_type):
    conversion_dict = {
        LaneMarkType.DOUBLE_SOLID_YELLOW: "ROAD_LINE_SOLID_DOUBLE_YELLOW",
        LaneMarkType.DOUBLE_SOLID_WHITE: "ROAD_LINE_SOLID_DOUBLE_WHITE",
        LaneMarkType.SOLID_YELLOW: "ROAD_LINE_SOLID_SINGLE_YELLOW",
        LaneMarkType.SOLID_WHITE: "ROAD_LINE_SOLID_SINGLE_WHITE",
        LaneMarkType.DASHED_WHITE: "ROAD_LINE_BROKEN_SINGLE_WHITE",
        LaneMarkType.DASHED_YELLOW: "ROAD_LINE_BROKEN_SINGLE_YELLOW",
        LaneMarkType.DASH_SOLID_YELLOW: "ROAD_LINE_SOLID_DOUBLE_YELLOW",
        LaneMarkType.DASH_SOLID_WHITE: "ROAD_LINE_SOLID_DOUBLE_WHITE",
        LaneMarkType.DOUBLE_DASH_YELLOW: "ROAD_LINE_BROKEN_SINGLE_YELLOW",
        LaneMarkType.DOUBLE_DASH_WHITE: "ROAD_LINE_BROKEN_SINGLE_WHITE",
        LaneMarkType.SOLID_DASH_WHITE: "ROAD_LINE_BROKEN_SINGLE_WHITE",
        LaneMarkType.SOLID_DASH_YELLOW: "ROAD_LINE_BROKEN_SINGLE_YELLOW",
        LaneMarkType.SOLID_BLUE: "UNKNOWN_LINE",
        LaneMarkType.NONE: "UNKNOWN_LINE",
        LaneMarkType.UNKNOWN: "UNKNOWN_LINE"
    }

    return conversion_dict.get(av2_mark_type, "UNKNOWN_LINE")
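As a quick illustration of the mapping behavior above (not part of the commit; assumes the av2 devkit is installed so the enums import):

# Illustration only, assuming av2-devkit is installed.
from av2.datasets.motion_forecasting.data_schema import ObjectType
from av2.map.lane_segment import LaneMarkType
from metadrive.type import MetaDriveType
from scenarionet.converter.argoverse2.type import get_traffic_obj_type, get_lane_mark_type

assert get_traffic_obj_type(ObjectType.BUS) == MetaDriveType.VEHICLE  # buses are folded into VEHICLE
assert get_traffic_obj_type(ObjectType.RIDERLESS_BICYCLE) == MetaDriveType.OTHER  # unmapped types fall back to OTHER
assert get_lane_mark_type(LaneMarkType.SOLID_BLUE) == "UNKNOWN_LINE"  # unsupported marks degrade to UNKNOWN_LINE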
scenarionet/converter/argoverse2/utils.py  (new file, 257 lines)
@@ -0,0 +1,257 @@
import logging
from pathlib import Path
from typing import Final

import numpy as np
from tqdm import tqdm

from av2.datasets.motion_forecasting import scenario_serialization
from av2.map.map_api import ArgoverseStaticMap
from metadrive.scenario import ScenarioDescription as SD
from metadrive.type import MetaDriveType

from scenarionet.converter.argoverse2.type import get_traffic_obj_type, get_lane_type, get_lane_mark_type
from scenarionet.converter.utils import mph_to_kmh

logger = logging.getLogger(__name__)

# Constant object sizes in meters; AV2 motion forecasting tracks carry no extents, so rough estimates are used.
_ESTIMATED_VEHICLE_LENGTH_M: Final[float] = 4.0
_ESTIMATED_VEHICLE_WIDTH_M: Final[float] = 2.0
_ESTIMATED_CYCLIST_LENGTH_M: Final[float] = 2.0
_ESTIMATED_CYCLIST_WIDTH_M: Final[float] = 0.7
_ESTIMATED_PEDESTRIAN_LENGTH_M: Final[float] = 0.5
_ESTIMATED_PEDESTRIAN_WIDTH_M: Final[float] = 0.5
_ESTIMATED_BUS_LENGTH_M: Final[float] = 12.0
_ESTIMATED_BUS_WIDTH_M: Final[float] = 2.5

_HIGHWAY_SPEED_LIMIT_MPH: Final[float] = 85.0

def extract_tracks(tracks, sdc_idx, track_length):
    ret = dict()

    def _object_state_template(object_id):
        return dict(
            type=None,
            state=dict(
                # Never add extra dim if the value is scalar.
                position=np.zeros([track_length, 3], dtype=np.float32),
                length=np.zeros([track_length], dtype=np.float32),
                width=np.zeros([track_length], dtype=np.float32),
                height=np.zeros([track_length], dtype=np.float32),
                heading=np.zeros([track_length], dtype=np.float32),
                velocity=np.zeros([track_length, 2], dtype=np.float32),
                valid=np.zeros([track_length], dtype=bool),
            ),
            metadata=dict(track_length=track_length, type=None, object_id=object_id, dataset="av2")
        )

    track_category = []

    for obj in tracks:
        object_id = obj.track_id
        track_category.append(obj.category.value)
        obj_state = _object_state_template(object_id)

        # Map the AV2 object type to the MetaDrive (Waymo-style) type string
        obj_state["type"] = get_traffic_obj_type(obj.object_type)
        if obj_state["type"] == MetaDriveType.VEHICLE:
            length = _ESTIMATED_VEHICLE_LENGTH_M
            width = _ESTIMATED_VEHICLE_WIDTH_M
        elif obj_state["type"] == MetaDriveType.PEDESTRIAN:
            length = _ESTIMATED_PEDESTRIAN_LENGTH_M
            width = _ESTIMATED_PEDESTRIAN_WIDTH_M
        elif obj_state["type"] == MetaDriveType.CYCLIST:
            length = _ESTIMATED_CYCLIST_LENGTH_M
            width = _ESTIMATED_CYCLIST_WIDTH_M
        # elif obj_state["type"] == MetaDriveType.BUS:
        #     length = _ESTIMATED_BUS_LENGTH_M
        #     width = _ESTIMATED_BUS_WIDTH_M
        else:
            length = 1
            width = 1

        for state in obj.object_states:
            step_count = state.timestep
            obj_state["state"]["position"][step_count][0] = state.position[0]
            obj_state["state"]["position"][step_count][1] = state.position[1]
            obj_state["state"]["position"][step_count][2] = 0

            # AV2 provides no per-step extents, so fill in the estimated constant size per type.
            obj_state["state"]["length"][step_count] = length
            obj_state["state"]["width"][step_count] = width
            obj_state["state"]["height"][step_count] = 1

            obj_state["state"]["heading"][step_count] = state.heading

            obj_state["state"]["velocity"][step_count][0] = state.velocity[0]
            obj_state["state"]["velocity"][step_count][1] = state.velocity[1]

            obj_state["state"]["valid"][step_count] = True

        obj_state["metadata"]["type"] = obj_state["type"]

        ret[object_id] = obj_state

    return ret, track_category

def extract_lane_mark(lane_mark):
    line = dict()
    line["type"] = get_lane_mark_type(lane_mark.mark_type)
    line["polyline"] = lane_mark.polyline.astype(np.float32)
    return line


def extract_map_features(map_features):
    ret = {}
    vector_lane_segments = map_features.get_scenario_lane_segments()
    vector_drivable_areas = map_features.get_scenario_vector_drivable_areas()
    ped_crossings = map_features.get_scenario_ped_crossings()

    ids = map_features.get_scenario_lane_segment_ids()

    max_id = max(ids)
    for seg in vector_lane_segments:
        center = {}
        lane_id = str(seg.id)

        # Boundary polylines are stored under synthetic ids derived from the lane segment id.
        left_id = str(seg.id + max_id + 1)
        right_id = str(seg.id + max_id + 2)
        left_marking = extract_lane_mark(seg.left_lane_marking)
        right_marking = extract_lane_mark(seg.right_lane_marking)

        ret[left_id] = left_marking
        ret[right_id] = right_marking

        center["speed_limit_mph"] = _HIGHWAY_SPEED_LIMIT_MPH
        center["speed_limit_kmh"] = mph_to_kmh(_HIGHWAY_SPEED_LIMIT_MPH)
        center["type"] = get_lane_type(seg.lane_type)

        polyline = map_features.get_lane_segment_centerline(seg.id)
        center["polyline"] = polyline.astype(np.float32)
        center["interpolating"] = True
        center["entry_lanes"] = [str(id) for id in seg.predecessors]
        center["exit_lanes"] = [str(id) for id in seg.successors]
        center["left_boundaries"] = []
        center["right_boundaries"] = []
        center["left_neighbor"] = []
        center["right_neighbor"] = []
        center["width"] = np.zeros([len(polyline), 2], dtype=np.float32)

        ret[lane_id] = center

    # for edge in vector_drivable_areas:
    #     bound = dict()
    #     bound["type"] = MetaDriveType.BOUNDARY_LINE
    #     bound["polyline"] = edge.xyz.astype(np.float32)
    #     ret[str(edge.id)] = bound

    for cross in ped_crossings:
        bound = dict()
        bound["type"] = MetaDriveType.CROSSWALK
        bound["polygon"] = cross.polygon.astype(np.float32)
        ret[str(cross.id)] = bound

    return ret

def get_av2_scenarios(av2_data_directory, start_index, num):
    # List all scenario parquet files under the raw data directory and
    # select the window files[start_index: start_index + num], matching the CLI arguments.
    logger.info("\nReading raw data")

    all_scenario_files = sorted(Path(av2_data_directory).rglob("*.parquet"))

    return all_scenario_files[start_index:start_index + num]

def convert_av2_scenario(scenario, version):
    md_scenario = SD()

    md_scenario[SD.ID] = scenario.scenario_id
    md_scenario[SD.VERSION] = version

    # Please note that SDC track index is not identical to sdc_id.
    # sdc_id is a unique identifier of a track, while sdc_track_index is only the index of the sdc track
    # in the tracks datastructure.

    track_length = scenario.timestamps_ns.shape[0]

    tracks, category = extract_tracks(scenario.tracks, scenario.focal_track_id, track_length)

    md_scenario[SD.LENGTH] = track_length
    md_scenario[SD.TRACKS] = tracks
    md_scenario[SD.DYNAMIC_MAP_STATES] = {}

    map_features = extract_map_features(scenario.static_map)
    md_scenario[SD.MAP_FEATURES] = map_features
    # compute_width(md_scenario[SD.MAP_FEATURES])

    md_scenario[SD.METADATA] = {}
    md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
    md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
    md_scenario[SD.METADATA][SD.TIMESTEP] = np.array(list(range(track_length))) / 10
    md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
    md_scenario[SD.METADATA][SD.SDC_ID] = 'AV'
    md_scenario[SD.METADATA]["dataset"] = "av2"
    md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id
    md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id
    md_scenario[SD.METADATA]["track_length"] = track_length

    # === Waymo specific data. Storing them here ===
    # AV2 scenarios are sampled at 10 Hz with 5 s of observed history, so the current step index is 49.
    md_scenario[SD.METADATA]["current_time_index"] = 49
    md_scenario[SD.METADATA]["sdc_track_index"] = scenario.focal_track_id

    # obj id
    obj_keys = list(tracks.keys())
    md_scenario[SD.METADATA]["objects_of_interest"] = [obj_keys[idx] for idx, cat in enumerate(category) if cat == 2]

    track_index = [obj_keys.index(scenario.focal_track_id)]
    track_id = [scenario.focal_track_id]
    track_difficulty = [0]
    track_obj_type = [tracks[id]["type"] for id in track_id]
    md_scenario[SD.METADATA]["tracks_to_predict"] = {
        id: {
            "track_index": track_index[count],
            "track_id": id,
            "difficulty": track_difficulty[count],
            "object_type": track_obj_type[count]
        }
        for count, id in enumerate(track_id)
    }
    # clean memory
    del scenario
    return md_scenario

def preprocess_av2_scenarios(files, worker_index):
    """
    Load the AV2 scenario parquet files and attach their static maps. This happens in each worker.
    :param files: a list of scenario parquet file paths
    :param worker_index: the index of the worker
    :return: a generator of ArgoverseScenario objects, each with its static map attached
    """

    for scenario_path in tqdm(files, desc="Process av2 scenarios for worker {}".format(worker_index)):
        # The scenario id is the suffix of the file name; the map JSON sits next to the parquet file.
        scenario_id = scenario_path.stem.split("_")[-1]
        static_map_path = scenario_path.parents[0] / f"log_map_archive_{scenario_id}.json"
        scenario = scenario_serialization.load_argoverse_scenario_parquet(scenario_path)
        static_map = ArgoverseStaticMap.from_json(static_map_path)
        scenario.static_map = static_map
        yield scenario
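For reference, a single-scenario walkthrough of the helpers above; the directory and scenario id are placeholders that follow the layout preprocess_av2_scenarios expects (scenario_<id>.parquet next to log_map_archive_<id>.json):

# Sketch only; the concrete scenario id and directory are placeholders.
from pathlib import Path
from av2.datasets.motion_forecasting import scenario_serialization
from av2.map.map_api import ArgoverseStaticMap
from metadrive.scenario import ScenarioDescription as SD
from scenarionet.converter.argoverse2.utils import convert_av2_scenario

scenario_path = Path("/data/av2_origin/0a0ef009/scenario_0a0ef009.parquet")  # placeholder
scenario_id = scenario_path.stem.split("_")[-1]
scenario = scenario_serialization.load_argoverse_scenario_parquet(scenario_path)
scenario.static_map = ArgoverseStaticMap.from_json(scenario_path.parent / f"log_map_archive_{scenario_id}.json")

md_scenario = convert_av2_scenario(scenario, version="v2")
print(md_scenario[SD.ID], md_scenario[SD.LENGTH])  # scenario id and number of timesteps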
scenarionet/converter/waymo/utils.py  (modified)
@@ -19,7 +19,7 @@ try:
 except ImportError as e:
     logger.info(e)
 
-import scenarionet.converter.waymo.waymo_protos.scenario_pb2
+import scenarionet.converter.waymo.waymo_protos.scenario_pb2 as scenario_pb2
 
 from metadrive.scenario import ScenarioDescription as SD
 from metadrive.type import MetaDriveType
@@ -453,7 +453,7 @@ def preprocess_waymo_scenarios(files, worker_index):
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
         for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
-            scenario = scenarionet.converter.waymo.waymo_protos.scenario_pb2.Scenario()
+            scenario = scenario_pb2.Scenario()
             scenario.ParseFromString(data)
             # a trick for logging file name
             scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file