diff --git a/.gitignore b/.gitignore index 621944c..9671ccf 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ /documentation/build/ dataset/* **/combine/ +**.json diff --git a/scenarionet/tests/_test_combine_dataset_local.py b/scenarionet/tests/_test_combine_dataset_local.py new file mode 100644 index 0000000..ed5b879 --- /dev/null +++ b/scenarionet/tests/_test_combine_dataset_local.py @@ -0,0 +1,21 @@ +import os + +from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet.builder.utils import combine_multiple_dataset +from scenarionet.verifier.utils import verify_loading_into_metadrive + + +def _test_combine_dataset(): + dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")) + + combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") + combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) + success, result = verify_loading_into_metadrive(combine_path) + assert success + + +if __name__ == '__main__': + _test_combine_dataset() diff --git a/scenarionet/tests/generate_test_dataset.py b/scenarionet/tests/script/generate_test_dataset.py similarity index 100% rename from scenarionet/tests/generate_test_dataset.py rename to scenarionet/tests/script/generate_test_dataset.py diff --git a/scenarionet/tests/script/run_env.py b/scenarionet/tests/script/run_env.py new file mode 100644 index 0000000..d13390c --- /dev/null +++ b/scenarionet/tests/script/run_env.py @@ -0,0 +1,60 @@ +import os + +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import get_number_of_scenarios + +from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet.builder.utils import combine_multiple_dataset + +if __name__ == '__main__': + dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) + dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")) + + combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") + combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) + + env = ScenarioEnv( + { + "use_render": True, + "agent_policy": ReplayEgoCarPolicy, + "manual_control": False, + "show_interface": True, + "debug": False, + "show_logo": False, + "show_fps": False, + "force_reuse_object_name": True, + "num_scenarios": get_number_of_scenarios(combine_path), + "horizon": 1000, + "no_static_vehicles": True, + "vehicle_config": dict( + show_navi_mark=False, + no_wheel_friction=True, + lidar=dict(num_lasers=120, distance=50, num_others=4), + lane_line_detector=dict(num_lasers=12, distance=50), + side_detector=dict(num_lasers=160, distance=50) + ), + "data_directory": combine_path, + } + ) + success = [] + env.reset(force_seed=91) + while True: + env.reset(force_seed=91) + for t in range(10000): + o, r, d, info = env.step([0, 0]) + assert env.observation_space.contains(o) + c_lane = env.vehicle.lane + long, lat, = c_lane.local_coordinates(env.vehicle.position) + if env.config["use_render"]: + env.render( + text={ + "seed": env.engine.global_seed + env.config["start_scenario_index"], + } + ) + + if d and info["arrive_dest"]: + print("seed:{}, success".format(env.engine.global_random_seed)) + break diff --git a/scenarionet/tests/test_combine_dataset.py b/scenarionet/tests/test_combine_dataset.py index 36bfc05..a9ef874 100644 --- a/scenarionet/tests/test_combine_dataset.py +++ b/scenarionet/tests/test_combine_dataset.py @@ -21,7 +21,8 @@ def test_combine_multiple_dataset(): summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) for scenario_file in sorted_scenarios: read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file)) - verify_loading_into_metadrive(dataset_path) + success, result = verify_loading_into_metadrive(dataset_path, result_save_dir="./test_dataset") + assert success if __name__ == '__main__': diff --git a/scenarionet/verifier/utils.py b/scenarionet/verifier/utils.py index c0a6cec..35d1986 100644 --- a/scenarionet/verifier/utils.py +++ b/scenarionet/verifier/utils.py @@ -1,14 +1,20 @@ +import json +import logging import os +logger = logging.getLogger(__name__) import tqdm from metadrive.envs.scenario_env import ScenarioEnv from metadrive.policy.replay_policy import ReplayEgoCarPolicy from metadrive.scenario.utils import get_number_of_scenarios -def verify_loading_into_metadrive(dataset_path): +def verify_loading_into_metadrive(dataset_path, result_save_dir=None): scenario_num = get_number_of_scenarios(dataset_path) - + if result_save_dir is not None: + assert os.path.exists(result_save_dir) and os.path.isdir( + result_save_dir), "Argument result_save_dir must be an existing dir" + success = True env = ScenarioEnv( { "agent_policy": ReplayEgoCarPolicy, @@ -18,11 +24,20 @@ def verify_loading_into_metadrive(dataset_path): "data_directory": dataset_path, } ) + error_files = [] try: for i in tqdm.tqdm(range(scenario_num)): env.reset(force_seed=i) except Exception as e: file_name = env.engine.data_manager.summary_lookup[i] file_path = os.path.join(dataset_path, env.engine.data_manager.mapping[file_name], file_name) - raise ValueError("Scenario Error, seed: {}, file_path: {}. " - "\n Error message: {}".format(i, file_path, e)) + error_file = {"seed": i, "file_path": file_path, "error": e} + error_files.append(error_file) + logger.warning("\n Scenario Error, seed: {}, file_path: {}.\n Error message: {}".format(i, file_path, e)) + success=False + finally: + env.close() + if result_save_dir is not None: + with open(os.path.join(result_save_dir, "error_scenarios_{}.json".format(os.path.basename(dataset_path))), "w+") as f: + json.dump(error_files, f) + return success, error_files