diff --git a/scenarionet/builder/conditions.py b/scenarionet/builder/conditions.py
deleted file mode 100644
index 333cb95..0000000
--- a/scenarionet/builder/conditions.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import numpy as np
-
-
-def validate_sdc_track(sdc_state):
-    """
-    This function filters the scenario based on SDC information.
-
-    Rule 1: Filter out if the trajectory length < 10
-
-    Rule 2: Filter out if the whole trajectory last < 5s, assuming sampling frequency = 10Hz.
-    """
-    valid_array = sdc_state["valid"]
-    sdc_trajectory = sdc_state["position"][valid_array, :2]
-    sdc_track_length = [
-        np.linalg.norm(sdc_trajectory[i] - sdc_trajectory[i + 1]) for i in range(sdc_trajectory.shape[0] - 1)
-    ]
-    sdc_track_length = sum(sdc_track_length)
-
-    # Rule 1
-    if sdc_track_length < 10:
-        return False
-
-    print("sdc_track_length: ", sdc_track_length)
-
-    # Rule 2
-    if valid_array.sum() < 50:
-        return False
-
-    return True
diff --git a/scenarionet/builder/filters.py b/scenarionet/builder/filters.py
new file mode 100644
index 0000000..a292c78
--- /dev/null
+++ b/scenarionet/builder/filters.py
@@ -0,0 +1,60 @@
+from functools import partial
+
+from metadrive.scenario.scenario_description import ScenarioDescription as SD
+
+
+class ScenarioFilter:
+    GREATER = "greater"
+    SMALLER = "smaller"
+
+    @staticmethod
+    def sdc_moving_dist(metadata, target_dist, condition=GREATER):
+        """
+        Filter scenarios by the moving distance of the SDC (self-driving car).
+        """
+        assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
+        sdc_info = metadata[SD.SUMMARY.OBJECT_SUMMARY][metadata[SD.SDC_ID]]
+        moving_dist = sdc_info[SD.SUMMARY.MOVING_DIST]
+        if moving_dist > target_dist and condition == ScenarioFilter.GREATER:
+            return True
+        if moving_dist < target_dist and condition == ScenarioFilter.SMALLER:
+            return True
+        return False
+
+    @staticmethod
+    def object_number(metadata, number_threshold, object_type=None, condition=SMALLER):
+        """
+        Return True if the scenario satisfies the object-number condition
+        :param metadata: metadata of one scenario
+        :param number_threshold: threshold on the number of objects
+        :param object_type: MetaDriveType.VEHICLE or another object type.
+            If None, count objects of all types
+        :param condition: SMALLER or GREATER
+        :return: boolean
+        """
+        assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
+        if object_type is not None:
+            num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS_EACH_TYPE].get(object_type, 0)
+        else:
+            num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS]
+        if num > number_threshold and condition == ScenarioFilter.GREATER:
+            return True
+        if num < number_threshold and condition == ScenarioFilter.SMALLER:
+            return True
+        return False
+
+    @staticmethod
+    def has_traffic_light(metadata):
+        return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0
+
+    @staticmethod
+    def make(func, **kwargs):
+        """
+        A wrapper around partial() for pre-filling some parameters
+        :param func: a filter function in this class
+        :param kwargs: kwargs for the filter
+        :return: a function taking only metadata as input
+        """
+        assert "metadata" not in kwargs, "Only fill in the conditions; metadata will be filled automatically"
+        if "condition" in kwargs:
+            assert kwargs["condition"] in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
+        return partial(func, **kwargs)
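For context, a minimal sketch of how a filter from this class is meant to be composed (names mirror the code above; the 30 m threshold is an arbitrary illustration, not a project default):

    from scenarionet.builder.filters import ScenarioFilter

    # `make` pre-fills every argument except `metadata`, returning a predicate
    # that can be called on each scenario's metadata dict.
    keep_long_drives = ScenarioFilter.make(
        ScenarioFilter.sdc_moving_dist,
        target_dist=30,  # arbitrary example threshold, in meters
        condition=ScenarioFilter.GREATER
    )
    # keep_long_drives(metadata) -> True if the SDC moved more than 30 m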
diff --git a/scenarionet/builder/select_demo_pickles.py b/scenarionet/builder/select_demo_pickles.py
deleted file mode 100644
index 77b6323..0000000
--- a/scenarionet/builder/select_demo_pickles.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import pickle
-
-if __name__ == '__main__':
-
-    with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f:
-        summary_dict = pickle.load(f)
-
-    new_summary = {}
-    for obj_id, summary in summary_dict.items():
-
-        if summary["number_summary"]["dynamic_object_states"] == 0:
-            continue
-
-        if summary["object_summary"]["sdc"]["distance"] < 80 or \
-                summary["object_summary"]["sdc"]["continuous_valid_length"] < 50:
-            continue
-
-        if len(summary["number_summary"]["object_types"]) < 3:
-            continue
-
-        if summary["number_summary"]["object"] < 80:
-            continue
-
-        new_summary[obj_id] = summary
-
-        if len(new_summary) >= 3:
-            break
-
-    file_path = AssetLoader.file_path("../converter/waymo", "dataset_summary.pkl", return_raw_style=False)
-    with open(file_path, "wb") as f:
-        pickle.dump(new_summary, f)
-
-    print(new_summary.keys())
diff --git a/scenarionet/builder/utils.py b/scenarionet/builder/utils.py
index d2e5702..9a38095 100644
--- a/scenarionet/builder/utils.py
+++ b/scenarionet/builder/utils.py
@@ -4,6 +4,7 @@ import os
 import os.path as osp
 import pickle
 import shutil
+from typing import Callable, List
 
 import metadrive.scenario.utils as sd_utils
 from metadrive.scenario.scenario_description import ScenarioDescription
@@ -40,15 +41,20 @@ def try_generating_mapping(file_folder):
     return mapping
 
 
-def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True):
+def combine_multiple_dataset(output_path, *dataset_paths,
+                             force_overwrite=False,
+                             try_generate_missing_file=True,
+                             filters: List[Callable] = None):
     """
     Combine multiple datasets. Each dataset should have a dataset_summary.pkl
     :param output_path: The path to store the output dataset
     :param force_overwrite: If True, overwrite the output_path even if it exists
     :param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
     :param dataset_paths: Path of each dataset
+    :param filters: a list of filters; only scenarios accepted by every filter are added to the combined dataset
     :return:
     """
+    filters = filters or []
    output_abs_path = osp.abspath(output_path)
     if os.path.exists(output_abs_path):
         if not force_overwrite:
@@ -80,8 +86,9 @@ def combine_multiple_dataset(output_path, *dataset_paths,
             for v in list(intersect):
                 existing.append(mappings[v])
             logging.warning("Repeat scenarios: {} in : {}. Existing: {}".format(intersect, abs_dir_path, existing))
         summaries.update(summary)
 
+        # mapping
         if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
             if try_generate_missing_file:
                 mapping = {k: "" for k in summary}
@@ -94,8 +100,19 @@ def combine_multiple_dataset(output_path, *dataset_paths,
         new_mapping = {k: os.path.relpath(abs_dir_path, output_abs_path) for k, v in mapping.items()}
         mappings.update(new_mapping)
 
+    # apply filter stage
+    file_to_pop = []
+    for file_name, metadata in summaries.items():
+        if not all([fil(metadata) for fil in filters]):
+            file_to_pop.append(file_name)
+    for file in file_to_pop:
+        summaries.pop(file)
+        mappings.pop(file)
+
     with open(osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE), "wb+") as f:
         pickle.dump(summaries, f)
 
     with open(osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE), "wb+") as f:
         pickle.dump(mappings, f)
+
+    return summaries, mappings
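The filter stage added above is conjunctive: a scenario survives only if every filter accepts its metadata. A hypothetical standalone rendering of the same logic, for illustration only:

    def apply_filters(summaries, mappings, filters):
        # Keep a scenario only when *all* filters return True for its metadata.
        kept = {name: meta for name, meta in summaries.items() if all(fil(meta) for fil in filters)}
        return kept, {name: mappings[name] for name in kept}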
diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py
index 7761a4b..b3a30b0 100644
--- a/scenarionet/converter/utils.py
+++ b/scenarionet/converter/utils.py
@@ -146,3 +146,5 @@ def write_to_directory(
         assert delay_remove == save_path
         shutil.rmtree(delay_remove)
     os.rename(output_path, save_path)
+
+    return summary, mapping
diff --git a/scenarionet/examples/run_combine_scenarios.py b/scenarionet/examples/combine_dataset_and_run.py
similarity index 100%
rename from scenarionet/examples/run_combine_scenarios.py
rename to scenarionet/examples/combine_dataset_and_run.py
diff --git a/scenarionet/tests/script/generate_test_dataset.py b/scenarionet/tests/script/generate_test_dataset.py
index 5624e8d..71e50cb 100644
--- a/scenarionet/tests/script/generate_test_dataset.py
+++ b/scenarionet/tests/script/generate_test_dataset.py
@@ -9,7 +9,6 @@ from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_
 from scenarionet.converter.utils import write_to_directory
 
 if __name__ == "__main__":
-    raise ValueError("Avoid overwriting existing ata")
     dataset_name = "nuscenes"
     output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
     version = 'v1.0-mini'
@@ -26,5 +25,4 @@ if __name__ == "__main__":
         dataset_version=version,
         dataset_name=dataset_name,
         force_overwrite=force_overwrite,
-        nuscenes=nusc
-    )
+        nuscenes=nusc)
diff --git a/scenarionet/tests/test_combine_dataset.py b/scenarionet/tests/test_combine_dataset.py
index aef28c5..25d2157 100644
--- a/scenarionet/tests/test_combine_dataset.py
+++ b/scenarionet/tests/test_combine_dataset.py
@@ -8,7 +8,7 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive
 
 def test_combine_multiple_dataset():
     dataset_name = "nuscenes"
-    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
+    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
     dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
 
     output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
@@ -22,7 +22,7 @@ def test_combine_multiple_dataset():
     for scenario_file in sorted_scenarios:
         read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
     success, result = verify_loading_into_metadrive(dataset_path,
-                                                    result_save_dir="./test_dataset",
+                                                    result_save_dir="_test_dataset",
                                                     steps_to_run=300)
     assert success
 
diff --git a/scenarionet/tests/test_filter.py b/scenarionet/tests/test_filter.py
new file mode 100644
index 0000000..96ba438
--- /dev/null
+++ b/scenarionet/tests/test_filter.py
@@ -0,0 +1,35 @@
+import os
+import os.path
+from scenarionet.builder.filters import ScenarioFilter
+from scenarionet import SCENARIONET_PACKAGE_PATH
+from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
+from scenarionet.verifier.utils import verify_loading_into_metadrive
+from metadrive.type import MetaDriveType
+
+
+def test_filter_dataset():
+    dataset_name = "nuscenes"
+    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
+    dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
+
+    output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
+
+    num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
+                                        number_threshold=6,
+                                        object_type=MetaDriveType.PEDESTRIAN,
+                                        condition="greater")
+    # nuscenes data has no traffic light
+    # light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
+    sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist,
+                                                target_dist=2,
+                                                condition="smaller")
+
+    summary, mapping = combine_multiple_dataset(output_path,
+                                                *dataset_paths,
+                                                force_overwrite=True,
+                                                try_generate_missing_file=True,
+                                                filters=[num_condition, sdc_driving_condition])
+
+
+if __name__ == '__main__':
+    test_filter_dataset()
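Since a filter is just a callable receiving a scenario's metadata dict, project-specific predicates can be mixed with the built-ins. A sketch under that assumption (the `crowded_scene` function, its 50-object threshold, and both paths are hypothetical; the summary keys come from filters.py above):

    from metadrive.scenario.scenario_description import ScenarioDescription as SD

    from scenarionet.builder.filters import ScenarioFilter
    from scenarionet.builder.utils import combine_multiple_dataset


    def crowded_scene(metadata, min_objects=50):
        # Hypothetical custom filter: keep scenarios tracking many objects.
        return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS] >= min_objects


    summary, mapping = combine_multiple_dataset(
        "/tmp/combined",        # hypothetical output path
        "/data/nuscenes_sn",    # hypothetical input dataset paths
        "/data/waymo_sn",
        force_overwrite=True,
        filters=[ScenarioFilter.make(crowded_scene, min_objects=50)]
    )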