diff --git a/scenarionet/builder/utils.py b/scenarionet/builder/utils.py
new file mode 100644
index 0000000..a05b9d3
--- /dev/null
+++ b/scenarionet/builder/utils.py
@@ -0,0 +1,93 @@
+import copy
+import logging
+import os
+import os.path as osp
+import pickle
+
+from metadrive.scenario.scenario_description import ScenarioDescription
+
+logger = logging.getLogger(__name__)
+
+
+def try_generating_summary(file_folder):
+    # Create a fake summary by reading the metadata stored in every scenario file
+    files = os.listdir(file_folder)
+    summary = {}
+    for file in files:
+        if not file.endswith(".pkl") or file in (ScenarioDescription.DATASET.SUMMARY_FILE, ScenarioDescription.DATASET.MAPPING_FILE):
+            continue
+        with open(osp.join(file_folder, file), "rb") as f:
+            scenario = pickle.load(f)
+        summary[file] = copy.deepcopy(scenario[ScenarioDescription.METADATA])
+    return summary
+
+
+def try_generating_mapping(file_folder):
+    # Create a fake mapping: every scenario is assumed to live in file_folder itself
+    files = os.listdir(file_folder)
+    mapping = {}
+    for file in files:
+        if not file.endswith(".pkl") or file in (ScenarioDescription.DATASET.SUMMARY_FILE, ScenarioDescription.DATASET.MAPPING_FILE):
+            continue
+        mapping[file] = ""
+    return mapping
+
+
+def combine_multiple_dataset(output_path, *dataset_paths, force_overwrite=False, try_generate_missing_file=True):
+    """
+    Combine multiple datasets. Each dataset should have a dataset_summary.pkl
+    :param output_path: The path to store the combined dataset
+    :param dataset_paths: The path of each dataset to combine
+    :param force_overwrite: If True, overwrite output_path even if it already exists
+    :param try_generate_missing_file: If dataset_summary.pkl or dataset_mapping.pkl is missing, try to generate it
+    :return: None. The combined summary and mapping are written to output_path
+    """
+    output_abs_path = osp.abspath(output_path)
+    if os.path.exists(output_abs_path) and not force_overwrite:
+        raise FileExistsError("Output path already exists!")
+    os.makedirs(output_abs_path, exist_ok=True)
+
+    summaries = {}
+    mappings = {}
+
+    # collect
+    for dataset_path in dataset_paths:
+        abs_dir_path = osp.abspath(dataset_path)
+        assert osp.exists(abs_dir_path), "Wrong dataset path. Can not find dataset at: {}".format(abs_dir_path)
+
+        # summary
+        if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE)):
+            if try_generate_missing_file:
+                # TODO add test for 1. number dataset 2. missing summary dataset 3. missing mapping dataset
+                summary = try_generating_summary(abs_dir_path)
+            else:
+                raise FileNotFoundError("Can not find summary file for dataset: {}".format(abs_dir_path))
+        else:
+            with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE), "rb") as f:
+                summary = pickle.load(f)
+
+        intersect = set(summaries.keys()).intersection(set(summary.keys()))
+        if len(intersect) > 0:
+            existing = [mappings[v] for v in intersect]
+            logger.warning("Repeat scenarios: {} in: {}. Existing: {}".format(intersect, abs_dir_path, existing))
+        summaries.update(summary)
+
+        # mapping
+        if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)):
+            if try_generate_missing_file:
+                mapping = try_generating_mapping(abs_dir_path)
+            else:
+                raise FileNotFoundError("Can not find mapping file for dataset: {}".format(abs_dir_path))
+        else:
+            with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE), "rb") as f:
+                mapping = pickle.load(f)
+
+        # join with the stored relative dir so that combining an already-combined dataset still resolves
+        new_mapping = {k: osp.join(os.path.relpath(abs_dir_path, output_abs_path), v) for k, v in mapping.items()}
+        mappings.update(new_mapping)
+
+    with open(osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE), "wb") as f:
+        pickle.dump(summaries, f)
+
+    with open(osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE), "wb") as f:
+        pickle.dump(mappings, f)
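A minimal usage sketch of the new `combine_multiple_dataset` API, assuming the flags are keyword-only as above; all paths here are hypothetical.

```python
# Combine two previously converted datasets into one output directory.
# The directory names below are hypothetical placeholders.
from scenarionet.builder.utils import combine_multiple_dataset

combine_multiple_dataset(
    "/data/combined_dataset",    # hypothetical output directory (created if missing)
    "/data/nuscenes_converted",  # hypothetical input dataset 1
    "/data/waymo_converted",     # hypothetical input dataset 2
    force_overwrite=True,
    try_generate_missing_file=True,
)
```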
Existing: {}".format(intersect, abs_dir_path, existing)) + + summaries.update(summary) + + if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE)): + if try_generate_missing_file: + mapping = try_generating_mapping(abs_dir_path) + else: + raise FileNotFoundError("Can not find mapping file for dataset: {}".format(abs_dir_path)) + else: + with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE), "rb+") as f: + mapping = pickle.load(f) + new_mapping = {k: os.path.relpath(abs_dir_path, output_abs_path) for k, v in mapping.items()} + mappings.update(new_mapping) + + with open(osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE), "wb+") as f: + pickle.dump(summaries, f) + + with open(osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE), "wb+") as f: + pickle.dump(mappings, f) diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py index 4c8ad96..1570827 100644 --- a/scenarionet/converter/utils.py +++ b/scenarionet/converter/utils.py @@ -67,7 +67,7 @@ def contains_explicit_return(f): def write_to_directory( - convert_func, scenarios, output_path, dataset_version, dataset_name, force_overwrite=False, **kwargs + convert_func, scenarios, output_path, dataset_version, dataset_name, force_overwrite=False, **kwargs ): """ Convert a batch of scenarios. @@ -94,9 +94,14 @@ def write_to_directory( else: raise ValueError("Directory already exists! Abort") - summary_file = "dataset_summary.pkl" + summary_file = SD.DATASET.SUMMARY_FILE + mapping_file = SD.DATASET.MAPPING_FILE - metadata_recorder = {} + summary_file_path = os.path.join(output_path, summary_file) + mapping_file_path = os.path.join(output_path, mapping_file) + + summary = {} + mapping = {} for scenario in tqdm.tqdm(scenarios): # convert scenario sd_scenario = convert_func(scenario, dataset_version, **kwargs) @@ -115,7 +120,10 @@ def write_to_directory( # count some objects occurrence sd_scenario[SD.METADATA][SD.SUMMARY.NUMBER_SUMMARY] = SD.get_number_summary(sd_scenario) - metadata_recorder[export_file_name] = copy.deepcopy(sd_scenario[SD.METADATA]) + + # update summary/mapping dicy + summary[export_file_name] = copy.deepcopy(sd_scenario[SD.METADATA]) + mapping[export_file_name] = "" # in the same dir # sanity check sd_scenario = sd_scenario.to_dict() @@ -127,10 +135,11 @@ def write_to_directory( pickle.dump(sd_scenario, f) # store summary file, which is human-readable - summary_file = os.path.join(output_path, summary_file) - with open(summary_file, "wb") as file: - pickle.dump(dict_recursive_remove_array_and_set(metadata_recorder), file) - print("Summary is saved at: {}".format(summary_file)) + with open(summary_file_path, "wb") as file: + pickle.dump(dict_recursive_remove_array_and_set(summary), file) + with open(mapping_file_path, "wb") as file: + pickle.dump(mapping, file) + print("Dataset Summary and Mapping are saved at: {}".format(summary_file_path)) # rename and save if delay_remove is not None: diff --git a/scenarionet/examples/run_scenario.py b/scenarionet/examples/run_scenario.py new file mode 100644 index 0000000..8d5f97f --- /dev/null +++ b/scenarionet/examples/run_scenario.py @@ -0,0 +1,50 @@ +import os + +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy + +from scenarionet import SCENARIONET_DATASET_PATH + +if __name__ == '__main__': + dataset_path = os.path.join(SCENARIONET_DATASET_PATH, "nuscenes_no_mapping_test") + + env = ScenarioEnv( + { + "use_render": True, + 
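The summary/mapping pair written by `write_to_directory` is the dataset contract that `combine_multiple_dataset` consumes. A sketch of how a reader resolves scenario files through it, assuming a hypothetical dataset root:

```python
# Resolve every scenario file in a dataset via its summary and mapping files.
# dataset_root is a hypothetical placeholder.
import os
import pickle

from metadrive.scenario.scenario_description import ScenarioDescription as SD

dataset_root = "/data/combined_dataset"  # hypothetical

with open(os.path.join(dataset_root, SD.DATASET.SUMMARY_FILE), "rb") as f:
    summary = pickle.load(f)  # {scenario_file_name: metadata}
with open(os.path.join(dataset_root, SD.DATASET.MAPPING_FILE), "rb") as f:
    mapping = pickle.load(f)  # {scenario_file_name: dir relative to dataset_root}

for file_name in summary:
    # "" means the scenario sits next to the summary file; otherwise the mapping
    # stores the relative directory recorded by combine_multiple_dataset
    scenario_path = os.path.join(dataset_root, mapping[file_name], file_name)
    with open(scenario_path, "rb") as f:
        scenario = pickle.load(f)
```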
"agent_policy": ReplayEgoCarPolicy, + "manual_control": False, + "show_interface": True, + "show_logo": False, + "show_fps": False, + "num_scenarios": 10, + "horizon": 1000, + "no_static_vehicles": True, + "vehicle_config": dict( + show_navi_mark=False, + no_wheel_friction=True, + lidar=dict(num_lasers=120, distance=50, num_others=4), + lane_line_detector=dict(num_lasers=12, distance=50), + side_detector=dict(num_lasers=160, distance=50) + ), + "data_directory": dataset_path, + } + ) + success = [] + env.reset(force_seed=0) + while True: + env.reset(force_seed=env.current_seed + 1) + for t in range(10000): + o, r, d, info = env.step([0, 0]) + assert env.observation_space.contains(o) + c_lane = env.vehicle.lane + long, lat, = c_lane.local_coordinates(env.vehicle.position) + # if env.config["use_render"]: + env.render( + text={ + "seed": env.engine.global_seed + env.config["start_scenario_index"], + } + ) + + if d and info["arrive_dest"]: + print("seed:{}, success".format(env.engine.global_random_seed)) + break diff --git a/scenarionet/test/test_combine_dataset.py b/scenarionet/test/test_combine_dataset.py new file mode 100644 index 0000000..64d20fa --- /dev/null +++ b/scenarionet/test/test_combine_dataset.py @@ -0,0 +1,4 @@ + + +def test_combine_multiple_dataset(): + pass \ No newline at end of file