From 8bc1d88f069d9ddad511bfbb2ba806cb59df1949 Mon Sep 17 00:00:00 2001
From: Quanyi Li
Date: Sun, 5 Nov 2023 16:07:07 +0000
Subject: [PATCH] Optimize waymo converter (#44)

* use generator for waymo

* add preprocessor

* use generator

---
 scenarionet/convert_waymo.py         | 11 +++---
 scenarionet/converter/utils.py       | 32 +++++++++++++++--
 scenarionet/converter/waymo/utils.py | 54 ++++++++++------------------
 3 files changed, 53 insertions(+), 44 deletions(-)

diff --git a/scenarionet/convert_waymo.py b/scenarionet/convert_waymo.py
index 294f08b..8b4037c 100644
--- a/scenarionet/convert_waymo.py
+++ b/scenarionet/convert_waymo.py
@@ -9,7 +9,7 @@ if __name__ == '__main__':
 
     from scenarionet import SCENARIONET_DATASET_PATH, SCENARIONET_REPO_PATH
     from scenarionet.converter.utils import write_to_directory
-    from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
+    from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
 
     logger = logging.getLogger(__name__)
 
@@ -63,16 +63,15 @@ if __name__ == '__main__':
         shutil.rmtree(output_path)
 
     waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
-    scenarios = get_waymo_scenarios(
-        waymo_data_directory, args.start_file_index, args.num_files, num_workers=8
-    )  # do not use too much worker to read data
+    files = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files)
 
     write_to_directory(
         convert_func=convert_waymo_scenario,
-        scenarios=scenarios,
+        scenarios=files,
         output_path=output_path,
         dataset_version=version,
         dataset_name=dataset_name,
         overwrite=overwrite,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        preprocess=preprocess_waymo_scenarios,
     )
diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py
index bc031f8..6e8d88b 100644
--- a/scenarionet/converter/utils.py
+++ b/scenarionet/converter/utils.py
@@ -21,6 +21,18 @@ from scenarionet.converter.pg.utils import convert_pg_scenario, make_env
 logger = logging.getLogger(__file__)
 
 
+def single_worker_preprocess(x, worker_index):
+    """
+    All scenarios passed to write_to_directory_single_worker are preprocessed by this function. The input is expected
+    to be a list, and the output should be a list as well; its elements are then consumed by the convert_func. By
+    default, you don't need to provide this preprocessor. We override it for the Waymo converter to free memory in time.
+    :param x: the input scenarios
+    :param worker_index: worker index, useful for logging
+    :return: the scenarios to be converted
+    """
+    return x
+
+
 def nuplan_to_metadrive_vector(vector, nuplan_center=(0, 0)):
     "All vec in nuplan should be centered in (0,0) to avoid numerical explosion"
     vector = np.array(vector)
@@ -63,7 +75,15 @@ def contains_explicit_return(f):
 
 
 def write_to_directory(
-    convert_func, scenarios, output_path, dataset_version, dataset_name, overwrite=False, num_workers=8, **kwargs
+    convert_func,
+    scenarios,
+    output_path,
+    dataset_version,
+    dataset_name,
+    overwrite=False,
+    num_workers=8,
+    preprocess=single_worker_preprocess,
+    **kwargs
 ):
     # make sure dir not exist
     kwargs_for_workers = [{} for _ in range(num_workers)]
@@ -117,6 +137,7 @@ def write_to_directory(
         convert_func=convert_func,
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite
     )
 
@@ -127,13 +148,16 @@ def write_to_directory(
     merge_database(save_path, *output_pathes, exist_ok=True, overwrite=False, try_generate_missing_file=False)
 
 
-def writing_to_directory_wrapper(args, convert_func, dataset_version, dataset_name, overwrite=False):
+def writing_to_directory_wrapper(
+    args, convert_func, dataset_version, dataset_name, overwrite=False, preprocess=single_worker_preprocess
+):
     return write_to_directory_single_worker(
         convert_func=convert_func,
         scenarios=args[0],
         output_path=args[3],
         dataset_version=dataset_version,
         dataset_name=dataset_name,
+        preprocess=preprocess,
         overwrite=overwrite,
         worker_index=args[2],
         **args[1]
@@ -149,6 +173,7 @@ def write_to_directory_single_worker(
     worker_index=0,
     overwrite=False,
     report_memory_freq=None,
+    preprocess=single_worker_preprocess,
     **kwargs
 ):
     """
@@ -161,6 +186,9 @@ def write_to_directory_single_worker(
         kwargs.pop("version")
         logger.info("the specified version in kwargs is replaced by argument: 'dataset_version'")
 
+    # preprocess the scenarios before conversion
+    scenarios = preprocess(scenarios, worker_index)
+
     save_path = copy.deepcopy(output_path)
     output_path = output_path + "_tmp"
     # meta recorder and data summary
diff --git a/scenarionet/converter/waymo/utils.py b/scenarionet/converter/waymo/utils.py
index 8d06ef5..a5c0813 100644
--- a/scenarionet/converter/waymo/utils.py
+++ b/scenarionet/converter/waymo/utils.py
@@ -421,59 +421,40 @@ def convert_waymo_scenario(scenario, version):
         for count, id in enumerate(track_id)
     }
     # clean memory
+    scenario.Clear()
     del scenario
     scenario = None
     return md_scenario
 
 
-def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
+def get_waymo_scenarios(waymo_data_directory, start_index, num):
     # parse raw data from input path to output path,
     # there is 1000 raw data in google cloud, each of them produce about 500 pkl file
-    logger.info("\n Reading raw data")
+    logger.info("\nReading raw data")
     file_list = os.listdir(waymo_data_directory)
     assert len(file_list) >= start_index + num and start_index >= 0, \
         "No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
     file_list = file_list[start_index:start_index + num]
     num_files = len(file_list)
-    if num_files < num_workers:
-        # single process
-        logger.info("Use one worker, as num_scenario < num_workers:")
-        num_workers = 1
-
-    argument_list = []
-    num_files_each_worker = int(num_files // num_workers)
-    for i in range(num_workers):
-        if i == num_workers - 1:
-            end_idx = num_files
-        else:
-            end_idx = (i + 1) * num_files_each_worker
-        argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
-
-    # Run, workers and process result from worker
-    # with multiprocessing.Pool(num_workers) as p:
-    #     all_result = list(p.imap(read_from_files, argument_list))
-    # Disable multiprocessing read
-    all_result = read_from_files([waymo_data_directory, file_list])
-    # ret = []
-    #
-    # # get result
-    # for r in all_result:
-    #     if len(r) == 0:
-    #         logger.warning("0 scenarios found")
-    #     ret += r
-    logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
+    all_result = [os.path.join(waymo_data_directory, f) for f in file_list]
+    logger.info("\nFound {} waymo files".format(num_files))
     return all_result
 
 
-def read_from_files(arg):
+def preprocess_waymo_scenarios(files, worker_index):
+    """
+    Convert the Waymo TFRecord files into scenario_pb2 objects. This runs inside each worker.
+    :param files: a list of file paths
+    :param worker_index: the index of the worker, useful for logging
+    :return: a generator of scenario_pb2 objects
+    """
     try:
         scenario_pb2
     except NameError:
         raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0")
-    waymo_data_directory, file_list = arg[0], arg[1]
-    scenarios = []
-    for file in tqdm.tqdm(file_list):
-        file_path = os.path.join(waymo_data_directory, file)
+
+    for file in tqdm.tqdm(files, desc="Process Waymo scenarios for worker {}".format(worker_index)):
+        file_path = os.path.join(file)
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
         for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
@@ -481,5 +462,6 @@ def read_from_files(arg):
             scenario.ParseFromString(data)
             # a trick for loging file name
             scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file
-            scenarios.append(scenario)
-    return scenarios
+            yield scenario
+    # logger.info("Worker {}: Process {} waymo scenarios".format(worker_index, len(scenarios)))
+    # return scenarios
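
Usage sketch (illustrative, not part of the patch): with this change, any callable with the signature preprocess(scenarios, worker_index) can be handed to write_to_directory; the Waymo path collects raw TFRecord file paths via get_waymo_scenarios and lets each worker decode them lazily through preprocess_waymo_scenarios. The directory, output path, worker count, and version string below are placeholders, not values from the patch:

    from scenarionet.converter.utils import write_to_directory
    from scenarionet.converter.waymo.utils import (
        convert_waymo_scenario, get_waymo_scenarios, preprocess_waymo_scenarios
    )

    # Collect TFRecord file paths only; decoding is deferred to the workers.
    files = get_waymo_scenarios("./waymo_raw", start_index=0, num=10)  # hypothetical directory and counts

    write_to_directory(
        convert_func=convert_waymo_scenario,
        scenarios=files,                        # a list of file paths, not decoded scenarios
        output_path="./exp_waymo",              # hypothetical output database path
        dataset_version="v1.2",                 # hypothetical version string
        dataset_name="waymo",
        num_workers=4,
        preprocess=preprocess_waymo_scenarios,  # each worker lazily yields scenario_pb2 objects
    )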