filter function

QuanyiLi
2023-05-07 19:06:23 +01:00
parent 26013c6fd6
commit b7fe44d472
9 changed files with 119 additions and 69 deletions


@@ -1,29 +0,0 @@
import numpy as np
def validate_sdc_track(sdc_state):
"""
This function filters scenarios based on the SDC (self-driving car) trajectory.
Rule 1: Filter out the scenario if the SDC trajectory is shorter than 10 m.
Rule 2: Filter out the scenario if the whole trajectory lasts less than 5 s, assuming a sampling frequency of 10 Hz.
"""
valid_array = sdc_state["valid"]
sdc_trajectory = sdc_state["position"][valid_array, :2]
sdc_track_length = [
np.linalg.norm(sdc_trajectory[i] - sdc_trajectory[i + 1]) for i in range(sdc_trajectory.shape[0] - 1)
]
sdc_track_length = sum(sdc_track_length)
# Rule 1
if sdc_track_length < 10:
return False
print("sdc_track_length: ", sdc_track_length)
# Rule 2
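# 50 valid frames at the assumed 10 Hz sampling rate correspond to 5 s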
if valid_array.sum() < 50:
return False
return True


@@ -0,0 +1,60 @@
from functools import partial
from metadrive.scenario.scenario_description import ScenarioDescription as SD
class ScenarioFilter:
GREATER = "greater"
SMALLER = "smaller"
@staticmethod
def sdc_moving_dist(metadata, target_dist, condition=GREATER):
"""
Filter the scenario according to the moving distance of the SDC (self-driving car).
"""
assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
sdc_info = metadata[SD.SUMMARY.OBJECT_SUMMARY][metadata[SD.SDC_ID]]
moving_dist = sdc_info[SD.SUMMARY.MOVING_DIST]
if moving_dist > target_dist and condition == ScenarioFilter.GREATER:
return True
if moving_dist < target_dist and condition == ScenarioFilter.SMALLER:
return True
return False
@staticmethod
def object_number(metadata, number_threshold, object_type=None, condition=SMALLER):
"""
Return True if the scenario satisfies the object number condition
:param metadata: metadata of each scenario
:param number_threshold: threshold on the number of objects
:param object_type: MetaDriveType.VEHICLE or another object type. If None, count objects of all types
:param condition: SMALLER or GREATER
:return: boolean
"""
assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
if object_type is not None:
num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS_EACH_TYPE].get(object_type, 0)
else:
num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS]
if num > number_threshold and condition == ScenarioFilter.GREATER:
return True
if num < number_threshold and condition == ScenarioFilter.SMALLER:
return True
return False
@staticmethod
def has_traffic_light(metadata):
return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0
@staticmethod
def make(func, **kwargs):
"""
A wrapper around partial() for pre-filling filter parameters
:param func: a filter function in this class
:param kwargs: keyword arguments for the filter
:return: a function taking only metadata as input
"""
assert "metadata" not in kwargs, "You should only fill conditions, metadata will be fill automatically"
if "condition" in kwargs:
assert kwargs["condition"] in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
return partial(func, **kwargs)
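
A minimal usage sketch (illustrative; metadata here stands for an assumed scenario metadata dict taken from a dataset summary): make() pre-fills the filter parameters so the returned callable takes only the metadata.

dist_filter = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=30, condition=ScenarioFilter.GREATER)
keep_scenario = dist_filter(metadata)  # True only if the SDC moved farther than 30 m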


@@ -1,33 +0,0 @@
import pickle
from metadrive.engine.asset_loader import AssetLoader  # needed by AssetLoader.file_path below
if __name__ == '__main__':
with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f:
summary_dict = pickle.load(f)
new_summary = {}
for obj_id, summary in summary_dict.items():
if summary["number_summary"]["dynamic_object_states"] == 0:
continue
if summary["object_summary"]["sdc"]["distance"] < 80 or \
summary["object_summary"]["sdc"]["continuous_valid_length"] < 50:
continue
if len(summary["number_summary"]["object_types"]) < 3:
continue
if summary["number_summary"]["object"] < 80:
continue
new_summary[obj_id] = summary
if len(new_summary) >= 3:
break
file_path = AssetLoader.file_path("../converter/waymo", "dataset_summary.pkl", return_raw_style=False)
with open(file_path, "wb") as f:
pickle.dump(new_summary, f)
print(new_summary.keys())


@@ -4,6 +4,7 @@ import os
import os.path as osp
import pickle
import shutil
from typing import Callable, List
import metadrive.scenario.utils as sd_utils
from metadrive.scenario.scenario_description import ScenarioDescription
@@ -40,15 +41,20 @@ def try_generating_mapping(file_folder):
return mapping
def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True):
def combine_multiple_dataset(output_path, *dataset_paths,
force_overwrite=False,
try_generate_missing_file=True,
filters: List[Callable] = None):
"""
Combine multiple datasets. Each dataset should have a dataset_summary.pkl
:param output_path: The path to store the output dataset
:param force_overwrite: If True, overwrite the output_path even if it exists
:param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
:param dataset_paths: Path of each dataset
:param filters: a list of filter functions deciding which scenarios are selected for the combined dataset
:return:
"""
filters = filters or []
output_abs_path = osp.abspath(output_path)
if os.path.exists(output_abs_path):
if not force_overwrite:
@@ -80,9 +86,9 @@ def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False,
for v in list(intersect):
existing.append(mappings[v])
logging.warning("Repeat scenarios: {} in : {}. Existing: {}".format(intersect, abs_dir_path, existing))
summaries.update(summary)
# mapping
if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
if try_generate_missing_file:
mapping = {k: "" for k in summary}
@@ -94,8 +100,19 @@ def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False,
new_mapping = {k: os.path.relpath(abs_dir_path, output_abs_path) for k, v in mapping.items()}
mappings.update(new_mapping)
# apply filter stage
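# a scenario is kept only if every filter in the list returns True for its metadata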
file_to_pop = []
for file_name, metadata in summaries.items():
if not all([fil(metadata) for fil in filters]):
file_to_pop.append(file_name)
for file in file_to_pop:
summaries.pop(file)
mappings.pop(file)
with open(osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE), "wb+") as f:
pickle.dump(summaries, f)
with open(osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE), "wb+") as f:
pickle.dump(mappings, f)
return summaries, mappings


@@ -146,3 +146,5 @@ def write_to_directory(
assert delay_remove == save_path
shutil.rmtree(delay_remove)
os.rename(output_path, save_path)
return summary, mapping


@@ -9,7 +9,6 @@ from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_
from scenarionet.converter.utils import write_to_directory
if __name__ == "__main__":
raise ValueError("Avoid overwriting existing data")
dataset_name = "nuscenes"
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
version = 'v1.0-mini'
@@ -26,5 +25,4 @@ if __name__ == "__main__":
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite,
nuscenes=nusc
)
nuscenes=nusc)


@@ -8,7 +8,7 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive
def test_combine_multiple_dataset():
dataset_name = "nuscenes"
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
@@ -22,7 +22,7 @@ def test_combine_multiple_dataset():
for scenario_file in sorted_scenarios:
read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
success, result = verify_loading_into_metadrive(dataset_path,
result_save_dir="./test_dataset",
result_save_dir="_test_dataset",
steps_to_run=300)
assert success


@@ -0,0 +1,35 @@
import os
import os.path
from scenarionet.builder.filters import ScenarioFilter
from scenarionet import SCENARIONET_PACKAGE_PATH
from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
from scenarionet.verifier.utils import verify_loading_into_metadrive
from metadrive.type import MetaDriveType
def test_filter_dataset():
dataset_name = "nuscenes"
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
number_threshold=6,
object_type=MetaDriveType.PEDESTRIAN,
condition="greater")
# nuScenes data has no traffic lights, so the traffic light filter is not applied
# light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist,
target_dist=2,
condition="smaller")
summary, mapping = combine_multiple_dataset(output_path,
*dataset_paths,
force_overwrite=True,
try_generate_missing_file=True,
filters=[num_condition, sdc_driving_condition])
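# illustrative sanity check (a sketch, assuming summary maps scenario files to
# their metadata dicts): every surviving scenario must pass both filters
for metadata in summary.values():
assert num_condition(metadata) and sdc_driving_condition(metadata)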
if __name__ == '__main__':
test_filter_dataset()