scenarionet/scenarionet/builder/utils.py
import copy
import logging
import os
import os.path as osp
import pickle
import shutil
from random import sample
from typing import Callable, List

import tqdm
from metadrive.scenario.scenario_description import ScenarioDescription
from metadrive.scenario.utils import read_dataset_summary

from scenarionet.common_utils import save_summary_anda_mapping

logger = logging.getLogger(__name__)


def try_generating_summary(file_folder):
    # Create a fake summary by reading the metadata from each scenario file in the folder
    files = os.listdir(file_folder)
    summary = {}
    for file in files:
        if ScenarioDescription.is_scenario_file(file):
            with open(osp.join(file_folder, file), "rb+") as f:
                scenario = pickle.load(f)
            summary[file] = copy.deepcopy(scenario[ScenarioDescription.METADATA])
    return summary


def merge_database(
    output_path,
    *dataset_paths,
    exist_ok=False,
    overwrite=False,
    try_generate_missing_file=True,
    filters: List[Callable] = None,
    save=True,
):
    """
    Combine multiple databases into one. Each database should contain a dataset_summary.pkl.
    :param output_path: the path where the merged database will be stored
    :param dataset_paths: the path of each source database
    :param exist_ok: if True, write into output_path even if it already exists
    :param overwrite: if True, overwrite existing dataset_summary.pkl and mapping.pkl; otherwise, raise an error
    :param try_generate_missing_file: if dataset_summary.pkl or mapping.pkl is missing, try to generate it
    :param filters: a list of filter callables deciding which scenarios are added to the merged database
    :param save: if True, immediately save the result to output_path
    :return: summary, mapping
    """
    filters = filters or []
    output_abs_path = osp.abspath(output_path)
    os.makedirs(output_abs_path, exist_ok=exist_ok)
    summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE)
    mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE)
    for file in [summary_file, mapping_file]:
        if os.path.exists(file):
            if overwrite:
                os.remove(file)
            else:
                raise FileExistsError("{} already exists at: {}!".format(file, output_abs_path))

    summaries = {}
    mappings = {}

    # collect
    for dataset_path in tqdm.tqdm(dataset_paths, desc="Merge Data"):
        abs_dir_path = osp.abspath(dataset_path)
        # summary
        assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
        if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE)):
            if try_generate_missing_file:
                summary = try_generating_summary(abs_dir_path)
            else:
                raise FileNotFoundError("Can not find summary file for database: {}".format(abs_dir_path))
        else:
            with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE), "rb+") as f:
                summary = pickle.load(f)
        intersect = set(summaries.keys()).intersection(set(summary.keys()))
        if len(intersect) > 0:
            existing = []
            for v in list(intersect):
                existing.append(mappings[v])
            logger.warning("Repeated scenarios: {} in: {}. Existing: {}".format(intersect, abs_dir_path, existing))
        summaries.update(summary)
        # mapping
        if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
            if try_generate_missing_file:
                mapping = {k: "" for k in summary}
            else:
                raise FileNotFoundError("Can not find mapping file for database: {}".format(abs_dir_path))
        else:
            with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE), "rb+") as f:
                mapping = pickle.load(f)

        new_mapping = {}
        for file, rel_path in mapping.items():
            # mapping to real file path
            new_mapping[file] = os.path.relpath(osp.join(abs_dir_path, rel_path), output_abs_path)
        mappings.update(new_mapping)

    # apply filter stage
    file_to_pop = []
    for file_name in tqdm.tqdm(summaries.keys(), desc="Filter Scenarios"):
        metadata = summaries[file_name]
        if not all([fil(metadata, os.path.join(output_abs_path, mappings[file_name], file_name)) for fil in filters]):
            file_to_pop.append(file_name)
    for file in file_to_pop:
        summaries.pop(file)
        mappings.pop(file)

    if save:
        save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
    return summaries, mappings
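
# Usage sketch (not part of the original file; the paths and the filter below are
# hypothetical): merge two source databases into one output database, keeping only
# scenarios whose pickle file actually exists on disk. Filters receive the scenario
# metadata and the resolved file path, and return True to keep the scenario.
#
#   def existing_file_filter(metadata, scenario_file_path):
#       return os.path.exists(scenario_file_path)
#
#   summaries, mappings = merge_database(
#       "./merged_db", "./waymo_db", "./nuplan_db",
#       exist_ok=True, overwrite=True, filters=[existing_file_filter]
#   )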


def copy_database(
    from_path, to_path, exist_ok=False, overwrite=False, copy_raw_data=False, remove_source=False, force_move=False
):
    if not os.path.exists(from_path):
        raise FileNotFoundError("Can not find database: {}".format(from_path))
    if os.path.exists(to_path):
        assert exist_ok, "to_directory already exists. Set exist_ok=True to allow turning it into a database"
        assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
    files = os.listdir(from_path)
    if not force_move and (ScenarioDescription.DATASET.MAPPING_FILE in files
                           and ScenarioDescription.DATASET.SUMMARY_FILE in files and len(files) > 2):
        raise RuntimeError(
            "The source database is not allowed to move! "
            "This will break the relationship between this database and other databases built on it. "
            "If it is ok for you, use 'mv' to move it manually."
        )
    summaries, mappings = merge_database(
        to_path, from_path, exist_ok=exist_ok, overwrite=overwrite, try_generate_missing_file=True, save=False
    )
    summary_file = osp.join(to_path, ScenarioDescription.DATASET.SUMMARY_FILE)
    mapping_file = osp.join(to_path, ScenarioDescription.DATASET.MAPPING_FILE)

    if copy_raw_data:
        logger.info("Copy raw data...")
        for scenario_file in tqdm.tqdm(mappings.keys()):
            rel_path = mappings[scenario_file]
            shutil.copyfile(os.path.join(to_path, rel_path, scenario_file), os.path.join(to_path, scenario_file))
        mappings = {key: "./" for key in summaries.keys()}

    save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)

    if remove_source and ScenarioDescription.DATASET.MAPPING_FILE in files and \
            ScenarioDescription.DATASET.SUMMARY_FILE in files and len(files) == 2:
        shutil.rmtree(from_path)
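
# Usage sketch (not part of the original file; paths are hypothetical): copy a
# database to a new folder and make it self-contained by also copying the raw
# scenario files, so the copy no longer depends on the source database.
#
#   copy_database(
#       "./merged_db", "./merged_db_standalone",
#       exist_ok=True, overwrite=True, copy_raw_data=True
#   )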


def split_database(
    from_path,
    to_path,
    start_index,
    num_scenarios,
    exist_ok=False,
    overwrite=False,
    random=False,
):
    if not os.path.exists(from_path):
        raise FileNotFoundError("Can not find database: {}".format(from_path))
    if os.path.exists(to_path):
        assert exist_ok, "to_directory already exists. Set exist_ok=True to allow turning it into a database"
        assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
    output_abs_path = osp.abspath(to_path)
    os.makedirs(output_abs_path, exist_ok=exist_ok)
    summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE)
    mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE)
    for file in [summary_file, mapping_file]:
        if os.path.exists(file):
            if overwrite:
                os.remove(file)
            else:
                raise FileExistsError("{} already exists at: {}!".format(file, output_abs_path))
    # collect
    abs_dir_path = osp.abspath(from_path)
    # summary
    assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
    summaries, lookup, mappings = read_dataset_summary(from_path)
    assert start_index >= 0 and start_index + num_scenarios <= len(
        lookup
    ), "Not enough scenarios in source dataset: total {}, start_index: {}, need: {}".format(
        len(lookup), start_index, num_scenarios
    )
    if random:
        selected = sample(lookup[start_index:], k=num_scenarios)
    else:
        selected = lookup[start_index:start_index + num_scenarios]
    selected_summary = {}
    selected_mapping = {}
    for scenario in selected:
        selected_summary[scenario] = summaries[scenario]
        selected_mapping[scenario] = os.path.relpath(osp.join(abs_dir_path, mappings[scenario]), output_abs_path)

    save_summary_anda_mapping(summary_file, mapping_file, selected_summary, selected_mapping)
    return selected_summary, selected_mapping
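
# Usage sketch (not part of the original file; paths and sizes are hypothetical):
# split the first 100 scenarios of a database into a training set and the next
# 20 into a validation set. With random=True, num_scenarios scenarios are sampled
# from the scenarios at or after start_index instead of being taken in order.
#
#   split_database("./merged_db", "./train_db", start_index=0, num_scenarios=100)
#   split_database("./merged_db", "./val_db", start_index=100, num_scenarios=20)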