install md for github test
.github/workflows/main.yml (vendored), 6 lines changed
@@ -39,6 +39,12 @@ jobs:
           pip install pytest
           pip install pytest-cov
           pip install ray
+
+          git clone git@github.com:metadriverse/metadrive.git
+          cd metadrive
+          pip install -e .
+          cd ../
+
           cd scenarionet/
           pytest --cov=./ --cov-config=.coveragerc --cov-report=xml -sv tests

@@ -26,8 +26,7 @@ def try_generating_summary(file_folder):


 def combine_multiple_dataset(
-    output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True,
-    filters: List[Callable] = None
+    output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True, filters: List[Callable] = None
 ):
     """
     Combine multiple datasets. Each dataset should have a dataset_summary.pkl
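The joined signature keeps the variadic dataset paths and the filters argument on one line. A hedged usage sketch of that call shape follows; the dataset directories are placeholders and the import location is assumed rather than stated in this diff.

    # Hedged usage sketch of the signature shown above; paths are placeholders.
    from scenarionet.builder.utils import combine_multiple_dataset  # assumed module path

    combine_multiple_dataset(
        "/tmp/combined_dataset",   # output_path
        "/tmp/waymo_dataset",      # *dataset_paths
        "/tmp/nuplan_dataset",
        force_overwrite=True,
        try_generate_missing_file=True,
        filters=None,              # optionally a list of callables
    )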
@@ -71,8 +71,10 @@ def save_summary_anda_mapping(summary_file_path, mapping_file_path, summary, map
         pickle.dump(dict_recursive_remove_array_and_set(summary), file)
     with open(mapping_file_path, "wb") as file:
         pickle.dump(mapping, file)
-    print("\n ================ Dataset Summary and Mapping are saved at: {} "
-          "================ \n".format(summary_file_path))
+    print(
+        "\n ================ Dataset Summary and Mapping are saved at: {} "
+        "================ \n".format(summary_file_path)
+    )


 def read_dataset_summary(dataset_path):
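save_summary_anda_mapping writes both files with pickle in binary mode; reading them back is symmetric. A small sketch (the literal file names are placeholders; the real names come from the SD.DATASET constants used later in this commit):

    # Hedged read-back sketch for the two pickle files written above. The literal file
    # names are placeholders for SD.DATASET.SUMMARY_FILE and SD.DATASET.MAPPING_FILE.
    import pickle

    with open("dataset_summary.pkl", "rb") as file:
        summary = pickle.load(file)
    with open("dataset_mapping.pkl", "rb") as file:
        mapping = pickle.load(file)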
@@ -60,20 +60,24 @@ def contains_explicit_return(f):
     return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))


-def write_to_directory(convert_func,
-                       scenarios,
-                       output_path,
-                       dataset_version,
-                       dataset_name,
-                       force_overwrite=False,
-                       num_workers=8,
-                       **kwargs):
+def write_to_directory(
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    force_overwrite=False,
+    num_workers=8,
+    **kwargs
+):
     # make sure dir not exist
     save_path = copy.deepcopy(output_path)
     if os.path.exists(output_path):
         if not force_overwrite:
-            raise ValueError("Directory {} already exists! Abort. "
-                             "\n Try setting force_overwrite=True or adding --overwrite".format(output_path))
+            raise ValueError(
+                "Directory {} already exists! Abort. "
+                "\n Try setting force_overwrite=True or adding --overwrite".format(output_path)
+            )

     basename = os.path.basename(output_path)
     dir = os.path.dirname(output_path)
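The multi-line messages in this and the following hunks rely on Python's adjacent string literal concatenation: the two literals inside the parentheses are merged into one string at compile time, so moving them onto separate lines does not change the text that gets raised or printed. A standalone check of that behavior:

    # Adjacent string literals are concatenated before .format is applied,
    # so the placeholder in the first literal is still filled in.
    output_path = "/tmp/exp"
    msg = (
        "Directory {} already exists! Abort. "
        "\n Try setting force_overwrite=True or adding --overwrite".format(output_path)
    )
    assert msg.startswith("Directory /tmp/exp already exists!")
    print(msg)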
@@ -81,8 +85,10 @@ def write_to_directory(convert_func,
         output_path = os.path.join(dir, "{}_{}".format(basename, str(i)))
         if os.path.exists(output_path):
             if not force_overwrite:
-                raise ValueError("Directory {} already exists! Abort. "
-                                 "\n Try setting force_overwrite=True or adding --overwrite".format(output_path))
+                raise ValueError(
+                    "Directory {} already exists! Abort. "
+                    "\n Try setting force_overwrite=True or adding --overwrite".format(output_path)
+                )
     # get arguments for workers
     num_files = len(scenarios)
     if num_files < num_workers:
@@ -103,42 +109,46 @@ def write_to_directory(convert_func,
         argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs, i, output_path])

     # prefill arguments
-    func = partial(writing_to_directory_wrapper,
-                   convert_func=convert_func,
-                   dataset_version=dataset_version,
-                   dataset_name=dataset_name,
-                   force_overwrite=force_overwrite)
+    func = partial(
+        writing_to_directory_wrapper,
+        convert_func=convert_func,
+        dataset_version=dataset_version,
+        dataset_name=dataset_name,
+        force_overwrite=force_overwrite
+    )

     # Run, workers and process result from worker
     with multiprocessing.Pool(num_workers) as p:
         all_result = list(p.imap(func, argument_list))
-    combine_multiple_dataset(save_path, *output_pathes, force_overwrite=force_overwrite,
-                             try_generate_missing_file=False)
+    combine_multiple_dataset(
+        save_path, *output_pathes, force_overwrite=force_overwrite, try_generate_missing_file=False
+    )
     return all_result


-def writing_to_directory_wrapper(args,
-                                 convert_func,
-                                 dataset_version,
-                                 dataset_name,
-                                 force_overwrite=False):
-    return write_to_directory_single_worker(convert_func=convert_func,
-                                            scenarios=args[0],
-                                            output_path=args[3],
-                                            dataset_version=dataset_version,
-                                            dataset_name=dataset_name,
-                                            force_overwrite=force_overwrite,
-                                            worker_index=args[2],
-                                            **args[1])
+def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, force_overwrite=False):
+    return write_to_directory_single_worker(
+        convert_func=convert_func,
+        scenarios=args[0],
+        output_path=args[3],
+        dataset_version=dataset_version,
+        dataset_name=dataset_name,
+        force_overwrite=force_overwrite,
+        worker_index=args[2],
+        **args[1]
+    )


-def write_to_directory_single_worker(convert_func,
-                                     scenarios,
-                                     output_path,
-                                     dataset_version,
-                                     dataset_name,
-                                     worker_index=0,
-                                     force_overwrite=False, **kwargs):
+def write_to_directory_single_worker(
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    worker_index=0,
+    force_overwrite=False,
+    **kwargs
+):
     """
     Convert a batch of scenarios.
     """
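write_to_directory splits the scenario list into per-worker argument lists, binds the fixed keyword arguments with functools.partial, and fans the lists out through multiprocessing.Pool.imap; writing_to_directory_wrapper then unpacks each positional list for the single-worker function. A self-contained sketch of that dispatch pattern (names and data here are illustrative, not ScenarioNet APIs):

    # Illustrative stand-alone version of the partial + Pool.imap dispatch used above.
    import multiprocessing
    from functools import partial

    def process_chunk(args, tag):
        # args mirrors the per-worker argument list: (chunk, worker_index)
        chunk, worker_index = args
        return ["{}:{} (worker {})".format(tag, item, worker_index) for item in chunk]

    if __name__ == "__main__":
        argument_list = [(["a", "b"], 0), (["c", "d"], 1)]
        func = partial(process_chunk, tag="converted")
        with multiprocessing.Pool(2) as p:
            all_result = list(p.imap(func, argument_list))
        print(all_result)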
@@ -162,8 +172,10 @@ def write_to_directory_single_worker(convert_func,
         if force_overwrite:
             delay_remove = save_path
         else:
-            raise ValueError("Directory already exists! Abort."
-                             "\n Try setting force_overwrite=True or using --overwrite")
+            raise ValueError(
+                "Directory already exists! Abort."
+                "\n Try setting force_overwrite=True or using --overwrite"
+            )

     summary_file = SD.DATASET.SUMMARY_FILE
     mapping_file = SD.DATASET.MAPPING_FILE
@@ -9,7 +9,6 @@ if __name__ == '__main__':
     parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
     args = parser.parse_args()
     if len(args.from_datasets) != 0:
-        combine_multiple_dataset(args.to,
-                                 *args.from_datasets,
-                                 force_overwrite=args.overwrite,
-                                 try_generate_missing_file=True)
+        combine_multiple_dataset(
+            args.to, *args.from_datasets, force_overwrite=args.overwrite, try_generate_missing_file=True
+        )
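args.from_datasets is treated as a list above (its length is checked before combining). A hedged sketch of argument declarations that would produce that shape; the nargs choice and the exact flag spellings are assumptions, not taken from this diff:

    # Hedged sketch: one way the flags used above could be declared. nargs="*" is an
    # assumption so that args.from_datasets is always a list.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--to", help="Path of the combined output dataset")
    parser.add_argument("--from_datasets", nargs="*", default=[], help="Dataset paths to combine")
    parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
    args = parser.parse_args(["--to", "/tmp/out", "--from_datasets", "/tmp/a", "/tmp/b"])
    assert len(args.from_datasets) == 2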
@@ -7,10 +7,15 @@ from scenarionet.converter.utils import write_to_directory

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_name", "-n", default="nuplan",
-                        help="Dataset name, will be used to generate scenario files")
-    parser.add_argument("--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
-                        help="The path of the dataset")
+    parser.add_argument(
+        "--dataset_name", "-n", default="nuplan", help="Dataset name, will be used to generate scenario files"
+    )
+    parser.add_argument(
+        "--dataset_path",
+        "-d",
+        default=os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
+        help="The path of the dataset"
+    )
     parser.add_argument("--version", "-v", default='v1.1', help="version")
     parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
@@ -7,11 +7,16 @@ from scenarionet.converter.utils import write_to_directory

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_name", "-n", default="nuscenes",
-                        help="Dataset name, will be used to generate scenario files")
-    parser.add_argument("--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
-                        help="The path of the dataset")
-    parser.add_argument("--version", "-v", default='v1.0-mini', help="version")
+    parser.add_argument(
+        "--dataset_name", "-n", default="nuscenes", help="Dataset name, will be used to generate scenario files"
+    )
+    parser.add_argument(
+        "--dataset_path",
+        "-d",
+        default=os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
+        help="The path of the dataset"
+    )
+    parser.add_argument("--version", "-v", default='v1.0-mini', help="version")
     parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
     args = parser.parse_args()
@@ -10,11 +10,13 @@ from scenarionet.converter.utils import write_to_directory

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_name", "-n", default="pg",
-                        help="Dataset name, will be used to generate scenario files")
-    parser.add_argument("--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "pg"),
-                        help="The path of the dataset")
-    parser.add_argument("--version", "-v", default=metadrive.constants.DATA_VERSION, help="version")
+    parser.add_argument(
+        "--dataset_name", "-n", default="pg", help="Dataset name, will be used to generate scenario files"
+    )
+    parser.add_argument(
+        "--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "pg"), help="The path of the dataset"
+    )
+    parser.add_argument("--version", "-v", default=metadrive.constants.DATA_VERSION, help="version")
     parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
     args = parser.parse_args()
@@ -10,11 +10,13 @@ logger = logging.getLogger(__name__)

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset_name", "-n", default="waymo",
-                        help="Dataset name, will be used to generate scenario files")
-    parser.add_argument("--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
-                        help="The path of the dataset")
-    parser.add_argument("--version", "-v", default='v1.2', help="version")
+    parser.add_argument(
+        "--dataset_name", "-n", default="waymo", help="Dataset name, will be used to generate scenario files"
+    )
+    parser.add_argument(
+        "--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"), help="The path of the dataset"
+    )
+    parser.add_argument("--version", "-v", default='v1.2', help="version")
     parser.add_argument("--overwrite", action="store_true", help="If the dataset_path exists, overwrite it")
     parser.add_argument("--num_workers", type=int, default=8, help="number of workers to use")
     args = parser.parse_args()
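The import shown at the top of these convert scripts suggests the parsed flags are handed to write_to_directory, whose reformatted signature appears earlier in this commit. A hedged sketch of that hand-off; convert_one and the scenario handles are dummies, not part of this diff:

    # Hedged sketch of the hand-off suggested by the scripts' import.
    from scenarionet.converter.utils import write_to_directory  # import shown in the hunk headers

    def convert_one(scenario, version, **kwargs):
        ...  # dummy stand-in for a dataset-specific converter

    write_to_directory(
        convert_func=convert_one,
        scenarios=["scenario_0", "scenario_1"],  # placeholder scenario handles
        output_path="/tmp/waymo_dataset",        # would come from --dataset_path
        dataset_version="v1.2",                  # would come from --version
        dataset_name="waymo",                    # would come from --dataset_name
        force_overwrite=True,                    # would come from --overwrite
        num_workers=8,                           # would come from --num_workers
    )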
@@ -31,7 +31,8 @@ def test_generate_from_error():
     for scenario_file in sorted_scenarios:
         read_scenario(dataset_path, mapping, scenario_file)
     success, logs = verify_loading_into_metadrive(
-        dataset_path, result_save_dir="../test_dataset", steps_to_run=1000, num_workers=16)
+        dataset_path, result_save_dir="../test_dataset", steps_to_run=1000, num_workers=16
+    )
     set_random_drop(False)
     # get error file
     file_name = ErrorFile.get_error_file_name(dataset_path)
@@ -39,10 +40,12 @@ def test_generate_from_error():
     # regenerate
     pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "passed_scenarios")
     fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "failed_scenarios")
-    pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, pass_dataset, force_overwrite=True,
-                                                            broken_scenario=False)
-    fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, fail_dataset, force_overwrite=True,
-                                                            broken_scenario=True)
+    pass_summary, pass_mapping = ErrorFile.generate_dataset(
+        error_file_path, pass_dataset, force_overwrite=True, broken_scenario=False
+    )
+    fail_summary, fail_mapping = ErrorFile.generate_dataset(
+        error_file_path, fail_dataset, force_overwrite=True, broken_scenario=True
+    )

     # assert
     read_pass_summary, _, read_pass_mapping = read_dataset_summary(pass_dataset)
@@ -24,7 +24,8 @@ def test_generate_from_error():
     for scenario_file in sorted_scenarios:
         read_scenario(dataset_path, mapping, scenario_file)
     success, logs = verify_loading_into_metadrive(
-        dataset_path, result_save_dir="test_dataset", steps_to_run=1000, num_workers=3)
+        dataset_path, result_save_dir="test_dataset", steps_to_run=1000, num_workers=3
+    )
     set_random_drop(False)
     # get error file
     file_name = ErrorFile.get_error_file_name(dataset_path)
@@ -32,10 +33,12 @@ def test_generate_from_error():
     # regenerate
     pass_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "passed_senarios")
     fail_dataset = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "tmp", "failed_scenarios")
-    pass_summary, pass_mapping = ErrorFile.generate_dataset(error_file_path, pass_dataset, force_overwrite=True,
-                                                            broken_scenario=False)
-    fail_summary, fail_mapping = ErrorFile.generate_dataset(error_file_path, fail_dataset, force_overwrite=True,
-                                                            broken_scenario=True)
+    pass_summary, pass_mapping = ErrorFile.generate_dataset(
+        error_file_path, pass_dataset, force_overwrite=True, broken_scenario=False
+    )
+    fail_summary, fail_mapping = ErrorFile.generate_dataset(
+        error_file_path, fail_dataset, force_overwrite=True, broken_scenario=True
+    )

     # assert
     read_pass_summary, _, read_pass_mapping = read_dataset_summary(pass_dataset)
@@ -24,10 +24,7 @@ class ErrorDescription:
             "\n Scenario Error, "
             "scenario_index: {}, file_path: {}.\n Error message: {}".format(scenario_index, file_path, str(error))
         )
-        return {cls.INDEX: scenario_index,
-                cls.PATH: file_path,
-                cls.FILE_NAME: file_name,
-                cls.ERROR: str(error)}
+        return {cls.INDEX: scenario_index, cls.PATH: file_path, cls.FILE_NAME: file_name, cls.ERROR: str(error)}


 class ErrorFile:
@@ -69,8 +66,10 @@ class ErrorFile:
             if force_overwrite:
                 shutil.rmtree(new_dataset_path)
             else:
-                raise ValueError("Directory: {} already exists! "
-                                 "Set force_overwrite=True to overwrite".format(new_dataset_path))
+                raise ValueError(
+                    "Directory: {} already exists! "
+                    "Set force_overwrite=True to overwrite".format(new_dataset_path)
+                )
         os.makedirs(new_dataset_path, exist_ok=False)

         with open(error_file_path, "r+") as f:
@@ -59,7 +59,8 @@ def verify_loading_into_metadrive(dataset_path, result_save_dir, steps_to_run=10
         path = EF.dump(result_save_dir, errors, dataset_path)
         logger.info(
             "Fail to load all scenarios. Number of failed scenarios: {}. "
-            "See: {} more details! ".format(len(errors), path))
+            "See: {} more details! ".format(len(errors), path)
+        )
     return success, errors

@@ -67,17 +68,21 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
     global RANDOM_DROP
     logger.info(
         "================ Begin Scenario Loading Verification for scenario {}-{} ================ \n".format(
-            start_scenario_index, num_scenario + start_scenario_index))
+            start_scenario_index, num_scenario + start_scenario_index
+        )
+    )
     success = True
     metadrive_config = metadrive_config or {}
-    metadrive_config.update({
-        "agent_policy": ReplayEgoCarPolicy,
-        "num_scenarios": num_scenario,
-        "horizon": 1000,
-        "start_scenario_index": start_scenario_index,
-        "no_static_vehicles": False,
-        "data_directory": dataset_path,
-    })
+    metadrive_config.update(
+        {
+            "agent_policy": ReplayEgoCarPolicy,
+            "num_scenarios": num_scenario,
+            "horizon": 1000,
+            "start_scenario_index": start_scenario_index,
+            "no_static_vehicles": False,
+            "data_directory": dataset_path,
+        }
+    )
     env = ScenarioEnv(metadrive_config)
     logging.disable(logging.INFO)
     error_msgs = []
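loading_into_metadrive assembles the replay configuration once and then constructs the environment from it. A standalone sketch of that pattern; the MetaDrive import paths and the data directory are assumptions, not stated in this diff:

    # Hedged sketch of the config-then-construct pattern above. Import paths are assumed
    # from MetaDrive's usual layout; data_directory is a placeholder and must point at a
    # converted dataset before the environment can actually load scenarios.
    from metadrive.envs.scenario_env import ScenarioEnv
    from metadrive.policy.replay_policy import ReplayEgoCarPolicy

    metadrive_config = {}
    metadrive_config.update(
        {
            "agent_policy": ReplayEgoCarPolicy,
            "num_scenarios": 3,
            "horizon": 1000,
            "start_scenario_index": 0,
            "no_static_vehicles": False,
            "data_directory": "/path/to/converted_dataset",
        }
    )
    env = ScenarioEnv(metadrive_config)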