add pg converter
This commit is contained in:
26
scenarionet/converter/pg/utils.py
Normal file
26
scenarionet/converter/pg/utils.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from metadrive.envs.metadrive_env import MetaDriveEnv
|
||||
from metadrive.scenario.scenario_description import ScenarioDescription as SD
|
||||
|
||||
|
||||
def convert_pg_scenario(scenario_index, version, env):
|
||||
"""
|
||||
Simulate to collect PG Scenarios
|
||||
:param scenario_index: the index to export
|
||||
:param version: place holder
|
||||
:param env: metadrive env instance
|
||||
"""
|
||||
policy = lambda x: [0, 1] # placeholder
|
||||
scenarios, done_info = env.export_scenarios(policy, scenario_index=[scenario_index], to_dict=False)
|
||||
scenario = scenarios[scenario_index]
|
||||
assert scenario[SD.VERSION] == version, "Data version mismatch"
|
||||
return scenario
|
||||
|
||||
|
||||
def get_pg_scenarios(num_scenarios, policy, start_seed=0):
|
||||
env = MetaDriveEnv(dict(start_seed=start_seed,
|
||||
num_scenarios=num_scenarios,
|
||||
traffic_density=0.2,
|
||||
agent_policy=policy,
|
||||
crash_vehicle_done=False,
|
||||
map=2))
|
||||
return [i for i in range(num_scenarios)], env
|
||||
@@ -100,7 +100,7 @@ def write_to_directory(
|
||||
for scenario in tqdm.tqdm(scenarios):
|
||||
# convert scenario
|
||||
sd_scenario = convert_func(scenario, dataset_version, **kwargs)
|
||||
scenario_id = sd_scenario[SD.METADATA][SD.ID]
|
||||
scenario_id = sd_scenario[SD.ID]
|
||||
export_file_name = "sd_{}_{}.pkl".format(dataset_name + "_" + dataset_version, scenario_id)
|
||||
|
||||
# add agents summary
|
||||
|
||||
@@ -16,8 +16,8 @@ if __name__ == "__main__":
|
||||
|
||||
data_root = os.path.join(os.getenv("NUPLAN_DATA_ROOT"), "nuplan-v1.1/splits/mini")
|
||||
map_root = os.getenv("NUPLAN_MAPS_ROOT")
|
||||
# scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"])
|
||||
scenarios = get_nuplan_scenarios(data_root, map_root)
|
||||
scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"])
|
||||
# scenarios = get_nuplan_scenarios(data_root, map_root)
|
||||
|
||||
write_to_directory(
|
||||
convert_func=convert_nuplan_scenario,
|
||||
|
||||
27
scenarionet/examples/convert_pg.py
Normal file
27
scenarionet/examples/convert_pg.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os.path
|
||||
|
||||
import metadrive
|
||||
|
||||
from scenarionet import SCENARIONET_DATASET_PATH
|
||||
from scenarionet.converter.pg.utils import get_pg_scenarios, convert_pg_scenario
|
||||
from scenarionet.converter.utils import write_to_directory
|
||||
from metadrive.policy.idm_policy import IDMPolicy
|
||||
# from metadrive.policy.expert_policy import ExpertPolicy
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset_name = "pg"
|
||||
output_path = os.path.join(SCENARIONET_DATASET_PATH, dataset_name)
|
||||
version = metadrive.constants.DATA_VERSION
|
||||
force_overwrite = True
|
||||
|
||||
scenario_indices, env = get_pg_scenarios(30, IDMPolicy)
|
||||
|
||||
write_to_directory(
|
||||
convert_func=convert_pg_scenario,
|
||||
scenarios=scenario_indices,
|
||||
output_path=output_path,
|
||||
dataset_version=version,
|
||||
dataset_name=dataset_name,
|
||||
force_overwrite=force_overwrite,
|
||||
env=env
|
||||
)
|
||||
Reference in New Issue
Block a user