multi processing get files
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
@@ -421,14 +422,41 @@ def convert_waymo_scenario(scenario, version):
|
|||||||
return md_scenario
|
return md_scenario
|
||||||
|
|
||||||
|
|
||||||
def get_waymo_scenarios(waymo_data_direction):
|
def get_waymo_scenarios(waymo_data_directory, num_workers=8):
|
||||||
# parse raw data from input path to output path,
|
# parse raw data from input path to output path,
|
||||||
# there is 1000 raw data in google cloud, each of them produce about 500 pkl file
|
# there is 1000 raw data in google cloud, each of them produce about 500 pkl file
|
||||||
file_list = os.listdir(waymo_data_direction)
|
file_list = os.listdir(waymo_data_directory)
|
||||||
|
num_files = len(file_list)
|
||||||
|
if num_files < num_workers:
|
||||||
|
# single process
|
||||||
|
logger.info("Use one worker, as num_scenario < num_workers:")
|
||||||
|
num_workers = 1
|
||||||
|
|
||||||
|
argument_list = []
|
||||||
|
num_files_each_worker = int(num_files // num_workers)
|
||||||
|
for i in range(num_workers):
|
||||||
|
if i == num_workers - 1:
|
||||||
|
end_idx = num_files
|
||||||
|
else:
|
||||||
|
end_idx = (i + 1) * num_files_each_worker
|
||||||
|
argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
|
||||||
|
|
||||||
|
# Run, workers and process result from worker
|
||||||
|
with multiprocessing.Pool(num_workers) as p:
|
||||||
|
all_result = list(p.imap(read_from_files, argument_list))
|
||||||
|
ret = []
|
||||||
|
|
||||||
|
# get result
|
||||||
|
for r in all_result:
|
||||||
|
ret += r
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def read_from_files(arg):
|
||||||
|
waymo_data_directory, file_list = arg[0], arg[1]
|
||||||
scenarios = []
|
scenarios = []
|
||||||
for file_count, file in enumerate(file_list):
|
for file_count, file in enumerate(file_list):
|
||||||
file_path = os.path.join(waymo_data_direction, file)
|
file_path = os.path.join(waymo_data_directory, file)
|
||||||
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
||||||
continue
|
continue
|
||||||
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ if __name__ == '__main__':
|
|||||||
output_path = args.dataset_path
|
output_path = args.dataset_path
|
||||||
version = args.version
|
version = args.version
|
||||||
|
|
||||||
waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
|
waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
|
||||||
scenarios = get_waymo_scenarios(waymo_data_direction)
|
scenarios = get_waymo_scenarios(waymo_data_directory)
|
||||||
|
|
||||||
write_to_directory(
|
write_to_directory(
|
||||||
convert_func=convert_waymo_scenario,
|
convert_func=convert_waymo_scenario,
|
||||||
|
|||||||
Reference in New Issue
Block a user