diff --git a/README.md b/README.md index 2bf9633..e11979d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,82 @@ # ScenarioNet + ScenarioNet: Scalable Traffic Scenario Management System for Autonomous Driving + +## Installation + +``` +git clone git@github.com:metadriverse/scenarionet.git +cd scenarionet +pip install -e . +``` + +## Usage + +We provide some explanation and demo for all scripts here. +**You are encouraged to try them on your own, add ```-h``` or ```--help``` argument to know more details about these scripts.** + +### Convert + +**Waymo**: the following script can convert Waymo tfrecord to Metadrive scenario description and +store them at directory ./waymo + +``` +python -m scenarionet.convert_waymo -d waymo --raw_data_path /path/to/tfrecords --num_workers=16 +``` + +**nuPlan**: the following script will convert nuPlan split containing .db files to Metadrive scenario description and +store them at directory ./nuplan + +``` +python -m scenarionet.convert_nuplan -d nuplan -raw_data_path /path/to/dir/containing/.db files --num_workers=16 +``` + +**nuScenes**: as nuScenes split can be read by specifying version like v1.0-mini and v1.0-training, the following script +will convert all scenarios in that split + +``` +python -m scenarionet.convert_nuscenes -d nuscenes --version v1.0-mini --num_workers=16 +``` + +**PG**: the following script can generate 10000 scenarios stored at directory ./pg + +``` +python -m scenarionet.scripts.convert_pg -d pg --num_workers=16 --num_scenarios=10000 +``` + +### Merge & move +For merging two or more database, use +``` +python -m scenarionet.merge_database -d /destination/path --from_databases /database/path1 /database/path2 ... +``` +As a database contains a path mapping, one should move database folder with the following script instead of ```cp``` +command +``` +python -m scenarionet.move_database --to /destination/path --from /source/path +``` + +### Verify +The following scripts will check whether all scenarios exist or can be loaded into simulator. +The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be +generated. +With teh error file, one can build a new database excluding or including the broken or missing scenarios. +**Existence check** +``` +python -m scenarionet.verify_completeness -d /database/to/check --result_save_dir /where/to/save/test/result +``` +**Runnable check** +``` +python -m scenarionet.verify_simulation -d /database/to/check --result_save_dir /where/to/save/test/result +``` +**Generating new database** +``` +python -m scenarionet.generate_from_error_file -d /where/to/create/the/new/database --file /error/file/path --broken +``` + +### visualization + +Visualizing the simulated scenario +``` +python -m scenarionet.run_simulation -d /path/to/database --render --scenario_index +``` + diff --git a/scenarionet/builder/utils.py b/scenarionet/builder/utils.py index 0a67d58..9d002fe 100644 --- a/scenarionet/builder/utils.py +++ b/scenarionet/builder/utils.py @@ -1,4 +1,3 @@ -import pkg_resources # for suppress warning import copy import logging import os @@ -6,8 +5,8 @@ import os.path as osp import pickle import shutil from typing import Callable, List -import tqdm +import tqdm from metadrive.scenario.scenario_description import ScenarioDescription from scenarionet.common_utils import save_summary_anda_mapping @@ -27,7 +26,7 @@ def try_generating_summary(file_folder): return summary -def combine_dataset( +def merge_database( output_path, *dataset_paths, exist_ok=False, @@ -109,3 +108,27 @@ def combine_dataset( save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings) return summaries, mappings + + +def move_database( + from_path, + to_path, + exist_ok=False, + overwrite=False, +): + if not os.path.exists(from_path): + raise FileNotFoundError("Can not find dataset: {}".format(from_path)) + if os.path.exists(to_path): + assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a dataset" + assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!" + merge_database( + to_path, + from_path, + exist_ok=exist_ok, + overwrite=overwrite, + try_generate_missing_file=True, + ) + files = os.listdir(from_path) + if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len( + files) == 2: + shutil.rmtree(from_path) diff --git a/scenarionet/convert_nuplan.py b/scenarionet/convert_nuplan.py index 20007cf..ac4dea6 100644 --- a/scenarionet/convert_nuplan.py +++ b/scenarionet/convert_nuplan.py @@ -8,16 +8,16 @@ from scenarionet.converter.utils import write_to_directory if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_name", "-n", default="nuplan", help="Dataset name, will be used to generate scenario files" - ) - parser.add_argument( - "--dataset_path", + "--database_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "nuplan"), help="A directory, the path to place the data" ) + parser.add_argument( + "--dataset_name", "-n", default="nuplan", help="Dataset name, will be used to generate scenario files" + ) parser.add_argument("--version", "-v", default='v1.1', help="version of the raw data") - parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it") + parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it") parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use") parser.add_argument( "--raw_data_path", @@ -30,7 +30,7 @@ if __name__ == '__main__': overwrite = args.overwrite dataset_name = args.dataset_name - output_path = args.dataset_path + output_path = args.database_path version = args.version data_root = args.raw_data_path diff --git a/scenarionet/convert_nuscenes.py b/scenarionet/convert_nuscenes.py index 33a7a5f..eae1da4 100644 --- a/scenarionet/convert_nuscenes.py +++ b/scenarionet/convert_nuscenes.py @@ -8,27 +8,27 @@ from scenarionet.converter.utils import write_to_directory if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_name", "-n", default="nuscenes", help="Dataset name, will be used to generate scenario files" - ) - parser.add_argument( - "--dataset_path", + "--database_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"), help="directory, The path to place the data" ) + parser.add_argument( + "--dataset_name", "-n", default="nuscenes", help="Dataset name, will be used to generate scenario files" + ) parser.add_argument( "--version", "-v", default='v1.0-mini', help="version of nuscenes data, scenario of this version will be converted " ) - parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it") + parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it") parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use") args = parser.parse_args() overwrite = args.overwrite dataset_name = args.dataset_name - output_path = args.dataset_path + output_path = args.database_path version = args.version dataroot = '/home/shady/data/nuscenes' diff --git a/scenarionet/convert_pg.py b/scenarionet/convert_pg.py index ba79509..edf26f5 100644 --- a/scenarionet/convert_pg.py +++ b/scenarionet/convert_pg.py @@ -11,16 +11,16 @@ from scenarionet.converter.utils import write_to_directory if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_name", "-n", default="pg", help="Dataset name, will be used to generate scenario files" - ) - parser.add_argument( - "--dataset_path", + "--database_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "pg"), help="directory, The path to place the data" ) + parser.add_argument( + "--dataset_name", "-n", default="pg", help="Dataset name, will be used to generate scenario files" + ) parser.add_argument("--version", "-v", default=metadrive.constants.DATA_VERSION, help="version") - parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it") + parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it") parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use") parser.add_argument("--num_scenarios", type=int, default=64, help="how many scenarios to generate (default: 30)") parser.add_argument("--start_index", type=int, default=0, help="which index to start") @@ -28,7 +28,7 @@ if __name__ == '__main__': overwrite = args.overwrite dataset_name = args.dataset_name - output_path = args.dataset_path + output_path = args.database_path version = args.version scenario_indices = get_pg_scenarios(args.start_index, args.num_scenarios) diff --git a/scenarionet/convert_waymo.py b/scenarionet/convert_waymo.py index 1711fe8..99c589d 100644 --- a/scenarionet/convert_waymo.py +++ b/scenarionet/convert_waymo.py @@ -12,16 +12,16 @@ logger = logging.getLogger(__name__) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_name", "-n", default="waymo", help="Dataset name, will be used to generate scenario files" - ) - parser.add_argument( - "--dataset_path", + "--database_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"), help="A directory, the path to place the converted data" ) + parser.add_argument( + "--dataset_name", "-n", default="waymo", help="Dataset name, will be used to generate scenario files" + ) parser.add_argument("--version", "-v", default='v1.2', help="version") - parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, whether to overwrite it") + parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, whether to overwrite it") parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use") parser.add_argument( "--raw_data_path", @@ -32,7 +32,7 @@ if __name__ == '__main__': overwrite = args.overwrite dataset_name = args.dataset_name - output_path = args.dataset_path + output_path = args.database_path version = args.version waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path) diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py index d2b3640..106455b 100644 --- a/scenarionet/converter/utils.py +++ b/scenarionet/converter/utils.py @@ -16,7 +16,7 @@ from metadrive.envs.metadrive_env import MetaDriveEnv from metadrive.policy.idm_policy import IDMPolicy from metadrive.scenario import ScenarioDescription as SD -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database from scenarionet.common_utils import save_summary_anda_mapping from scenarionet.converter.pg.utils import convert_pg_scenario @@ -126,7 +126,7 @@ def write_to_directory( with multiprocessing.Pool(num_workers, maxtasksperchild=10) as p: ret = list(p.imap(func, argument_list)) # call ret to block the process - combine_dataset(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False) + merge_database(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False) def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False): diff --git a/scenarionet/generate_from_error_file.py b/scenarionet/generate_from_error_file.py index cb9e096..5897f24 100644 --- a/scenarionet/generate_from_error_file.py +++ b/scenarionet/generate_from_error_file.py @@ -5,9 +5,9 @@ from scenarionet.verifier.error import ErrorFile if __name__ == '__main__': parser = argparse.ArgumentParser() + parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated dataset") parser.add_argument("--file", "-f", required=True, help="The path of the error file, should be xyz.json") - parser.add_argument("--dataset_path", "-d", required=True, help="The path of the newly generated dataset") - parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it") + parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, overwrite it") parser.add_argument( "--broken", action="store_true", @@ -15,4 +15,4 @@ if __name__ == '__main__': "If turn on this flog, it will generate dataset containing only broken scenarios." ) args = parser.parse_args() - ErrorFile.generate_dataset(args.file, args.dataset_path, args.overwrite, args.broken) + ErrorFile.generate_dataset(args.file, args.database_path, args.overwrite, args.broken) diff --git a/scenarionet/combine_dataset.py b/scenarionet/merge_database.py similarity index 92% rename from scenarionet/combine_dataset.py rename to scenarionet/merge_database.py index 01fa48d..be23d5c 100644 --- a/scenarionet/combine_dataset.py +++ b/scenarionet/merge_database.py @@ -1,12 +1,13 @@ import pkg_resources # for suppress warning import argparse from scenarionet.builder.filters import ScenarioFilter -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_path", + "--database_path", + "-d", required=True, help="The name of the new combined dataset. " "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. " @@ -43,8 +44,8 @@ if __name__ == '__main__': filters = [ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater")] if len(args.from_datasets) != 0: - combine_dataset( - args.dataset_path, + merge_database( + args.database_path, *args.from_datasets, exist_ok=args.exist_ok, overwrite=args.overwrite, diff --git a/scenarionet/move_database.py b/scenarionet/move_database.py new file mode 100644 index 0000000..d32ca78 --- /dev/null +++ b/scenarionet/move_database.py @@ -0,0 +1,35 @@ +import argparse + +from scenarionet.builder.utils import move_database + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--from', required=True, help="Which dataset to move.") + parser.add_argument( + "--to", + required=True, + help="The name of the new dataset. " + "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. " + "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn " + "that directory into a dataset." + ) + parser.add_argument( + "--exist_ok", + action="store_true", + help="Still allow to write, if the to_folder exists already. " + "This write will only create two .pkl files and this directory will become a dataset." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, " + "whether to overwrite both files" + ) + args = parser.parse_args() + from_path = args.__getattr__("from") + move_database( + from_path, + args.to, + exist_ok=args.exist_ok, + overwrite=args.overwrite, + ) diff --git a/scenarionet/run_simulation.py b/scenarionet/run_simulation.py index a460b22..65bb13f 100644 --- a/scenarionet/run_simulation.py +++ b/scenarionet/run_simulation.py @@ -7,13 +7,13 @@ from metadrive.scenario.utils import get_number_of_scenarios if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--dataset_path", "-d", required=True, help="The path of the dataset") + parser.add_argument("--database_path", "-d", required=True, help="The path of the dataset") parser.add_argument("--render", action="store_true", help="Enable 3D rendering") parser.add_argument("--scenario_index", default=None, type=int, help="Specifying a scenario to run") args = parser.parse_args() - dataset_path = os.path.abspath(args.dataset_path) - num_scenario = get_number_of_scenarios(dataset_path) + database_path = os.path.abspath(args.database_path) + num_scenario = get_number_of_scenarios(database_path) if args.scenario_index is not None: assert args.scenario_index < num_scenario, \ "The specified scenario index exceeds the scenario range: {}!".format(num_scenario) @@ -35,7 +35,7 @@ if __name__ == '__main__': lane_line_detector=dict(num_lasers=12, distance=50), side_detector=dict(num_lasers=160, distance=50) ), - "data_directory": dataset_path, + "data_directory": database_path, } ) for seed in range(num_scenario if args.scenario_index is not None else 1000000): diff --git a/scenarionet/tests/local_test/_test_combine_dataset_local.py b/scenarionet/tests/local_test/_test_combine_dataset_local.py index b35dfc3..778e250 100644 --- a/scenarionet/tests/local_test/_test_combine_dataset_local.py +++ b/scenarionet/tests/local_test/_test_combine_dataset_local.py @@ -3,7 +3,7 @@ import os from metadrive.scenario.scenario_description import ScenarioDescription as SD from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database from scenarionet.common_utils import read_dataset_summary, read_scenario @@ -17,7 +17,7 @@ def _test_combine_dataset(): ] combine_path = os.path.join(TMP_PATH, "combine") - combine_dataset(combine_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + merge_database(combine_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) summary, _, mapping = read_dataset_summary(combine_path) for scenario in summary: sd = read_scenario(combine_path, mapping, scenario) diff --git a/scenarionet/tests/local_test/_test_filter_local.py b/scenarionet/tests/local_test/_test_filter_local.py index 2af2662..f3524e2 100644 --- a/scenarionet/tests/local_test/_test_filter_local.py +++ b/scenarionet/tests/local_test/_test_filter_local.py @@ -5,7 +5,7 @@ from metadrive.type import MetaDriveType from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH, TMP_PATH from scenarionet.builder.filters import ScenarioFilter -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database def test_filter_dataset(): @@ -23,7 +23,7 @@ def test_filter_dataset(): # nuscenes data has no light # light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light) sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=30, condition="greater") - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, @@ -39,7 +39,7 @@ def test_filter_dataset(): ScenarioFilter.object_number, number_threshold=50, object_type=MetaDriveType.PEDESTRIAN, condition="greater" ) - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, @@ -53,7 +53,7 @@ def test_filter_dataset(): traffic_light = ScenarioFilter.make(ScenarioFilter.has_traffic_light) - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, diff --git a/scenarionet/tests/local_test/_test_generate_from_error_file.py b/scenarionet/tests/local_test/_test_generate_from_error_file.py index 45b67a3..9a1f0aa 100644 --- a/scenarionet/tests/local_test/_test_generate_from_error_file.py +++ b/scenarionet/tests/local_test/_test_generate_from_error_file.py @@ -6,12 +6,12 @@ from metadrive.scenario.scenario_description import ScenarioDescription as SD from scenarionet import SCENARIONET_DATASET_PATH from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database from scenarionet.common_utils import read_dataset_summary, read_scenario from scenarionet.common_utils import recursive_equal from scenarionet.verifier.error import ErrorFile from scenarionet.verifier.utils import set_random_drop -from scenarionet.verifier.utils import verify_dataset +from scenarionet.verifier.utils import verify_database def test_generate_from_error(): @@ -25,12 +25,12 @@ def test_generate_from_error(): ] dataset_path = os.path.join(TMP_PATH, "combine") - combine_dataset(dataset_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + merge_database(dataset_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) for scenario_file in sorted_scenarios: read_scenario(dataset_path, mapping, scenario_file) - success, logs = verify_dataset( + success, logs = verify_database( dataset_path, result_save_dir="../test_dataset", steps_to_run=1000, num_workers=16, overwrite=True ) set_random_drop(False) diff --git a/scenarionet/tests/local_test/combine_verify_generate.sh b/scenarionet/tests/local_test/combine_verify_generate.sh index 997b37e..529a621 100644 --- a/scenarionet/tests/local_test/combine_verify_generate.sh +++ b/scenarionet/tests/local_test/combine_verify_generate.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -python ../../combine_dataset.py --overwrite --exist_ok --dataset_path ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite -python ../../verify_simulation.py --overwrite --dataset_path ../tmp/test_combine_dataset --result_save_dir ../tmp/test_combine_dataset --random_drop --num_workers=16 -python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --dataset_path ../tmp/verify_pass -python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --dataset_path ../tmp/verify_fail --broken \ No newline at end of file +python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite +python ../../verify_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --result_save_dir ../tmp/test_combine_dataset --random_drop --num_workers=16 +python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_pass +python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_fail --broken \ No newline at end of file diff --git a/scenarionet/tests/local_test/convert_pg_large.sh b/scenarionet/tests/local_test/convert_pg_large.sh index 3f4c038..8e65e6b 100644 --- a/scenarionet/tests/local_test/convert_pg_large.sh +++ b/scenarionet/tests/local_test/convert_pg_large.sh @@ -32,8 +32,8 @@ done # combine the datasets if [ "$overwrite" = true ]; then - python -m scenarionet.scripts.combine_dataset --dataset_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok + python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok else - python -m scenarionet.scripts.combine_dataset --dataset_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok + python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok fi diff --git a/scenarionet/tests/script/run_env.py b/scenarionet/tests/script/run_env.py index 53153f2..682a145 100644 --- a/scenarionet/tests/script/run_env.py +++ b/scenarionet/tests/script/run_env.py @@ -5,7 +5,7 @@ from metadrive.policy.replay_policy import ReplayEgoCarPolicy from metadrive.scenario.utils import get_number_of_scenarios from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database if __name__ == '__main__': dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] @@ -14,7 +14,7 @@ if __name__ == '__main__': dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg")) combine_path = os.path.join(TMP_PATH, "combine") - combine_dataset(combine_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + merge_database(combine_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) env = ScenarioEnv( { diff --git a/scenarionet/tests/test_combine_dataset.py b/scenarionet/tests/test_combine_dataset.py index a14f2b4..1f3b1a9 100644 --- a/scenarionet/tests/test_combine_dataset.py +++ b/scenarionet/tests/test_combine_dataset.py @@ -2,9 +2,9 @@ import os import os.path from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database from scenarionet.common_utils import read_dataset_summary, read_scenario -from scenarionet.verifier.utils import verify_dataset +from scenarionet.verifier.utils import verify_database def test_combine_multiple_dataset(): @@ -14,13 +14,13 @@ def test_combine_multiple_dataset(): dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] output_path = os.path.join(TMP_PATH, "combine") - combine_dataset(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) dataset_paths.append(output_path) for dataset_path in dataset_paths: summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) for scenario_file in sorted_scenarios: read_scenario(dataset_path, mapping, scenario_file) - success, result = verify_dataset( + success, result = verify_database( dataset_path, result_save_dir=test_dataset_path, steps_to_run=1000, num_workers=4, overwrite=True ) assert success diff --git a/scenarionet/tests/test_filter.py b/scenarionet/tests/test_filter.py index 1f153b0..6adbad4 100644 --- a/scenarionet/tests/test_filter.py +++ b/scenarionet/tests/test_filter.py @@ -5,7 +5,7 @@ from metadrive.type import MetaDriveType from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH from scenarionet.builder.filters import ScenarioFilter -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database def test_filter_dataset(): @@ -20,7 +20,7 @@ def test_filter_dataset(): # light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light) sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=30, condition="smaller") answer = ['sd_nuscenes_v1.0-mini_scene-0553.pkl', '0.pkl', 'sd_nuscenes_v1.0-mini_scene-1100.pkl'] - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, @@ -38,7 +38,7 @@ def test_filter_dataset(): assert in_, summary.keys() sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=5, condition="greater") - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, @@ -55,7 +55,7 @@ def test_filter_dataset(): ) answer = ['sd_nuscenes_v1.0-mini_scene-0061.pkl', 'sd_nuscenes_v1.0-mini_scene-1094.pkl'] - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, @@ -69,7 +69,7 @@ def test_filter_dataset(): num_condition = ScenarioFilter.make(ScenarioFilter.object_number, number_threshold=50, condition="greater") - summary, mapping = combine_dataset( + summary, mapping = merge_database( output_path, *dataset_paths, exist_ok=True, diff --git a/scenarionet/tests/test_generate_from_error_file.py b/scenarionet/tests/test_generate_from_error_file.py index c3f9dd3..16aa668 100644 --- a/scenarionet/tests/test_generate_from_error_file.py +++ b/scenarionet/tests/test_generate_from_error_file.py @@ -5,11 +5,11 @@ import os.path from metadrive.scenario.scenario_description import ScenarioDescription as SD from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database from scenarionet.common_utils import read_dataset_summary, read_scenario from scenarionet.common_utils import recursive_equal from scenarionet.verifier.error import ErrorFile -from scenarionet.verifier.utils import verify_dataset, set_random_drop +from scenarionet.verifier.utils import verify_database, set_random_drop def test_generate_from_error(): @@ -18,12 +18,12 @@ def test_generate_from_error(): original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name) dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] dataset_path = os.path.join(TMP_PATH, "combine") - combine_dataset(dataset_path, *dataset_paths, exist_ok=True, try_generate_missing_file=True, overwrite=True) + merge_database(dataset_path, *dataset_paths, exist_ok=True, try_generate_missing_file=True, overwrite=True) summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) for scenario_file in sorted_scenarios: read_scenario(dataset_path, mapping, scenario_file) - success, logs = verify_dataset( + success, logs = verify_database( dataset_path, result_save_dir=TMP_PATH, steps_to_run=1000, num_workers=3, overwrite=True ) set_random_drop(False) diff --git a/scenarionet/tests/test_move.py b/scenarionet/tests/test_move.py new file mode 100644 index 0000000..f6c3d44 --- /dev/null +++ b/scenarionet/tests/test_move.py @@ -0,0 +1,54 @@ +import os +import os.path + +import pytest + +from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH +from scenarionet.builder.utils import move_database, merge_database +from scenarionet.common_utils import read_dataset_summary, read_scenario +from scenarionet.verifier.utils import verify_database + + +@pytest.mark.order("first") +def test_move_database(): + dataset_name = "nuscenes" + original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name) + dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] + moved_path = [] + output_path = os.path.join(TMP_PATH, "move_combine") + # move + for k, from_path in enumerate(dataset_paths): + to = os.path.join(TMP_PATH, str(k)) + move_database(from_path, to) + moved_path.append(to) + assert os.path.exists(from_path) + merge_database(output_path, *moved_path, exist_ok=True, overwrite=True, try_generate_missing_file=True) + # verify + summary, sorted_scenarios, mapping = read_dataset_summary(output_path) + for scenario_file in sorted_scenarios: + read_scenario(output_path, mapping, scenario_file) + success, result = verify_database( + output_path, result_save_dir=output_path, steps_to_run=0, num_workers=4, overwrite=True + ) + assert success + + # move 2 + new_move_pathes = [] + for k, from_path in enumerate(moved_path): + new_p = os.path.join(TMP_PATH, str(k) + str(k)) + new_move_pathes.append(new_p) + move_database(from_path, new_p, exist_ok=True, overwrite=True) + assert not os.path.exists(from_path) + merge_database(output_path, *new_move_pathes, exist_ok=True, overwrite=True, try_generate_missing_file=True) + # verify + summary, sorted_scenarios, mapping = read_dataset_summary(output_path) + for scenario_file in sorted_scenarios: + read_scenario(output_path, mapping, scenario_file) + success, result = verify_database( + output_path, result_save_dir=output_path, steps_to_run=0, num_workers=4, overwrite=True + ) + assert success + + +if __name__ == '__main__': + test_move_database() diff --git a/scenarionet/tests/test_verify_completeness.py b/scenarionet/tests/test_verify_completeness.py index 070b795..24f2f30 100644 --- a/scenarionet/tests/test_verify_completeness.py +++ b/scenarionet/tests/test_verify_completeness.py @@ -2,9 +2,9 @@ import os import os.path from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import combine_dataset +from scenarionet.builder.utils import merge_database from scenarionet.common_utils import read_dataset_summary, read_scenario -from scenarionet.verifier.utils import verify_dataset, set_random_drop +from scenarionet.verifier.utils import verify_database, set_random_drop def test_verify_completeness(): @@ -13,19 +13,19 @@ def test_verify_completeness(): dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] output_path = os.path.join(TMP_PATH, "combine") - combine_dataset(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) dataset_path = output_path summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path) for scenario_file in sorted_scenarios: read_scenario(dataset_path, mapping, scenario_file) set_random_drop(True) - success, result = verify_dataset( + success, result = verify_database( dataset_path, result_save_dir=TMP_PATH, steps_to_run=0, num_workers=4, overwrite=True ) assert not success set_random_drop(False) - success, result = verify_dataset( + success, result = verify_database( dataset_path, result_save_dir=TMP_PATH, steps_to_run=0, num_workers=4, overwrite=True ) assert success diff --git a/scenarionet/verifier/utils.py b/scenarionet/verifier/utils.py index 1ff6d47..6a2c7d8 100644 --- a/scenarionet/verifier/utils.py +++ b/scenarionet/verifier/utils.py @@ -25,7 +25,7 @@ def set_random_drop(drop): RANDOM_DROP = drop -def verify_dataset(dataset_path, result_save_dir, overwrite=False, num_workers=8, steps_to_run=1000): +def verify_database(dataset_path, result_save_dir, overwrite=False, num_workers=8, steps_to_run=1000): global RANDOM_DROP assert os.path.isdir(result_save_dir), "result_save_dir must be a dir, get {}".format(result_save_dir) os.makedirs(result_save_dir, exist_ok=True) diff --git a/scenarionet/verify_completeness.py b/scenarionet/verify_completeness.py index eb6b19b..13606cf 100644 --- a/scenarionet/verify_completeness.py +++ b/scenarionet/verify_completeness.py @@ -1,11 +1,11 @@ import argparse -from scenarionet.verifier.utils import verify_dataset, set_random_drop +from scenarionet.verifier.utils import verify_database, set_random_drop if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_path", "-d", required=True, help="Dataset path, a directory containing summary.pkl and mapping.pkl" + "--database_path", "-d", required=True, help="Dataset path, a directory containing summary.pkl and mapping.pkl" ) parser.add_argument("--result_save_dir", default="./", help="Where to save the error file") parser.add_argument( @@ -18,6 +18,10 @@ if __name__ == '__main__': parser.add_argument("--random_drop", action="store_true", help="Randomly make some scenarios fail. for test only!") args = parser.parse_args() set_random_drop(args.random_drop) - verify_dataset( - args.dataset_path, args.result_save_dir, overwrite=args.overwrite, num_workers=args.num_workers, steps_to_run=0 + verify_database( + args.database_path, + args.result_save_dir, + overwrite=args.overwrite, + num_workers=args.num_workers, + steps_to_run=0 ) diff --git a/scenarionet/verify_simulation.py b/scenarionet/verify_simulation.py index 6c279fe..de1e294 100644 --- a/scenarionet/verify_simulation.py +++ b/scenarionet/verify_simulation.py @@ -1,11 +1,11 @@ import pkg_resources # for suppress warning import argparse -from scenarionet.verifier.utils import verify_dataset, set_random_drop +from scenarionet.verifier.utils import verify_database, set_random_drop if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--dataset_path", "-d", required=True, help="Dataset path, a directory containing summary.pkl and mapping.pkl" + "--database_path", "-d", required=True, help="Dataset path, a directory containing summary.pkl and mapping.pkl" ) parser.add_argument("--result_save_dir", default="./", help="Where to save the error file") parser.add_argument( @@ -18,4 +18,4 @@ if __name__ == '__main__': parser.add_argument("--random_drop", action="store_true", help="Randomly make some scenarios fail. for test only!") args = parser.parse_args() set_random_drop(args.random_drop) - verify_dataset(args.dataset_path, args.result_save_dir, overwrite=args.overwrite, num_workers=args.num_workers) + verify_database(args.database_path, args.result_save_dir, overwrite=args.overwrite, num_workers=args.num_workers)