This commit is contained in:
QuanyiLi
2023-05-06 22:10:58 +01:00
parent b101f6fdea
commit 49537ffc1e
10 changed files with 61 additions and 33 deletions

7
.style.yapf Normal file
View File

@@ -0,0 +1,7 @@
[style]
based_on_style=pep8
dedent_closing_brackets=True
split_before_first_argument=True
allow_split_before_dict_value=False
join_multiple_lines=False
column_limit=120

20
format.sh Executable file
View File

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
# Usage: at the root dir >> bash scripts/format.sh
# Check yapf version.
ver=$(yapf --version)
if ! echo $ver | grep -q 0.30.0; then
echo "Wrong YAPF version installed: 0.30.0 is required, not $ver. Please install via `pip install yapf==0.30.0`"
exit 1
fi
yapf --in-place --recursive -p --verbose --style .style.yapf scenarionet/
if [[ "$1" == '--test' ]]; then # Only for CI usage, user should not use --test flag.
if ! git diff --quiet &>/dev/null; then
echo '*** You have not formatted your code! Please run [bash format.sh] at root directory before commit! Thanks! ***'
exit 1
else
echo "Code style test passed!"
fi
fi

View File

@@ -1,6 +1,5 @@
import pickle import pickle
if __name__ == '__main__': if __name__ == '__main__':
with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f: with open("waymo120/0408_output_final/dataset_summary.pkl", "rb") as f:

View File

@@ -356,9 +356,9 @@ def extract_traffic(scenario: NuPlanScenario, center):
type=MetaDriveType.UNSET, type=MetaDriveType.UNSET,
state=dict( state=dict(
position=np.zeros(shape=(episode_len, 3)), position=np.zeros(shape=(episode_len, 3)),
heading=np.zeros(shape=(episode_len,)), heading=np.zeros(shape=(episode_len, )),
velocity=np.zeros(shape=(episode_len, 2)), velocity=np.zeros(shape=(episode_len, 2)),
valid=np.zeros(shape=(episode_len,)), valid=np.zeros(shape=(episode_len, )),
length=np.zeros(shape=(episode_len, 1)), length=np.zeros(shape=(episode_len, 1)),
width=np.zeros(shape=(episode_len, 1)), width=np.zeros(shape=(episode_len, 1)),
height=np.zeros(shape=(episode_len, 1)) height=np.zeros(shape=(episode_len, 1))

View File

@@ -128,9 +128,9 @@ def get_tracks_from_frames(nuscenes: NuScenes, scene_info, frames, num_to_interp
type=MetaDriveType.UNSET, type=MetaDriveType.UNSET,
state=dict( state=dict(
position=np.zeros(shape=(episode_len, 3)), position=np.zeros(shape=(episode_len, 3)),
heading=np.zeros(shape=(episode_len,)), heading=np.zeros(shape=(episode_len, )),
velocity=np.zeros(shape=(episode_len, 2)), velocity=np.zeros(shape=(episode_len, 2)),
valid=np.zeros(shape=(episode_len,)), valid=np.zeros(shape=(episode_len, )),
length=np.zeros(shape=(episode_len, 1)), length=np.zeros(shape=(episode_len, 1)),
width=np.zeros(shape=(episode_len, 1)), width=np.zeros(shape=(episode_len, 1)),
height=np.zeros(shape=(episode_len, 1)) height=np.zeros(shape=(episode_len, 1))
@@ -183,7 +183,7 @@ def get_tracks_from_frames(nuscenes: NuScenes, scene_info, frames, num_to_interp
interpolate_tracks[id]["metadata"]["track_length"] = new_episode_len interpolate_tracks[id]["metadata"]["track_length"] = new_episode_len
# valid first # valid first
new_valid = np.zeros(shape=(new_episode_len,)) new_valid = np.zeros(shape=(new_episode_len, ))
if track["state"]["valid"][0]: if track["state"]["valid"][0]:
new_valid[0] = 1 new_valid[0] = 1
for k, valid in enumerate(track["state"]["valid"][1:], start=1): for k, valid in enumerate(track["state"]["valid"][1:], start=1):

View File

@@ -66,13 +66,9 @@ def contains_explicit_return(f):
return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f)))) return any(isinstance(node, ast.Return) for node in ast.walk(ast.parse(inspect.getsource(f))))
def write_to_directory(convert_func, def write_to_directory(
scenarios, convert_func, scenarios, output_path, dataset_version, dataset_name, force_overwrite=False, **kwargs
output_path, ):
dataset_version,
dataset_name,
force_overwrite=False,
**kwargs):
""" """
Convert a batch of scenarios. Convert a batch of scenarios.
""" """

View File

@@ -19,10 +19,11 @@ if __name__ == "__main__":
# scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"]) # scenarios = get_nuplan_scenarios(data_root, map_root, logs=["2021.07.16.20.45.29_veh-35_01095_01486"])
scenarios = get_nuplan_scenarios(data_root, map_root) scenarios = get_nuplan_scenarios(data_root, map_root)
write_to_directory(convert_func=convert_nuplan_scenario, write_to_directory(
scenarios=scenarios, convert_func=convert_nuplan_scenario,
output_path=output_path, scenarios=scenarios,
dataset_version=version, output_path=output_path,
dataset_name=dataset_name, dataset_version=version,
force_overwrite=force_overwrite, dataset_name=dataset_name,
) force_overwrite=force_overwrite,
)

View File

@@ -17,10 +17,12 @@ if __name__ == "__main__":
dataroot = '/home/shady/data/nuscenes' dataroot = '/home/shady/data/nuscenes'
scenarios, nusc = get_nuscenes_scenarios(dataroot, version) scenarios, nusc = get_nuscenes_scenarios(dataroot, version)
write_to_directory(convert_func=convert_nuscenes_scenario, write_to_directory(
scenarios=scenarios, convert_func=convert_nuscenes_scenario,
output_path=output_path, scenarios=scenarios,
dataset_version=version, output_path=output_path,
dataset_name=dataset_name, dataset_version=version,
force_overwrite=force_overwrite, dataset_name=dataset_name,
nuscenes=nusc) force_overwrite=force_overwrite,
nuscenes=nusc
)

View File

@@ -17,9 +17,11 @@ if __name__ == '__main__':
waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin") waymo_data_direction = os.path.join(SCENARIONET_DATASET_PATH, "waymo_origin")
scenarios = get_waymo_scenarios(waymo_data_direction) scenarios = get_waymo_scenarios(waymo_data_direction)
write_to_directory(convert_func=convert_waymo_scenario, write_to_directory(
scenarios=scenarios, convert_func=convert_waymo_scenario,
output_path=output_path, scenarios=scenarios,
dataset_version=version, output_path=output_path,
dataset_name=dataset_name, dataset_version=version,
force_overwrite=force_overwrite) dataset_name=dataset_name,
force_overwrite=force_overwrite
)

View File

@@ -36,6 +36,7 @@ install_requires = [
"tqdm", "tqdm",
"metadrive-simulator", "metadrive-simulator",
"geopandas", "geopandas",
"yapf==0.30.0",
"shapely" "shapely"
] ]