add common utils
This commit is contained in:
@@ -6,21 +6,13 @@ import pickle
|
|||||||
import shutil
|
import shutil
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
|
||||||
import metadrive.scenario.utils as sd_utils
|
|
||||||
import numpy as np
|
|
||||||
from metadrive.scenario.scenario_description import ScenarioDescription
|
from metadrive.scenario.scenario_description import ScenarioDescription
|
||||||
|
|
||||||
|
from scenarionet.common_utils import save_summary_anda_mapping
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def read_dataset_summary(dataset_path):
|
|
||||||
return sd_utils.read_dataset_summary(dataset_path)
|
|
||||||
|
|
||||||
|
|
||||||
def read_scenario(pkl_file_path):
|
|
||||||
return sd_utils.read_scenario_data(pkl_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
def try_generating_summary(file_folder):
|
def try_generating_summary(file_folder):
|
||||||
# Create a fake one
|
# Create a fake one
|
||||||
files = os.listdir(file_folder)
|
files = os.listdir(file_folder)
|
||||||
@@ -108,20 +100,3 @@ def combine_multiple_dataset(
|
|||||||
return summaries, mappings
|
return summaries, mappings
|
||||||
|
|
||||||
|
|
||||||
def dict_recursive_remove_array_and_set(d):
|
|
||||||
if isinstance(d, np.ndarray):
|
|
||||||
return d.tolist()
|
|
||||||
if isinstance(d, set):
|
|
||||||
return tuple(d)
|
|
||||||
if isinstance(d, dict):
|
|
||||||
for k in d.keys():
|
|
||||||
d[k] = dict_recursive_remove_array_and_set(d[k])
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def save_summary_anda_mapping(summary_file_path, mapping_file_path, summary, mapping):
|
|
||||||
with open(summary_file_path, "wb") as file:
|
|
||||||
pickle.dump(dict_recursive_remove_array_and_set(summary), file)
|
|
||||||
with open(mapping_file_path, "wb") as file:
|
|
||||||
pickle.dump(mapping, file)
|
|
||||||
print("Dataset Summary and Mapping are saved at: {}".format(summary_file_path))
|
|
||||||
|
|||||||
81
scenarionet/common_utils.py
Normal file
81
scenarionet/common_utils.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from metadrive.scenario import utils as sd_utils
|
||||||
|
|
||||||
|
|
||||||
|
def recursive_equal(data1, data2, need_assert=False):
|
||||||
|
from metadrive.utils.config import Config
|
||||||
|
if isinstance(data1, Config):
|
||||||
|
data1 = data1.get_dict()
|
||||||
|
if isinstance(data2, Config):
|
||||||
|
data2 = data2.get_dict()
|
||||||
|
|
||||||
|
if isinstance(data1, np.ndarray):
|
||||||
|
tmp = np.asarray(data2)
|
||||||
|
return np.all(data1 == tmp)
|
||||||
|
|
||||||
|
if isinstance(data2, np.ndarray):
|
||||||
|
tmp = np.asarray(data1)
|
||||||
|
return np.all(tmp == data2)
|
||||||
|
|
||||||
|
if isinstance(data1, dict):
|
||||||
|
is_ins = isinstance(data2, dict)
|
||||||
|
key_right = set(data1.keys()) == set(data2.keys())
|
||||||
|
if need_assert:
|
||||||
|
assert is_ins and key_right, (data1.keys(), data2.keys())
|
||||||
|
if not (is_ins and key_right):
|
||||||
|
return False
|
||||||
|
ret = []
|
||||||
|
for k in data1:
|
||||||
|
ret.append(recursive_equal(data1[k], data2[k], need_assert=need_assert))
|
||||||
|
return all(ret)
|
||||||
|
|
||||||
|
elif isinstance(data1, (list, tuple)):
|
||||||
|
len_right = len(data1) == len(data2)
|
||||||
|
is_ins = isinstance(data2, (list, tuple))
|
||||||
|
if need_assert:
|
||||||
|
assert len_right and is_ins, (len(data1), len(data2), data1, data2)
|
||||||
|
if not (is_ins and len_right):
|
||||||
|
return False
|
||||||
|
ret = []
|
||||||
|
for i in range(len(data1)):
|
||||||
|
ret.append(recursive_equal(data1[i], data2[i], need_assert=need_assert))
|
||||||
|
return all(ret)
|
||||||
|
elif isinstance(data1, np.ndarray):
|
||||||
|
ret = np.isclose(data1, data2).all()
|
||||||
|
if need_assert:
|
||||||
|
assert ret, (type(data1), type(data2), data1, data2)
|
||||||
|
return ret
|
||||||
|
else:
|
||||||
|
ret = data1 == data2
|
||||||
|
if need_assert:
|
||||||
|
assert ret, (type(data1), type(data2), data1, data2)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def dict_recursive_remove_array_and_set(d):
|
||||||
|
if isinstance(d, np.ndarray):
|
||||||
|
return d.tolist()
|
||||||
|
if isinstance(d, set):
|
||||||
|
return tuple(d)
|
||||||
|
if isinstance(d, dict):
|
||||||
|
for k in d.keys():
|
||||||
|
d[k] = dict_recursive_remove_array_and_set(d[k])
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def save_summary_anda_mapping(summary_file_path, mapping_file_path, summary, mapping):
|
||||||
|
with open(summary_file_path, "wb") as file:
|
||||||
|
pickle.dump(dict_recursive_remove_array_and_set(summary), file)
|
||||||
|
with open(mapping_file_path, "wb") as file:
|
||||||
|
pickle.dump(mapping, file)
|
||||||
|
print("Dataset Summary and Mapping are saved at: {}".format(summary_file_path))
|
||||||
|
|
||||||
|
|
||||||
|
def read_dataset_summary(dataset_path):
|
||||||
|
return sd_utils.read_dataset_summary(dataset_path)
|
||||||
|
|
||||||
|
|
||||||
|
def read_scenario(pkl_file_path):
|
||||||
|
return sd_utils.read_scenario_data(pkl_file_path)
|
||||||
@@ -6,7 +6,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import shutil
|
import shutil
|
||||||
from scenarionet.builder.utils import save_summary_anda_mapping
|
from scenarionet.common_utils import save_summary_anda_mapping
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tqdm
|
import tqdm
|
||||||
from metadrive.scenario import ScenarioDescription as SD
|
from metadrive.scenario import ScenarioDescription as SD
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import os
|
|||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
from scenarionet import SCENARIONET_PACKAGE_PATH
|
from scenarionet import SCENARIONET_PACKAGE_PATH
|
||||||
from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
|
from scenarionet.builder.utils import combine_multiple_dataset
|
||||||
|
from scenarionet.common_utils import read_dataset_summary, read_scenario
|
||||||
from scenarionet.verifier.utils import verify_loading_into_metadrive
|
from scenarionet.verifier.utils import verify_loading_into_metadrive
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
|
from metadrive.scenario.utils import assert_scenario_equal
|
||||||
from scenarionet import SCENARIONET_PACKAGE_PATH
|
from scenarionet import SCENARIONET_PACKAGE_PATH
|
||||||
from scenarionet.builder.utils import combine_multiple_dataset, read_dataset_summary, read_scenario
|
from scenarionet.builder.utils import combine_multiple_dataset
|
||||||
|
from scenarionet.common_utils import read_dataset_summary, read_scenario
|
||||||
|
from scenarionet.verifier.error import ErrorFile
|
||||||
from scenarionet.verifier.utils import verify_loading_into_metadrive, set_random_drop
|
from scenarionet.verifier.utils import verify_loading_into_metadrive, set_random_drop
|
||||||
|
|
||||||
|
|
||||||
@@ -11,18 +13,27 @@ def test_combine_multiple_dataset():
|
|||||||
dataset_name = "nuscenes"
|
dataset_name = "nuscenes"
|
||||||
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
|
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
|
||||||
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
|
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
|
||||||
output_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
|
dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "combine")
|
||||||
combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
|
combine_multiple_dataset(dataset_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
|
||||||
|
|
||||||
dataset_paths.append(output_path)
|
summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
|
||||||
for dataset_path in dataset_paths:
|
for scenario_file in sorted_scenarios:
|
||||||
summary, sorted_scenarios, mapping = read_dataset_summary(dataset_path)
|
read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
|
||||||
for scenario_file in sorted_scenarios:
|
success, logs = verify_loading_into_metadrive(
|
||||||
read_scenario(os.path.join(dataset_path, mapping[scenario_file], scenario_file))
|
dataset_path, result_save_dir="test_dataset", steps_to_run=1000, num_workers=4)
|
||||||
success, result = verify_loading_into_metadrive(
|
|
||||||
dataset_path, result_save_dir="test_dataset", steps_to_run=1000, num_workers=4)
|
|
||||||
assert success
|
|
||||||
set_random_drop(False)
|
set_random_drop(False)
|
||||||
|
# regenerate
|
||||||
|
file_name = ErrorFile.get_error_file_name(dataset_path)
|
||||||
|
error_file_path = os.path.join("test_dataset", file_name)
|
||||||
|
|
||||||
|
pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "passed_senarios")
|
||||||
|
fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "failed_scenarios")
|
||||||
|
pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, pass_dataset, broken_scenario=False)
|
||||||
|
fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, fail_dataset, broken_scenario=True)
|
||||||
|
|
||||||
|
read_pass_summary, _, read_pass_mapping = read_dataset_summary(pass_dataset)
|
||||||
|
read_fail_summary, _, read_fail_mapping, = read_dataset_summary(fail_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ from typing import List
|
|||||||
|
|
||||||
from metadrive.scenario.scenario_description import ScenarioDescription as SD
|
from metadrive.scenario.scenario_description import ScenarioDescription as SD
|
||||||
|
|
||||||
from scenarionet.builder.utils import read_dataset_summary
|
from scenarionet.common_utils import save_summary_anda_mapping, read_dataset_summary
|
||||||
from scenarionet.builder.utils import save_summary_anda_mapping
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -35,6 +34,10 @@ class ErrorFile:
|
|||||||
DATASET = "dataset_path"
|
DATASET = "dataset_path"
|
||||||
ERRORS = "errors"
|
ERRORS = "errors"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_error_file_name(cls, dataset_path):
|
||||||
|
return "{}_{}.json".format(cls.PREFIX, os.path.basename(dataset_path))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def dump(cls, save_dir, errors: List, dataset_path):
|
def dump(cls, save_dir, errors: List, dataset_path):
|
||||||
"""
|
"""
|
||||||
@@ -43,9 +46,11 @@ class ErrorFile:
|
|||||||
:param errors: error list, containing a list of dict from ErrorDescription.make()
|
:param errors: error list, containing a list of dict from ErrorDescription.make()
|
||||||
:param dataset_path: dataset_path, the dir of dataset_summary.pkl
|
:param dataset_path: dataset_path, the dir of dataset_summary.pkl
|
||||||
"""
|
"""
|
||||||
file_name = "{}_{}.json".format(cls.PREFIX, os.path.basename(dataset_path))
|
file_name = cls.get_error_file_name(dataset_path)
|
||||||
with open(os.path.join(save_dir, file_name), "w+") as f:
|
path = os.path.join(save_dir, file_name)
|
||||||
|
with open(path, "w+") as f:
|
||||||
json.dump({cls.DATASET: dataset_path, cls.ERRORS: errors}, f, indent=4)
|
json.dump({cls.DATASET: dataset_path, cls.ERRORS: errors}, f, indent=4)
|
||||||
|
return path
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_dataset(cls, error_file_path, new_dataset_path, force_overwrite=False, broken_scenario=False):
|
def generate_dataset(cls, error_file_path, new_dataset_path, force_overwrite=False, broken_scenario=False):
|
||||||
|
|||||||
@@ -46,21 +46,20 @@ def verify_loading_into_metadrive(dataset_path, result_save_dir, steps_to_run=10
|
|||||||
# Run, workers and process result from worker
|
# Run, workers and process result from worker
|
||||||
with multiprocessing.Pool(num_workers) as p:
|
with multiprocessing.Pool(num_workers) as p:
|
||||||
all_result = list(p.imap(func, argument_list))
|
all_result = list(p.imap(func, argument_list))
|
||||||
result = all([i[0] for i in all_result])
|
success = all([i[0] for i in all_result])
|
||||||
errors = []
|
errors = []
|
||||||
for _, error in all_result:
|
for _, error in all_result:
|
||||||
errors += error
|
errors += error
|
||||||
|
|
||||||
# save result
|
|
||||||
EF.dump(result_save_dir, errors, dataset_path)
|
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
if result:
|
if success:
|
||||||
logger.info("All scenarios can be loaded successfully!")
|
logger.info("All scenarios can be loaded successfully!")
|
||||||
else:
|
else:
|
||||||
|
# save result
|
||||||
|
path = EF.dump(result_save_dir, errors, dataset_path)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Fail to load all scenarios, see log for more details! Number of failed scenarios: {}".format(len(errors)))
|
"Fail to load all scenarios. Number of failed scenarios: {}. "
|
||||||
return result, errors
|
"See: {} more details! ".format(len(errors), path))
|
||||||
|
return success, errors
|
||||||
|
|
||||||
|
|
||||||
def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, steps_to_run, metadrive_config=None):
|
def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, steps_to_run, metadrive_config=None):
|
||||||
@@ -86,7 +85,7 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
|
|||||||
try:
|
try:
|
||||||
env.reset(force_seed=scenario_index)
|
env.reset(force_seed=scenario_index)
|
||||||
arrive = False
|
arrive = False
|
||||||
if RANDOM_DROP and np.random.rand() < 0.5:
|
if RANDOM_DROP and np.random.rand() < 0.8:
|
||||||
raise ValueError("Random Drop")
|
raise ValueError("Random Drop")
|
||||||
for _ in range(steps_to_run):
|
for _ in range(steps_to_run):
|
||||||
o, r, d, info = env.step([0, 0])
|
o, r, d, info = env.step([0, 0])
|
||||||
|
|||||||
Reference in New Issue
Block a user