test waymo converter
This commit is contained in:
@@ -90,15 +90,15 @@ def write_to_directory(convert_func,
|
||||
end_idx = num_files
|
||||
else:
|
||||
end_idx = (i + 1) * num_files_each_worker
|
||||
argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs])
|
||||
output_path = os.path.join(dir, "{}_{}".format(basename, str(i)))
|
||||
argument_list.append([scenarios[i * num_files_each_worker:end_idx], kwargs, i, output_path])
|
||||
|
||||
# prefill arguments
|
||||
func = partial(writing_to_directory_wrapper,
|
||||
convert_func,
|
||||
output_path,
|
||||
dataset_version,
|
||||
dataset_name,
|
||||
force_overwrite)
|
||||
convert_func=convert_func,
|
||||
dataset_version=dataset_version,
|
||||
dataset_name=dataset_name,
|
||||
force_overwrite=force_overwrite)
|
||||
|
||||
# Run, workers and process result from worker
|
||||
with multiprocessing.Pool(num_workers) as p:
|
||||
@@ -108,20 +108,25 @@ def write_to_directory(convert_func,
|
||||
|
||||
def writing_to_directory_wrapper(args,
|
||||
convert_func,
|
||||
output_path,
|
||||
dataset_version,
|
||||
dataset_name,
|
||||
force_overwrite=False):
|
||||
return write_to_directory_single_worker(convert_func=convert_func,
|
||||
scenarios=args[0],
|
||||
output_path=output_path,
|
||||
output_path=args[3],
|
||||
dataset_version=dataset_version,
|
||||
dataset_name=dataset_name,
|
||||
force_overwrite=force_overwrite,
|
||||
worker_index=args[2],
|
||||
**args[1])
|
||||
|
||||
|
||||
def write_to_directory_single_worker(convert_func, scenarios, output_path, dataset_version, dataset_name,
|
||||
def write_to_directory_single_worker(convert_func,
|
||||
scenarios,
|
||||
output_path,
|
||||
dataset_version,
|
||||
dataset_name,
|
||||
worker_index=0,
|
||||
force_overwrite=False, **kwargs):
|
||||
"""
|
||||
Convert a batch of scenarios.
|
||||
@@ -157,7 +162,7 @@ def write_to_directory_single_worker(convert_func, scenarios, output_path, datas
|
||||
|
||||
summary = {}
|
||||
mapping = {}
|
||||
for scenario in tqdm.tqdm(scenarios):
|
||||
for scenario in tqdm.tqdm(scenarios, desc="Worker Index: {}".format(worker_index)):
|
||||
# convert scenario
|
||||
sd_scenario = convert_func(scenario, dataset_version, **kwargs)
|
||||
scenario_id = sd_scenario[SD.ID]
|
||||
|
||||
@@ -448,7 +448,10 @@ def get_waymo_scenarios(waymo_data_directory, num_workers=8):
|
||||
|
||||
# get result
|
||||
for r in all_result:
|
||||
if len(r) == 0:
|
||||
logger.warning("0 scenarios found")
|
||||
ret += r
|
||||
logger.info("\n Find {} waymo scenarios from {} files".format(len(ret), num_files))
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
36
scenarionet/tests/local_test/_test_convert_waymo.py
Normal file
36
scenarionet/tests/local_test/_test_convert_waymo.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
from scenarionet import SCENARIONET_DATASET_PATH
|
||||
from scenarionet.converter.utils import write_to_directory
|
||||
from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset_name", "-n", default="waymo",
|
||||
help="Dataset name, will be used to generate scenario files")
|
||||
parser.add_argument("--dataset_path", "-d", default=os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
|
||||
help="The path of the dataset")
|
||||
parser.add_argument("--version", "-v", default='v1.2', help="version")
|
||||
args = parser.parse_args()
|
||||
|
||||
force_overwrite = True
|
||||
dataset_name = args.dataset_name
|
||||
output_path = args.dataset_path
|
||||
version = args.version
|
||||
|
||||
waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
|
||||
scenarios = get_waymo_scenarios(waymo_data_directory, num_workers=3)
|
||||
|
||||
write_to_directory(
|
||||
convert_func=convert_waymo_scenario,
|
||||
scenarios=scenarios,
|
||||
output_path=output_path,
|
||||
dataset_version=version,
|
||||
dataset_name=dataset_name,
|
||||
force_overwrite=force_overwrite,
|
||||
num_workers=8
|
||||
)
|
||||
Reference in New Issue
Block a user