From 436cbd3ab356b51b38e3f7b0145958cbb61490f7 Mon Sep 17 00:00:00 2001 From: QuanyiLi Date: Mon, 8 May 2023 14:48:27 +0100 Subject: [PATCH] test waymo converter --- scenarionet/converter/utils.py | 25 +++++++------ scenarionet/converter/waymo/utils.py | 3 ++ .../tests/local_test/_test_convert_waymo.py | 36 +++++++++++++++++++ 3 files changed, 54 insertions(+), 10 deletions(-) create mode 100644 scenarionet/tests/local_test/_test_convert_waymo.py diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py index ecdf0df..bf3c270 100644 --- a/scenarionet/converter/utils.py +++ b/scenarionet/converter/utils.py @@ -90,15 +90,15 @@ def write_to_directory(convert_func, end_idx = num_files else: end_idx = (i + 1) * num_files_each_worker - argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs]) + output_path = os.path.join(dir, "{}_{}".format(basename, str(i))) + argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs, i, output_path]) # prefill arguments func = partial(writing_to_directory_wrapper, - convert_func, - output_path, - dataset_version, - dataset_name, - force_overwrite) + convert_func=convert_func, + dataset_version=dataset_version, + dataset_name=dataset_name, + force_overwrite=force_overwrite) # Run, workers and process result from worker with multiprocessing.Pool(num_workers) as p: @@ -108,20 +108,25 @@ def write_to_directory(convert_func, def writing_to_directory_wrapper(args, convert_func, - output_path, dataset_version, dataset_name, force_overwrite=False): return write_to_directory_single_worker(convert_func=convert_func, scenarios=args[0], - output_path=output_path, + output_path=args[3], dataset_version=dataset_version, dataset_name=dataset_name, force_overwrite=force_overwrite, + worker_index=args[2], **args[1]) -def write_to_directory_single_worker(convert_func, scenarios, output_path, dataset_version, dataset_name, +def write_to_directory_single_worker(convert_func, + scenarios, + output_path, + dataset_version, + dataset_name, + worker_index=0, force_overwrite=False, **kwargs): """ Convert a batch of scenarios. @@ -157,7 +162,7 @@ def write_to_directory_single_worker(convert_func, scenarios, output_path, datas summary = {} mapping = {} - for scenario in tqdm.tqdm(scenarios): + for scenario in tqdm.tqdm(scenarios, desc="Worker Index: {}".format(worker_index)): # convert scenario sd_scenario = convert_func(scenario, dataset_version, **kwargs) scenario_id = sd_scenario[SD.ID] diff --git a/scenarionet/converter/waymo/utils.py b/scenarionet/converter/waymo/utils.py index 3220bb2..1b67530 100644 --- a/scenarionet/converter/waymo/utils.py +++ b/scenarionet/converter/waymo/utils.py @@ -448,7 +448,10 @@ def get_waymo_scenarios(waymo_data_directory, num_workers=8): # get result for r in all_result: + if len(r) == 0: + logger.warning("0 scenarios found") ret += r + logger.info("\n Find {} waymo scenarios from {} files".format(len(ret), num_files)) return ret diff --git a/scenarionet/tests/local_test/_test_convert_waymo.py b/scenarionet/tests/local_test/_test_convert_waymo.py new file mode 100644 index 0000000..228d9ed --- /dev/null +++ b/scenarionet/tests/local_test/_test_convert_waymo.py @@ -0,0 +1,36 @@ +import argparse +import logging +import os + +from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet.converter.utils import write_to_directory +from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios + +logger = logging.getLogger(__name__) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_name", "-n", default="waymo", + help="Dataset name, will be used to generate scenario files") + parser.add_argument("--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"), + help="The path of the dataset") + parser.add_argument("--version", "-v", default='v1.2', help="version") + args = parser.parse_args() + + force_overwrite = True + dataset_name = args.dataset_name + output_path = args.dataset_path + version = args.version + + waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin") + scenarios = get_waymo_scenarios(waymo_data_directory, num_workers=3) + + write_to_directory( + convert_func=convert_waymo_scenario, + scenarios=scenarios, + output_path=output_path, + dataset_version=version, + dataset_name=dataset_name, + force_overwrite=force_overwrite, + num_workers=8 + )