Optimize waymo converter (#44)
* use generator for waymo
* add preprocessor
* use generator
@@ -9,7 +9,7 @@ if __name__ == '__main__':
 
 from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
 from scenarionet.converter.utils import write_to_directory
-from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
+from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
 
 logger = logging.getLogger(__name__)
 
@@ -63,16 +63,15 @@ if __name__ == '__main__':
         shutil.rmtree(output_path)
 
     waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
-    scenarios = get_waymo_scenarios(
-        waymo_data_directory, args.start_file_index, args.num_files, num_workers=8
-    ) # do not use too much worker to read data
+    files = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files)
 
     write_to_directory(
         convert_func=convert_waymo_scenario,
-        scenarios=scenarios,
+        scenarios=files,
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
         overwrite=overwrite,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        preprocess=preprocess_waymo_scenarios,
     )
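In short, the entry script no longer parses every TFRecord up front: `get_waymo_scenarios` now returns plain file paths, and the new `preprocess` hook defers parsing to the workers spawned by `write_to_directory`. A minimal sketch of how the pieces fit together after this change (the paths and version string are placeholders, not values from the commit):

```python
# Illustrative sketch only, not part of this commit; placeholder values are marked.
from scenarionet.converter.utils import write_to_directory
from scenarionet.converter.waymo.utils import (
    convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
)

waymo_data_directory = "/data/waymo/raw"        # placeholder
files = get_waymo_scenarios(waymo_data_directory, start_index=0, num=10)  # cheap: paths only

write_to_directory(
    convert_func=convert_waymo_scenario,
    scenarios=files,                            # file paths, not parsed protos
    output_path="/data/waymo/converted",        # placeholder
    dataset_version="v1.2",                     # placeholder
    dataset_name="waymo",
    num_workers=8,
    preprocess=preprocess_waymo_scenarios,      # each worker parses its own TFRecords lazily
)
```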
@@ -21,6 +21,18 @@ from scenarionet.converter.pg.utils import convert_pg_scenario, make_env
 logger = logging.getLogger(__file__)
 
 
+def single_worker_preprocess(x, worker_index):
+    """
+    All scenarios passed to write_to_directory_single_worker will be preprocessed. The input is expected to be a list.
+    The output should be a list too. The element in the second list will be processed by convertors. By default, you
+    don't need to provide this processor. We override it for waymo convertor to release the memory in time.
+    :param x: input
+    :param worker_index: worker_index, useful for logging
+    :return: input
+    """
+    return x
+
+
 def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
     "All vec in nuplan should be centered in (0,0) to avoid numerical explosion"
     vector = np.array(vector)
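The docstring above defines the contract for the new hook: a callable taking `(scenarios, worker_index)` and returning an iterable whose elements are handed to the converter. A hypothetical custom preprocessor conforming to that contract (not part of this commit) could look like:

```python
def keep_every_other_scenario(scenarios, worker_index):
    # Hypothetical example: drop half of this worker's scenarios before conversion.
    # Any iterable works as the return value; the default single_worker_preprocess just returns x.
    return [s for i, s in enumerate(scenarios) if i % 2 == 0]
```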
@@ -63,7 +75,15 @@ def contains_explicit_return(f):
 
 
 def write_to_directory(
-    convert_func, scenarios, output_path, dataset_version, dataset_name, overwrite=False, num_workers=8, **kwargs
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    overwrite=False,
+    num_workers=8,
+    preprocess=single_worker_preprocess,
+    **kwargs
 ):
     # make sure dir not exist
     kwargs_for_workers = [{} for _ in range(num_workers)]
@@ -117,6 +137,7 @@ def write_to_directory(
         convert_func=convert_func,
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite
     )
 
@@ -127,13 +148,16 @@ def write_to_directory(
     merge_database(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False)
 
 
-def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False):
+def writing_to_directory_wrapper(
+    args, convert_func, dataset_version, dataset_name, overwrite=False, preprocess=single_worker_preprocess
+):
     return write_to_directory_single_worker(
         convert_func=convert_func,
         scenarios=args[0],
         output_path=args[3],
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite,
         worker_index=args[2],
         **args[1]
@@ -149,6 +173,7 @@ def write_to_directory_single_worker(
     worker_index=0,
     overwrite=False,
     report_memory_freq=None,
+    preprocess=single_worker_preprocess,
     **kwargs
 ):
     """
@@ -161,6 +186,9 @@ def write_to_directory_single_worker(
         kwargs.pop("version")
         logger.info("the specified version in kwargs is replaced by argument: 'dataset_version'")
 
+    # preprocess
+    scenarios = preprocess(scenarios, worker_index)
+
     save_path = copy.deepcopy(output_path)
     output_path = output_path + "_tmp"
     # meta recorder and data summary
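The worker now calls the hook before converting anything, so it can consume either a list (the default case) or a generator (which the waymo preprocessor below relies on). A small self-contained sketch of why both satisfy the same contract:

```python
# Both hooks obey the (scenarios, worker_index) -> iterable contract assumed above.
def as_list(scenarios, worker_index):
    return list(scenarios)                 # everything materialized at once

def as_generator(scenarios, worker_index):
    for s in scenarios:
        yield s                            # one element alive at a time

for hook in (as_list, as_generator):
    print([x * 2 for x in hook(range(3), worker_index=0)])  # prints [0, 2, 4] in both cases
```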
@@ -421,12 +421,13 @@ def convert_waymo_scenario(scenario, version):
         for count, id in enumerate(track_id)
     }
     # clean memory
     scenario.Clear()
     del scenario
+    scenario = None
     return md_scenario
 
 
-def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
+def get_waymo_scenarios(waymo_data_directory, start_index, num):
     # parse raw data from input path to output path,
     # there is 1000 raw data in google cloud, each of them produce about 500 pkl file
     logger.info("\nReading raw data")
@@ -435,45 +436,25 @@ def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
         "No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
     file_list = file_list[start_index:start_index + num]
     num_files = len(file_list)
-    if num_files < num_workers:
-        # single process
-        logger.info("Use one worker, as num_scenario < num_workers:")
-        num_workers = 1
-
-    argument_list = []
-    num_files_each_worker = int(num_files // num_workers)
-    for i in range(num_workers):
-        if i == num_workers - 1:
-            end_idx = num_files
-        else:
-            end_idx = (i + 1) * num_files_each_worker
-        argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
-
-    # Run, workers and process result from worker
-    # with multiprocessing.Pool(num_workers) as p:
-    #     all_result = list(p.imap(read_from_files, argument_list))
-    # Disable multiprocessing read
-    all_result = read_from_files([waymo_data_directory, file_list])
-    # ret = []
-    #
-    # # get result
-    # for r in all_result:
-    #     if len(r) == 0:
-    #         logger.warning("0 scenarios found")
-    #     ret += r
-    logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
+    all_result = [os.path.join(waymo_data_directory, f) for f in file_list]
+    logger.info("\nFind {} waymo files".format(num_files))
     return all_result
 
 
-def read_from_files(arg):
+def preprocess_waymo_scenarios(files, worker_index):
+    """
+    Convert the waymo files into scenario_pb2. This happens in each worker.
+    :param files: a list of file path
+    :param worker_index, the index for the worker
+    :return: a list of scenario_pb2
+    """
     try:
         scenario_pb2
     except NameError:
         raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0")
-    waymo_data_directory, file_list = arg[0], arg[1]
-    scenarios = []
-    for file in tqdm.tqdm(file_list):
-        file_path = os.path.join(waymo_data_directory, file)
+
+    for file in tqdm.tqdm(files, desc="Process Waymo scenarios for worker {}".format(worker_index)):
+        file_path = os.path.join(file)
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
         for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
@@ -481,5 +462,6 @@ def read_from_files(arg):
             scenario.ParseFromString(data)
             # a trick for loging file name
             scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
-            scenarios.append(scenario)
-    return scenarios
+            yield scenario
+    # logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios)))
+    # return scenarios
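The key change above is that `preprocess_waymo_scenarios` yields each parsed proto instead of appending it to a list, so a worker only keeps the scenario it is currently converting in memory (plus whatever TensorFlow buffers internally). The same pattern in a self-contained sketch, with plain file handling as a stand-in for `tf.data.TFRecordDataset` and no dependency on `waymo_open_dataset`:

```python
import os

def iter_records(paths):
    """Yield raw records one by one instead of returning them all in a list."""
    for path in paths:
        if not os.path.isfile(path):       # mirror the "skip missing files" guard above
            continue
        with open(path, "rb") as f:        # stand-in for tf.data.TFRecordDataset(...)
            for record in f:
                yield record               # caller converts and writes it, then it can be freed

# A list-based reader keeps every record from every file alive at once;
# this generator keeps only the record currently being processed.
```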