This commit is contained in:
QuanyiLi
2023-05-08 13:28:21 +01:00
parent 50d338da17
commit 720fc61d72

View File

@@ -0,0 +1,70 @@
import copy
import os
import os.path
from metadrive.scenario.scenario_description import ScenarioDescription as SD
from scenarionet import SCENARIONET_DATASET_PATH
from scenarionet import SCENARIONET_PACKAGE_PATH
from scenarionet.builder.utils import combine_multiple_dataset
from scenarionet.common_utils import read_dataset_summary, read_scenario
from scenarionet.common_utils import recursive_equal
from scenarionet.verifier.error import ErrorFile
from scenarionet.verifier.utils import set_random_drop
from scenarionet.verifier.utils import verify_loading_into_metadrive
def test_combine_multiple_dataset():
    """End-to-end test of dataset combination, verification and error-based regeneration.

    Steps:
      1. Combine four source datasets (nuscenes/nuplan/waymo/pg) into one.
      2. Load every scenario back and run the combined dataset through MetaDrive
         with random scenario dropping enabled, producing an error file.
      3. Regenerate a "passed" and a "failed" dataset from that error file and
         assert their summaries/mappings round-trip through read_dataset_summary,
         and that passed + failed partitions exactly cover the original summary.
      4. Sanity-check every regenerated scenario description.
    """
    # Enable random scenario dropping so the verifier records some failures,
    # which the error-file regeneration below depends on.
    set_random_drop(True)
    dataset_paths = [
        os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
        os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
        os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
        os.path.join(SCENARIONET_DATASET_PATH, "pg"),
    ]
    dataset_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset")
    combine_multiple_dataset(dataset_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)

    # Every combined scenario must be readable via the merged mapping.
    summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
    for scenario_file in sorted_scenarios:
        read_scenario(dataset_path, mapping, scenario_file)

    result_save_dir = "test_dataset"  # single source of truth for the verifier output dir
    success, logs = verify_loading_into_metadrive(
        dataset_path, result_save_dir=result_save_dir, steps_to_run=1000, num_workers=8
    )
    # Random drop is on, so failures are expected; verification must NOT report success.
    assert not success, "random drop is enabled, so some scenarios should have failed"
    set_random_drop(False)

    # Locate the error file produced by the verification run.
    file_name = ErrorFile.get_error_file_name(dataset_path)
    error_file_path = os.path.join(result_save_dir, file_name)

    # Regenerate two disjoint datasets from the error file:
    # scenarios that loaded fine vs. scenarios that broke.
    pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "passed_scenarios")
    fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "failed_scenarios")
    pass_summary, pass_mapping = ErrorFile.generate_dataset(
        error_file_path, pass_dataset, force_overwrite=True, broken_scenario=False
    )
    fail_summary, fail_mapping = ErrorFile.generate_dataset(
        error_file_path, fail_dataset, force_overwrite=True, broken_scenario=True
    )

    # Summaries/mappings written to disk must equal what generate_dataset returned.
    read_pass_summary, _, read_pass_mapping = read_dataset_summary(pass_dataset)
    assert recursive_equal(read_pass_summary, pass_summary)
    assert recursive_equal(read_pass_mapping, pass_mapping)
    read_fail_summary, _, read_fail_mapping = read_dataset_summary(fail_dataset)
    assert recursive_equal(read_fail_mapping, fail_mapping)
    assert recursive_equal(read_fail_summary, fail_summary)

    # The passed and failed partitions together must reconstruct the original summary.
    all_summaries = copy.deepcopy(read_pass_summary)
    all_summaries.update(fail_summary)
    assert recursive_equal(all_summaries, summary)

    # Every regenerated scenario must still be readable and pass the sanity check.
    for scenario in read_pass_summary:
        sd = read_scenario(pass_dataset, read_pass_mapping, scenario)
        SD.sanity_check(sd)
    for scenario in read_fail_summary:
        sd = read_scenario(fail_dataset, read_fail_mapping, scenario)
        SD.sanity_check(sd)
# Allow running this integration test directly as a script (outside pytest).
if __name__ == '__main__':
    test_combine_multiple_dataset()