From 720fc61d724bd2684a6ea43ba70552833754d7ab Mon Sep 17 00:00:00 2001 From: QuanyiLi Date: Mon, 8 May 2023 13:28:21 +0100 Subject: [PATCH] add test --- .../_test_generate_from_error_file.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 scenarionet/tests/local_test/_test_generate_from_error_file.py diff --git a/scenarionet/tests/local_test/_test_generate_from_error_file.py b/scenarionet/tests/local_test/_test_generate_from_error_file.py new file mode 100644 index 0000000..e556535 --- /dev/null +++ b/scenarionet/tests/local_test/_test_generate_from_error_file.py @@ -0,0 +1,70 @@ +import copy +import os +import os.path + +from metadrive.scenario.scenario_description import ScenarioDescription as SD + +from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet import SCENARIONET_PACKAGE_PATH +from scenarionet.builder.utils import combine_multiple_dataset +from scenarionet.common_utils import read_dataset_summary, read_scenario +from scenarionet.common_utils import recursive_equal +from scenarionet.verifier.error import ErrorFile +from scenarionet.verifier.utils import set_random_drop +from scenarionet.verifier.utils import verify_loading_into_metadrive + + +def test_combine_multiple_dataset(): + set_random_drop(True) + + dataset_paths = [ + os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"), + os.path.join(SCENARIONET_DATASET_PATH, "nuplan"), + os.path.join(SCENARIONET_DATASET_PATH, "waymo"), + os.path.join(SCENARIONET_DATASET_PATH, "pg") + ] + + dataset_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") + combine_multiple_dataset(dataset_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) + + summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) + for scenario_file in sorted_scenarios: + read_scenario(dataset_path, mapping, scenario_file) + success, logs = verify_loading_into_metadrive( + dataset_path, result_save_dir="test_dataset", steps_to_run=1000, num_workers=8) + set_random_drop(False) + # get error file + file_name = ErrorFile.get_error_file_name(dataset_path) + error_file_path = os.path.join("test_dataset", file_name) + # regenerate + pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "passed_senarios") + fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "failed_scenarios") + pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, pass_dataset, force_overwrite=True, + broken_scenario=False) + fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, fail_dataset, force_overwrite=True, + broken_scenario=True) + + # assert + read_pass_summary, _, read_pass_mapping = read_dataset_summary(pass_dataset) + assert recursive_equal(read_pass_summary, pass_summary) + assert recursive_equal(read_pass_mapping, pass_mapping) + read_fail_summary, _, read_fail_mapping, = read_dataset_summary(fail_dataset) + assert recursive_equal(read_fail_mapping, fail_mapping) + assert recursive_equal(read_fail_summary, fail_summary) + + # assert pass+fail = origin + all_summaries = copy.deepcopy(read_pass_summary) + all_summaries.update(fail_summary) + assert recursive_equal(all_summaries, summary) + + # test read + for scenario in read_pass_summary: + sd = read_scenario(pass_dataset, read_pass_mapping, scenario) + SD.sanity_check(sd) + for scenario in read_fail_summary: + sd = read_scenario(fail_dataset, read_fail_mapping, scenario) + SD.sanity_check(sd) + + +if __name__ == '__main__': + test_combine_multiple_dataset()