From 7b44e99d88408c931d61169144290a124fe8261a Mon Sep 17 00:00:00 2001
From: QuanyiLi
Date: Mon, 8 May 2023 14:14:15 +0100
Subject: [PATCH] multi-process writing

---
 scenarionet/converter/utils.py | 96 ++++++++++++++++++++++++++++++++--
 1 file changed, 93 insertions(+), 3 deletions(-)

diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py
index ec3997f..ecdf0df 100644
--- a/scenarionet/converter/utils.py
+++ b/scenarionet/converter/utils.py
@@ -3,9 +3,11 @@ import copy
 import inspect
 import logging
 import math
+import multiprocessing
 import os
 import pickle
 import shutil
+from functools import partial
 
 import numpy as np
 import tqdm
@@ -57,9 +59,97 @@ def contains_explicit_return(f):
     return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
 
 
-def write_to_directory(
-    convert_func, scenarios, output_path, dataset_version, dataset_name, force_overwrite=False, **kwargs
-):
+def write_to_directory(convert_func,
+                       scenarios,
+                       output_path,
+                       dataset_version,
+                       dataset_name,
+                       force_overwrite=False,
+                       num_workers=8,
+                       **kwargs):
+    """
+    Convert scenarios with several processes. The scenarios are split into num_workers contiguous chunks
+    and worker i converts its chunk with write_to_directory_single_worker, using "<output_path>_<i>" as
+    its output_path.
+
+    :param convert_func: handed to write_to_directory_single_worker
+    :param scenarios: sequence of scenarios; must support len() and slicing
+    :param output_path: prefix of the per-worker output directories
+    :param dataset_version: handed to write_to_directory_single_worker
+    :param dataset_name: handed to write_to_directory_single_worker
+    :param force_overwrite: if False, raise ValueError when a worker directory already exists
+    :param num_workers: number of worker processes; 1 is used if there are fewer scenarios than workers
+    :param kwargs: extra keyword arguments handed unchanged to every worker
+    :return: list of the workers' return values, in worker order
+    """
+    if num_workers < 1:
+        raise ValueError("num_workers must be at least 1, got {}".format(num_workers))
+
+    # get arguments for workers
+    num_files = len(scenarios)
+    if num_files < num_workers:
+        # single process
+        logger.info("Use one worker, as num_scenario < num_workers:")
+        num_workers = 1
+
+    # one output directory per worker, so that workers never write into the same place
+    basename = os.path.basename(output_path)
+    parent_dir = os.path.dirname(output_path)
+    worker_paths = [os.path.join(parent_dir, "{}_{}".format(basename, i)) for i in range(num_workers)]
+
+    # make sure dir not exist
+    for worker_path in worker_paths:
+        if os.path.exists(worker_path) and not force_overwrite:
+            raise ValueError("Directory {} already exists! Abort. "
+                             "\n Try setting force_overwrite=True or adding --overwrite".format(worker_path))
+
+    argument_list = []
+    num_files_each_worker = num_files // num_workers
+    for i, worker_path in enumerate(worker_paths):
+        start_idx = i * num_files_each_worker
+        # the last worker also takes the remainder of the integer division
+        end_idx = num_files if i == num_workers - 1 else start_idx + num_files_each_worker
+        argument_list.append([scenarios[start_idx:end_idx], kwargs, worker_path])
+
+    # prefill arguments by keyword: imap supplies the per-worker item as the first positional argument
+    func = partial(writing_to_directory_wrapper,
+                   convert_func=convert_func,
+                   output_path=output_path,
+                   dataset_version=dataset_version,
+                   dataset_name=dataset_name,
+                   force_overwrite=force_overwrite)
+
+    # Run, workers and process result from worker
+    with multiprocessing.Pool(num_workers) as p:
+        all_result = list(p.imap(func, argument_list))
+    return all_result
+
+
+def writing_to_directory_wrapper(args,
+                                 convert_func,
+                                 output_path,
+                                 dataset_version,
+                                 dataset_name,
+                                 force_overwrite=False):
+    """
+    Unpack the work item of one worker and run write_to_directory_single_worker on it.
+
+    args is [scenarios, kwargs] or [scenarios, kwargs, worker_output_path]; when the third item is given
+    it replaces output_path, so every worker can write to its own directory.
+    """
+    if len(args) > 2:
+        output_path = args[2]
+    return write_to_directory_single_worker(convert_func=convert_func,
+                                            scenarios=args[0],
+                                            output_path=output_path,
+                                            dataset_version=dataset_version,
+                                            dataset_name=dataset_name,
+                                            force_overwrite=force_overwrite,
+                                            **args[1])
+
+
+def write_to_directory_single_worker(convert_func, scenarios, output_path, dataset_version, dataset_name,
+                                     force_overwrite=False, **kwargs):
     """
     Convert a batch of scenarios.
     """