Files
scenarionet/scenarionet/converter/utils.py
QuanyiLi 8eb2e07d50 rename
2023-05-06 17:02:28 +01:00

183 lines
6.2 KiB
Python

import ast
import copy
import inspect
import math
import os
import pickle
import shutil
from collections import defaultdict
import numpy as np
import tqdm
from metadrive.scenario import ScenarioDescription as SD, ScenarioDescription
def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
"All vec in nuplan should be centered in (0,0) to avoid numerical explosion"
vector = np.array(vector)
vector -= np.asarray(nuplan_center)
return vector
def compute_angular_velocity(initial_heading, final_heading, dt):
"""
Calculate the angular velocity between two headings given in radians.
Parameters:
initial_heading (float): The initial heading in radians.
final_heading (float): The final heading in radians.
dt (float): The time interval between the two headings in seconds.
Returns:
float: The angular velocity in radians per second.
"""
# Calculate the difference in headings
delta_heading = final_heading - initial_heading
# Adjust the delta_heading to be in the range (-π, π]
delta_heading = (delta_heading + math.pi) % (2 * math.pi) - math.pi
# Compute the angular velocity
angular_vel = delta_heading / dt
return angular_vel
def dict_recursive_remove_array(d):
if isinstance(d, np.ndarray):
return d.tolist()
if isinstance(d, dict):
for k in d.keys():
d[k] = dict_recursive_remove_array(d[k])
return d
def mph_to_kmh(speed_in_mph: float):
speed_in_kmh = speed_in_mph * 1.609344
return speed_in_kmh
def get_agent_summary(state_dict, id, type):
track = state_dict["position"]
valid_track = track[state_dict["valid"], :2]
distance = float(sum(np.linalg.norm(valid_track[i] - valid_track[i + 1]) for i in range(valid_track.shape[0] - 1)))
valid_length = int(sum(state_dict["valid"]))
continuous_valid_length = 0
for v in state_dict["valid"]:
if v:
continuous_valid_length += 1
if continuous_valid_length > 0 and not v:
break
return {
"type": type,
"object_id": str(id),
"track_length": int(len(track)),
"distance": float(distance),
"valid_length": int(valid_length),
"continuous_valid_length": int(continuous_valid_length)
}
def get_number_summary(scenario):
number_summary_dict = {}
number_summary_dict["object"] = len(scenario[SD.TRACKS])
number_summary_dict["dynamic_object_states"] = len(scenario[SD.DYNAMIC_MAP_STATES])
number_summary_dict["map_features"] = len(scenario[SD.MAP_FEATURES])
number_summary_dict["object_types"] = set(v["type"] for v in scenario[SD.TRACKS].values())
object_types_counter = defaultdict(int)
for v in scenario[SD.TRACKS].values():
object_types_counter[v["type"]] += 1
number_summary_dict["object_types_counter"] = dict(object_types_counter)
# Number of different dynamic object states
dynamic_object_states_types = set()
dynamic_object_states_counter = defaultdict(int)
for v in scenario[SD.DYNAMIC_MAP_STATES].values():
for step_state in v["state"]["object_state"]:
if step_state is None:
continue
dynamic_object_states_types.add(step_state)
dynamic_object_states_counter[step_state] += 1
number_summary_dict["dynamic_object_states_types"] = dynamic_object_states_types
number_summary_dict["dynamic_object_states_counter"] = dict(dynamic_object_states_counter)
return number_summary_dict
def contains_explicit_return(f):
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
def write_to_directory(convert_func,
scenarios,
output_path,
dataset_version,
dataset_name,
force_overwrite=False,
**kwargs):
"""
Convert a batch of scenarios.
"""
if not contains_explicit_return(convert_func):
raise RuntimeError("The convert function should return a metadata dict")
save_path = copy.deepcopy(output_path)
output_path = output_path + "_tmp"
# meta recorder and data summary
if os.path.exists(output_path):
shutil.rmtree(output_path)
os.makedirs(output_path, exist_ok=False)
# make real save dir
delay_remove = None
if os.path.exists(save_path):
if force_overwrite:
delay_remove = save_path
else:
raise ValueError("Directory already exists! Abort")
summary_file = "dataset_summary.pkl"
metadata_recorder = {}
for scenario in tqdm.tqdm(scenarios):
# convert scenario
sd_scenario, scenario_id = convert_func(scenario, **kwargs)
export_file_name = "sd_{}_{}.pkl".format(dataset_name + "_" + dataset_version, scenario_id)
# add agents summary
summary_dict = {}
ego_car_id = sd_scenario[SD.METADATA][SD.SDC_ID]
summary_dict[ego_car_id] = get_agent_summary(
state_dict=sd_scenario.get_sdc_track()["state"], id=ego_car_id, type=sd_scenario.get_sdc_track()["type"]
)
for track_id, track in sd_scenario[SD.TRACKS].items():
summary_dict[track_id] = get_agent_summary(state_dict=track["state"], id=track_id, type=track["type"])
sd_scenario[SD.METADATA]["object_summary"] = summary_dict
# count some objects occurrence
sd_scenario[SD.METADATA]["number_summary"] = get_number_summary(sd_scenario)
metadata_recorder[export_file_name] = copy.deepcopy(sd_scenario[SD.METADATA])
# sanity check
sd_scenario = sd_scenario.to_dict()
ScenarioDescription.sanity_check(sd_scenario, check_self_type=True)
# dump
p = os.path.join(output_path, export_file_name)
with open(p, "wb") as f:
pickle.dump(sd_scenario, f)
# rename and save
if delay_remove is not None:
shutil.rmtree(delay_remove)
os.rename(output_path, save_path)
summary_file = os.path.join(save_path, summary_file)
with open(summary_file, "wb") as file:
pickle.dump(dict_recursive_remove_array(metadata_recorder), file)
print("Summary is saved at: {}".format(summary_file))
assert delay_remove == save_path