add common utils
scenarionet/builder/utils.py
@@ -6,21 +6,13 @@ import pickle
 import shutil
 from typing import Callable, List

 import metadrive.scenario.utils as sd_utils
 import numpy as np
 from metadrive.scenario.scenario_description import ScenarioDescription

+from scenarionet.common_utils import save_summary_anda_mapping

 logger = logging.getLogger(__name__)


-def read_dataset_summary(dataset_path):
-    return sd_utils.read_dataset_summary(dataset_path)
-
-
-def read_scenario(pkl_file_path):
-    return sd_utils.read_scenario_data(pkl_file_path)
-
-
 def try_generating_summary(file_folder):
     # Create a fake one
     files = os.listdir(file_folder)
@@ -108,20 +100,3 @@ def combine_multiple_dataset(
     return summaries, mappings
-
-
-def dict_recursive_remove_array_and_set(d):
-    if isinstance(d, np.ndarray):
-        return d.tolist()
-    if isinstance(d, set):
-        return tuple(d)
-    if isinstance(d, dict):
-        for k in d.keys():
-            d[k] = dict_recursive_remove_array_and_set(d[k])
-    return d
-
-
-def save_summary_anda_mapping(summary_file_path, mapping_file_path, summary, mapping):
-    with open(summary_file_path, "wb") as file:
-        pickle.dump(dict_recursive_remove_array_and_set(summary), file)
-    with open(mapping_file_path, "wb") as file:
-        pickle.dump(mapping, file)
-    print("Dataset Summary and Mapping are saved at: {}".format(summary_file_path))
scenarionet/common_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import pickle

import numpy as np
from metadrive.scenario import utils as sd_utils


def recursive_equal(data1, data2, need_assert=False):
    from metadrive.utils.config import Config
    if isinstance(data1, Config):
        data1 = data1.get_dict()
    if isinstance(data2, Config):
        data2 = data2.get_dict()

    if isinstance(data1, np.ndarray):
        tmp = np.asarray(data2)
        return np.all(data1 == tmp)

    if isinstance(data2, np.ndarray):
        tmp = np.asarray(data1)
        return np.all(tmp == data2)

    if isinstance(data1, dict):
        is_ins = isinstance(data2, dict)
        key_right = set(data1.keys()) == set(data2.keys())
        if need_assert:
            assert is_ins and key_right, (data1.keys(), data2.keys())
        if not (is_ins and key_right):
            return False
        ret = []
        for k in data1:
            ret.append(recursive_equal(data1[k], data2[k], need_assert=need_assert))
        return all(ret)

    elif isinstance(data1, (list, tuple)):
        len_right = len(data1) == len(data2)
        is_ins = isinstance(data2, (list, tuple))
        if need_assert:
            assert len_right and is_ins, (len(data1), len(data2), data1, data2)
        if not (is_ins and len_right):
            return False
        ret = []
        for i in range(len(data1)):
            ret.append(recursive_equal(data1[i], data2[i], need_assert=need_assert))
        return all(ret)
    elif isinstance(data1, np.ndarray):
        ret = np.isclose(data1, data2).all()
        if need_assert:
            assert ret, (type(data1), type(data2), data1, data2)
        return ret
    else:
        ret = data1 == data2
        if need_assert:
            assert ret, (type(data1), type(data2), data1, data2)
        return ret


def dict_recursive_remove_array_and_set(d):
    if isinstance(d, np.ndarray):
        return d.tolist()
    if isinstance(d, set):
        return tuple(d)
    if isinstance(d, dict):
        for k in d.keys():
            d[k] = dict_recursive_remove_array_and_set(d[k])
    return d


def save_summary_anda_mapping(summary_file_path, mapping_file_path, summary, mapping):
    with open(summary_file_path, "wb") as file:
        pickle.dump(dict_recursive_remove_array_and_set(summary), file)
    with open(mapping_file_path, "wb") as file:
        pickle.dump(mapping, file)
    print("Dataset Summary and Mapping are saved at: {}".format(summary_file_path))


def read_dataset_summary(dataset_path):
    return sd_utils.read_dataset_summary(dataset_path)


def read_scenario(pkl_file_path):
    return sd_utils.read_scenario_data(pkl_file_path)
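For orientation, here is a minimal sketch of what the two pickling helpers above do. The summary keys and output file names are made up for illustration; numpy and a scenarionet checkout are assumed to be installed.

import numpy as np

from scenarionet.common_utils import dict_recursive_remove_array_and_set, save_summary_anda_mapping

# A toy summary: a nested dict holding an ndarray and a set, neither of which
# we want to keep as-is inside the pickled summary.
summary = {
    "scenario_0.pkl": {
        "object_types": {"VEHICLE", "PEDESTRIAN"},        # set     -> tuple
        "ego_track": np.array([[0.0, 0.0], [1.0, 0.5]]),  # ndarray -> nested list
    }
}
cleaned = dict_recursive_remove_array_and_set(summary)
assert isinstance(cleaned["scenario_0.pkl"]["object_types"], tuple)
assert isinstance(cleaned["scenario_0.pkl"]["ego_track"], list)

# save_summary_anda_mapping runs the same cleanup on the summary before pickling it,
# then pickles the file-name -> sub-directory mapping next to it.
mapping = {"scenario_0.pkl": "."}
save_summary_anda_mapping("summary.pkl", "mapping.pkl", summary, mapping)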
@@ -6,7 +6,7 @@ import math
 import os
 import pickle
 import shutil
-from scenarionet.builder.utils import save_summary_anda_mapping
+from scenarionet.common_utils import save_summary_anda_mapping
 import numpy as np
 import tqdm
 from metadrive.scenario import ScenarioDescription as SD
@@ -2,7 +2,8 @@ import os
 import os.path

 from scenarionet import SCENARIONET_PACKAGE_PATH
-from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
+from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.common_utils import read_dataset_summary, read_scenario
 from scenarionet.verifier.utils import verify_loading_into_metadrive
@@ -1,8 +1,10 @@
 import os
 import os.path

 from metadrive.scenario.utils import assert_scenario_equal
 from scenarionet import SCENARIONET_PACKAGE_PATH
-from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
+from scenarionet.builder.utils import combine_multiple_dataset
+from scenarionet.common_utils import read_dataset_summary, read_scenario
 from scenarionet.verifier.error import ErrorFile
 from scenarionet.verifier.utils import verify_loading_into_metadrive, set_random_drop
@@ -11,18 +13,27 @@ def test_combine_multiple_dataset():
     dataset_name = "nuscenes"
     original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
     dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
-    output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
-    combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
+    dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
+    combine_multiple_dataset(dataset_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)

     dataset_paths.append(output_path)
     for dataset_path in dataset_paths:
         summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
         for scenario_file in sorted_scenarios:
             read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
-    success, result = verify_loading_into_metadrive(
+    success, logs = verify_loading_into_metadrive(
         dataset_path, result_save_dir="test_dataset", steps_to_run=1000, num_workers=4)
     assert success
     set_random_drop(False)
     # regenerate
     file_name = ErrorFile.get_error_file_name(dataset_path)
     error_file_path = os.path.join("test_dataset", file_name)

     pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "passed_senarios")
     fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "failed_scenarios")
     pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, pass_dataset, broken_scenario=False)
     fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, fail_dataset, broken_scenario=True)

     read_pass_summary, _, read_pass_mapping = read_dataset_summary(pass_dataset)
     read_fail_summary, _, read_fail_mapping, = read_dataset_summary(fail_dataset)


 if __name__ == '__main__':
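The loop in the test above is also the basic reading pattern outside of tests. A small sketch, assuming scenarionet and metadrive are installed and using a hypothetical dataset path:

import os

from scenarionet.common_utils import read_dataset_summary, read_scenario

dataset_path = "./combined_dataset"  # hypothetical path to a combined dataset

# read_dataset_summary returns the summary dict, the sorted scenario file names and
# a mapping from each file name to the sub-directory that actually stores it.
summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
for scenario_file in sorted_scenarios:
    scenario = read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
    print(scenario_file, type(scenario))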
scenarionet/verifier/error.py
@@ -5,8 +5,7 @@ from typing import List
 from metadrive.scenario.scenario_description import ScenarioDescription as SD

-from scenarionet.builder.utils import read_dataset_summary
-from scenarionet.builder.utils import save_summary_anda_mapping
+from scenarionet.common_utils import save_summary_anda_mapping, read_dataset_summary

 logger = logging.getLogger(__name__)
@@ -35,6 +34,10 @@ class ErrorFile:
     DATASET = "dataset_path"
     ERRORS = "errors"

+    @classmethod
+    def get_error_file_name(cls, dataset_path):
+        return "{}_{}.json".format(cls.PREFIX, os.path.basename(dataset_path))
+
     @classmethod
     def dump(cls, save_dir, errors: List, dataset_path):
         """
@@ -43,9 +46,11 @@ class ErrorFile:
         :param errors: error list, containing a list of dict from ErrorDescription.make()
         :param dataset_path: dataset_path, the dir of dataset_summary.pkl
         """
-        file_name = "{}_{}.json".format(cls.PREFIX, os.path.basename(dataset_path))
-        with open(os.path.join(save_dir, file_name), "w+") as f:
+        file_name = cls.get_error_file_name(dataset_path)
+        path = os.path.join(save_dir, file_name)
+        with open(path, "w+") as f:
             json.dump({cls.DATASET: dataset_path, cls.ERRORS: errors}, f, indent=4)
+        return path

     @classmethod
     def generate_dataset(cls, error_file_path, new_dataset_path, force_overwrite=False, broken_scenario=False):
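Putting the ErrorFile pieces together, a sketch of consuming a dumped error file and regenerating the passed/failed splits. All paths here are hypothetical; the class attributes and method signatures are the ones shown in the diff above.

import json
import os

from scenarionet.verifier.error import ErrorFile

result_save_dir = "./verify_results"   # hypothetical
dataset_path = "./combined_dataset"    # hypothetical
error_file_path = os.path.join(result_save_dir, ErrorFile.get_error_file_name(dataset_path))

# The dumped file is plain JSON with two top-level keys.
with open(error_file_path, "r") as f:
    report = json.load(f)
print(report[ErrorFile.DATASET])       # the dataset that was verified
print(len(report[ErrorFile.ERRORS]))   # list of ErrorDescription.make() dicts

# Rebuild two datasets from the report: the scenarios that loaded fine and the broken ones.
pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, "./passed_scenarios", broken_scenario=False)
fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, "./failed_scenarios", broken_scenario=True)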
scenarionet/verifier/utils.py
@@ -46,21 +46,20 @@ def verify_loading_into_metadrive(dataset_path, result_save_dir, steps_to_run=10
     # Run, workers and process result from worker
     with multiprocessing.Pool(num_workers) as p:
         all_result = list(p.imap(func, argument_list))
-    result = all([i[0] for i in all_result])
+    success = all([i[0] for i in all_result])
     errors = []
     for _, error in all_result:
         errors += error

-    # save result
-    EF.dump(result_save_dir, errors, dataset_path)
-
     # logging
-    if result:
+    if success:
         logger.info("All scenarios can be loaded successfully!")
     else:
+        # save result
+        path = EF.dump(result_save_dir, errors, dataset_path)
         logger.info(
-            "Fail to load all scenarios, see log for more details! Number of failed scenarios: {}".format(len(errors)))
-    return result, errors
+            "Fail to load all scenarios. Number of failed scenarios: {}. "
+            "See: {} more details! ".format(len(errors), path))
+    return success, errors

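As a usage sketch of the renamed return value (paths are hypothetical; the keyword arguments mirror the test above):

import os

from scenarionet.verifier.error import ErrorFile
from scenarionet.verifier.utils import verify_loading_into_metadrive

dataset_path = "./combined_dataset"    # hypothetical dataset directory
result_save_dir = "./verify_results"   # where the error JSON ends up on failure

success, errors = verify_loading_into_metadrive(
    dataset_path, result_save_dir=result_save_dir, steps_to_run=1000, num_workers=4)
if not success:
    # With this change the error file is written in the failure branch; its name
    # comes from the new ErrorFile.get_error_file_name helper.
    error_file = os.path.join(result_save_dir, ErrorFile.get_error_file_name(dataset_path))
    print("{} scenarios failed, see {}".format(len(errors), error_file))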
 def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, steps_to_run, metadrive_config=None):

@@ -86,7 +85,7 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
         try:
             env.reset(force_seed=scenario_index)
             arrive = False
-            if RANDOM_DROP and np.random.rand() < 0.5:
+            if RANDOM_DROP and np.random.rand() < 0.8:
                 raise ValueError("Random Drop")
             for _ in range(steps_to_run):
                 o, r, d, info = env.step([0, 0])
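Finally, a sketch of the random-drop knob used for fault injection (paths hypothetical). With it enabled, roughly 80% of scenario resets now raise ValueError("Random Drop"), so verification is expected to fail and produce a populated error file; the test switches it off again before regenerating the pass/fail datasets.

from scenarionet.verifier.utils import set_random_drop, verify_loading_into_metadrive

set_random_drop(True)   # ~80% of scenarios are dropped on purpose after this change
success, errors = verify_loading_into_metadrive(
    "./combined_dataset", result_save_dir="./verify_results", steps_to_run=1000, num_workers=4)
print(success, len(errors))  # with the 0.8 drop rate this is almost certainly (False, > 0)

set_random_drop(False)  # restore normal loading before regenerating the pass/fail datasets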