Optimize waymo converter (#44)
* use generator for waymo
* add preprocessor
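The heart of this change: instead of parsing every TFRecord up front and handing fully materialized protobufs to the conversion workers, `get_waymo_scenarios` now returns only file paths, and a new per-worker `preprocess` hook (`preprocess_waymo_scenarios`) parses scenarios lazily. A minimal sketch of the difference, with illustrative names rather than the actual scenarionet API:

```python
# Sketch only: `parse` stands in for TFRecord parsing; names are illustrative.

def eager_load(paths, parse):
    # Old approach: every parsed scenario is alive in memory at once.
    return [parse(p) for p in paths]


def lazy_load(paths, parse):
    # New approach: scenarios are yielded one at a time, so peak memory
    # per worker stays roughly constant in the number of files.
    for p in paths:
        yield parse(p)
```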
@@ -9,7 +9,7 @@ if __name__ == '__main__':
 from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
 from scenarionet.converter.utils import write_to_directory
-from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
+from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios

 logger = logging.getLogger(__name__)

@@ -63,16 +63,15 @@ if __name__ == '__main__':
     shutil.rmtree(output_path)

     waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
-    scenarios = get_waymo_scenarios(
-        waymo_data_directory, args.start_file_index, args.num_files, num_workers=8
-    )  # do not use too much worker to read data
+    files = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files)

     write_to_directory(
         convert_func=convert_waymo_scenario,
-        scenarios=scenarios,
+        scenarios=files,
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
         overwrite=overwrite,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        preprocess=preprocess_waymo_scenarios,
     )
@@ -21,6 +21,18 @@ from scenarionet.converter.pg.utils import convert_pg_scenario, make_env
 logger = logging.getLogger(__file__)


+def single_worker_preprocess(x, worker_index):
+    """
+    All scenarios passed to write_to_directory_single_worker are preprocessed here. The input is expected to be a
+    list, and the output should be a list as well; its elements are then handed to the converters. By default you
+    don't need to provide this processor. We override it for the Waymo converter to release memory in time.
+    :param x: input
+    :param worker_index: worker index, useful for logging
+    :return: input
+    """
+    return x
+
+
 def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
     "All vec in nuplan should be centered in (0,0) to avoid numerical explosion"
     vector = np.array(vector)
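The hook contract is simple: a callable receiving the worker's scenarios and the worker index, returning an iterable that the converter then consumes. As a hedged illustration (the hook name and filter condition below are hypothetical, not part of this commit), a custom hook could thin out scenarios before conversion:

```python
import logging

logger = logging.getLogger(__name__)


def drop_empty_scenarios(scenarios, worker_index):
    # Hypothetical hook with the same (scenarios, worker_index) contract as
    # single_worker_preprocess. Yielding rather than returning a list keeps
    # the memory benefit of the generator-based pipeline.
    kept = 0
    for scenario in scenarios:
        if not getattr(scenario, "tracks", None):
            continue  # illustrative filter, not a real scenarionet check
        kept += 1
        yield scenario
    logger.info("Worker %d kept %d scenarios", worker_index, kept)
```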
@@ -63,7 +75,15 @@ def contains_explicit_return(f):


 def write_to_directory(
-    convert_func, scenarios, output_path, dataset_version, dataset_name, overwrite=False, num_workers=8, **kwargs
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    overwrite=False,
+    num_workers=8,
+    preprocess=single_worker_preprocess,
+    **kwargs
 ):
     # make sure dir not exist
     kwargs_for_workers = [{} for _ in range(num_workers)]
@@ -117,6 +137,7 @@ def write_to_directory(
         convert_func=convert_func,
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite
     )

@@ -127,13 +148,16 @@ def write_to_directory(
     merge_database(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False)


-def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False):
+def writing_to_directory_wrapper(
+    args, convert_func, dataset_version, dataset_name, overwrite=False, preprocess=single_worker_preprocess
+):
     return write_to_directory_single_worker(
         convert_func=convert_func,
         scenarios=args[0],
         output_path=args[3],
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite,
         worker_index=args[2],
         **args[1]
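For readers puzzled by the positional indexing above: judging from the indices used, each worker's `args` list appears to be packed as `[scenarios, per-worker kwargs, worker_index, output_path]`. This is inferred from the wrapper, not spelled out in the diff:

```python
# Inferred packing of the per-worker argument list (assumption, derived
# from the args[0]..args[3] accesses above):
example_args = [
    ["scenario_a", "scenario_b"],  # args[0]: this worker's scenarios
    {},                            # args[1]: extra kwargs for convert_func
    0,                             # args[2]: worker_index
    "/tmp/dataset_worker_0",       # args[3]: worker-specific output_path
]
```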
@@ -149,6 +173,7 @@ def write_to_directory_single_worker(
     worker_index=0,
     overwrite=False,
     report_memory_freq=None,
+    preprocess=single_worker_preprocess,
     **kwargs
 ):
     """
@@ -161,6 +186,9 @@ def write_to_directory_single_worker(
         kwargs.pop("version")
         logger.info("the specified version in kwargs is replaced by argument: 'dataset_version'")

+    # preprocess
+    scenarios = preprocess(scenarios, worker_index)
+
     save_path = copy.deepcopy(output_path)
     output_path = output_path + "_tmp"
     # meta recorder and data summary
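One consequence worth noting: because the hook may return a generator (as the Waymo one does), `scenarios` can only be traversed once after this point and no longer supports `len()`. A small sketch of the safe consumption pattern, assuming the worker also wants a count:

```python
def consume(scenarios, convert):
    # Works for both lists and generators: iterate once, and count while
    # converting instead of calling len() on a possible generator.
    count = 0
    for sd in scenarios:
        convert(sd)
        count += 1
    return count
```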
@@ -421,59 +421,40 @@ def convert_waymo_scenario(scenario, version):
         for count, id in enumerate(track_id)
     }
     # clean memory
+    scenario.Clear()
     del scenario
     scenario = None
     return md_scenario


-def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
+def get_waymo_scenarios(waymo_data_directory, start_index, num):
     # parse raw data from input path to output path,
     # there are 1000 raw data files on Google Cloud, each producing about 500 pkl files
-    logger.info("\n Reading raw data")
+    logger.info("\nReading raw data")
     file_list = os.listdir(waymo_data_directory)
     assert len(file_list) >= start_index + num and start_index >= 0, \
         "No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
     file_list = file_list[start_index:start_index + num]
     num_files = len(file_list)
-    if num_files < num_workers:
-        # single process
-        logger.info("Use one worker, as num_scenario < num_workers:")
-        num_workers = 1
-
-    argument_list = []
-    num_files_each_worker = int(num_files // num_workers)
-    for i in range(num_workers):
-        if i == num_workers - 1:
-            end_idx = num_files
-        else:
-            end_idx = (i + 1) * num_files_each_worker
-        argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
-
-    # Run, workers and process result from worker
-    # with multiprocessing.Pool(num_workers) as p:
-    #     all_result = list(p.imap(read_from_files, argument_list))
-    # Disable multiprocessing read
-    all_result = read_from_files([waymo_data_directory, file_list])
-    # ret = []
-    #
-    # # get result
-    # for r in all_result:
-    #     if len(r) == 0:
-    #         logger.warning("0 scenarios found")
-    #     ret += r
-    logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
+    all_result = [os.path.join(waymo_data_directory, f) for f in file_list]
+    logger.info("\nFind {} waymo files".format(num_files))
     return all_result


-def read_from_files(arg):
+def preprocess_waymo_scenarios(files, worker_index):
+    """
+    Convert the Waymo files into scenario_pb2. This happens in each worker.
+    :param files: a list of file paths
+    :param worker_index: the index of the worker
+    :return: a generator of scenario_pb2
+    """
     try:
         scenario_pb2
     except NameError:
         raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0")
-    waymo_data_directory, file_list = arg[0], arg[1]
-    scenarios = []
-    for file in tqdm.tqdm(file_list):
-        file_path = os.path.join(waymo_data_directory, file)
+    for file in tqdm.tqdm(files, desc="Process Waymo scenarios for worker {}".format(worker_index)):
+        file_path = os.path.join(file)
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
         for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
@@ -481,5 +462,6 @@ def read_from_files(arg):
             scenario.ParseFromString(data)
             # a trick for logging file name
             scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
-            scenarios.append(scenario)
-    return scenarios
+            yield scenario
+    # logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios)))
+    # return scenarios
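Putting the pieces together, the new call pattern looks roughly like this (paths, version, and name are placeholder values; the imports are the ones shown in the first hunk):

```python
from scenarionet.converter.utils import write_to_directory
from scenarionet.converter.waymo.utils import (
    convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
)

# get_waymo_scenarios now returns file paths only; the heavy parsing is
# deferred to preprocess_waymo_scenarios inside each worker.
files = get_waymo_scenarios("/data/waymo", start_index=0, num=100)
write_to_directory(
    convert_func=convert_waymo_scenario,
    scenarios=files,  # file paths, not parsed protobufs
    output_path="/data/waymo_converted",  # placeholder path
    dataset_version="v1.2",  # placeholder version
    dataset_name="waymo",
    num_workers=8,
    preprocess=preprocess_waymo_scenarios,  # yields scenario_pb2 lazily
)
```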