small dataset test!

This commit is contained in:
QuanyiLi
2023-05-07 23:01:01 +01:00
parent 25bb9764c8
commit 5c1b2e053b
3 changed files with 18 additions and 14 deletions

View File

@@ -8,10 +8,10 @@ from scenarionet import SCENARIONET_DATASET_PATH
from scenarionet.builder.utils import combine_multiple_dataset from scenarionet.builder.utils import combine_multiple_dataset
if __name__ == '__main__': if __name__ == '__main__':
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg")) os.path.join(SCENARIONET_DATASET_PATH, "pg")]
combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset")
combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
@@ -38,9 +38,9 @@ if __name__ == '__main__':
} }
) )
success = [] success = []
env.reset(force_seed=0) env.reset(force_seed=2)
while True: while True:
env.reset(force_seed=env.current_seed + 1) env.reset(force_seed=2)
for t in range(10000): for t in range(10000):
o, r, d, info = env.step([0, 0]) o, r, d, info = env.step([0, 0])
assert env.observation_space.contains(o) assert env.observation_space.contains(o)
@@ -53,6 +53,7 @@ if __name__ == '__main__':
} }
) )
if d and info["arrive_dest"]: if d:
print("seed:{}, success".format(env.engine.global_random_seed)) if info["arrive_dest"]:
print("seed:{}, success".format(env.engine.global_random_seed))
break break

View File

@@ -6,10 +6,10 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive
def _test_combine_dataset(): def _test_combine_dataset():
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")] dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan")) os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo")) os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg")) os.path.join(SCENARIONET_DATASET_PATH, "pg")]
combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset") combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset")
combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True) combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
@@ -17,5 +17,6 @@ def _test_combine_dataset():
success, result = verify_loading_into_metadrive(combine_path, "verify_results", steps_to_run=250) success, result = verify_loading_into_metadrive(combine_path, "verify_results", steps_to_run=250)
assert success assert success
if __name__ == '__main__': if __name__ == '__main__':
_test_combine_dataset() _test_combine_dataset()

View File

@@ -70,10 +70,12 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
desc="Scenarios: {}-{}".format(start_scenario_index, desc="Scenarios: {}-{}".format(start_scenario_index,
start_scenario_index + num_scenario)): start_scenario_index + num_scenario)):
env.reset(force_seed=scenario_index) env.reset(force_seed=scenario_index)
arrive = False
for _ in range(steps_to_run): for _ in range(steps_to_run):
o, r, d, info = env.step([0, 0]) o, r, d, info = env.step([0, 0])
if d: if d and info["arrive_dest"]:
assert info["arrive_dest"], "Can not arrive destination" arrive = True
assert arrive, "Can not arrive destination"
except Exception as e: except Exception as e:
file_name = env.engine.data_manager.summary_lookup[scenario_index] file_name = env.engine.data_manager.summary_lookup[scenario_index]
file_path = os.path.join(dataset_path, env.engine.data_manager.mapping[file_name], file_name) file_path = os.path.join(dataset_path, env.engine.data_manager.mapping[file_name], file_name)