filter function
@@ -1,29 +0,0 @@
-import numpy as np
-
-
-def validate_sdc_track(sdc_state):
-    """
-    This function filters the scenario based on SDC information.
-
-    Rule 1: Filter out the scenario if the trajectory length is < 10.
-
-    Rule 2: Filter out the scenario if the whole trajectory lasts < 5 s, assuming a sampling frequency of 10 Hz.
-    """
-    valid_array = sdc_state["valid"]
-    sdc_trajectory = sdc_state["position"][valid_array, :2]
-    sdc_track_length = [
-        np.linalg.norm(sdc_trajectory[i] - sdc_trajectory[i + 1]) for i in range(sdc_trajectory.shape[0] - 1)
-    ]
-    sdc_track_length = sum(sdc_track_length)
-
-    # Rule 1
-    if sdc_track_length < 10:
-        return False
-
-    print("sdc_track_length: ", sdc_track_length)
-
-    # Rule 2
-    if valid_array.sum() < 50:
-        return False
-
-    return True
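A note on the removed thresholds: Rule 2's constant follows directly from the docstring, since 5 seconds of trajectory sampled at 10 Hz is 50 frames, which is exactly what the `valid_array.sum() < 50` check encodes. A quick check of that arithmetic:

    # 5 s of trajectory at a 10 Hz sampling frequency -> 50 valid frames,
    # the constant used in Rule 2 above
    min_valid_frames = 5 * 10
    assert min_valid_frames == 50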
scenarionet/builder/filters.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+from functools import partial
+
+from metadrive.scenario.scenario_description import ScenarioDescription as SD
+
+
+class ScenarioFilter:
+    GREATER = "greater"
+    SMALLER = "smaller"
+
+    @staticmethod
+    def sdc_moving_dist(metadata, target_dist, condition=GREATER):
+        """
+        Filter the scenario based on the SDC's moving distance.
+        """
+        assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
+        sdc_info = metadata[SD.SUMMARY.OBJECT_SUMMARY][metadata[SD.SDC_ID]]
+        moving_dist = sdc_info[SD.SUMMARY.MOVING_DIST]
+        if moving_dist > target_dist and condition == ScenarioFilter.GREATER:
+            return True
+        if moving_dist < target_dist and condition == ScenarioFilter.SMALLER:
+            return True
+        return False
+
+    @staticmethod
+    def object_number(metadata, number_threshold, object_type=None, condition=SMALLER):
+        """
+        Return True if the scenario satisfies the object-number condition.
+        :param metadata: metadata of each scenario
+        :param number_threshold: object-number threshold
+        :param object_type: MetaDriveType.VEHICLE or another object type. If None, count objects of all types
+        :param condition: SMALLER or GREATER
+        :return: boolean
+        """
+        assert condition in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
+        if object_type is not None:
+            num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS_EACH_TYPE].get(object_type, 0)
+        else:
+            num = metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_OBJECTS]
+        if num > number_threshold and condition == ScenarioFilter.GREATER:
+            return True
+        if num < number_threshold and condition == ScenarioFilter.SMALLER:
+            return True
+        return False
+
+    @staticmethod
+    def has_traffic_light(metadata):
+        return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0
+
+    @staticmethod
+    def make(func, **kwargs):
+        """
+        A wrapper around partial() for pre-filling some parameters.
+        :param func: a filter function defined in this class
+        :param kwargs: kwargs for the filter
+        :return: a function taking only metadata as input
+        """
+        assert "metadata" not in kwargs, "You should only fill conditions; metadata will be filled automatically"
+        if "condition" in kwargs:
+            assert kwargs["condition"] in [ScenarioFilter.GREATER, ScenarioFilter.SMALLER], "Wrong condition type"
+        return partial(func, **kwargs)
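For context, `ScenarioFilter.make` is a thin wrapper over `functools.partial`: it pre-binds every argument except `metadata`, so the combine stage can treat each filter uniformly as a `metadata -> bool` callable. A minimal sketch of the same mechanism, using a hypothetical stand-alone predicate rather than the real summary schema:

    from functools import partial

    # Hypothetical predicate shaped like the ScenarioFilter methods:
    # metadata comes first, then the parameters to pre-bind.
    def moved_at_least(metadata, target_dist):
        return metadata["moving_dist"] > target_dist

    # Bind everything except metadata, mirroring ScenarioFilter.make(...)
    fil = partial(moved_at_least, target_dist=30.0)

    print(fil({"moving_dist": 42.0}))  # True
    print(fil({"moving_dist": 10.0}))  # False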
@@ -1,33 +0,0 @@
-import pickle
-
-if __name__ == '__main__':
-
-    with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f:
-        summary_dict = pickle.load(f)
-
-    new_summary = {}
-    for obj_id, summary in summary_dict.items():
-
-        if summary["number_summary"]["dynamic_object_states"] == 0:
-            continue
-
-        if summary["object_summary"]["sdc"]["distance"] < 80 or \
-                summary["object_summary"]["sdc"]["continuous_valid_length"] < 50:
-            continue
-
-        if len(summary["number_summary"]["object_types"]) < 3:
-            continue
-
-        if summary["number_summary"]["object"] < 80:
-            continue
-
-        new_summary[obj_id] = summary
-
-        if len(new_summary) >= 3:
-            break
-
-    file_path = AssetLoader.file_path("../converter/waymo", "dataset_summary.pkl", return_raw_style=False)
-    with open(file_path, "wb") as f:
-        pickle.dump(new_summary, f)
-
-    print(new_summary.keys())
scenarionet/builder/utils.py
@@ -4,6 +4,7 @@ import os
 import os.path as osp
 import pickle
 import shutil
+from typing import Callable, List
 
 import metadrive.scenario.utils as sd_utils
 from metadrive.scenario.scenario_description import ScenarioDescription
@@ -40,15 +41,20 @@ def try_generating_mapping(file_folder):
     return mapping
 
 
-def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True):
+def combine_multiple_dataset(output_path, *dataset_paths,
+                             force_overwrite=False,
+                             try_generate_missing_file=True,
+                             filters: List[Callable] = None):
     """
     Combine multiple datasets. Each dataset should have a dataset_summary.pkl
     :param output_path: The path to store the output dataset
     :param force_overwrite: If True, overwrite the output_path even if it exists
     :param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
     :param dataset_paths: Path of each dataset
+    :param filters: a set of filters selecting which scenarios are added to the combined dataset
     :return:
     """
+    filters = filters or []
     output_abs_path = osp.abspath(output_path)
     if os.path.exists(output_abs_path):
         if not force_overwrite:
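The new `filters` argument accepts any callable that maps a scenario's metadata dict to a boolean, so hand-rolled rules can sit alongside the built-in `ScenarioFilter` methods. A minimal sketch, assuming a hypothetical `num_frames` key (the real summary layout is defined by `ScenarioDescription`):

    # Hypothetical custom filter; "num_frames" is an assumed key for illustration.
    def min_length_filter(metadata, min_frames=50):
        return metadata.get("num_frames", 0) >= min_frames

    # Usable directly, or pre-bound via ScenarioFilter.make / functools.partial:
    # combine_multiple_dataset(output_path, *dataset_paths,
    #                          filters=[min_length_filter])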
@@ -80,9 +86,9 @@ def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False,
             for v in list(intersect):
                 existing.append(mappings[v])
             logging.warning("Repeat scenarios: {} in : {}. Existing: {}".format(intersect, abs_dir_path, existing))
 
         summaries.update(summary)
 
+        # mapping
         if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
             if try_generate_missing_file:
                 mapping = {k: "" for k in summary}
@@ -94,8 +100,19 @@ def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False,
         new_mapping = {k: os.path.relpath(abs_dir_path, output_abs_path) for k, v in mapping.items()}
         mappings.update(new_mapping)
 
+    # apply filter stage
+    file_to_pop = []
+    for file_name, metadata in summaries.items():
+        if not all([fil(metadata) for fil in filters]):
+            file_to_pop.append(file_name)
+    for file in file_to_pop:
+        summaries.pop(file)
+        mappings.pop(file)
+
     with open(osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE), "wb+") as f:
         pickle.dump(summaries, f)
 
     with open(osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE), "wb+") as f:
         pickle.dump(mappings, f)
 
+    return summaries, mappings
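Note the semantics of the filter stage above: `all(...)` composes the filters as a logical AND, so a scenario survives only if every filter returns True. Because `all([])` is True in Python, the default empty filter list keeps every scenario:

    print(all([]))             # True  -> no filters, scenario is kept
    print(all([True, False]))  # False -> one failing filter drops it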
@@ -146,3 +146,5 @@ def write_to_directory(
         assert delay_remove == save_path
         shutil.rmtree(delay_remove)
     os.rename(output_path, save_path)
+
+    return summary, mapping
@@ -9,7 +9,6 @@ from scenarionet.converter.nuscenes.utils import convert_nuscenes_scenario, get_
 from scenarionet.converter.utils import write_to_directory
 
 if __name__ == "__main__":
-    raise ValueError("Avoid overwriting existing ata")
     dataset_name = "nuscenes"
     output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
     version = 'v1.0-mini'
@@ -26,5 +25,4 @@ if __name__ == "__main__":
         dataset_version=version,
         dataset_name=dataset_name,
         force_overwrite=force_overwrite,
-        nuscenes=nusc
-    )
+        nuscenes=nusc)
@@ -8,7 +8,7 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive
 
 def test_combine_multiple_dataset():
     dataset_name = "nuscenes"
-    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
+    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
     dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
 
     output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
@@ -22,7 +22,7 @@ def test_combine_multiple_dataset():
     for scenario_file in sorted_scenarios:
         read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
     success, result = verify_loading_into_metadrive(dataset_path,
-                                                    result_save_dir="./test_dataset",
+                                                    result_save_dir="_test_dataset",
                                                     steps_to_run=300)
     assert success
scenarionet/tests/test_filter.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+import os
+import os.path
+
+from scenarionet.builder.filters import ScenarioFilter
+from scenarionet import SCENARIONET_PACKAGE_PATH
+from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
+from scenarionet.verifier.utils import verify_loading_into_metadrive
+from metadrive.type import MetaDriveType
+
+
+def test_filter_dataset():
+    dataset_name = "nuscenes"
+    original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "_test_dataset", dataset_name)
+    dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
+
+    output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
+
+    num_condition = ScenarioFilter.make(ScenarioFilter.object_number,
+                                        number_threshold=6,
+                                        object_type=MetaDriveType.PEDESTRIAN,
+                                        condition="greater")
+    # nuscenes data has no traffic light
+    # light_condition = ScenarioFilter.make(ScenarioFilter.has_traffic_light)
+    sdc_driving_condition = ScenarioFilter.make(ScenarioFilter.sdc_moving_dist,
+                                                target_dist=2,
+                                                condition="smaller")
+
+    summary, mapping = combine_multiple_dataset(output_path,
+                                                *dataset_paths,
+                                                force_overwrite=True,
+                                                try_generate_missing_file=True,
+                                                filters=[num_condition, sdc_driving_condition])
+
+if __name__ == '__main__':
+    test_filter_dataset()
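The test currently only exercises the pipeline end to end without asserting anything about the output. A possible follow-up check (hypothetical, not part of this commit) would re-apply the bound filters to the returned summary:

    # Hypothetical assertion: every scenario kept by combine_multiple_dataset
    # should pass every filter it was built with.
    for file_name, metadata in summary.items():
        assert num_condition(metadata) and sdc_driving_condition(metadata)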