Add support to AV2 (#48)

* add support to av2
---------

Co-authored-by: Alan-LanFeng <fenglan18@outook.com>
This commit is contained in:
Alan
2023-12-09 15:23:46 +01:00
committed by GitHub
parent d4709347f4
commit f2b21d709f
4 changed files with 404 additions and 2 deletions

View File

@@ -0,0 +1,77 @@
desc = "Build database from Argoverse scenarios"

if __name__ == '__main__':
    import argparse
    import logging
    import os
    import shutil

    from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
    from scenarionet.converter.argoverse2.utils import convert_av2_scenario, get_av2_scenarios, preprocess_av2_scenarios
    from scenarionet.converter.utils import write_to_directory

    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument(
        "--database_path",
        "-d",
        default=os.path.join(SCENARIONET_DATASET_PATH, "av2"),
        help="A directory, the path to place the converted data"
    )
    parser.add_argument(
        "--dataset_name", "-n", default="av2", help="Dataset name, will be used to generate scenario files"
    )
    parser.add_argument("--version", "-v", default='v2', help="version")
    parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it")
    parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
    parser.add_argument(
        "--raw_data_path",
        # NOTE(review): the default still points at "waymo_origin", which looks like a
        # copy-paste leftover from the Waymo converter. Kept unchanged for backward
        # compatibility -- confirm the intended default location for AV2 data.
        default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
        # Fixed help text: this converter consumes AV2 parquet scenarios, not Waymo tfrecords.
        help="The directory that stores all AV2 scenario files (.parquet)"
    )
    parser.add_argument(
        "--start_file_index",
        default=0,
        type=int,
        help="Control how many files to use. We will list all files in the raw data folder "
        "and select files[start_file_index: start_file_index+num_files]"
    )
    parser.add_argument(
        "--num_files",
        default=1000,
        type=int,
        help="Control how many files to use. We will list all files in the raw data folder "
        "and select files[start_file_index: start_file_index+num_files]"
    )
    args = parser.parse_args()

    overwrite = args.overwrite
    dataset_name = args.dataset_name
    output_path = args.database_path
    version = args.version

    # Refuse to clobber an existing database unless --overwrite was given.
    if os.path.exists(output_path):
        if not overwrite:
            raise ValueError(
                "Directory {} already exists! Abort. "
                "\n Try setting overwrite=True or adding --overwrite".format(output_path)
            )
        else:
            shutil.rmtree(output_path)

    # os.path.join drops the first component when raw_data_path is absolute,
    # which is the case for the default value above.
    av2_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
    scenarios = get_av2_scenarios(av2_data_directory, args.start_file_index, args.num_files)
    write_to_directory(
        convert_func=convert_av2_scenario,
        scenarios=scenarios,
        output_path=output_path,
        dataset_version=version,
        dataset_name=dataset_name,
        overwrite=overwrite,
        num_workers=args.num_workers,
        preprocess=preprocess_av2_scenarios
    )

View File

@@ -0,0 +1,68 @@
import logging
from metadrive.type import MetaDriveType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
from av2.datasets.motion_forecasting.data_schema import ObjectType
from av2.map.lane_segment import LaneType, LaneMarkType
except ImportError as e:
logger.warning("Can not import av2-devkit: {}".format(e))
def get_traffic_obj_type(av2_obj_type):
    """Map an AV2 ``ObjectType`` onto the corresponding MetaDrive type string.

    Vehicles and buses both become ``MetaDriveType.VEHICLE``; every object
    type without an explicit mapping falls back to ``MetaDriveType.OTHER``.
    """
    if av2_obj_type in (ObjectType.VEHICLE, ObjectType.BUS):
        return MetaDriveType.VEHICLE
    if av2_obj_type == ObjectType.PEDESTRIAN:
        return MetaDriveType.PEDESTRIAN
    if av2_obj_type == ObjectType.CYCLIST:
        return MetaDriveType.CYCLIST
    return MetaDriveType.OTHER
def get_lane_type(av2_lane_type):
    """Map an AV2 ``LaneType`` onto a MetaDrive lane-surface type string.

    Vehicle and bus lanes both map to ``LANE_SURFACE_STREET``; bike lanes map
    to ``LANE_BIKE_LANE``.

    Raises:
        ValueError: if the lane type has no known mapping.
    """
    if av2_lane_type == LaneType.VEHICLE or av2_lane_type == LaneType.BUS:
        return MetaDriveType.LANE_SURFACE_STREET
    elif av2_lane_type == LaneType.BIKE:
        return MetaDriveType.LANE_BIKE_LANE
    else:
        # Fixed copy-paste bug: the message previously said "nuplan" even
        # though this is the Argoverse 2 converter.
        raise ValueError("Unknown av2 lane type: {}".format(av2_lane_type))
def get_lane_mark_type(av2_mark_type):
    """Translate an AV2 ``LaneMarkType`` into a MetaDrive road-line type string.

    Marks without an explicit entry (including NONE/UNKNOWN/SOLID_BLUE) are
    reported as ``"UNKNOWN_LINE"``.
    """
    # Build the lookup grouped by target line type; built per call because the
    # LaneMarkType import above may fail when av2-devkit is absent.
    lookup = {}
    for marks, line_type in (
        ((LaneMarkType.DOUBLE_SOLID_YELLOW, LaneMarkType.DASH_SOLID_YELLOW), "ROAD_LINE_SOLID_DOUBLE_YELLOW"),
        ((LaneMarkType.DOUBLE_SOLID_WHITE, LaneMarkType.DASH_SOLID_WHITE), "ROAD_LINE_SOLID_DOUBLE_WHITE"),
        ((LaneMarkType.SOLID_YELLOW,), "ROAD_LINE_SOLID_SINGLE_YELLOW"),
        ((LaneMarkType.SOLID_WHITE,), "ROAD_LINE_SOLID_SINGLE_WHITE"),
        (
            (LaneMarkType.DASHED_WHITE, LaneMarkType.DOUBLE_DASH_WHITE, LaneMarkType.SOLID_DASH_WHITE),
            "ROAD_LINE_BROKEN_SINGLE_WHITE",
        ),
        (
            (LaneMarkType.DASHED_YELLOW, LaneMarkType.DOUBLE_DASH_YELLOW, LaneMarkType.SOLID_DASH_YELLOW),
            "ROAD_LINE_BROKEN_SINGLE_YELLOW",
        ),
        ((LaneMarkType.SOLID_BLUE, LaneMarkType.NONE, LaneMarkType.UNKNOWN), "UNKNOWN_LINE"),
    ):
        for mark in marks:
            lookup[mark] = line_type
    return lookup.get(av2_mark_type, "UNKNOWN_LINE")

View File

@@ -0,0 +1,257 @@
import logging
import tqdm
from scenarionet.converter.utils import mph_to_kmh
logger = logging.getLogger(__name__)
import numpy as np
from pathlib import Path
from tqdm import tqdm
from metadrive.scenario import ScenarioDescription as SD
from metadrive.type import MetaDriveType
from scenarionet.converter.argoverse2.type import get_traffic_obj_type, get_lane_type, get_lane_mark_type
from av2.datasets.motion_forecasting import scenario_serialization
from av2.map.map_api import ArgoverseStaticMap
from typing import Final
# Fixed bounding-box size estimates (meters) applied per object type; the
# converter does not read per-object extents from the data, so these stand in
# for actual lengths/widths in extract_tracks().
_ESTIMATED_VEHICLE_LENGTH_M: Final[float] = 4.0
_ESTIMATED_VEHICLE_WIDTH_M: Final[float] = 2.0
_ESTIMATED_CYCLIST_LENGTH_M: Final[float] = 2.0
_ESTIMATED_CYCLIST_WIDTH_M: Final[float] = 0.7
_ESTIMATED_PEDESTRIAN_LENGTH_M: Final[float] = 0.5
_ESTIMATED_PEDESTRIAN_WIDTH_M: Final[float] = 0.5
# Bus constants are currently unused (the BUS branch is commented out below).
_ESTIMATED_BUS_LENGTH_M: Final[float] = 12.0
_ESTIMATED_BUS_WIDTH_M: Final[float] = 2.5
# Speed limit assumed for every lane; no limit is read from the map data here.
_HIGHWAY_SPEED_LIMIT_MPH: Final[float] = 85.0
def extract_tracks(tracks, sdc_idx, track_length):
    """Convert AV2 tracks into MetaDrive ScenarioDescription track states.

    Args:
        tracks: iterable of AV2 track objects, each exposing ``track_id``,
            ``category``, ``object_type`` and per-timestep ``object_states``.
        sdc_idx: id of the focal (ego) track. Unused here; kept so the call
            signature stays stable for the caller.
        track_length: total number of timesteps; every state array is
            allocated with this length and filled only at observed timesteps.

    Returns:
        Tuple ``(ret, track_category)``: a dict mapping object id to its state
        dict, and the list of raw AV2 category values in input order.
    """
    ret = dict()

    def _object_state_template(object_id):
        # Arrays are zero-initialized; "valid" marks which steps were observed.
        # Never add an extra dim if the value is scalar.
        return dict(
            type=None,
            state=dict(
                position=np.zeros([track_length, 3], dtype=np.float32),
                length=np.zeros([track_length], dtype=np.float32),
                width=np.zeros([track_length], dtype=np.float32),
                height=np.zeros([track_length], dtype=np.float32),
                heading=np.zeros([track_length], dtype=np.float32),
                velocity=np.zeros([track_length, 2], dtype=np.float32),
                valid=np.zeros([track_length], dtype=bool),
            ),
            metadata=dict(track_length=track_length, type=None, object_id=object_id, dataset="av2"),
        )

    track_category = []
    for obj in tracks:
        object_id = obj.track_id
        track_category.append(obj.category.value)
        obj_state = _object_state_template(object_id)
        # Transform it to the Waymo-style MetaDrive type string.
        obj_state["type"] = get_traffic_obj_type(obj.object_type)
        # Per-object extents are not read from the data here, so fall back to
        # fixed per-type size estimates (1 m x 1 m for everything else).
        if obj_state["type"] == MetaDriveType.VEHICLE:
            length, width = _ESTIMATED_VEHICLE_LENGTH_M, _ESTIMATED_VEHICLE_WIDTH_M
        elif obj_state["type"] == MetaDriveType.PEDESTRIAN:
            length, width = _ESTIMATED_PEDESTRIAN_LENGTH_M, _ESTIMATED_PEDESTRIAN_WIDTH_M
        elif obj_state["type"] == MetaDriveType.CYCLIST:
            length, width = _ESTIMATED_CYCLIST_LENGTH_M, _ESTIMATED_CYCLIST_WIDTH_M
        else:
            length, width = 1, 1
        # Fixed idiom: the previous loop used enumerate() but discarded the
        # index; states carry their own timestep.
        for state in obj.object_states:
            step_count = state.timestep
            obj_state["state"]["position"][step_count][0] = state.position[0]
            obj_state["state"]["position"][step_count][1] = state.position[1]
            # AV2 positions here are 2D; z is filled with 0.
            obj_state["state"]["position"][step_count][2] = 0
            obj_state["state"]["length"][step_count] = length
            obj_state["state"]["width"][step_count] = width
            # No height in the data; use a 1 m placeholder.
            obj_state["state"]["height"][step_count] = 1
            obj_state["state"]["heading"][step_count] = state.heading
            obj_state["state"]["velocity"][step_count][0] = state.velocity[0]
            obj_state["state"]["velocity"][step_count][1] = state.velocity[1]
            obj_state["state"]["valid"][step_count] = True
        obj_state["metadata"]["type"] = obj_state["type"]
        ret[object_id] = obj_state
    return ret, track_category
def extract_lane_mark(lane_mark):
    """Build a boundary-line dict (MetaDrive type string + float32 polyline)
    from a single AV2 lane marking."""
    return {
        "type": get_lane_mark_type(lane_mark.mark_type),
        "polyline": lane_mark.polyline.astype(np.float32),
    }
def extract_map_features(map_features):
    """Convert an AV2 static map into the MetaDrive map-feature dict.

    Emits, keyed by string id:
      * one lane-center entry per lane segment (under its original segment id),
      * the segment's left/right lane markings under synthetic numeric ids,
      * one crosswalk polygon per pedestrian crossing.

    Args:
        map_features: the scenario's static map (ArgoverseStaticMap).

    Returns:
        Dict mapping feature id (str) to its feature dict.
    """
    ret = {}
    vector_lane_segments = map_features.get_scenario_lane_segments()
    ped_crossings = map_features.get_scenario_ped_crossings()
    ids = map_features.get_scenario_lane_segment_ids()
    max_id = max(ids)
    for seg in vector_lane_segments:
        center = {}
        lane_id = str(seg.id)
        # Synthetic boundary ids must not collide with lane ids or each other.
        # The previous scheme (seg.id + max_id + 1 / + 2) collided for
        # consecutive segment ids: right(n) == left(n + 1). Doubling seg.id
        # keeps left ids and right ids on different parities, so both sequences
        # are injective, disjoint, and still above max_id.
        left_id = str(2 * seg.id + max_id + 1)
        right_id = str(2 * seg.id + max_id + 2)
        ret[left_id] = extract_lane_mark(seg.left_lane_marking)
        ret[right_id] = extract_lane_mark(seg.right_lane_marking)
        # No speed limit is read from the map; assume a fixed highway-level cap.
        center["speed_limit_mph"] = _HIGHWAY_SPEED_LIMIT_MPH
        center["speed_limit_kmh"] = mph_to_kmh(_HIGHWAY_SPEED_LIMIT_MPH)
        center["type"] = get_lane_type(seg.lane_type)
        polyline = map_features.get_lane_segment_centerline(seg.id)
        center["polyline"] = polyline.astype(np.float32)
        center["interpolating"] = True
        center["entry_lanes"] = [str(pred) for pred in seg.predecessors]
        center["exit_lanes"] = [str(succ) for succ in seg.successors]
        # Boundary/neighbor topology is not populated; widths are placeholders.
        center["left_boundaries"] = []
        center["right_boundaries"] = []
        center["left_neighbor"] = []
        center["right_neighbor"] = []
        center['width'] = np.zeros([len(polyline), 2], dtype=np.float32)
        ret[lane_id] = center
    # (Removed an unused get_scenario_vector_drivable_areas() call and the dead
    # commented-out boundary-export code that referenced it.)
    for cross in ped_crossings:
        bound = dict()
        bound["type"] = MetaDriveType.CROSSWALK
        bound["polygon"] = cross.polygon.astype(np.float32)
        ret[str(cross.id)] = bound
    return ret
def get_av2_scenarios(av2_data_directory, start_index, num):
    """List AV2 scenario parquet files under a directory, sliced to a window.

    Args:
        av2_data_directory: root directory searched recursively for
            ``*.parquet`` scenario files.
        start_index: index of the first file to use (after sorting).
        num: number of files to use from start_index onward.

    Returns:
        Sorted list of Path objects, restricted to
        ``files[start_index:start_index + num]``.
    """
    logger.info("\nReading raw data")
    all_scenario_files = sorted(Path(av2_data_directory).rglob("*.parquet"))
    # Fixed: start_index/num were previously ignored and every file was
    # returned, contradicting the CLI help text ("select
    # files[start_file_index: start_file_index+num_files]").
    return all_scenario_files[start_index:start_index + num]
def convert_av2_scenario(scenario, version):
    """Convert one AV2 scenario into a MetaDrive ScenarioDescription.

    The scenario must already have its static map attached as
    ``scenario.static_map`` (done by ``preprocess_av2_scenarios``).

    :param scenario: an AV2 scenario with ``scenario_id``, ``timestamps_ns``,
        ``tracks``, ``focal_track_id`` and an attached ``static_map``
    :param version: dataset version string stored in the output
    :return: the populated ScenarioDescription dict
    """
    md_scenario = SD()
    md_scenario[SD.ID] = scenario.scenario_id
    md_scenario[SD.VERSION] = version
    # Please note that SDC track index is not identical to sdc_id.
    # sdc_id is a unique indicator to a track, while sdc_track_index is only the index of the sdc track
    # in the tracks datastructure.
    track_length = scenario.timestamps_ns.shape[0]
    tracks, category = extract_tracks(scenario.tracks, scenario.focal_track_id, track_length)
    md_scenario[SD.LENGTH] = track_length
    md_scenario[SD.TRACKS] = tracks
    # AV2 motion forecasting carries no traffic-light states here.
    md_scenario[SD.DYNAMIC_MAP_STATES] = {}
    map_features = extract_map_features(scenario.static_map)
    md_scenario[SD.MAP_FEATURES] = map_features
    # compute_width(md_scenario[SD.MAP_FEATURES])
    md_scenario[SD.METADATA] = {}
    md_scenario[SD.METADATA][SD.ID] = md_scenario[SD.ID]
    md_scenario[SD.METADATA][SD.COORDINATE] = MetaDriveType.COORDINATE_WAYMO
    # Timesteps at 0.1 s spacing (10 Hz assumed for AV2 data -- confirm).
    md_scenario[SD.METADATA][SD.TIMESTEP] = np.array(list(range(track_length))) / 10
    md_scenario[SD.METADATA][SD.METADRIVE_PROCESSED] = False
    # The ego track is always named 'AV' in AV2 exports -- TODO confirm.
    md_scenario[SD.METADATA][SD.SDC_ID] = 'AV'
    md_scenario[SD.METADATA]["dataset"] = "av2"
    md_scenario[SD.METADATA]["scenario_id"] = scenario.scenario_id
    md_scenario[SD.METADATA]["source_file"] = scenario.scenario_id
    md_scenario[SD.METADATA]["track_length"] = track_length
    # === Waymo specific data. Storing them here ===
    # NOTE(review): hard-coded split between history and future at step 49;
    # presumably the AV2 observation horizon -- confirm against the dataset spec.
    md_scenario[SD.METADATA]["current_time_index"] = 49
    # NOTE(review): despite the key name this stores the focal track *id*
    # (a string), not an integer index -- downstream consumers must expect that.
    md_scenario[SD.METADATA]["sdc_track_index"] = scenario.focal_track_id
    # obj id
    obj_keys = list(tracks.keys())
    # Category value 2 marks objects of interest -- TODO confirm against
    # the AV2 TrackCategory enum.
    md_scenario[SD.METADATA]["objects_of_interest"] = [obj_keys[idx] for idx, cat in enumerate(category) if cat == 2]
    # Only the focal track is marked for prediction.
    track_index = [obj_keys.index(scenario.focal_track_id)]
    track_id = [scenario.focal_track_id]
    track_difficulty = [0]
    track_obj_type = [tracks[id]["type"] for id in track_id]
    md_scenario[SD.METADATA]["tracks_to_predict"] = {
        id: {
            "track_index": track_index[count],
            "track_id": id,
            "difficulty": track_difficulty[count],
            "object_type": track_obj_type[count]
        }
        for count, id in enumerate(track_id)
    }
    # clean memory
    del scenario
    return md_scenario
def preprocess_av2_scenarios(files, worker_index):
    """
    Load AV2 scenario parquet files and attach their static maps. This happens in each worker.
    :param files: a list of scenario ``.parquet`` file paths
    :param worker_index: the index for the worker, used only in the progress bar
    :return: a generator yielding loaded scenarios, each with a ``static_map`` attribute attached
    """
    for scenario_path in tqdm(files, desc="Process av2 scenarios for worker {}".format(worker_index)):
        # The scenario id is the trailing component of the file stem; the map
        # JSON sits next to the parquet as log_map_archive_<id>.json.
        scenario_id = scenario_path.stem.split("_")[-1]
        static_map_path = (scenario_path.parents[0] / f"log_map_archive_{scenario_id}.json")
        scenario = scenario_serialization.load_argoverse_scenario_parquet(scenario_path)
        static_map = ArgoverseStaticMap.from_json(static_map_path)
        # Attach the map so convert_av2_scenario can read scenario.static_map.
        scenario.static_map = static_map
        yield scenario
    # logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios))) # return scenarios

View File

@@ -19,7 +19,7 @@ try:
except ImportError as e:
logger.info(e)
import scenarionet.converter.waymo.waymo_protos.scenario_pb2
import scenarionet.converter.waymo.waymo_protos.scenario_pb2 as scenario_pb2
from metadrive.scenario import ScenarioDescription as SD
from metadrive.type import MetaDriveType
@@ -453,7 +453,7 @@ def preprocess_waymo_scenarios(files, worker_index):
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
continue
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
scenario = scenarionet.converter.waymo.waymo_protos.scenario_pb2.Scenario()
scenario = scenario_pb2.Scenario()
scenario.ParseFromString(data)
# a trick for loging file name
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file