get a unified func

This commit is contained in:
QuanyiLi
2023-05-06 16:45:06 +01:00
parent 5d536819c7
commit 2b9be248c4
5 changed files with 73 additions and 64 deletions

View File

@@ -1,8 +1,15 @@
import copy
import math
import ast
import inspect
import os
import pickle
import shutil
from collections import defaultdict
import numpy as np
from metadrive.scenario import ScenarioDescription as SD
import tqdm
from metadrive.scenario import ScenarioDescription as SD, ScenarioDescription
def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
@@ -99,3 +106,49 @@ def get_number_summary(scenario):
number_summary_dict["dynamic_object_states_counter"] = dict(dynamic_object_states_counter)
return number_summary_dict
def contains_explicit_return(f):
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
def write_to_directory(convert_func, scenarios, output_path, version, dataset_name, force_overwrite=False, **kwargs):
if not contains_explicit_return(convert_func):
raise RuntimeError("The convert function should return a metadata dict")
save_path = copy.deepcopy(output_path)
output_path = output_path + "_tmp"
# meta recorder and data summary
if os.path.exists(output_path):
shutil.rmtree(output_path)
os.makedirs(output_path, exist_ok=False)
# make real save dir
delay_remove = None
if os.path.exists(save_path):
if force_overwrite:
delay_remove = save_path
else:
raise ValueError("Directory already exists! Abort")
summary_file = "dataset_summary.pkl"
metadata_recorder = {}
for scenario in tqdm.tqdm(scenarios):
sd_scenario = convert_func(scenario, **kwargs)
sd_scenario = sd_scenario.to_dict()
ScenarioDescription.sanity_check(sd_scenario, check_self_type=True)
export_file_name = "sd_{}_{}.pkl".format(dataset_name+"_" + version, scenario["token"])
p = os.path.join(output_path, export_file_name)
with open(p, "wb") as f:
pickle.dump(sd_scenario, f)
# rename and save
if delay_remove is not None:
shutil.rmtree(delay_remove)
os.rename(output_path, save_path)
summary_file = os.path.join(save_path, summary_file)
with open(summary_file, "wb") as file:
pickle.dump(dict_recursive_remove_array(metadata_recorder), file)
print("Summary is saved at: {}".format(summary_file))
assert delay_remove == save_path