Use multiprocessing to get files
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
|
||||
@@ -421,14 +422,41 @@ def convert_waymo_scenario(scenario, version):
|
||||
return md_scenario
|
||||
|
||||
|
||||
def get_waymo_scenarios(waymo_data_direction):
|
||||
def get_waymo_scenarios(waymo_data_directory, num_workers=8):
|
||||
# parse raw data from input path to output path,
|
||||
# there are 1000 raw data files in Google Cloud, each of which produces about 500 pkl files
|
||||
file_list = os.listdir(waymo_data_direction)
|
||||
file_list = os.listdir(waymo_data_directory)
|
||||
num_files = len(file_list)
|
||||
if num_files < num_workers:
|
||||
# single process
|
||||
logger.info("Use one worker, as num_scenario < num_workers:")
|
||||
num_workers = 1
|
||||
|
||||
argument_list = []
|
||||
num_files_each_worker = int(num_files // num_workers)
|
||||
for i in range(num_workers):
|
||||
if i == num_workers - 1:
|
||||
end_idx = num_files
|
||||
else:
|
||||
end_idx = (i + 1) * num_files_each_worker
|
||||
argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
|
||||
|
||||
# Run the workers and collect the result from each worker
|
||||
with multiprocessing.Pool(num_workers) as p:
|
||||
all_result = list(p.imap(read_from_files, argument_list))
|
||||
ret = []
|
||||
|
||||
# get result
|
||||
for r in all_result:
|
||||
ret += r
|
||||
return ret
|
||||
|
||||
|
||||
def read_from_files(arg):
|
||||
waymo_data_directory, file_list = arg[0], arg[1]
|
||||
scenarios = []
|
||||
for file_count, file in enumerate(file_list):
|
||||
file_path = os.path.join(waymo_data_direction, file)
|
||||
file_path = os.path.join(waymo_data_directory, file)
|
||||
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
||||
continue
|
||||
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
|
||||
|
||||
@@ -23,8 +23,8 @@ if __name__ == '__main__':
|
||||
output_path = args.dataset_path
|
||||
version = args.version
|
||||
|
||||
waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
|
||||
scenarios = get_waymo_scenarios(waymo_data_direction)
|
||||
waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
|
||||
scenarios = get_waymo_scenarios(waymo_data_directory)
|
||||
|
||||
write_to_directory(
|
||||
convert_func=convert_waymo_scenario,
|
||||
|
||||
Reference in New Issue
Block a user