From 36858549560138b782aac3a7cd7e03d788535383 Mon Sep 17 00:00:00 2001
From: QuanyiLi
Date: Sun, 7 May 2023 13:52:43 +0100
Subject: [PATCH] test script

---
 .gitignore                                |  1 +
 scenarionet/builder/utils.py              | 30 +++++++++++++-----
 scenarionet/converter/utils.py            |  2 +-
 scenarionet/test/generate_test_dataset.py | 30 ++++++++++++++++++
 scenarionet/test/test_combine_dataset.py  | 38 ++++++++++++++++++++++-
 5 files changed, 91 insertions(+), 10 deletions(-)
 create mode 100644 scenarionet/test/generate_test_dataset.py

diff --git a/.gitignore b/.gitignore
index 9216a92..621944c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,4 @@
 /dist/
 /documentation/build/
 dataset/*
+**/combine/
diff --git a/scenarionet/builder/utils.py b/scenarionet/builder/utils.py
index a05b9d3..d2e5702 100644
--- a/scenarionet/builder/utils.py
+++ b/scenarionet/builder/utils.py
@@ -3,21 +3,31 @@ import logging
 import os
 import os.path as osp
 import pickle
+import shutil
 
+import metadrive.scenario.utils as sd_utils
 from metadrive.scenario.scenario_description import ScenarioDescription
 
 logger = logging.getLogger(__name__)
 
 
+def read_dataset_summary(dataset_path):
+    return sd_utils.read_dataset_summary(dataset_path)
+
+
+def read_scenario(pkl_file_path):
+    return sd_utils.read_scenario_data(pkl_file_path)
+
+
 def try_generating_summary(file_folder):
     # Create a fake one
     files = os.listdir(file_folder)
     summary = {}
     for file in files:
-        file = file.replace(".pkl", "")
-        with open(osp.join(file_folder, file), "rb+") as f:
-            scenario = pickle.load(f)
-            summary[file] = copy.deepcopy(scenario[ScenarioDescription.METADATA])
+        if file != ScenarioDescription.DATASET.SUMMARY_FILE and file != ScenarioDescription.DATASET.MAPPING_FILE:
+            with open(osp.join(file_folder, file), "rb+") as f:
+                scenario = pickle.load(f)
+                summary[file] = copy.deepcopy(scenario[ScenarioDescription.METADATA])
     return summary
 
 
@@ -30,7 +40,7 @@ def try_generating_mapping(file_folder):
     return mapping
 
 
-def combine_multiple_dataset(output_path, force_overwrite=False, try_generate_missing_file=True, *dataset_paths):
+def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True):
     """
     Combine multiple datasets. Each dataset should have a dataset_summary.pkl
     :param output_path: The path to store the output dataset
@@ -40,8 +50,12 @@ def combine_multiple_dataset(output_path, force_overwrite=False, try_generate_mi
     :return:
     """
     output_abs_path = osp.abspath(output_path)
-    if os.path.exists(output_abs_path) and not force_overwrite:
-        raise FileExistsError("Output path already exists!")
+    if os.path.exists(output_abs_path):
+        if not force_overwrite:
+            raise FileExistsError("Output path already exists!")
+        else:
+            shutil.rmtree(output_abs_path)
+    os.mkdir(output_abs_path)
 
     summaries = {}
     mappings = {}
@@ -71,7 +85,7 @@ def combine_multiple_dataset(output_path, force_overwrite=False, try_generate_mi
 
         if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
             if try_generate_missing_file:
-                mapping = try_generating_mapping(abs_dir_path)
+                mapping = {k: "" for k in summary}
             else:
                 raise FileNotFoundError("Can not find mapping file for dataset: {}".format(abs_dir_path))
         else:
diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py
index 1570827..7761a4b 100644
--- a/scenarionet/converter/utils.py
+++ b/scenarionet/converter/utils.py
@@ -106,7 +106,7 @@ def write_to_directory(
         # convert scenario
         sd_scenario = convert_func(scenario, dataset_version, **kwargs)
         scenario_id = sd_scenario[SD.ID]
-        export_file_name = "sd_{}_{}.pkl".format(dataset_name + "_" + dataset_version, scenario_id)
+        export_file_name = SD.get_export_file_name(dataset_name, dataset_version, scenario_id)
 
         # add agents summary
         summary_dict = {}
diff --git a/scenarionet/test/generate_test_dataset.py b/scenarionet/test/generate_test_dataset.py
new file mode 100644
index 0000000..a12dc14
--- /dev/null
+++ b/scenarionet/test/generate_test_dataset.py
@@ -0,0 +1,30 @@
+"""
+This script aims to convert nuscenes scenarios to ScenarioDescription, so that we can load any nuscenes scenarios into
+MetaDrive.
+""" +import os.path + +from scenarionet import SCENARIONET_PACKAGE_PATH +from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_nuscenes_scenarios +from scenarionet.converter.utils import write_to_directory + +if __name__ == "__main__": + # raise ValueError("Avoid generating ata") + dataset_name = "nuscenes" + output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "test", "test_dataset", dataset_name) + version = 'v1.0-mini' + force_overwrite = True + + dataroot = '/home/shady/data/nuscenes' + scenarios, nusc = get_nuscenes_scenarios(dataroot, version) + + for i in range(5): + write_to_directory( + convert_func=convert_nuscenes_scenario, + scenarios=scenarios[i * 2:i * 2 + 2], + output_path=output_path + "_{}".format(i), + dataset_version=version, + dataset_name=dataset_name, + force_overwrite=force_overwrite, + nuscenes=nusc + ) diff --git a/scenarionet/test/test_combine_dataset.py b/scenarionet/test/test_combine_dataset.py index 64d20fa..e1419d7 100644 --- a/scenarionet/test/test_combine_dataset.py +++ b/scenarionet/test/test_combine_dataset.py @@ -1,4 +1,40 @@ +import os +import os.path + +import tqdm +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy + +from scenarionet import SCENARIONET_PACKAGE_PATH +from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario def test_combine_multiple_dataset(): - pass \ No newline at end of file + dataset_name = "nuscenes" + original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "test", "test_dataset", dataset_name) + dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] + + output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "test", "combine") + combine_multiple_dataset(output_path, + *dataset_paths, + force_overwrite=True, + try_generate_missing_file=True) + dataset_paths.append(output_path) + for dataset_path in dataset_paths: + summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) + for scenario_file in sorted_scenarios: + read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file)) + + env = ScenarioEnv({"agent_policy": ReplayEgoCarPolicy, + "num_scenarios": 10, + "horizon": 1000, + "data_directory": output_path}) + try: + for i in tqdm.tqdm(range(10), desc="Test env loading"): + env.reset(force_seed=i) + finally: + env.close() + + +if __name__ == '__main__': + test_combine_multiple_dataset()