multi processing get files

This commit is contained in:
QuanyiLi
2023-05-08 13:53:57 +01:00
parent 9483b8fc56
commit 97ae6addf6
2 changed files with 33 additions and 5 deletions

View File

@@ -1,4 +1,5 @@
import logging
import multiprocessing
import os
import pickle
@@ -421,14 +422,41 @@ def convert_waymo_scenario(scenario, version):
return md_scenario
def get_waymo_scenarios(waymo_data_direction):
def get_waymo_scenarios(waymo_data_directory, num_workers=8):
# parse raw data from input path to output path,
# there is 1000 raw data in google cloud, each of them produce about 500 pkl file
file_list = os.listdir(waymo_data_direction)
file_list = os.listdir(waymo_data_directory)
num_files = len(file_list)
if num_files < num_workers:
# single process
logger.info("Use one worker, as num_scenario < num_workers:")
num_workers = 1
argument_list = []
num_files_each_worker = int(num_files // num_workers)
for i in range(num_workers):
if i == num_workers - 1:
end_idx = num_files
else:
end_idx = (i + 1) * num_files_each_worker
argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
# Run, workers and process result from worker
with multiprocessing.Pool(num_workers) as p:
all_result = list(p.imap(read_from_files, argument_list))
ret = []
# get result
for r in all_result:
ret += r
return ret
def read_from_files(arg):
waymo_data_directory, file_list = arg[0], arg[1]
scenarios = []
for file_count, file in enumerate(file_list):
file_path = os.path.join(waymo_data_direction, file)
file_path = os.path.join(waymo_data_directory, file)
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
continue
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():

View File

@@ -23,8 +23,8 @@ if __name__ == '__main__':
output_path = args.dataset_path
version = args.version
waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
scenarios = get_waymo_scenarios(waymo_data_direction)
waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
scenarios = get_waymo_scenarios(waymo_data_directory)
write_to_directory(
convert_func=convert_waymo_scenario,