get a unified func
This commit is contained in:
@@ -1,8 +1,15 @@
|
||||
import copy
|
||||
import math
|
||||
import ast
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from metadrive.scenario import ScenarioDescription as SD
|
||||
import tqdm
|
||||
from metadrive.scenario import ScenarioDescription as SD, ScenarioDescription
|
||||
|
||||
|
||||
def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
|
||||
@@ -99,3 +106,49 @@ def get_number_summary(scenario):
|
||||
number_summary_dict["dynamic_object_states_counter"] = dict(dynamic_object_states_counter)
|
||||
|
||||
return number_summary_dict
|
||||
|
||||
|
||||
def contains_explicit_return(f):
|
||||
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
|
||||
|
||||
|
||||
def write_to_directory(convert_func, scenarios, output_path, version, dataset_name, force_overwrite=False, **kwargs):
|
||||
if not contains_explicit_return(convert_func):
|
||||
raise RuntimeError("The convert function should return a metadata dict")
|
||||
|
||||
save_path = copy.deepcopy(output_path)
|
||||
output_path = output_path + "_tmp"
|
||||
# meta recorder and data summary
|
||||
if os.path.exists(output_path):
|
||||
shutil.rmtree(output_path)
|
||||
os.makedirs(output_path, exist_ok=False)
|
||||
|
||||
# make real save dir
|
||||
delay_remove = None
|
||||
if os.path.exists(save_path):
|
||||
if force_overwrite:
|
||||
delay_remove = save_path
|
||||
else:
|
||||
raise ValueError("Directory already exists! Abort")
|
||||
|
||||
summary_file = "dataset_summary.pkl"
|
||||
|
||||
metadata_recorder = {}
|
||||
for scenario in tqdm.tqdm(scenarios):
|
||||
sd_scenario = convert_func(scenario, **kwargs)
|
||||
sd_scenario = sd_scenario.to_dict()
|
||||
ScenarioDescription.sanity_check(sd_scenario, check_self_type=True)
|
||||
export_file_name = "sd_{}_{}.pkl".format(dataset_name+"_" + version, scenario["token"])
|
||||
p = os.path.join(output_path, export_file_name)
|
||||
with open(p, "wb") as f:
|
||||
pickle.dump(sd_scenario, f)
|
||||
|
||||
# rename and save
|
||||
if delay_remove is not None:
|
||||
shutil.rmtree(delay_remove)
|
||||
os.rename(output_path, save_path)
|
||||
summary_file = os.path.join(save_path, summary_file)
|
||||
with open(summary_file, "wb") as file:
|
||||
pickle.dump(dict_recursive_remove_array(metadata_recorder), file)
|
||||
print("Summary is saved at: {}".format(summary_file))
|
||||
assert delay_remove == save_path
|
||||
|
||||
Reference in New Issue
Block a user