filter function
This commit is contained in:
@@ -1,29 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def validate_sdc_track(sdc_state):
|
||||
"""
|
||||
This function filters the scenario based on SDC information.
|
||||
|
||||
Rule 1: Filter out if the trajectory length < 10
|
||||
|
||||
Rule 2: Filter out if the whole trajectory last < 5s, assuming sampling frequency = 10Hz.
|
||||
"""
|
||||
valid_array = sdc_state["valid"]
|
||||
sdc_trajectory = sdc_state["position"][valid_array, :2]
|
||||
sdc_track_length = [
|
||||
np.linalg.norm(sdc_trajectory[i] - sdc_trajectory[i + 1]) for i in range(sdc_trajectory.shape[0] - 1)
|
||||
]
|
||||
sdc_track_length = sum(sdc_track_length)
|
||||
|
||||
# Rule 1
|
||||
if sdc_track_length < 10:
|
||||
return False
|
||||
|
||||
print("sdc_track_length: ", sdc_track_length)
|
||||
|
||||
# Rule 2
|
||||
if valid_array.sum() < 50:
|
||||
return False
|
||||
|
||||
return True
|
||||
60
scenarionet/builder/filters.py
Normal file
60
scenarionet/builder/filters.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from functools import partial
|
||||
|
||||
from metadrive.scenario.scenario_description import ScenarioDescription as SD
|
||||
|
||||
|
||||
class ScenarioFilter:
|
||||
GREATER = "greater"
|
||||
SMALLER = "smaller"
|
||||
|
||||
@staticmethod
|
||||
def sdc_moving_dist(metadata, target_dist, condition=GREATER):
|
||||
"""
|
||||
This function filters the scenario based on SDC information.
|
||||
"""
|
||||
assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
|
||||
sdc_info = metadata[SD.SUMMARY.OBJECT_SUMMARY][metadata[SD.SDC_ID]]
|
||||
moving_dist = sdc_info[SD.SUMMARY.MOVING_DIST]
|
||||
if moving_dist > target_dist and condition == ScenarioFilter.GREATER:
|
||||
return True
|
||||
if moving_dist < target_dist and condition == ScenarioFilter.SMALLER:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def object_number(metadata, number_threshold, object_type=None, condition=SMALLER):
|
||||
"""
|
||||
Return True if the scenario satisfying the object number condition
|
||||
:param metadata: metadata in each scenario
|
||||
:param number_threshold: number of objects threshold
|
||||
:param object_type: MetaDriveType.VEHICLE or other object type. If none, calculate number for all object types
|
||||
:param condition: SMALLER or GREATER
|
||||
:return: boolean
|
||||
"""
|
||||
assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
|
||||
if object_type is not None:
|
||||
num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS_EACH_TYPE].get(object_type, 0)
|
||||
else:
|
||||
num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS]
|
||||
if num > number_threshold and condition == ScenarioFilter.GREATER:
|
||||
return True
|
||||
if num < number_threshold and condition == ScenarioFilter.SMALLER:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def has_traffic_light(metadata):
|
||||
return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0
|
||||
|
||||
@staticmethod
|
||||
def make(func, **kwargs):
|
||||
"""
|
||||
A wrapper for partial() for filling some parameters
|
||||
:param func: func in this class
|
||||
:param kwargs: kwargs for filter
|
||||
:return: func taking only metadat as input
|
||||
"""
|
||||
assert "metadata" not in kwargs, "You should only fill conditions, metadata will be fill automatically"
|
||||
if "condition" in kwargs:
|
||||
assert kwargs["condition"] in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
|
||||
return partial(func, **kwargs)
|
||||
@@ -1,33 +0,0 @@
|
||||
import pickle
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f:
|
||||
summary_dict = pickle.load(f)
|
||||
|
||||
new_summary = {}
|
||||
for obj_id, summary in summary_dict.items():
|
||||
|
||||
if summary["number_summary"]["dynamic_object_states"] == 0:
|
||||
continue
|
||||
|
||||
if summary["object_summary"]["sdc"]["distance"] < 80 or \
|
||||
summary["object_summary"]["sdc"]["continuous_valid_length"] < 50:
|
||||
continue
|
||||
|
||||
if len(summary["number_summary"]["object_types"]) < 3:
|
||||
continue
|
||||
|
||||
if summary["number_summary"]["object"] < 80:
|
||||
continue
|
||||
|
||||
new_summary[obj_id] = summary
|
||||
|
||||
if len(new_summary) >= 3:
|
||||
break
|
||||
|
||||
file_path = AssetLoader.file_path("../converter/waymo", "dataset_summary.pkl", return_raw_style=False)
|
||||
with open(file_path, "wb") as f:
|
||||
pickle.dump(new_summary, f)
|
||||
|
||||
print(new_summary.keys())
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import shutil
|
||||
from typing import Callable, List
|
||||
|
||||
import metadrive.scenario.utils as sd_utils
|
||||
from metadrive.scenario.scenario_description import ScenarioDescription
|
||||
@@ -40,15 +41,20 @@ def try_generating_mapping(file_folder):
|
||||
return mapping
|
||||
|
||||
|
||||
def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True):
|
||||
def combine_multiple_dataset(output_path, *dataset_paths,
|
||||
force_overwrite=False,
|
||||
try_generate_missing_file=True,
|
||||
filters: List[Callable] = None):
|
||||
"""
|
||||
Combine multiple datasets. Each dataset should have a dataset_summary.pkl
|
||||
:param output_path: The path to store the output dataset
|
||||
:param force_overwrite: If True, overwrite the output_path even if it exists
|
||||
:param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
|
||||
:param dataset_paths: Path of each dataset
|
||||
:param filters: a set of filters to choose which scenario to be selected and added into this combined dataset
|
||||
:return:
|
||||
"""
|
||||
filters = filters or []
|
||||
output_abs_path = osp.abspath(output_path)
|
||||
if os.path.exists(output_abs_path):
|
||||
if not force_overwrite:
|
||||
@@ -80,9 +86,9 @@ def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False,
|
||||
for v in list(intersect):
|
||||
existing.append(mappings[v])
|
||||
logging.warning("Repeat scenarios: {} in : {}. Existing: {}".format(intersect, abs_dir_path, existing))
|
||||
|
||||
summaries.update(summary)
|
||||
|
||||
# mapping
|
||||
if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
|
||||
if try_generate_missing_file:
|
||||
mapping = {k: "" for k in summary}
|
||||
@@ -94,8 +100,19 @@ def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False,
|
||||
new_mapping = {k: os.path.relpath(abs_dir_path, output_abs_path) for k, v in mapping.items()}
|
||||
mappings.update(new_mapping)
|
||||
|
||||
# apply filter stage
|
||||
file_to_pop = []
|
||||
for file_name, metadata, in summaries.items():
|
||||
if not all([fil(metadata) for fil in filters]):
|
||||
file_to_pop.append(file_name)
|
||||
for file in file_to_pop:
|
||||
summaries.pop(file)
|
||||
mappings.pop(file)
|
||||
|
||||
with open(osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE), "wb+") as f:
|
||||
pickle.dump(summaries, f)
|
||||
|
||||
with open(osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE), "wb+") as f:
|
||||
pickle.dump(mappings, f)
|
||||
|
||||
return summaries, mappings
|
||||
|
||||
@@ -146,3 +146,5 @@ def write_to_directory(
|
||||
assert delay_remove == save_path
|
||||
shutil.rmtree(delay_remove)
|
||||
os.rename(output_path, save_path)
|
||||
|
||||
return summary, mapping
|
||||
|
||||
@@ -9,7 +9,6 @@ from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_
|
||||
from scenarionet.converter.utils import write_to_directory
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise ValueError("Avoid overwriting existing ata")
|
||||
dataset_name = "nuscenes"
|
||||
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
|
||||
version = 'v1.0-mini'
|
||||
@@ -26,5 +25,4 @@ if __name__ == "__main__":
|
||||
dataset_version=version,
|
||||
dataset_name=dataset_name,
|
||||
force_overwrite=force_overwrite,
|
||||
nuscenes=nusc
|
||||
)
|
||||
nuscenes=nusc)
|
||||
|
||||
@@ -8,7 +8,7 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive
|
||||
|
||||
def test_combine_multiple_dataset():
|
||||
dataset_name = "nuscenes"
|
||||
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
|
||||
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
|
||||
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
|
||||
|
||||
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
|
||||
@@ -22,7 +22,7 @@ def test_combine_multiple_dataset():
|
||||
for scenario_file in sorted_scenarios:
|
||||
read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
|
||||
success, result = verify_loading_into_metadrive(dataset_path,
|
||||
result_save_dir="./test_dataset",
|
||||
result_save_dir="_test_dataset",
|
||||
steps_to_run=300)
|
||||
assert success
|
||||
|
||||
|
||||
35
scenarionet/tests/test_filter.py
Normal file
35
scenarionet/tests/test_filter.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
import os.path
|
||||
from scenarionet.builder.filters import ScenarioFilter
|
||||
from scenarionet import SCENARIONET_PACKAGE_PATH
|
||||
from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
|
||||
from scenarionet.verifier.utils import verify_loading_into_metadrive
|
||||
from metadrive.type import MetaDriveType
|
||||
|
||||
|
||||
def test_filter_dataset():
|
||||
dataset_name = "nuscenes"
|
||||
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
|
||||
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
|
||||
|
||||
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
|
||||
|
||||
num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
|
||||
number_threshold=6,
|
||||
object_type=MetaDriveType.PEDESTRIAN,
|
||||
condition="greater")
|
||||
# nuscenes data has no light
|
||||
# light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
|
||||
sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist,
|
||||
target_dist=2,
|
||||
condition="smaller")
|
||||
|
||||
summary, mapping = combine_multiple_dataset(output_path,
|
||||
*dataset_paths,
|
||||
force_overwrite=True,
|
||||
try_generate_missing_file=True,
|
||||
filters=[num_condition, sdc_driving_condition])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_filter_dataset()
|
||||
Reference in New Issue
Block a user