diff --git a/scenarionet/examples/run_scenario.py b/scenarionet/examples/run_combine_scenarios.py similarity index 68% rename from scenarionet/examples/run_scenario.py rename to scenarionet/examples/run_combine_scenarios.py index 8d5f97f..75d5039 100644 --- a/scenarionet/examples/run_scenario.py +++ b/scenarionet/examples/run_combine_scenarios.py @@ -2,11 +2,19 @@ import os from metadrive.envs.scenario_env import ScenarioEnv from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import get_number_of_scenarios from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet.builder.utils import combine_multiple_dataset if __name__ == '__main__': - dataset_path = os.path.join(SCENARIONET_DATASET_PATH, "nuscenes_no_mapping_test") + dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")) + + combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") + combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) env = ScenarioEnv( { @@ -16,7 +24,7 @@ if __name__ == '__main__': "show_interface": True, "show_logo": False, "show_fps": False, - "num_scenarios": 10, + "num_scenarios": get_number_of_scenarios(combine_path), "horizon": 1000, "no_static_vehicles": True, "vehicle_config": dict( @@ -26,7 +34,7 @@ if __name__ == '__main__': lane_line_detector=dict(num_lasers=12, distance=50), side_detector=dict(num_lasers=160, distance=50) ), - "data_directory": dataset_path, + "data_directory": combine_path, } ) success = [] diff --git a/scenarionet/tests/generate_test_dataset.py b/scenarionet/tests/generate_test_dataset.py index a12dc14..5624e8d 100644 --- a/scenarionet/tests/generate_test_dataset.py +++ b/scenarionet/tests/generate_test_dataset.py @@ -9,9 +9,9 @@ from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_ from scenarionet.converter.utils import write_to_directory if __name__ == "__main__": - # raise ValueError("Avoid generating ata") + raise ValueError("Avoid overwriting existing ata") dataset_name = "nuscenes" - output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "test", "test_dataset", dataset_name) + output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name) version = 'v1.0-mini' force_overwrite = True diff --git a/scenarionet/tests/test_combine_dataset.py b/scenarionet/tests/test_combine_dataset.py index e1419d7..36bfc05 100644 --- a/scenarionet/tests/test_combine_dataset.py +++ b/scenarionet/tests/test_combine_dataset.py @@ -1,20 +1,17 @@ import os import os.path -import tqdm -from metadrive.envs.scenario_env import ScenarioEnv -from metadrive.policy.replay_policy import ReplayEgoCarPolicy - from scenarionet import SCENARIONET_PACKAGE_PATH from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario +from scenarionet.verifier.utils import verify_loading_into_metadrive def test_combine_multiple_dataset(): dataset_name = "nuscenes" - original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "test", "test_dataset", dataset_name) + original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name) dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] - output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "test", "combine") + output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine") combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=True, @@ -24,16 +21,7 @@ def test_combine_multiple_dataset(): summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) for scenario_file in sorted_scenarios: read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file)) - - env = ScenarioEnv({"agent_policy": ReplayEgoCarPolicy, - "num_scenarios": 10, - "horizon": 1000, - "data_directory": output_path}) - try: - for i in tqdm.tqdm(range(10), desc="Test env loading"): - env.reset(force_seed=i) - finally: - env.close() + verify_loading_into_metadrive(dataset_path) if __name__ == '__main__': diff --git a/scenarionet/verifier/__init__.py b/scenarionet/verifier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scenarionet/verifier/utils.py b/scenarionet/verifier/utils.py new file mode 100644 index 0000000..c0a6cec --- /dev/null +++ b/scenarionet/verifier/utils.py @@ -0,0 +1,28 @@ +import os + +import tqdm +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import get_number_of_scenarios + + +def verify_loading_into_metadrive(dataset_path): + scenario_num = get_number_of_scenarios(dataset_path) + + env = ScenarioEnv( + { + "agent_policy": ReplayEgoCarPolicy, + "num_scenarios": scenario_num, + "horizon": 1000, + "no_static_vehicles": False, + "data_directory": dataset_path, + } + ) + try: + for i in tqdm.tqdm(range(scenario_num)): + env.reset(force_seed=i) + except Exception as e: + file_name = env.engine.data_manager.summary_lookup[i] + file_path = os.path.join(dataset_path, env.engine.data_manager.mapping[file_name], file_name) + raise ValueError("Scenario Error, seed: {}, file_path: {}. " + "\n Error message: {}".format(i, file_path, e))