small dataset test!

This commit is contained in:
QuanyiLi
2023-05-07 23:01:01 +01:00
parent 25bb9764c8
commit 5c1b2e053b
3 changed files with 18 additions and 14 deletions

View File

@@ -8,10 +8,10 @@ from scenarionet import SCENARIONET_DATASET_PATH
from scenarionet.builder.utils import combine_multiple_dataset
if __name__ == '__main__':
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")]
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan"))
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo"))
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg"))
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
os.path.join(SCENARIONET_DATASET_PATH, "pg")]
combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset")
combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
@@ -38,9 +38,9 @@ if __name__ == '__main__':
}
)
success = []
env.reset(force_seed=0)
env.reset(force_seed=2)
while True:
env.reset(force_seed=env.current_seed + 1)
env.reset(force_seed=2)
for t in range(10000):
o, r, d, info = env.step([0, 0])
assert env.observation_space.contains(o)
@@ -53,6 +53,7 @@ if __name__ == '__main__':
}
)
if d and info["arrive_dest"]:
if d:
if info["arrive_dest"]:
print("seed:{}, success".format(env.engine.global_random_seed))
break

View File

@@ -6,10 +6,10 @@ from scenarionet.verifier.utils import verify_loading_into_metadrive
def _test_combine_dataset():
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes")]
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "nuplan"))
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "waymo"))
dataset_paths.append(os.path.join(SCENARIONET_DATASET_PATH, "pg"))
dataset_paths = [os.path.join(SCENARIONET_DATASET_PATH, "nuscenes"),
os.path.join(SCENARIONET_DATASET_PATH, "nuplan"),
os.path.join(SCENARIONET_DATASET_PATH, "waymo"),
os.path.join(SCENARIONET_DATASET_PATH, "pg")]
combine_path = os.path.join(SCENARIONET_DATASET_PATH, "combined_dataset")
combine_multiple_dataset(combine_path, *dataset_paths, force_overwrite=True, try_generate_missing_file=True)
@@ -17,5 +17,6 @@ def _test_combine_dataset():
success, result = verify_loading_into_metadrive(combine_path, "verify_results", steps_to_run=250)
assert success
if __name__ == '__main__':
_test_combine_dataset()

View File

@@ -70,10 +70,12 @@ def loading_into_metadrive(start_scenario_index, num_scenario, dataset_path, ste
desc="Scenarios: {}-{}".format(start_scenario_index,
start_scenario_index + num_scenario)):
env.reset(force_seed=scenario_index)
arrive = False
for _ in range(steps_to_run):
o, r, d, info = env.step([0, 0])
if d:
assert info["arrive_dest"], "Can not arrive destination"
if d and info["arrive_dest"]:
arrive = True
assert arrive, "Can not arrive destination"
except Exception as e:
file_name = env.engine.data_manager.summary_lookup[scenario_index]
file_path = os.path.join(dataset_path, env.engine.data_manager.mapping[file_name], file_name)