diff --git a/scenarionet/converter/pg/convert_pg.py b/scenarionet/converter/pg/convert_pg.py deleted file mode 100644 index e69de29..0000000 diff --git a/scenarionet/converter/pg/utils.py b/scenarionet/converter/pg/utils.py new file mode 100644 index 0000000..d5e66d0 --- /dev/null +++ b/scenarionet/converter/pg/utils.py @@ -0,0 +1,26 @@ +from metadrive.envs.metadrive_env import MetaDriveEnv +from metadrive.scenario.scenario_description import ScenarioDescription as SD + + +def convert_pg_scenario(scenario_index, version, env): + """ + Simulate to collect PG Scenarios + :param scenario_index: the index to export + :param version: place holder + :param env: metadrive env instance + """ + policy = lambda x: [0, 1] # placeholder + scenarios, done_info = env.export_scenarios(policy, scenario_index=[scenario_index], to_dict=False) + scenario = scenarios[scenario_index] + assert scenario[SD.VERSION] == version, "Data version mismatch" + return scenario + + +def get_pg_scenarios(num_scenarios, policy, start_seed=0): + env = MetaDriveEnv(dict(start_seed=start_seed, + num_scenarios=num_scenarios, + traffic_density=0.2, + agent_policy=policy, + crash_vehicle_done=False, + map=2)) + return [i for i in range(num_scenarios)], env diff --git a/scenarionet/converter/utils.py b/scenarionet/converter/utils.py index a4f7133..4c8ad96 100644 --- a/scenarionet/converter/utils.py +++ b/scenarionet/converter/utils.py @@ -100,7 +100,7 @@ def write_to_directory( for scenario in tqdm.tqdm(scenarios): # convert scenario sd_scenario = convert_func(scenario, dataset_version, **kwargs) - scenario_id = sd_scenario[SD.METADATA][SD.ID] + scenario_id = sd_scenario[SD.ID] export_file_name = "sd_{}_{}.pkl".format(dataset_name + "_" + dataset_version, scenario_id) # add agents summary diff --git a/scenarionet/examples/convert_nuplan.py b/scenarionet/examples/convert_nuplan.py index dfb1f28..4c1d6f2 100644 --- a/scenarionet/examples/convert_nuplan.py +++ b/scenarionet/examples/convert_nuplan.py @@ -16,8 +16,8 @@ if __name__ == "__main__": data_root = os.path.join(os.getenv("NUPLAN_DATA_ROOT"), "nuplan-v1.1/splits/mini") map_root = os.getenv("NUPLAN_MAPS_ROOT") - # scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"]) - scenarios = get_nuplan_scenarios(data_root, map_root) + scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"]) + # scenarios = get_nuplan_scenarios(data_root, map_root) write_to_directory( convert_func=convert_nuplan_scenario, diff --git a/scenarionet/examples/convert_pg.py b/scenarionet/examples/convert_pg.py new file mode 100644 index 0000000..072c0b4 --- /dev/null +++ b/scenarionet/examples/convert_pg.py @@ -0,0 +1,27 @@ +import os.path + +import metadrive + +from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet.converter.pg.utils import get_pg_scenarios, convert_pg_scenario +from scenarionet.converter.utils import write_to_directory +from metadrive.policy.idm_policy import IDMPolicy +# from metadrive.policy.expert_policy import ExpertPolicy + +if __name__ == '__main__': + dataset_name = "pg" + output_path = os.path.join(SCENARIONET_DATASET_PATH, dataset_name) + version = metadrive.constants.DATA_VERSION + force_overwrite = True + + scenario_indices, env = get_pg_scenarios(30, IDMPolicy) + + write_to_directory( + convert_func=convert_pg_scenario, + scenarios=scenario_indices, + output_path=output_path, + dataset_version=version, + dataset_name=dataset_name, + force_overwrite=force_overwrite, + env=env + )