Optimize waymo converter (#44)

* use generator for waymo

* add preprocessor

* use generator
Author: Quanyi Li
Date: 2023-11-05 16:07:07 +00:00 (committed by GitHub)
Parent: 3dd161188c
Commit: 8bc1d88f06
3 changed files with 53 additions and 44 deletions

View File

@@ -9,7 +9,7 @@ if __name__ == '__main__':
     from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
     from scenarionet.converter.utils import write_to_directory
-    from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
+    from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
     logger = logging.getLogger(__name__)
@@ -63,16 +63,15 @@ if __name__ == '__main__':
         shutil.rmtree(output_path)
     waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
-    scenarios = get_waymo_scenarios(
-        waymo_data_directory, args.start_file_index, args.num_files, num_workers=8
-    )  # do not use too many workers to read data
+    files = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files)
     write_to_directory(
         convert_func=convert_waymo_scenario,
-        scenarios=scenarios,
+        scenarios=files,
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
         overwrite=overwrite,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        preprocess=preprocess_waymo_scenarios,
     )
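
After this change the entry script hands write_to_directory a list of raw file paths together with a preprocess hook, instead of pre-parsed scenarios. A minimal sketch of the new call pattern (directory names, version string, and worker count are illustrative, not taken from this commit):

# Sketch of the new call pattern; paths and argument values are illustrative.
from scenarionet.converter.utils import write_to_directory
from scenarionet.converter.waymo.utils import (
    convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
)

files = get_waymo_scenarios("/data/waymo", start_index=0, num=10)  # file paths only
write_to_directory(
    convert_func=convert_waymo_scenario,
    scenarios=files,                        # paths, not parsed protos
    output_path="/data/waymo_converted",
    dataset_version="v1.2",
    dataset_name="waymo",
    num_workers=4,
    preprocess=preprocess_waymo_scenarios,  # each worker parses its own files lazily
)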

View File

@@ -21,6 +21,18 @@ from scenarionet.converter.pg.utils import convert_pg_scenario, make_env
 logger = logging.getLogger(__file__)
+def single_worker_preprocess(x, worker_index):
+    """
+    All scenarios passed to write_to_directory_single_worker are preprocessed here. The input is expected to be a
+    list; the output should be an iterable whose elements are then handed to the converter. By default you don't
+    need to provide this preprocessor. We override it for the waymo converter to release memory in time.
+    :param x: input scenarios
+    :param worker_index: worker index, useful for logging
+    :return: the input, unchanged
+    """
+    return x
 def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
     "All vectors in nuplan should be centered at (0, 0) to avoid numerical explosion"
     vector = np.array(vector)
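
The identity hook above fixes the contract: any callable that takes (scenarios, worker_index) and returns an iterable can be supplied through the new preprocess argument. For illustration only, a hypothetical hook that filters scenarios before conversion (the 'length' attribute is invented for this sketch):

# Hypothetical preprocess hook matching the single_worker_preprocess contract.
def keep_long_scenarios(scenarios, worker_index):
    # 'length' is an illustrative attribute, not part of the real scenario objects.
    kept = [s for s in scenarios if getattr(s, "length", 0) >= 10]
    print("worker {}: kept {}/{} scenarios".format(worker_index, len(kept), len(scenarios)))
    return kept
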
@@ -63,7 +75,15 @@ def contains_explicit_return(f):
 def write_to_directory(
-    convert_func, scenarios, output_path, dataset_version, dataset_name, overwrite=False, num_workers=8, **kwargs
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    overwrite=False,
+    num_workers=8,
+    preprocess=single_worker_preprocess,
+    **kwargs
 ):
     # make sure the output dir does not exist
     kwargs_for_workers = [{} for _ in range(num_workers)]
@@ -117,6 +137,7 @@ def write_to_directory(
         convert_func=convert_func,
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite
     )
@@ -127,13 +148,16 @@ def write_to_directory(
     merge_database(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False)
-def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False):
+def writing_to_directory_wrapper(
+    args, convert_func, dataset_version, dataset_name, overwrite=False, preprocess=single_worker_preprocess
+):
     return write_to_directory_single_worker(
         convert_func=convert_func,
         scenarios=args[0],
         output_path=args[3],
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite,
         worker_index=args[2],
         **args[1]
@@ -149,6 +173,7 @@ def write_to_directory_single_worker(
     worker_index=0,
     overwrite=False,
     report_memory_freq=None,
+    preprocess=single_worker_preprocess,
     **kwargs
 ):
     """
@@ -161,6 +186,9 @@ def write_to_directory_single_worker(
         kwargs.pop("version")
         logger.info("the specified version in kwargs is replaced by argument: 'dataset_version'")
+    # preprocess
+    scenarios = preprocess(scenarios, worker_index)
     save_path = copy.deepcopy(output_path)
     output_path = output_path + "_tmp"
     # meta recorder and data summary
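
Because the hook may return a generator instead of a list, the worker consumes scenarios one at a time. A simplified sketch of that consumption pattern, assuming the surrounding variables of write_to_directory_single_worker (the real loop does much more bookkeeping):

# Sketch: when preprocess returns a generator, at most one parsed scenario
# needs to be alive at a time inside the worker.
scenarios = preprocess(scenarios, worker_index)
for sd in scenarios:
    sd = convert_func(sd, dataset_version, **kwargs)
    # ... dump the converted scenario to the tmp directory, then drop the reference ...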

View File

@@ -421,12 +421,13 @@ def convert_waymo_scenario(scenario, version):
         for count, id in enumerate(track_id)
     }
     # clean memory
     scenario.Clear()
     del scenario
+    scenario = None
     return md_scenario
-def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
+def get_waymo_scenarios(waymo_data_directory, start_index, num):
     # parse raw data from input path to output path;
     # there are 1000 raw files on google cloud, each producing about 500 pkl files
     logger.info("\nReading raw data")
@@ -435,45 +436,25 @@ def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
         "No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
     file_list = file_list[start_index:start_index + num]
     num_files = len(file_list)
-    if num_files < num_workers:
-        # single process
-        logger.info("Use one worker, as num_scenario < num_workers:")
-        num_workers = 1
-    argument_list = []
-    num_files_each_worker = int(num_files // num_workers)
-    for i in range(num_workers):
-        if i == num_workers - 1:
-            end_idx = num_files
-        else:
-            end_idx = (i + 1) * num_files_each_worker
-        argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
-    # Run workers and process result from workers
-    # with multiprocessing.Pool(num_workers) as p:
-    #     all_result = list(p.imap(read_from_files, argument_list))
-    # Disable multiprocessing read
-    all_result = read_from_files([waymo_data_directory, file_list])
-    # ret = []
-    #
-    # # get result
-    # for r in all_result:
-    #     if len(r) == 0:
-    #         logger.warning("0 scenarios found")
-    #     ret += r
-    logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
+    all_result = [os.path.join(waymo_data_directory, f) for f in file_list]
+    logger.info("\nFind {} waymo files".format(num_files))
     return all_result
-def read_from_files(arg):
+def preprocess_waymo_scenarios(files, worker_index):
+    """
+    Convert the waymo files into scenario_pb2. This happens in each worker.
+    :param files: a list of file paths
+    :param worker_index: the index of the worker
+    :return: a generator yielding scenario_pb2 objects
+    """
     try:
         scenario_pb2
     except NameError:
         raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0")
-    waymo_data_directory, file_list = arg[0], arg[1]
-    scenarios = []
-    for file in tqdm.tqdm(file_list):
-        file_path = os.path.join(waymo_data_directory, file)
+    for file in tqdm.tqdm(files, desc="Process Waymo scenarios for worker {}".format(worker_index)):
+        file_path = os.path.join(file)
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
         for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
@@ -481,5 +462,6 @@ def read_from_files(arg):
             scenario.ParseFromString(data)
             # a trick for logging the file name
             scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
-            scenarios.append(scenario)
-    return scenarios
+            yield scenario
+    # logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios)))
+    # return scenarios
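
With preprocess_waymo_scenarios now a generator, parsing cost is paid only as scenarios are consumed, so a worker never holds a whole file's worth of protos at once. A quick, illustrative way to smoke-test it in isolation (the tfrecord path is a placeholder; tensorflow and the waymo-open-dataset package must be installed):

# Illustrative check of the lazy parser; the file path is a placeholder.
from scenarionet.converter.waymo.utils import preprocess_waymo_scenarios

gen = preprocess_waymo_scenarios(
    ["/data/waymo/training.tfrecord-00000-of-01000"], worker_index=0
)
first = next(gen)         # parses only up to the first scenario
print(first.scenario_id)  # the id embeds the source file name after SPLIT_KEY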