This commit is contained in:
QuanyiLi
2023-05-06 16:59:17 +01:00
parent 2b9be248c4
commit 8eb2e07d50
7 changed files with 49 additions and 22 deletions

View File

@@ -1,7 +1,7 @@
import copy
import math
import ast
import copy
import inspect
import math
import os
import pickle
import shutil
@@ -112,7 +112,16 @@ def contains_explicit_return(f):
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
def write_to_directory(convert_func, scenarios, output_path, version, dataset_name, force_overwrite=False, **kwargs):
def write_to_directory(convert_func,
scenarios,
output_path,
dataset_version,
dataset_name,
force_overwrite=False,
**kwargs):
"""
Convert a batch of scenarios.
"""
if not contains_explicit_return(convert_func):
raise RuntimeError("The convert function should return a metadata dict")
@@ -135,10 +144,29 @@ def write_to_directory(convert_func, scenarios, output_path, version, dataset_na
metadata_recorder = {}
for scenario in tqdm.tqdm(scenarios):
sd_scenario = convert_func(scenario, **kwargs)
# convert scenario
sd_scenario, scenario_id = convert_func(scenario, **kwargs)
export_file_name = "sd_{}_{}.pkl".format(dataset_name + "_" + dataset_version, scenario_id)
# add agents summary
summary_dict = {}
ego_car_id = sd_scenario[SD.METADATA][SD.SDC_ID]
summary_dict[ego_car_id] = get_agent_summary(
state_dict=sd_scenario.get_sdc_track()["state"], id=ego_car_id, type=sd_scenario.get_sdc_track()["type"]
)
for track_id, track in sd_scenario[SD.TRACKS].items():
summary_dict[track_id] = get_agent_summary(state_dict=track["state"], id=track_id, type=track["type"])
sd_scenario[SD.METADATA]["object_summary"] = summary_dict
# count some objects occurrence
sd_scenario[SD.METADATA]["number_summary"] = get_number_summary(sd_scenario)
metadata_recorder[export_file_name] = copy.deepcopy(sd_scenario[SD.METADATA])
# sanity check
sd_scenario = sd_scenario.to_dict()
ScenarioDescription.sanity_check(sd_scenario, check_self_type=True)
export_file_name = "sd_{}_{}.pkl".format(dataset_name+"_" + version, scenario["token"])
# dump
p = os.path.join(output_path, export_file_name)
with open(p, "wb") as f:
pickle.dump(sd_scenario, f)