diff --git a/scenarionet/converter/waymo/utils.py b/scenarionet/converter/waymo/utils.py index 96b9fff..3220bb2 100644 --- a/scenarionet/converter/waymo/utils.py +++ b/scenarionet/converter/waymo/utils.py @@ -1,4 +1,5 @@ import logging +import multiprocessing import os import pickle @@ -421,14 +422,41 @@ def convert_waymo_scenario(scenario, version): return md_scenario -def get_waymo_scenarios(waymo_data_direction): +def get_waymo_scenarios(waymo_data_directory, num_workers=8): # parse raw data from input path to output path, # there is 1000 raw data in google cloud, each of them produce about 500 pkl file - file_list = os.listdir(waymo_data_direction) + file_list = os.listdir(waymo_data_directory) + num_files = len(file_list) + if num_files < num_workers: + # single process + logger.info("Use one worker, as num_scenario < num_workers:") + num_workers = 1 + argument_list = [] + num_files_each_worker = int(num_files // num_workers) + for i in range(num_workers): + if i == num_workers - 1: + end_idx = num_files + else: + end_idx = (i + 1) * num_files_each_worker + argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]]) + + # Run, workers and process result from worker + with multiprocessing.Pool(num_workers) as p: + all_result = list(p.imap(read_from_files, argument_list)) + ret = [] + + # get result + for r in all_result: + ret += r + return ret + + +def read_from_files(arg): + waymo_data_directory, file_list = arg[0], arg[1] scenarios = [] for file_count, file in enumerate(file_list): - file_path = os.path.join(waymo_data_direction, file) + file_path = os.path.join(waymo_data_directory, file) if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)): continue for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator(): diff --git a/scenarionet/scripts/convert_waymo.py b/scenarionet/scripts/convert_waymo.py index ad0f4d6..d1c0d73 100644 --- a/scenarionet/scripts/convert_waymo.py +++ b/scenarionet/scripts/convert_waymo.py @@ -23,8 +23,8 @@ if __name__ == '__main__': output_path = args.dataset_path version = args.version - waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin") - scenarios = get_waymo_scenarios(waymo_data_direction) + waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin") + scenarios = get_waymo_scenarios(waymo_data_directory) write_to_directory( convert_func=convert_waymo_scenario,