This commit is contained in:
QuanyiLi
2023-05-06 22:10:58 +01:00
parent b101f6fdea
commit 49537ffc1e
10 changed files with 61 additions and 33 deletions

View File

@@ -1,6 +1,5 @@
import pickle
if __name__ == '__main__':
with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f:

View File

@@ -356,9 +356,9 @@ def extract_traffic(scenario: NuPlanScenario, center):
type=MetaDriveType.UNSET,
state=dict(
position=np.zeros(shape=(episode_len, 3)),
heading=np.zeros(shape=(episode_len,)),
heading=np.zeros(shape=(episode_len, )),
velocity=np.zeros(shape=(episode_len, 2)),
valid=np.zeros(shape=(episode_len,)),
valid=np.zeros(shape=(episode_len, )),
length=np.zeros(shape=(episode_len, 1)),
width=np.zeros(shape=(episode_len, 1)),
height=np.zeros(shape=(episode_len, 1))

View File

@@ -128,9 +128,9 @@ def get_tracks_from_frames(nuscenes: NuScenes, scene_info, frames, num_to_interp
type=MetaDriveType.UNSET,
state=dict(
position=np.zeros(shape=(episode_len, 3)),
heading=np.zeros(shape=(episode_len,)),
heading=np.zeros(shape=(episode_len, )),
velocity=np.zeros(shape=(episode_len, 2)),
valid=np.zeros(shape=(episode_len,)),
valid=np.zeros(shape=(episode_len, )),
length=np.zeros(shape=(episode_len, 1)),
width=np.zeros(shape=(episode_len, 1)),
height=np.zeros(shape=(episode_len, 1))
@@ -183,7 +183,7 @@ def get_tracks_from_frames(nuscenes: NuScenes, scene_info, frames, num_to_interp
interpolate_tracks[id]["metadata"]["track_length"] = new_episode_len
# valid first
new_valid = np.zeros(shape=(new_episode_len,))
new_valid = np.zeros(shape=(new_episode_len, ))
if track["state"]["valid"][0]:
new_valid[0] = 1
for k, valid in enumerate(track["state"]["valid"][1:], start=1):

View File

@@ -66,13 +66,9 @@ def contains_explicit_return(f):
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
def write_to_directory(convert_func,
scenarios,
output_path,
dataset_version,
dataset_name,
force_overwrite=False,
**kwargs):
def write_to_directory(
convert_func, scenarios, output_path, dataset_version, dataset_name, force_overwrite=False, **kwargs
):
"""
Convert a batch of scenarios.
"""

View File

@@ -19,10 +19,11 @@ if __name__ == "__main__":
# scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"])
scenarios = get_nuplan_scenarios(data_root, map_root)
write_to_directory(convert_func=convert_nuplan_scenario,
scenarios=scenarios,
output_path=output_path,
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite,
)
write_to_directory(
convert_func=convert_nuplan_scenario,
scenarios=scenarios,
output_path=output_path,
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite,
)

View File

@@ -17,10 +17,12 @@ if __name__ == "__main__":
dataroot = '/home/shady/data/nuscenes'
scenarios, nusc = get_nuscenes_scenarios(dataroot, version)
write_to_directory(convert_func=convert_nuscenes_scenario,
scenarios=scenarios,
output_path=output_path,
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite,
nuscenes=nusc)
write_to_directory(
convert_func=convert_nuscenes_scenario,
scenarios=scenarios,
output_path=output_path,
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite,
nuscenes=nusc
)

View File

@@ -17,9 +17,11 @@ if __name__ == '__main__':
waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
scenarios = get_waymo_scenarios(waymo_data_direction)
write_to_directory(convert_func=convert_waymo_scenario,
scenarios=scenarios,
output_path=output_path,
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite)
write_to_directory(
convert_func=convert_waymo_scenario,
scenarios=scenarios,
output_path=output_path,
dataset_version=version,
dataset_name=dataset_name,
force_overwrite=force_overwrite
)