From 5c1b2e053bd28cfd0f20978fdbd117f1b7110c67 Mon Sep 17 00:00:00 2001 From: QuanyiLi Date: Sun, 7 May 2023 23:01:01 +0100 Subject: [PATCH] small dataset test! --- scenarionet/examples/combine_dataset_and_run.py | 17 +++++++++-------- .../local_test/_test_combine_dataset_local.py | 9 +++++---- scenarionet/verifier/utils.py | 6 ++++-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/scenarionet/examples/combine_dataset_and_run.py b/scenarionet/examples/combine_dataset_and_run.py index cfe9076..61c69e8 100644 --- a/scenarionet/examples/combine_dataset_and_run.py +++ b/scenarionet/examples/combine_dataset_and_run.py @@ -8,10 +8,10 @@ from scenarionet import SCENARIONET_DATASET_PATH from scenarionet.builder.utils import combine_multiple_dataset if __name__ == '__main__': - dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] - dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) - dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) - dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg")) + dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"), + os.path.join(SCENARIONET_DATASET_PATH, "nuplan"), + os.path.join(SCENARIONET_DATASET_PATH, "waymo"), + os.path.join(SCENARIONET_DATASET_PATH, "pg")] combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) @@ -38,9 +38,9 @@ if __name__ == '__main__': } ) success = [] - env.reset(force_seed=0) + env.reset(force_seed=2) while True: - env.reset(force_seed=env.current_seed + 1) + env.reset(force_seed=2) for t in range(10000): o, r, d, info = env.step([0, 0]) assert env.observation_space.contains(o) @@ -53,6 +53,7 @@ if __name__ == '__main__': } ) - if d and info["arrive_dest"]: - print("seed:{}, success".format(env.engine.global_random_seed)) + if d: + if info["arrive_dest"]: + print("seed:{}, success".format(env.engine.global_random_seed)) break diff --git a/scenarionet/tests/local_test/_test_combine_dataset_local.py b/scenarionet/tests/local_test/_test_combine_dataset_local.py index 1c741c5..9b5621f 100644 --- a/scenarionet/tests/local_test/_test_combine_dataset_local.py +++ b/scenarionet/tests/local_test/_test_combine_dataset_local.py @@ -6,10 +6,10 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive def _test_combine_dataset(): - dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] - dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) - dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) - dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg")) + dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"), + os.path.join(SCENARIONET_DATASET_PATH, "nuplan"), + os.path.join(SCENARIONET_DATASET_PATH, "waymo"), + os.path.join(SCENARIONET_DATASET_PATH, "pg")] combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) @@ -17,5 +17,6 @@ def _test_combine_dataset(): success, result = verify_loading_into_metadrive(combine_path, "verify_results", steps_to_run=250) assert success + if __name__ == '__main__': _test_combine_dataset() diff --git a/scenarionet/verifier/utils.py b/scenarionet/verifier/utils.py index ee653e5..9aa4d9a 100644 --- a/scenarionet/verifier/utils.py +++ b/scenarionet/verifier/utils.py @@ -70,10 +70,12 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste desc="Scenarios: {}-{}".format(start_scenario_index, start_scenario_index + num_scenario)): env.reset(force_seed=scenario_index) + arrive = False for _ in range(steps_to_run): o, r, d, info = env.step([0, 0]) - if d: - assert info["arrive_dest"], "Can not arrive destination" + if d and info["arrive_dest"]: + arrive = True + assert arrive, "Can not arrive destination" except Exception as e: file_name = env.engine.data_manager.summary_lookup[scenario_index] file_path = os.path.join(dataset_path, env.engine.data_manager.mapping[file_name], file_name)