Optimize waymo converter (#44)

* use generator for waymo

* add preprocessor

* use generator
Author: Quanyi Li
Date: 2023-11-05 16:07:07 +00:00
Committed by: GitHub
Parent: 3dd161188c
Commit: 8bc1d88f06

3 changed files with 53 additions and 44 deletions
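
The gist of the change: get_waymo_scenarios no longer parses TFRecords up front. It now returns plain file paths, and a new preprocess hook (preprocess_waymo_scenarios) turns those paths into a generator of parsed protos inside each worker, so only one scenario proto needs to be alive per worker at any moment. Other converters are untouched because the hook defaults to an identity function. A minimal usage sketch of the new entry points, assuming scenarionet and waymo-open-dataset are installed; the directory names, version string, and worker count below are placeholders, not from the repo:

    import os

    from scenarionet import SCENARIONET_DATASET_PATH
    from scenarionet.converter.utils import write_to_directory
    from scenarionet.converter.waymo.utils import (
        convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
    )

    # Collect raw .tfrecord paths only; nothing is parsed yet.
    waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")  # placeholder dir
    files = get_waymo_scenarios(waymo_data_directory, start_index=0, num=2)

    write_to_directory(
        convert_func=convert_waymo_scenario,
        scenarios=files,                        # file paths, not scenario protos
        output_path=os.path.join(SCENARIONET_DATASET_PATH, "waymo_converted"),
        dataset_version="v1.2",                 # placeholder version string
        dataset_name="waymo",
        overwrite=True,
        num_workers=2,
        preprocess=preprocess_waymo_scenarios,  # each worker parses its share lazily
    )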

File 1 of 3

@@ -9,7 +9,7 @@ if __name__ == '__main__':
 from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
 from scenarionet.converter.utils import write_to_directory
-from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
+from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
 
 logger = logging.getLogger(__name__)
@@ -63,16 +63,15 @@ if __name__ == '__main__':
     shutil.rmtree(output_path)
     waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
-    scenarios = get_waymo_scenarios(
-        waymo_data_directory, args.start_file_index, args.num_files, num_workers=8
-    )  # do not use too much worker to read data
+    files = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files)
 
     write_to_directory(
         convert_func=convert_waymo_scenario,
-        scenarios=scenarios,
+        scenarios=files,
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
         overwrite=overwrite,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        preprocess=preprocess_waymo_scenarios,
     )

File 2 of 3

@@ -21,6 +21,18 @@ from scenarionet.converter.pg.utils import convert_pg_scenario, make_env
 logger = logging.getLogger(__file__)
 
 
+def single_worker_preprocess(x, worker_index):
+    """
+    All scenarios passed to write_to_directory_single_worker are run through this hook first. The input is expected
+    to be a list; the output can be any iterable, whose elements are then handed to the converter. By default you do
+    not need to provide this processor. We override it for the Waymo converter to release memory in time.
+    :param x: input scenarios
+    :param worker_index: the worker index, useful for logging
+    :return: the input, unchanged
+    """
+    return x
+
+
 def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
     "All vec in nuplan should be centered in (0,0) to avoid numerical explosion"
     vector = np.array(vector)
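
The hook's contract is simple: it receives the worker's list of scenarios plus the worker index, and must return something the worker can iterate. That means callers can plug in their own preprocessing. A hypothetical example (the 10-step cutoff and the timestamps_seconds field access assume Waymo scenario protos; neither appears in the repo):

    import logging

    logger = logging.getLogger(__file__)

    def keep_long_scenarios(x, worker_index):
        # Hypothetical hook with the same contract as single_worker_preprocess:
        # list in, iterable out.
        kept = [s for s in x if len(s.timestamps_seconds) > 10]
        logger.info("Worker %s kept %d of %d scenarios", worker_index, len(kept), len(x))
        return kept

Passing preprocess=keep_long_scenarios to write_to_directory would then apply the filter in every worker.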
@@ -63,7 +75,15 @@ def contains_explicit_return(f):
 def write_to_directory(
-    convert_func, scenarios, output_path, dataset_version, dataset_name, overwrite=False, num_workers=8, **kwargs
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    overwrite=False,
+    num_workers=8,
+    preprocess=single_worker_preprocess,
+    **kwargs
 ):
     # make sure dir not exist
     kwargs_for_workers = [{} for _ in range(num_workers)]
@@ -117,6 +137,7 @@ def write_to_directory(
         convert_func=convert_func,
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite
     )
@@ -127,13 +148,16 @@ def write_to_directory(
     merge_database(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False)
 
 
-def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False):
+def writing_to_directory_wrapper(
+    args, convert_func, dataset_version, dataset_name, overwrite=False, preprocess=single_worker_preprocess
+):
     return write_to_directory_single_worker(
         convert_func=convert_func,
         scenarios=args[0],
         output_path=args[3],
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite,
         worker_index=args[2],
         **args[1]
@@ -149,6 +173,7 @@ def write_to_directory_single_worker(
     worker_index=0,
     overwrite=False,
     report_memory_freq=None,
+    preprocess=single_worker_preprocess,
     **kwargs
 ):
     """
@@ -161,6 +186,9 @@ def write_to_directory_single_worker(
         kwargs.pop("version")
         logger.info("the specified version in kwargs is replaced by argument: 'dataset_version'")
 
+    # preprocess
+    scenarios = preprocess(scenarios, worker_index)
+
     save_path = copy.deepcopy(output_path)
     output_path = output_path + "_tmp"
 
     # meta recorder and data summary
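
The hook runs once per worker, before the conversion loop, so whatever it returns is simply iterated; a generator works as well as a list, with the caveat that len(scenarios) is no longer available afterwards. A simplified stand-in for what write_to_directory_single_worker does from here on (the loop body is an assumption; the real function also pickles each scenario and records summary metadata):

    def run_worker(convert_func, scenarios, dataset_version, worker_index,
                   preprocess=single_worker_preprocess, **kwargs):
        # Hypothetical condensed version of write_to_directory_single_worker.
        scenarios = preprocess(scenarios, worker_index)  # list in, any iterable out
        count = 0
        for scenario in scenarios:  # draining a generator keeps one proto alive at a time
            sd = convert_func(scenario, dataset_version, **kwargs)
            count += 1  # count manually, since len() does not work on generators
        return count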

File 3 of 3

@@ -421,59 +421,40 @@ def convert_waymo_scenario(scenario, version):
         for count, id in enumerate(track_id)
     }
 
     # clean memory
+    scenario.Clear()
     del scenario
     scenario = None
 
     return md_scenario
 
 
-def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
+def get_waymo_scenarios(waymo_data_directory, start_index, num):
     # parse raw data from input path to output path,
     # there is 1000 raw data in google cloud, each of them produce about 500 pkl file
-    logger.info("\n Reading raw data")
+    logger.info("\nReading raw data")
     file_list = os.listdir(waymo_data_directory)
     assert len(file_list) >= start_index + num and start_index >= 0, \
         "No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
     file_list = file_list[start_index:start_index + num]
     num_files = len(file_list)
-    if num_files < num_workers:
-        # single process
-        logger.info("Use one worker, as num_scenario < num_workers:")
-        num_workers = 1
-    argument_list = []
-    num_files_each_worker = int(num_files // num_workers)
-    for i in range(num_workers):
-        if i == num_workers - 1:
-            end_idx = num_files
-        else:
-            end_idx = (i + 1) * num_files_each_worker
-        argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
-    # Run, workers and process result from worker
-    # with multiprocessing.Pool(num_workers) as p:
-    #     all_result = list(p.imap(read_from_files, argument_list))
-    # Disable multiprocessing read
-    all_result = read_from_files([waymo_data_directory, file_list])
-    # ret = []
-    #
-    # # get result
-    # for r in all_result:
-    #     if len(r) == 0:
-    #         logger.warning("0 scenarios found")
-    #     ret += r
-    logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
+    all_result = [os.path.join(waymo_data_directory, f) for f in file_list]
+    logger.info("\nFind {} waymo files".format(num_files))
     return all_result
 
 
-def read_from_files(arg):
+def preprocess_waymo_scenarios(files, worker_index):
+    """
+    Convert the Waymo files into scenario_pb2 objects. This happens in each worker.
+    :param files: a list of file paths
+    :param worker_index: the index of the worker
+    :return: a generator of scenario_pb2 objects
+    """
     try:
         scenario_pb2
     except NameError:
         raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0")
-    waymo_data_directory, file_list = arg[0], arg[1]
-    scenarios = []
-    for file in tqdm.tqdm(file_list):
-        file_path = os.path.join(waymo_data_directory, file)
+    for file in tqdm.tqdm(files, desc="Process Waymo scenarios for worker {}".format(worker_index)):
+        file_path = os.path.join(file)
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
         for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
@@ -481,5 +462,6 @@ def read_from_files(arg):
             scenario.ParseFromString(data)
             # a trick for logging file name
             scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
-            scenarios.append(scenario)
-    return scenarios
+            yield scenario
+    # logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios)))
+    # return scenarios
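
The payoff: before this commit, every proto from the selected files (roughly 500 scenarios per raw shard) was parsed into one list before any conversion started, so peak memory scaled with the whole slice; now each proto lives only as long as its own conversion, and convert_waymo_scenario additionally calls scenario.Clear() when it finishes. An illustrative contrast, where parse_records is a hypothetical stand-in for the TFRecord loop above:

    import tensorflow as tf
    from waymo_open_dataset.protos import scenario_pb2

    from scenarionet.converter.waymo.utils import convert_waymo_scenario

    def parse_records(file_path):
        # Stand-in for the loop inside preprocess_waymo_scenarios.
        for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
            scenario = scenario_pb2.Scenario()
            scenario.ParseFromString(data)
            yield scenario

    def convert_eagerly(files, version):
        # Old behavior: every proto parsed and held in memory at once.
        scenarios = [s for f in files for s in parse_records(f)]
        return [convert_waymo_scenario(s, version) for s in scenarios]

    def convert_lazily(files, version):
        # New behavior: at most one proto alive at a time per worker.
        return [convert_waymo_scenario(s, version) for f in files for s in parse_records(f)]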