Add come updates for Neurips paper (#4)
* scenarionet training * wandb * train utils * fix callback * run PPO * use pg test * save path * use torch * add dependency * update ignore * update training * large model * use curriculum training * add time to exp name * storage_path * restore * update training * use my key * add log message * check seed * restore callback * restore call bacl * add log message * add logging message * restore ray1.4 * length 500 * ray 100 * wandb * use tf * more levels * add callback * 10 worker * show level * no env horizon * callback result level * more call back * add diffuculty * add mroen stat * mroe stat * show levels * add callback * new * ep len 600 * fix setup * fix stepup * fix to 3.8 * update setup * parallel worker! * new exp * add callback * lateral dist * pg dataset * evaluate * modify config * align config * train single RL * update training script * 100w eval * less eval to reveal * 2000 env eval * new trianing * eval 1000 * update eval * more workers * more worker * 20 worker * dataset to database * split tool! * split dataset * try fix * train 003 * fix mapping * fix test * add waymo tqdm * utils * fix bug * fix bug * waymo * int type * 8 worker read * disable * read file * add log message * check existence * dist 0 * int * check num * suprass warning * add filter API * filter * store map false * new * ablation * filter * fix * update filyter * reanme to from * random select * add overlapping checj * fix * new training sceheme * new reward * add waymo train script * waymo different config * copy raw data * fix bug * add tqdm * update readme * waymo * pg * max lateral dist 3 * pg * crash_done instead of penalty * no crash done * gpu * update eval script * steering range penalty * evaluate * finish pg * update setup * fix bug * test * fix * add on line * train nuplan * generate sensor * udpate training * static obj * multi worker eval * filx bug * use ray for testing * eval! * filter senario * id filter * fox bug * dist = 2 * filter * eval * eval ret * ok * update training pg * test before use * store data=False * collect figures * capture pic --------- Co-authored-by: Quanyi Li <quanyi@bolei-gpu02.cs.ucla.edu>
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -19,3 +19,7 @@ dataset/*
|
||||
**/passed_scenarios/
|
||||
**/waymo_origin
|
||||
/dataset/
|
||||
/scenarionet_training/wandb/*.pkl
|
||||
**/TEST/
|
||||
**/experiment/
|
||||
**/wandb/
|
||||
|
||||
29
README.md
29
README.md
@@ -13,7 +13,8 @@ pip install -e .
|
||||
## Usage
|
||||
|
||||
We provide some explanation and demo for all scripts here.
|
||||
**You are encouraged to try them on your own, add ```-h``` or ```--help``` argument to know more details about these scripts.**
|
||||
**You are encouraged to try them on your own, add ```-h``` or ```--help``` argument to know more details about these
|
||||
scripts.**
|
||||
|
||||
### Convert
|
||||
|
||||
@@ -45,31 +46,42 @@ python -m scenarionet.scripts.convert_pg -d pg --num_workers=16 --num_scenarios=
|
||||
```
|
||||
|
||||
### Merge & move
|
||||
|
||||
For merging two or more database, use
|
||||
|
||||
```
|
||||
python -m scenarionet.merge_database -d /destination/path --from_databases /database1 /2 ...
|
||||
python -m scenarionet.merge_database -d /destination/path --from /database1 /2 ...
|
||||
```
|
||||
|
||||
As a database contains a path mapping, one should move database folder with the following script instead of ```cp```
|
||||
command
|
||||
command.
|
||||
Using ```--copy_raw_data``` will copy the raw scenario file into target directory and cancel the virtual mapping.
|
||||
|
||||
```
|
||||
python -m scenarionet.move_database --to /destination/path --from /source/path
|
||||
python -m scenarionet.copy_database --to /destination/path --from /source/path
|
||||
```
|
||||
|
||||
### Verify
|
||||
|
||||
The following scripts will check whether all scenarios exist or can be loaded into simulator.
|
||||
The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be
|
||||
generated.
|
||||
The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be
|
||||
generated.
|
||||
With teh error file, one can build a new database excluding or including the broken or missing scenarios.
|
||||
|
||||
**Existence check**
|
||||
|
||||
```
|
||||
python -m scenarionet.verify_existence -d /database/to/check --error_file_path /error/file/path
|
||||
python -m scenarionet.check_existence -d /database/to/check --error_file_path /error/file/path
|
||||
```
|
||||
|
||||
**Runnable check**
|
||||
|
||||
```
|
||||
python -m scenarionet.verify_simulation -d /database/to/check --error_file_path /error/file/path
|
||||
python -m scenarionet.check_simulation -d /database/to/check --error_file_path /error/file/path
|
||||
```
|
||||
|
||||
**Generating new database**
|
||||
|
||||
```
|
||||
python -m scenarionet.generate_from_error_file -d /new/database/path --file /error/file/path
|
||||
```
|
||||
@@ -77,6 +89,7 @@ python -m scenarionet.generate_from_error_file -d /new/database/path --file /err
|
||||
### visualization
|
||||
|
||||
Visualizing the simulated scenario
|
||||
|
||||
```
|
||||
python -m scenarionet.run_simulation -d /path/to/database --render --scenario_index
|
||||
```
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from metadrive.scenario.scenario_description import ScenarioDescription as SD
|
||||
from metadrive.scenario.utils import read_scenario_data
|
||||
|
||||
|
||||
class ScenarioFilter:
|
||||
@@ -8,7 +10,7 @@ class ScenarioFilter:
|
||||
SMALLER = "smaller"
|
||||
|
||||
@staticmethod
|
||||
def sdc_moving_dist(metadata, target_dist, condition=GREATER):
|
||||
def sdc_moving_dist(metadata, file_path, target_dist, condition=GREATER):
|
||||
"""
|
||||
This function filters the scenario based on SDC information.
|
||||
"""
|
||||
@@ -22,10 +24,11 @@ class ScenarioFilter:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def object_number(metadata, number_threshold, object_type=None, condition=SMALLER):
|
||||
def object_number(metadata, file_path, number_threshold, object_type=None, condition=SMALLER):
|
||||
"""
|
||||
Return True if the scenario satisfying the object number condition
|
||||
:param metadata: metadata in each scenario
|
||||
:param file_path: where to find this scenario
|
||||
:param number_threshold: number of objects threshold
|
||||
:param object_type: MetaDriveType.VEHICLE or other object type. If none, calculate number for all object types
|
||||
:param condition: SMALLER or GREATER
|
||||
@@ -43,9 +46,33 @@ class ScenarioFilter:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def has_traffic_light(metadata):
|
||||
def has_traffic_light(metadata, file_path):
|
||||
return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0
|
||||
|
||||
@staticmethod
|
||||
def no_traffic_light(metadata, file_path):
|
||||
return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] == 0
|
||||
|
||||
@staticmethod
|
||||
def no_overpass(metadata, file_path):
|
||||
"""
|
||||
We need read the map data to do overpass filter
|
||||
"""
|
||||
max_height_diff = 5
|
||||
if SD.SUMMARY.MAP_HEIGHT_DIFF in metadata:
|
||||
return metadata[SD.SUMMARY.MAP_HEIGHT_DIFF] < max_height_diff
|
||||
else:
|
||||
# calculate online
|
||||
map_features = read_scenario_data(file_path)[SD.MAP_FEATURES]
|
||||
return abs(SD.map_height_diff(map_features, target=max_height_diff)) < max_height_diff
|
||||
|
||||
@staticmethod
|
||||
def id_filter(metadata, file_path, ids):
|
||||
for id in ids:
|
||||
if metadata["id"] in id:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def make(func, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import copy
|
||||
from random import sample
|
||||
from metadrive.scenario.utils import read_dataset_summary
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
@@ -27,21 +29,23 @@ def try_generating_summary(file_folder):
|
||||
|
||||
|
||||
def merge_database(
|
||||
output_path,
|
||||
*dataset_paths,
|
||||
exist_ok=False,
|
||||
overwrite=False,
|
||||
try_generate_missing_file=True,
|
||||
filters: List[Callable] = None
|
||||
output_path,
|
||||
*dataset_paths,
|
||||
exist_ok=False,
|
||||
overwrite=False,
|
||||
try_generate_missing_file=True,
|
||||
filters: List[Callable] = None,
|
||||
save=True,
|
||||
):
|
||||
"""
|
||||
Combine multiple datasets. Each dataset should have a dataset_summary.pkl
|
||||
:param output_path: The path to store the output dataset
|
||||
Combine multiple datasets. Each database should have a dataset_summary.pkl
|
||||
:param output_path: The path to store the output database
|
||||
:param exist_ok: If True, though the output_path already exist, still write into it
|
||||
:param overwrite: If True, overwrite existing dataset_summary.pkl and mapping.pkl. Otherwise, raise error
|
||||
:param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
|
||||
:param dataset_paths: Path of each dataset
|
||||
:param filters: a set of filters to choose which scenario to be selected and added into this combined dataset
|
||||
:param dataset_paths: Path of each database
|
||||
:param filters: a set of filters to choose which scenario to be selected and added into this combined database
|
||||
:param save: save to output path, immediately
|
||||
:return: summary, mapping
|
||||
"""
|
||||
filters = filters or []
|
||||
@@ -60,15 +64,15 @@ def merge_database(
|
||||
mappings = {}
|
||||
|
||||
# collect
|
||||
for dataset_path in tqdm.tqdm(dataset_paths):
|
||||
for dataset_path in tqdm.tqdm(dataset_paths, desc="Merge Data"):
|
||||
abs_dir_path = osp.abspath(dataset_path)
|
||||
# summary
|
||||
assert osp.exists(abs_dir_path), "Wrong dataset path. Can not find dataset at: {}".format(abs_dir_path)
|
||||
assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
|
||||
if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE)):
|
||||
if try_generate_missing_file:
|
||||
summary = try_generating_summary(abs_dir_path)
|
||||
else:
|
||||
raise FileNotFoundError("Can not find summary file for dataset: {}".format(abs_dir_path))
|
||||
raise FileNotFoundError("Can not find summary file for database: {}".format(abs_dir_path))
|
||||
else:
|
||||
with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE), "rb+") as f:
|
||||
summary = pickle.load(f)
|
||||
@@ -85,7 +89,7 @@ def merge_database(
|
||||
if try_generate_missing_file:
|
||||
mapping = {k: "" for k in summary}
|
||||
else:
|
||||
raise FileNotFoundError("Can not find mapping file for dataset: {}".format(abs_dir_path))
|
||||
raise FileNotFoundError("Can not find mapping file for database: {}".format(abs_dir_path))
|
||||
else:
|
||||
with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE), "rb+") as f:
|
||||
mapping = pickle.load(f)
|
||||
@@ -98,37 +102,108 @@ def merge_database(
|
||||
|
||||
# apply filter stage
|
||||
file_to_pop = []
|
||||
for file_name, metadata, in summaries.items():
|
||||
if not all([fil(metadata) for fil in filters]):
|
||||
for file_name in tqdm.tqdm(summaries.keys(), desc="Filter Scenarios"):
|
||||
metadata = summaries[file_name]
|
||||
if not all([fil(metadata, os.path.join(output_abs_path, mappings[file_name], file_name)) for fil in filters]):
|
||||
file_to_pop.append(file_name)
|
||||
for file in file_to_pop:
|
||||
summaries.pop(file)
|
||||
mappings.pop(file)
|
||||
|
||||
save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
|
||||
if save:
|
||||
save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
|
||||
|
||||
return summaries, mappings
|
||||
|
||||
|
||||
def move_database(
|
||||
from_path,
|
||||
to_path,
|
||||
exist_ok=False,
|
||||
overwrite=False,
|
||||
def copy_database(
|
||||
from_path,
|
||||
to_path,
|
||||
exist_ok=False,
|
||||
overwrite=False,
|
||||
copy_raw_data=False,
|
||||
remove_source=False
|
||||
):
|
||||
if not os.path.exists(from_path):
|
||||
raise FileNotFoundError("Can not find dataset: {}".format(from_path))
|
||||
raise FileNotFoundError("Can not find database: {}".format(from_path))
|
||||
if os.path.exists(to_path):
|
||||
assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a dataset"
|
||||
assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a database"
|
||||
assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
|
||||
merge_database(
|
||||
files = os.listdir(from_path)
|
||||
if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len(
|
||||
files) > 2:
|
||||
raise RuntimeError("The source database is not allowed to move! "
|
||||
"This will break the relationship between this database and other database built on it."
|
||||
"If it is ok for you, use 'mv' to move it manually ")
|
||||
|
||||
summaries, mappings = merge_database(
|
||||
to_path,
|
||||
from_path,
|
||||
exist_ok=exist_ok,
|
||||
overwrite=overwrite,
|
||||
try_generate_missing_file=True,
|
||||
save=False
|
||||
)
|
||||
files = os.listdir(from_path)
|
||||
if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len(
|
||||
files) == 2:
|
||||
summary_file = osp.join(to_path, ScenarioDescription.DATASET.SUMMARY_FILE)
|
||||
mapping_file = osp.join(to_path, ScenarioDescription.DATASET.MAPPING_FILE)
|
||||
|
||||
if copy_raw_data:
|
||||
logger.info("Copy raw data...")
|
||||
for scenario_file in tqdm.tqdm(mappings.keys()):
|
||||
rel_path = mappings[scenario_file]
|
||||
shutil.copyfile(os.path.join(to_path, rel_path, scenario_file), os.path.join(to_path, scenario_file))
|
||||
mappings = {key: "./" for key in summaries.keys()}
|
||||
save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
|
||||
|
||||
if remove_source and ScenarioDescription.DATASET.MAPPING_FILE in files and \
|
||||
ScenarioDescription.DATASET.SUMMARY_FILE in files and len(files) == 2:
|
||||
shutil.rmtree(from_path)
|
||||
|
||||
|
||||
def split_database(
|
||||
from_path,
|
||||
to_path,
|
||||
start_index,
|
||||
num_scenarios,
|
||||
exist_ok=False,
|
||||
overwrite=False,
|
||||
random=False,
|
||||
):
|
||||
if not os.path.exists(from_path):
|
||||
raise FileNotFoundError("Can not find database: {}".format(from_path))
|
||||
if os.path.exists(to_path):
|
||||
assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a database"
|
||||
assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
|
||||
overwrite = overwrite,
|
||||
output_abs_path = osp.abspath(to_path)
|
||||
os.makedirs(output_abs_path, exist_ok=exist_ok)
|
||||
summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE)
|
||||
mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE)
|
||||
for file in [summary_file, mapping_file]:
|
||||
if os.path.exists(file):
|
||||
if overwrite:
|
||||
os.remove(file)
|
||||
else:
|
||||
raise FileExistsError("{} already exists at: {}!".format(file, output_abs_path))
|
||||
|
||||
# collect
|
||||
abs_dir_path = osp.abspath(from_path)
|
||||
# summary
|
||||
assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
|
||||
summaries, lookup, mappings = read_dataset_summary(from_path)
|
||||
assert start_index >= 0 and start_index + num_scenarios <= len(
|
||||
lookup), "No enough scenarios in source dataset: total {}, start_index: {}, need: {}".format(len(lookup),
|
||||
start_index,
|
||||
num_scenarios)
|
||||
if random:
|
||||
selected = sample(lookup[start_index:], k=num_scenarios)
|
||||
else:
|
||||
selected = lookup[start_index: start_index + num_scenarios]
|
||||
selected_summary = {}
|
||||
selected_mapping = {}
|
||||
for scenario in selected:
|
||||
selected_summary[scenario] = summaries[scenario]
|
||||
selected_mapping[scenario] = os.path.relpath(osp.join(abs_dir_path, mappings[scenario]), output_abs_path)
|
||||
|
||||
save_summary_anda_mapping(summary_file, mapping_file, selected_summary, selected_mapping)
|
||||
|
||||
return summaries, mappings
|
||||
|
||||
22
scenarionet/check_overlap.py
Normal file
22
scenarionet/check_overlap.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Check If any overlap between two database
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from scenarionet.common_utils import read_dataset_summary
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--database_1', type=str, required=True, help="The path of the first database")
|
||||
parser.add_argument('--database_2', type=str, required=True, help="The path of the second database")
|
||||
args = parser.parse_args()
|
||||
|
||||
summary_1, _, _ = read_dataset_summary(args.database_1)
|
||||
summary_2, _, _ = read_dataset_summary(args.database_2)
|
||||
|
||||
intersection = set(summary_1.keys()).intersection(set(summary_2.keys()))
|
||||
if len(intersection) == 0:
|
||||
print("No overlapping in two database!")
|
||||
else:
|
||||
print("Find overlapped scenarios: {}".format(intersection))
|
||||
@@ -1,4 +1,5 @@
|
||||
import pkg_resources # for suppress warning
|
||||
import shutil
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
@@ -28,6 +29,20 @@ if __name__ == '__main__':
|
||||
default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
|
||||
help="The directory stores all waymo tfrecord"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_file_index",
|
||||
default=0,
|
||||
type=int,
|
||||
help="Control how many files to use. We will list all files in the raw data folder "
|
||||
"and select files[start_file_index: start_file_index+num_files]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_files",
|
||||
default=1000,
|
||||
type=int,
|
||||
help="Control how many files to use. We will list all files in the raw data folder "
|
||||
"and select files[start_file_index: start_file_index+num_files]"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
overwrite = args.overwrite
|
||||
@@ -35,8 +50,19 @@ if __name__ == '__main__':
|
||||
output_path = args.database_path
|
||||
version = args.version
|
||||
|
||||
save_path = output_path
|
||||
if os.path.exists(output_path):
|
||||
if not overwrite:
|
||||
raise ValueError(
|
||||
"Directory {} already exists! Abort. "
|
||||
"\n Try setting overwrite=True or adding --overwrite".format(output_path)
|
||||
)
|
||||
else:
|
||||
shutil.rmtree(output_path)
|
||||
|
||||
waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
|
||||
scenarios = get_waymo_scenarios(waymo_data_directory)
|
||||
scenarios = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files,
|
||||
num_workers=8) # do not use too much worker to read data
|
||||
|
||||
write_to_directory(
|
||||
convert_func=convert_waymo_scenario,
|
||||
|
||||
@@ -357,9 +357,9 @@ def extract_traffic(scenario: NuPlanScenario, center):
|
||||
type=MetaDriveType.UNSET,
|
||||
state=dict(
|
||||
position=np.zeros(shape=(episode_len, 3)),
|
||||
heading=np.zeros(shape=(episode_len, )),
|
||||
heading=np.zeros(shape=(episode_len,)),
|
||||
velocity=np.zeros(shape=(episode_len, 2)),
|
||||
valid=np.zeros(shape=(episode_len, )),
|
||||
valid=np.zeros(shape=(episode_len,)),
|
||||
length=np.zeros(shape=(episode_len, 1)),
|
||||
width=np.zeros(shape=(episode_len, 1)),
|
||||
height=np.zeros(shape=(episode_len, 1))
|
||||
@@ -490,3 +490,84 @@ example_scenario_types = "[behind_pedestrian_on_pickup_dropoff, \
|
||||
starting_unprotected_cross_turn, \
|
||||
starting_protected_noncross_turn, \
|
||||
on_pickup_dropoff]"
|
||||
|
||||
# - accelerating_at_crosswalk
|
||||
# - accelerating_at_stop_sign
|
||||
# - accelerating_at_stop_sign_no_crosswalk
|
||||
# - accelerating_at_traffic_light
|
||||
# - accelerating_at_traffic_light_with_lead
|
||||
# - accelerating_at_traffic_light_without_lead
|
||||
# - behind_bike
|
||||
# - behind_long_vehicle
|
||||
# - behind_pedestrian_on_driveable
|
||||
# - behind_pedestrian_on_pickup_dropoff
|
||||
# - changing_lane
|
||||
# - changing_lane_to_left
|
||||
# - changing_lane_to_right
|
||||
# - changing_lane_with_lead
|
||||
# - changing_lane_with_trail
|
||||
# - crossed_by_bike
|
||||
# - crossed_by_vehicle
|
||||
# - following_lane_with_lead
|
||||
# - following_lane_with_slow_lead
|
||||
# - following_lane_without_lead
|
||||
# - high_lateral_acceleration
|
||||
# - high_magnitude_jerk
|
||||
# - high_magnitude_speed
|
||||
# - low_magnitude_speed
|
||||
# - medium_magnitude_speed
|
||||
# - near_barrier_on_driveable
|
||||
# - near_construction_zone_sign
|
||||
# - near_high_speed_vehicle
|
||||
# - near_long_vehicle
|
||||
# - near_multiple_bikes
|
||||
# - near_multiple_pedestrians
|
||||
# - near_multiple_vehicles
|
||||
# - near_pedestrian_at_pickup_dropoff
|
||||
# - near_pedestrian_on_crosswalk
|
||||
# - near_pedestrian_on_crosswalk_with_ego
|
||||
# - near_trafficcone_on_driveable
|
||||
# - on_all_way_stop_intersection
|
||||
# - on_carpark
|
||||
# - on_intersection
|
||||
# - on_pickup_dropoff
|
||||
# - on_stopline_crosswalk
|
||||
# - on_stopline_stop_sign
|
||||
# - on_stopline_traffic_light
|
||||
# - on_traffic_light_intersection
|
||||
# - starting_high_speed_turn
|
||||
# - starting_left_turn
|
||||
# - starting_low_speed_turn
|
||||
# - starting_protected_cross_turn
|
||||
# - starting_protected_noncross_turn
|
||||
# - starting_right_turn
|
||||
# - starting_straight_stop_sign_intersection_traversal
|
||||
# - starting_straight_traffic_light_intersection_traversal
|
||||
# - starting_u_turn
|
||||
# - starting_unprotected_cross_turn
|
||||
# - starting_unprotected_noncross_turn
|
||||
# - stationary
|
||||
# - stationary_at_crosswalk
|
||||
# - stationary_at_traffic_light_with_lead
|
||||
# - stationary_at_traffic_light_without_lead
|
||||
# - stationary_in_traffic
|
||||
# - stopping_at_crosswalk
|
||||
# - stopping_at_stop_sign_no_crosswalk
|
||||
# - stopping_at_stop_sign_with_lead
|
||||
# - stopping_at_stop_sign_without_lead
|
||||
# - stopping_at_traffic_light_with_lead
|
||||
# - stopping_at_traffic_light_without_lead
|
||||
# - stopping_with_lead
|
||||
# - traversing_crosswalk
|
||||
# - traversing_intersection
|
||||
# - traversing_narrow_lane
|
||||
# - traversing_pickup_dropoff
|
||||
# - traversing_traffic_light_intersection
|
||||
# - waiting_for_pedestrian_to_cross
|
||||
#
|
||||
|
||||
all_scenario_types = "[near_pedestrian_on_crosswalk_with_ego," \
|
||||
"near_trafficcone_on_driveable, " \
|
||||
"following_lane_with_lead, " \
|
||||
"following_lane_with_slow_lead, " \
|
||||
"following_lane_without_lead]"
|
||||
|
||||
@@ -3,6 +3,8 @@ import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import tqdm
|
||||
|
||||
from scenarionet.converter.utils import mph_to_kmh
|
||||
from scenarionet.converter.waymo.type import WaymoLaneType, WaymoAgentType, WaymoRoadLineType, WaymoRoadEdgeType
|
||||
|
||||
@@ -418,13 +420,20 @@ def convert_waymo_scenario(scenario, version):
|
||||
}
|
||||
for count, id in enumerate(track_id)
|
||||
}
|
||||
# clean memory
|
||||
del scenario
|
||||
scenario = None
|
||||
return md_scenario
|
||||
|
||||
|
||||
def get_waymo_scenarios(waymo_data_directory, num_workers=8):
|
||||
def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
|
||||
# parse raw data from input path to output path,
|
||||
# there is 1000 raw data in google cloud, each of them produce about 500 pkl file
|
||||
logger.info("\n Reading raw data")
|
||||
file_list = os.listdir(waymo_data_directory)
|
||||
assert len(file_list) >= start_index + num and start_index >= 0, \
|
||||
"No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
|
||||
file_list = file_list[start_index: start_index + num]
|
||||
num_files = len(file_list)
|
||||
if num_files < num_workers:
|
||||
# single process
|
||||
@@ -441,23 +450,29 @@ def get_waymo_scenarios(waymo_data_directory, num_workers=8):
|
||||
argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
|
||||
|
||||
# Run, workers and process result from worker
|
||||
with multiprocessing.Pool(num_workers) as p:
|
||||
all_result = list(p.imap(read_from_files, argument_list))
|
||||
ret = []
|
||||
|
||||
# get result
|
||||
for r in all_result:
|
||||
if len(r) == 0:
|
||||
logger.warning("0 scenarios found")
|
||||
ret += r
|
||||
logger.info("\n Find {} waymo scenarios from {} files".format(len(ret), num_files))
|
||||
return ret
|
||||
# with multiprocessing.Pool(num_workers) as p:
|
||||
# all_result = list(p.imap(read_from_files, argument_list))
|
||||
# Disable multiprocessing read
|
||||
all_result = read_from_files([waymo_data_directory, file_list])
|
||||
# ret = []
|
||||
#
|
||||
# # get result
|
||||
# for r in all_result:
|
||||
# if len(r) == 0:
|
||||
# logger.warning("0 scenarios found")
|
||||
# ret += r
|
||||
logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
|
||||
return all_result
|
||||
|
||||
|
||||
def read_from_files(arg):
|
||||
try:
|
||||
scenario_pb2
|
||||
except NameError:
|
||||
raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0==1.5.0")
|
||||
waymo_data_directory, file_list = arg[0], arg[1]
|
||||
scenarios = []
|
||||
for file_count, file in enumerate(file_list):
|
||||
for file in tqdm.tqdm(file_list):
|
||||
file_path = os.path.join(waymo_data_directory, file)
|
||||
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
|
||||
continue
|
||||
|
||||
47
scenarionet/copy_database.py
Normal file
47
scenarionet/copy_database.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import argparse
|
||||
|
||||
from scenarionet.builder.utils import copy_database
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--from', required=True, help="Which database to move.")
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
required=True,
|
||||
help="The name of the new database. "
|
||||
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
|
||||
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
|
||||
"that directory into a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_source",
|
||||
action="store_true",
|
||||
help="Remove the `from_database` if set this flag"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--copy_raw_data",
|
||||
action="store_true",
|
||||
help="Instead of creating virtual file mapping, copy raw scenario.pkl file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exist_ok",
|
||||
action="store_true",
|
||||
help="Still allow to write, if the to_folder exists already. "
|
||||
"This write will only create two .pkl files and this directory will become a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
|
||||
"whether to overwrite both files"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
from_path = args.__getattribute__("from")
|
||||
copy_database(
|
||||
from_path,
|
||||
args.to,
|
||||
exist_ok=args.exist_ok,
|
||||
overwrite=args.overwrite,
|
||||
copy_raw_data=args.copy_raw_data,
|
||||
remove_source=args.remove_source
|
||||
)
|
||||
111
scenarionet/filter_database.py
Normal file
111
scenarionet/filter_database.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import argparse
|
||||
|
||||
from scenarionet.builder.filters import ScenarioFilter
|
||||
from scenarionet.builder.utils import merge_database
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--database_path",
|
||||
"-d",
|
||||
required=True,
|
||||
help="The name of the new database. "
|
||||
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
|
||||
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
|
||||
"that directory into a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--from',
|
||||
required=True,
|
||||
type=str,
|
||||
help="Which dataset to filter. It takes one directory path as input"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exist_ok",
|
||||
action="store_true",
|
||||
help="Still allow to write, if the dir exists already. "
|
||||
"This write will only create two .pkl files and this directory will become a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
|
||||
"whether to overwrite both files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--moving_dist",
|
||||
action="store_true",
|
||||
help="add this flag to select cases with SDC moving dist > sdc_moving_dist_min"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdc_moving_dist_min",
|
||||
default=10,
|
||||
type=float,
|
||||
help="Selecting case with sdc_moving_dist > this value. "
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_object",
|
||||
action="store_true",
|
||||
help="add this flag to select cases with object_num < max_num_object"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_num_object",
|
||||
default=30,
|
||||
type=float,
|
||||
help="case will be selected if num_obj < this argument"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_overpass",
|
||||
action="store_true",
|
||||
help="Scenarios with overpass WON'T be selected"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_traffic_light",
|
||||
action="store_true",
|
||||
help="Scenarios with traffic light WON'T be selected"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--id_filter",
|
||||
action="store_true",
|
||||
help="Scenarios with indicated name will NOT be selected"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exclude_ids",
|
||||
nargs='+',
|
||||
default=[],
|
||||
help="Scenarios with indicated name will NOT be selected"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
target = args.sdc_moving_dist_min
|
||||
obj_threshold = args.max_num_object
|
||||
from_path = args.__getattribute__("from")
|
||||
|
||||
filters = []
|
||||
if args.no_overpass:
|
||||
filters.append(ScenarioFilter.make(ScenarioFilter.no_overpass))
|
||||
if args.num_object:
|
||||
filters.append(ScenarioFilter.make(ScenarioFilter.object_number, number_threshold=obj_threshold))
|
||||
if args.moving_dist:
|
||||
filters.append(ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater"))
|
||||
if args.no_traffic_light:
|
||||
filters.append(ScenarioFilter.make(ScenarioFilter.no_traffic_light))
|
||||
if args.id_filter:
|
||||
filters.append(ScenarioFilter.make(ScenarioFilter.id_filter, ids=args.exclude_ids))
|
||||
|
||||
if len(filters) == 0:
|
||||
raise ValueError("No filters are applied. Abort.")
|
||||
|
||||
merge_database(
|
||||
args.database_path,
|
||||
from_path,
|
||||
exist_ok=args.exist_ok,
|
||||
overwrite=args.overwrite,
|
||||
try_generate_missing_file=True,
|
||||
filters=filters
|
||||
)
|
||||
@@ -5,14 +5,14 @@ from scenarionet.verifier.error import ErrorFile
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated dataset")
|
||||
parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated database")
|
||||
parser.add_argument("--file", "-f", required=True, help="The path of the error file, should be xyz.json")
|
||||
parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, overwrite it")
|
||||
parser.add_argument(
|
||||
"--broken",
|
||||
action="store_true",
|
||||
help="By default, only successful scenarios will be picked to build the new dataset. "
|
||||
"If turn on this flog, it will generate dataset containing only broken scenarios."
|
||||
help="By default, only successful scenarios will be picked to build the new database. "
|
||||
"If turn on this flog, it will generate database containing only broken scenarios."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ErrorFile.generate_dataset(args.file, args.database_path, args.overwrite, args.broken)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pkg_resources # for suppress warning
|
||||
import argparse
|
||||
|
||||
from scenarionet.builder.filters import ScenarioFilter
|
||||
from scenarionet.builder.utils import merge_database
|
||||
|
||||
@@ -9,13 +9,13 @@ if __name__ == '__main__':
|
||||
"--database_path",
|
||||
"-d",
|
||||
required=True,
|
||||
help="The name of the new combined dataset. "
|
||||
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
|
||||
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
|
||||
"that directory into a dataset."
|
||||
help="The name of the new combined database. "
|
||||
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
|
||||
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
|
||||
"that directory into a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--from_datasets',
|
||||
'--from',
|
||||
required=True,
|
||||
nargs='+',
|
||||
default=[],
|
||||
@@ -25,32 +25,38 @@ if __name__ == '__main__':
|
||||
"--exist_ok",
|
||||
action="store_true",
|
||||
help="Still allow to write, if the dir exists already. "
|
||||
"This write will only create two .pkl files and this directory will become a dataset."
|
||||
"This write will only create two .pkl files and this directory will become a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
|
||||
"whether to overwrite both files"
|
||||
"whether to overwrite both files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter_moving_dist",
|
||||
action="store_true",
|
||||
help="add this flag to select cases with SDC moving dist > sdc_moving_dist_min"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdc_moving_dist_min",
|
||||
default=20,
|
||||
default=5,
|
||||
type=float,
|
||||
help="Selecting case with sdc_moving_dist > this value. "
|
||||
"We will add more filter conditions in the future."
|
||||
"We will add more filter conditions in the future."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
target = args.sdc_moving_dist_min
|
||||
filters = [ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater")]
|
||||
|
||||
if len(args.from_datasets) != 0:
|
||||
source = args.__getattribute__("from")
|
||||
if len(source) != 0:
|
||||
merge_database(
|
||||
args.database_path,
|
||||
*args.from_datasets,
|
||||
*source,
|
||||
exist_ok=args.exist_ok,
|
||||
overwrite=args.overwrite,
|
||||
try_generate_missing_file=True,
|
||||
filters=filters
|
||||
filters=filters if args.filter_moving_dist else []
|
||||
)
|
||||
else:
|
||||
raise ValueError("No source dataset are provided. Abort.")
|
||||
raise ValueError("No source database are provided. Abort.")
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import argparse
|
||||
|
||||
from scenarionet.builder.utils import move_database
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--from', required=True, help="Which dataset to move.")
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
required=True,
|
||||
help="The name of the new dataset. "
|
||||
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
|
||||
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
|
||||
"that directory into a dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exist_ok",
|
||||
action="store_true",
|
||||
help="Still allow to write, if the to_folder exists already. "
|
||||
"This write will only create two .pkl files and this directory will become a dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
|
||||
"whether to overwrite both files"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
from_path = args.__getattr__("from")
|
||||
move_database(
|
||||
from_path,
|
||||
args.to,
|
||||
exist_ok=args.exist_ok,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
18
scenarionet/num_scenarios.py
Normal file
18
scenarionet/num_scenarios.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import pkg_resources # for suppress warning
|
||||
import argparse
|
||||
import logging
|
||||
from scenarionet.common_utils import read_dataset_summary
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--database_path",
|
||||
"-d",
|
||||
required=True,
|
||||
help="Database to check number of scenarios"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
summary, _, _, = read_dataset_summary(args.database_path)
|
||||
logger.info("Number of scenarios: {}".format(len(summary)))
|
||||
@@ -7,7 +7,7 @@ from metadrive.scenario.utils import get_number_of_scenarios
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--database_path", "-d", required=True, help="The path of the dataset")
|
||||
parser.add_argument("--database_path", "-d", required=True, help="The path of the database")
|
||||
parser.add_argument("--render", action="store_true", help="Enable 3D rendering")
|
||||
parser.add_argument("--scenario_index", default=None, type=int, help="Specifying a scenario to run")
|
||||
args = parser.parse_args()
|
||||
@@ -38,13 +38,13 @@ if __name__ == '__main__':
|
||||
"data_directory": database_path,
|
||||
}
|
||||
)
|
||||
for seed in range(num_scenario if args.scenario_index is not None else 1000000):
|
||||
env.reset(force_seed=seed if args.scenario_index is None else args.scenario_index)
|
||||
for index in range(num_scenario if args.scenario_index is not None else 1000000):
|
||||
env.reset(force_seed=index if args.scenario_index is None else args.scenario_index)
|
||||
for t in range(10000):
|
||||
o, r, d, info = env.step([0, 0])
|
||||
if env.config["use_render"]:
|
||||
env.render(text={
|
||||
"seed": env.engine.global_seed + env.config["start_scenario_index"],
|
||||
"scenario index": env.engine.global_seed + env.config["start_scenario_index"],
|
||||
})
|
||||
|
||||
if d and info["arrive_dest"]:
|
||||
|
||||
47
scenarionet/split_database.py
Normal file
47
scenarionet/split_database.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
This script is for extracting a subset of data from an existing database
|
||||
"""
|
||||
import pkg_resources # for suppress warning
|
||||
import argparse
|
||||
|
||||
from scenarionet.builder.utils import split_database
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--from', required=True, help="Which database to extract data from.")
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
required=True,
|
||||
help="The name of the new database. "
|
||||
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
|
||||
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
|
||||
"that directory into a database."
|
||||
)
|
||||
parser.add_argument("--num_scenarios", type=int, default=64, help="how many scenarios to extract (default: 30)")
|
||||
parser.add_argument("--start_index", type=int, default=0, help="which index to start")
|
||||
parser.add_argument("--random", action="store_true", help="If set to true, it will choose scenarios randomly "
|
||||
"from all_scenarios[start_index:]. "
|
||||
"Otherwise, the scenarios will be selected sequentially")
|
||||
parser.add_argument(
|
||||
"--exist_ok",
|
||||
action="store_true",
|
||||
help="Still allow to write, if the to_folder exists already. "
|
||||
"This write will only create two .pkl files and this directory will become a database."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
|
||||
"whether to overwrite both files"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
from_path = args.__getattribute__("from")
|
||||
split_database(
|
||||
from_path,
|
||||
args.to,
|
||||
args.start_index,
|
||||
args.num_scenarios,
|
||||
exist_ok=args.exist_ok,
|
||||
overwrite=args.overwrite,
|
||||
random=args.random
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite
|
||||
python ../../verify_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --error_file_path ../tmp/test_combine_dataset --random_drop --num_workers=16
|
||||
python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite
|
||||
python ../../check_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --error_file_path ../tmp/test_combine_dataset --random_drop --num_workers=16
|
||||
python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_pass
|
||||
python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_fail --broken
|
||||
@@ -32,8 +32,8 @@ done
|
||||
|
||||
# combine the datasets
|
||||
if [ "$overwrite" = true ]; then
|
||||
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok
|
||||
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok
|
||||
else
|
||||
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok
|
||||
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok
|
||||
fi
|
||||
|
||||
|
||||
119
scenarionet/tests/script/capture_pg.py
Normal file
119
scenarionet/tests/script/capture_pg.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import pygame
|
||||
from metadrive.envs.metadrive_env import MetaDriveEnv
|
||||
from metadrive.utils import setup_logger
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logger(True)
|
||||
env = MetaDriveEnv(
|
||||
{
|
||||
"num_scenarios": 1,
|
||||
"traffic_density": 0.15,
|
||||
"traffic_mode": "hybrid",
|
||||
"start_seed": 74,
|
||||
# "_disable_detector_mask":True,
|
||||
# "debug_physics_world": True,
|
||||
# "debug": True,
|
||||
# "global_light": False,
|
||||
# "debug_static_world": True,
|
||||
"show_interface": False,
|
||||
"cull_scene": False,
|
||||
"random_spawn_lane_index": False,
|
||||
"random_lane_width": False,
|
||||
# "image_observation": True,
|
||||
# "controller": "joystick",
|
||||
# "show_coordinates": True,
|
||||
"random_agent_model": False,
|
||||
"manual_control": True,
|
||||
"use_render": True,
|
||||
"accident_prob": 1,
|
||||
"decision_repeat": 5,
|
||||
"interface_panel": [],
|
||||
"need_inverse_traffic": False,
|
||||
"rgb_clip": True,
|
||||
"map": 2,
|
||||
# "agent_policy": ExpertPolicy,
|
||||
"random_traffic": False,
|
||||
# "random_lane_width": True,
|
||||
"driving_reward": 1.0,
|
||||
# "pstats": True,
|
||||
"force_destroy": False,
|
||||
# "show_skybox": False,
|
||||
"show_fps": False,
|
||||
"render_pipeline": True,
|
||||
# "camera_dist": 8,
|
||||
"window_size": (1600, 900),
|
||||
"camera_dist": 9,
|
||||
# "camera_pitch": 30,
|
||||
# "camera_height": 1,
|
||||
# "camera_smooth": False,
|
||||
# "camera_height": -1,
|
||||
"vehicle_config": {
|
||||
"enable_reverse": False,
|
||||
# "vehicle_model": "xl",
|
||||
# "rgb_camera": (1024, 1024),
|
||||
# "spawn_velocity": [8.728615581032535, -0.24411703918728195],
|
||||
"spawn_velocity_car_frame": True,
|
||||
# "image_source": "depth_camera",
|
||||
# "random_color": True
|
||||
# "show_lidar": True,
|
||||
"spawn_lane_index": None,
|
||||
# "destination":"2R1_3_",
|
||||
# "show_side_detector": True,
|
||||
# "show_lane_line_detector": True,
|
||||
# "side_detector": dict(num_lasers=2, distance=50),
|
||||
# "lane_line_detector": dict(num_lasers=2, distance=50),
|
||||
# "show_line_to_navi_mark": True,
|
||||
"show_navi_mark": False,
|
||||
# "show_dest_mark": True
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
o = env.reset()
|
||||
|
||||
|
||||
def capture():
|
||||
env.capture()
|
||||
ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(2000, 2000), track_target_vehicle=True)
|
||||
pygame.image.save(ret, "top_down_{}.png".format(env.current_seed))
|
||||
|
||||
env.engine.accept("c", capture)
|
||||
# env.main_camera.set_follow_lane(True)
|
||||
# env.vehicle.get_camera("rgb_camera").save_image(env.vehicle)
|
||||
# for line in env.engine.coordinate_line:
|
||||
# line.reparentTo(env.vehicle.origin)
|
||||
# env.vehicle.set_velocity([5, 0], in_local_frame=True)
|
||||
for s in range(1, 100000):
|
||||
# env.vehicle.set_velocity([1, 0], in_local_frame=True)
|
||||
o, r, d, info = env.step([0, 0])
|
||||
|
||||
# env.vehicle.set_pitch(-np.pi/4)
|
||||
# [0.09231533, 0.491018, 0.47076905, 0.7691619, 0.5, 0.5, 1.0, 0.0, 0.48037243, 0.8904728, 0.81229943, 0.7317231, 1.0, 0.85320455, 0.9747932, 0.65675277, 0.0, 0.5, 0.5]
|
||||
# else:
|
||||
# if s % 100 == 0:
|
||||
# env.close()
|
||||
# env.reset()
|
||||
# info["fuel"] = env.vehicle.energy_consumption
|
||||
# env.render(
|
||||
# text={
|
||||
# # "heading_diff": env.vehicle.heading_diff(env.vehicle.lane),
|
||||
# # "lane_width": env.vehicle.lane.width,
|
||||
# # "lane_index": env.vehicle.lane_index,
|
||||
# # "lateral": env.vehicle.lane.local_coordinates(env.vehicle.position),
|
||||
# "current_seed": env.current_seed
|
||||
# }
|
||||
# )
|
||||
# if d:
|
||||
# env.reset()
|
||||
# # assert env.observation_space.contains(o)
|
||||
# if (s + 1) % 100 == 0:
|
||||
# # print(
|
||||
# "Finish {}/10000 simulation steps. Time elapse: {:.4f}. Average FPS: {:.4f}".format(
|
||||
# s + 1,f
|
||||
# time.time() - start, (s + 1) / (time.time() - start)
|
||||
# )
|
||||
# )
|
||||
# if d:
|
||||
# # # env.close()
|
||||
# # # print(len(env.engine._spawned_objects))
|
||||
# env.reset()
|
||||
17
scenarionet/tests/script/compare_data.py
Normal file
17
scenarionet/tests/script/compare_data.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from metadrive.scenario.scenario_description import ScenarioDescription as SD
|
||||
from metadrive.scenario.utils import read_scenario_data, read_dataset_summary, assert_scenario_equal
|
||||
from scenarionet.common_utils import read_scenario
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_1 = "D:\\code\\scenarionet\\dataset\pg_2000"
|
||||
data_2 = "C:\\Users\\x1\\Desktop\\remote"
|
||||
summary_1, lookup_1, mapping_1 = read_dataset_summary(data_1)
|
||||
summary_2, lookup_2, mapping_2 = read_dataset_summary(data_2)
|
||||
# assert lookup_1[-10:] == lookup_2
|
||||
scenarios_1 = {}
|
||||
scenarios_2 = {}
|
||||
|
||||
for i in range(9):
|
||||
scenarios_1[str(i)] = read_scenario(data_1, mapping_1, lookup_1[-9+i])
|
||||
scenarios_2[str(i)] = read_scenario(data_2, mapping_2, lookup_2[i])
|
||||
# assert_scenario_equal(scenarios_1, scenarios_2, check_self_type=False, only_compare_sdc=True)
|
||||
97
scenarionet/tests/script/generate_sensor.py
Normal file
97
scenarionet/tests/script/generate_sensor.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import time
|
||||
import pygame
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
|
||||
|
||||
NuScenesEnv = ScenarioEnv
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = NuScenesEnv(
|
||||
{
|
||||
"use_render": True,
|
||||
"agent_policy": ReplayEgoCarPolicy,
|
||||
"show_interface": False,
|
||||
# "need_lane_localization": False,
|
||||
"show_logo": False,
|
||||
"no_traffic": False,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False,
|
||||
"show_fps": False,
|
||||
# "debug": True,
|
||||
# "render_pipeline": True,
|
||||
"daytime": "11:01",
|
||||
"window_size": (1600, 900),
|
||||
"camera_dist": 0.8,
|
||||
"camera_height": 1.5,
|
||||
"camera_pitch": None,
|
||||
"camera_fov": 60,
|
||||
"start_scenario_index": 0,
|
||||
"num_scenarios": 10,
|
||||
# "force_reuse_object_name": True,
|
||||
# "data_directory": "/home/shady/Downloads/test_processed",
|
||||
"horizon": 1000,
|
||||
# "no_static_vehicles": True,
|
||||
# "show_policy_mark": True,
|
||||
# "show_coordinates": True,
|
||||
# "force_destroy": True,
|
||||
# "default_vehicle_in_traffic": True,
|
||||
"vehicle_config": dict(
|
||||
# light=True,
|
||||
# random_color=True,
|
||||
show_navi_mark=False,
|
||||
use_special_color=False,
|
||||
image_source="depth_camera",
|
||||
rgb_camera=(1600, 900),
|
||||
depth_camera=(1600, 900, True),
|
||||
# no_wheel_friction=True,
|
||||
lidar=dict(num_lasers=120, distance=50),
|
||||
lane_line_detector=dict(num_lasers=0, distance=50),
|
||||
side_detector=dict(num_lasers=12, distance=50)
|
||||
),
|
||||
"data_directory": AssetLoader.file_path("nuscenes", return_raw_style=False),
|
||||
"image_observation": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 0,1,3,4,5,6
|
||||
|
||||
success = []
|
||||
reset_num = 0
|
||||
start = time.time()
|
||||
reset_used_time = 0
|
||||
s = 0
|
||||
while True:
|
||||
# for i in range(10):
|
||||
start_reset = time.time()
|
||||
env.reset(force_seed=0)
|
||||
|
||||
reset_used_time += time.time() - start_reset
|
||||
reset_num += 1
|
||||
for t in range(10000):
|
||||
if t==30:
|
||||
# env.capture("camera_deluxe.jpg")
|
||||
# ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(5000, 5000), track_target_vehicle=True)
|
||||
# pygame.image.save(ret, "top_down.png")
|
||||
env.vehicle.get_camera("depth_camera").save_image(env.vehicle, "camera.jpg")
|
||||
o, r, d, info = env.step([1, 0.88])
|
||||
assert env.observation_space.contains(o)
|
||||
s += 1
|
||||
# if env.config["use_render"]:
|
||||
# env.render(text={"seed": env.current_seed,
|
||||
# # "num_map": info["num_stored_maps"],
|
||||
# "data_coverage": info["data_coverage"],
|
||||
# "reward": r,
|
||||
# "heading_r": info["step_reward_heading"],
|
||||
# "lateral_r": info["step_reward_lateral"],
|
||||
# "smooth_action_r": info["step_reward_action_smooth"]})
|
||||
if d:
|
||||
print(
|
||||
"Time elapse: {:.4f}. Average FPS: {:.4f}, AVG_Reset_time: {:.4f}".format(
|
||||
time.time() - start, s / (time.time() - start - reset_used_time),
|
||||
reset_used_time / reset_num
|
||||
)
|
||||
)
|
||||
print("seed:{}, success".format(env.engine.global_random_seed))
|
||||
print(list(env.engine.curriculum_manager.recent_success.dict.values()))
|
||||
break
|
||||
104
scenarionet/tests/script/replay_origin.py
Normal file
104
scenarionet/tests/script/replay_origin.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import time
|
||||
|
||||
import pygame
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
|
||||
|
||||
NuScenesEnv = ScenarioEnv
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = NuScenesEnv(
|
||||
{
|
||||
"use_render": True,
|
||||
"agent_policy": ReplayEgoCarPolicy,
|
||||
"show_interface": False,
|
||||
"image_observation": False,
|
||||
"show_logo": False,
|
||||
"no_traffic": False,
|
||||
"drivable_region_extension": 15,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": False,
|
||||
"show_fps": False,
|
||||
# "debug": True,
|
||||
"render_pipeline": True,
|
||||
"daytime": "19:30",
|
||||
"window_size": (1600, 900),
|
||||
"camera_dist": 9,
|
||||
# "camera_height": 1.5,
|
||||
# "camera_pitch": None,
|
||||
# "camera_fov": 60,
|
||||
"start_scenario_index": 0,
|
||||
"num_scenarios": 4,
|
||||
# "force_reuse_object_name": True,
|
||||
# "data_directory": "/home/shady/Downloads/test_processed",
|
||||
"horizon": 1000,
|
||||
# "no_static_vehicles": True,
|
||||
# "show_policy_mark": True,
|
||||
# "show_coordinates": True,
|
||||
# "force_destroy": True,
|
||||
# "default_vehicle_in_traffic": True,
|
||||
"vehicle_config": dict(
|
||||
# light=True,
|
||||
# random_color=True,
|
||||
show_navi_mark=False,
|
||||
use_special_color=False,
|
||||
image_source="depth_camera",
|
||||
# rgb_camera=(1600, 900),
|
||||
# depth_camera=(1600, 900, True),
|
||||
# no_wheel_friction=True,
|
||||
lidar=dict(num_lasers=120, distance=50),
|
||||
lane_line_detector=dict(num_lasers=0, distance=50),
|
||||
side_detector=dict(num_lasers=12, distance=50)
|
||||
),
|
||||
"data_directory": "D:\\code\\scenarionet\\scenarionet\\tests\\script\\waymo_scenes_adv"
|
||||
}
|
||||
)
|
||||
|
||||
# 0,1,3,4,5,6
|
||||
|
||||
success = []
|
||||
reset_num = 0
|
||||
start = time.time()
|
||||
reset_used_time = 0
|
||||
s = 0
|
||||
|
||||
env.reset()
|
||||
|
||||
|
||||
def capture():
|
||||
env.capture()
|
||||
ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(7000, 7000), track_target_vehicle=True)
|
||||
pygame.image.save(ret, "top_down_{}.png".format(env.current_seed))
|
||||
|
||||
|
||||
env.engine.accept("c", capture)
|
||||
|
||||
while True:
|
||||
# for i in range(10):
|
||||
start_reset = time.time()
|
||||
env.reset()
|
||||
|
||||
reset_used_time += time.time() - start_reset
|
||||
reset_num += 1
|
||||
for t in range(10000):
|
||||
o, r, d, info = env.step([1, 0.88])
|
||||
assert env.observation_space.contains(o)
|
||||
s += 1
|
||||
# if env.config["use_render"]:
|
||||
# env.render(text={"seed": env.current_seed,
|
||||
# # "num_map": info["num_stored_maps"],
|
||||
# "data_coverage": info["data_coverage"],
|
||||
# "reward": r,
|
||||
# "heading_r": info["step_reward_heading"],
|
||||
# "lateral_r": info["step_reward_lateral"],
|
||||
# "smooth_action_r": info["step_reward_action_smooth"]})
|
||||
if d:
|
||||
print(
|
||||
"Time elapse: {:.4f}. Average FPS: {:.4f}, AVG_Reset_time: {:.4f}".format(
|
||||
time.time() - start, s / (time.time() - start - reset_used_time),
|
||||
reset_used_time / reset_num
|
||||
)
|
||||
)
|
||||
print("seed:{}, success".format(env.engine.global_random_seed))
|
||||
print(list(env.engine.curriculum_manager.recent_success.dict.values()))
|
||||
break
|
||||
@@ -50,9 +50,9 @@ if __name__ == '__main__':
|
||||
long, lat, = c_lane.local_coordinates(env.vehicle.position)
|
||||
if env.config["use_render"]:
|
||||
env.render(text={
|
||||
"seed": env.engine.global_seed + env.config["start_scenario_index"],
|
||||
"scenario index": env.engine.global_seed + env.config["start_scenario_index"],
|
||||
})
|
||||
|
||||
if d and info["arrive_dest"]:
|
||||
print("seed:{}, success".format(env.engine.global_random_seed))
|
||||
print("scenario index:{}, success".format(env.engine.global_random_seed))
|
||||
break
|
||||
|
||||
Binary file not shown.
37
scenarionet/tests/test_filter_overpass.py
Normal file
37
scenarionet/tests/test_filter_overpass.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import os.path
|
||||
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
|
||||
from scenarionet.builder.filters import ScenarioFilter
|
||||
from scenarionet.builder.utils import merge_database
|
||||
|
||||
|
||||
def test_filter_overpass():
|
||||
overpass_1 = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", "overpass")
|
||||
overpass_in_md = AssetLoader.file_path("waymo", return_raw_style=False)
|
||||
dataset_paths = [overpass_1, overpass_in_md]
|
||||
|
||||
output_path = os.path.join(TMP_PATH, "combine")
|
||||
merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
|
||||
|
||||
# filter
|
||||
filters = []
|
||||
filters.append(ScenarioFilter.make(ScenarioFilter.no_overpass))
|
||||
|
||||
summaries, _ = merge_database(
|
||||
output_path,
|
||||
*dataset_paths,
|
||||
exist_ok=True,
|
||||
overwrite=True,
|
||||
try_generate_missing_file=True,
|
||||
filters=filters
|
||||
)
|
||||
assert len(summaries) == 3
|
||||
for scenario in summaries:
|
||||
assert scenario != "sd_waymo_v1.2_eb4b91b10ca94ff2.pkl"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_filter_overpass()
|
||||
@@ -4,13 +4,13 @@ import os.path
|
||||
import pytest
|
||||
|
||||
from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
|
||||
from scenarionet.builder.utils import move_database, merge_database
|
||||
from scenarionet.builder.utils import copy_database, merge_database
|
||||
from scenarionet.common_utils import read_dataset_summary, read_scenario
|
||||
from scenarionet.verifier.utils import verify_database
|
||||
|
||||
|
||||
@pytest.mark.order("first")
|
||||
def test_move_database():
|
||||
def test_copy_database():
|
||||
dataset_name = "nuscenes"
|
||||
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
|
||||
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
|
||||
@@ -19,7 +19,7 @@ def test_move_database():
|
||||
# move
|
||||
for k, from_path in enumerate(dataset_paths):
|
||||
to = os.path.join(TMP_PATH, str(k))
|
||||
move_database(from_path, to)
|
||||
copy_database(from_path, to)
|
||||
moved_path.append(to)
|
||||
assert os.path.exists(from_path)
|
||||
merge_database(output_path, *moved_path, exist_ok=True, overwrite=True, try_generate_missing_file=True)
|
||||
@@ -37,7 +37,7 @@ def test_move_database():
|
||||
for k, from_path in enumerate(moved_path):
|
||||
new_p = os.path.join(TMP_PATH, str(k) + str(k))
|
||||
new_move_pathes.append(new_p)
|
||||
move_database(from_path, new_p, exist_ok=True, overwrite=True)
|
||||
copy_database(from_path, new_p, exist_ok=True, overwrite=True)
|
||||
assert not os.path.exists(from_path)
|
||||
merge_database(output_path, *new_move_pathes, exist_ok=True, overwrite=True, try_generate_missing_file=True)
|
||||
# verify
|
||||
@@ -51,4 +51,4 @@ def test_move_database():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_move_database()
|
||||
test_copy_database()
|
||||
|
||||
37
scenarionet/tests/test_split_dataset.py
Normal file
37
scenarionet/tests/test_split_dataset.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import os.path
|
||||
|
||||
from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
|
||||
from scenarionet.builder.utils import merge_database, split_database
|
||||
from scenarionet.common_utils import read_dataset_summary
|
||||
|
||||
|
||||
def test_split_dataset():
|
||||
dataset_name = "nuscenes"
|
||||
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
|
||||
test_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset")
|
||||
dataset_paths = [original_dataset_path + "_{}".format(i) for i in [0, 1, 3, 4]]
|
||||
|
||||
output_path = os.path.join(TMP_PATH, "combine")
|
||||
merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
|
||||
|
||||
# split
|
||||
from_path = output_path
|
||||
to_path = os.path.join(TMP_PATH, "split", "split")
|
||||
summary_1, lookup_1, mapping_1 = read_dataset_summary(from_path)
|
||||
|
||||
split_database(from_path, to_path, start_index=3, random=True, num_scenarios=4, overwrite=True, exist_ok=True)
|
||||
summary_2, lookup_2, mapping_2 = read_dataset_summary(to_path)
|
||||
assert len(summary_2) == 4
|
||||
for scenario in summary_2:
|
||||
assert scenario not in lookup_1[:3]
|
||||
|
||||
split_database(from_path, to_path, start_index=3, num_scenarios=4, overwrite=True, exist_ok=True)
|
||||
summary_2, lookup_2, mapping_2 = read_dataset_summary(to_path)
|
||||
assert lookup_1[3:7] == lookup_2
|
||||
for k in range(4):
|
||||
assert summary_1[lookup_2[k]] == summary_2[lookup_2[k]]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_split_dataset()
|
||||
@@ -53,12 +53,12 @@ class ErrorFile:
|
||||
@classmethod
|
||||
def generate_dataset(cls, error_file_path, new_dataset_path, overwrite=False, broken_scenario=False):
|
||||
"""
|
||||
Generate a new dataset containing all broken scenarios or all good scenarios
|
||||
Generate a new database containing all broken scenarios or all good scenarios
|
||||
:param error_file_path: error file path
|
||||
:param new_dataset_path: a directory where you want to store your data
|
||||
:param overwrite: if new_dataset_path exists, whether to overwrite
|
||||
:param broken_scenario: generate broken scenarios. You can generate such a broken scenarios for debugging
|
||||
:return: dataset summary, dataset mapping
|
||||
:return: database summary, database mapping
|
||||
"""
|
||||
new_dataset_path = os.path.abspath(new_dataset_path)
|
||||
if os.path.exists(new_dataset_path):
|
||||
|
||||
@@ -108,6 +108,7 @@ def loading_into_metadrive(
|
||||
"agent_policy": ReplayEgoCarPolicy,
|
||||
"num_scenarios": num_scenario,
|
||||
"horizon": 1000,
|
||||
"store_map": False,
|
||||
"start_scenario_index": start_scenario_index,
|
||||
"no_static_vehicles": False,
|
||||
"data_directory": dataset_path,
|
||||
|
||||
0
scenarionet_training/__init__.py
Normal file
0
scenarionet_training/__init__.py
Normal file
0
scenarionet_training/scripts/__init__.py
Normal file
0
scenarionet_training/scripts/__init__.py
Normal file
23
scenarionet_training/scripts/evaluate_nuplan.py
Normal file
23
scenarionet_training/scripts/evaluate_nuplan.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from scenarionet_training.scripts.train_nuplan import config
|
||||
from scenarionet_training.train_utils.utils import eval_ckpt
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 27 29 30 37 39
|
||||
ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_510\\checkpoint-510"
|
||||
scenario_data_path = "D:\\scenarionet_testset\\nuplan_test\\nuplan_test_w_raw"
|
||||
num_scenarios = 2000
|
||||
start_scenario_index = 0
|
||||
horizon = 600
|
||||
render = False
|
||||
explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
|
||||
log_interval = 10
|
||||
|
||||
eval_ckpt(config,
|
||||
ckpt_path,
|
||||
scenario_data_path,
|
||||
num_scenarios,
|
||||
start_scenario_index,
|
||||
horizon,
|
||||
render,
|
||||
explore,
|
||||
log_interval)
|
||||
27
scenarionet_training/scripts/evaluate_pg.py
Normal file
27
scenarionet_training/scripts/evaluate_pg.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os.path
|
||||
|
||||
from scenarionet import SCENARIONET_DATASET_PATH
|
||||
from scenarionet_training.scripts.train_pg import config
|
||||
from scenarionet_training.train_utils.utils import eval_ckpt
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Merge all evaluate script
|
||||
# 10/15/20/26/30/31/32
|
||||
ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_330\\checkpoint-330"
|
||||
scenario_data_path = os.path.join(SCENARIONET_DATASET_PATH, "pg_2000")
|
||||
num_scenarios = 2000
|
||||
start_scenario_index = 0
|
||||
horizon = 600
|
||||
render = False
|
||||
explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
|
||||
log_interval = 2
|
||||
|
||||
eval_ckpt(config,
|
||||
ckpt_path,
|
||||
scenario_data_path,
|
||||
num_scenarios,
|
||||
start_scenario_index,
|
||||
horizon,
|
||||
render,
|
||||
explore,
|
||||
log_interval)
|
||||
22
scenarionet_training/scripts/evaluate_waymo.py
Normal file
22
scenarionet_training/scripts/evaluate_waymo.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from scenarionet_training.scripts.train_waymo import config
|
||||
from scenarionet_training.train_utils.utils import eval_ckpt
|
||||
|
||||
if __name__ == '__main__':
|
||||
ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_170\\checkpoint-170"
|
||||
scenario_data_path = "D:\\scenarionet_testset\\waymo_test_raw_data"
|
||||
num_scenarios = 2000
|
||||
start_scenario_index = 0
|
||||
horizon = 600
|
||||
render = True
|
||||
explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
|
||||
log_interval = 2
|
||||
|
||||
eval_ckpt(config,
|
||||
ckpt_path,
|
||||
scenario_data_path,
|
||||
num_scenarios,
|
||||
start_scenario_index,
|
||||
horizon,
|
||||
render,
|
||||
explore,
|
||||
log_interval)
|
||||
80
scenarionet_training/scripts/local_test.py
Normal file
80
scenarionet_training/scripts/local_test.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import os.path
|
||||
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
|
||||
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
|
||||
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
|
||||
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
|
||||
|
||||
if __name__ == '__main__':
|
||||
env = ScenarioEnv
|
||||
args = get_train_parser().parse_args()
|
||||
exp_name = get_exp_name(args)
|
||||
stop = int(100_000_000)
|
||||
|
||||
config = dict(
|
||||
env=env,
|
||||
env_config=dict(
|
||||
# scenario
|
||||
start_scenario_index=0,
|
||||
num_scenarios=32,
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg"),
|
||||
sequential_seed=True,
|
||||
|
||||
# traffic & light
|
||||
reactive_traffic=False,
|
||||
no_static_vehicles=True,
|
||||
no_light=True,
|
||||
static_traffic_object=True,
|
||||
|
||||
# curriculum training
|
||||
curriculum_level=4,
|
||||
target_success_rate=0.8,
|
||||
|
||||
# training
|
||||
horizon=None,
|
||||
use_lateral_reward=True,
|
||||
),
|
||||
|
||||
# # ===== Evaluation =====
|
||||
evaluation_interval=2,
|
||||
evaluation_num_episodes=32,
|
||||
evaluation_config=dict(env_config=dict(start_scenario_index=32,
|
||||
num_scenarios=32,
|
||||
sequential_seed=True,
|
||||
curriculum_level=1, # turn off
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg"))),
|
||||
evaluation_num_workers=2,
|
||||
metrics_smoothing_episodes=10,
|
||||
|
||||
# ===== Training =====
|
||||
model=dict(fcnet_hiddens=[512, 256, 128]),
|
||||
horizon=600,
|
||||
num_sgd_iter=20,
|
||||
lr=5e-5,
|
||||
rollout_fragment_length=500,
|
||||
sgd_minibatch_size=100,
|
||||
train_batch_size=4000,
|
||||
num_gpus=0.5 if args.num_gpus != 0 else 0,
|
||||
num_cpus_per_worker=0.4,
|
||||
num_cpus_for_driver=1,
|
||||
num_workers=2,
|
||||
framework="tf"
|
||||
)
|
||||
|
||||
train(
|
||||
MultiWorkerPPO,
|
||||
exp_name=exp_name,
|
||||
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
|
||||
keep_checkpoints_num=5,
|
||||
stop=stop,
|
||||
config=config,
|
||||
num_gpus=args.num_gpus,
|
||||
# num_seeds=args.num_seeds,
|
||||
num_seeds=1,
|
||||
# test_mode=args.test,
|
||||
# local_mode=True,
|
||||
# TODO remove this when we release our code
|
||||
# wandb_key_file="~/wandb_api_key_file.txt",
|
||||
wandb_project="scenarionet",
|
||||
)
|
||||
74
scenarionet_training/scripts/multi_worker_eval.py
Normal file
74
scenarionet_training/scripts/multi_worker_eval.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import pickle
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scenarionet_training.scripts.train_nuplan import config
|
||||
from scenarionet_training.train_utils.callbacks import DrivingCallbacks
|
||||
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
|
||||
from scenarionet_training.train_utils.utils import initialize_ray
|
||||
|
||||
|
||||
class NumpyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, np.int32):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.int64):
|
||||
return int(obj)
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--start_index", type=int, default=0)
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--database_path", type=str, required=True)
|
||||
parser.add_argument("--id", type=str, default="")
|
||||
parser.add_argument("--num_scenarios", type=int, default=5000)
|
||||
parser.add_argument("--num_workers", type=int, default=10)
|
||||
parser.add_argument("--horizon", type=int, default=600)
|
||||
parser.add_argument("--allowed_more_steps", type=int, default=50)
|
||||
parser.add_argument("--max_lateral_dist", type=int, default=2.5)
|
||||
parser.add_argument("--overwrite", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
file = "eval_{}_{}_{}".format(args.id, os.path.basename(args.ckpt_path), os.path.basename(args.database_path))
|
||||
if os.path.exists(file) and not args.overwrite:
|
||||
raise FileExistsError("Please remove {} or set --overwrite".format(file))
|
||||
initialize_ray(test_mode=True, num_gpus=1)
|
||||
|
||||
config["callbacks"] = DrivingCallbacks
|
||||
config["evaluation_num_workers"] = args.num_workers
|
||||
config["evaluation_num_episodes"] = args.num_scenarios
|
||||
config["metrics_smoothing_episodes"] = args.num_scenarios
|
||||
config["custom_eval_function"] = None
|
||||
config["num_workers"] = 0
|
||||
config["evaluation_config"]["env_config"].update(dict(
|
||||
start_scenario_index=args.start_index,
|
||||
num_scenarios=args.num_scenarios,
|
||||
sequential_seed=True,
|
||||
store_map=False,
|
||||
store_data=False,
|
||||
allowed_more_steps=args.allowed_more_steps,
|
||||
# no_map=True,
|
||||
max_lateral_dist=args.max_lateral_dist,
|
||||
curriculum_level=1, # disable curriculum
|
||||
target_success_rate=1,
|
||||
horizon=args.horizon,
|
||||
episodes_to_evaluate_curriculum=args.num_scenarios,
|
||||
data_directory=args.database_path,
|
||||
use_render=False))
|
||||
|
||||
trainer = MultiWorkerPPO(config)
|
||||
trainer.restore(args.ckpt_path)
|
||||
|
||||
ret = trainer._evaluate()["evaluation"]
|
||||
with open(file + ".json", "w") as f:
|
||||
json.dump(ret, f, cls=NumpyEncoder)
|
||||
|
||||
with open(file + ".pkl", "wb+") as f:
|
||||
pickle.dump(ret, f)
|
||||
96
scenarionet_training/scripts/train_nuplan.py
Normal file
96
scenarionet_training/scripts/train_nuplan.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os.path
|
||||
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
|
||||
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
|
||||
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
|
||||
|
||||
config = dict(
|
||||
env=ScenarioEnv,
|
||||
env_config=dict(
|
||||
# scenario
|
||||
start_scenario_index=0,
|
||||
num_scenarios=40000,
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "nuplan_train"),
|
||||
sequential_seed=True,
|
||||
|
||||
# curriculum training
|
||||
curriculum_level=100,
|
||||
target_success_rate=0.8, # or 0.7
|
||||
# episodes_to_evaluate_curriculum=400, # default=num_scenarios/curriculum_level
|
||||
|
||||
# traffic & light
|
||||
reactive_traffic=True,
|
||||
no_static_vehicles=True,
|
||||
no_light=True,
|
||||
static_traffic_object=True,
|
||||
|
||||
# training scheme
|
||||
horizon=None,
|
||||
driving_reward=4,
|
||||
steering_range_penalty=1.0,
|
||||
heading_penalty=2,
|
||||
lateral_penalty=2.0,
|
||||
no_negative_reward=True,
|
||||
on_lane_line_penalty=0,
|
||||
crash_vehicle_penalty=2,
|
||||
crash_human_penalty=2,
|
||||
crash_object_penalty=0.5,
|
||||
# out_of_road_penalty=2,
|
||||
max_lateral_dist=2,
|
||||
# crash_vehicle_done=True,
|
||||
|
||||
vehicle_config=dict(side_detector=dict(num_lasers=0))
|
||||
|
||||
),
|
||||
|
||||
# ===== Evaluation =====
|
||||
evaluation_interval=15,
|
||||
evaluation_num_episodes=1000,
|
||||
# TODO (LQY), this is a sample from testset do eval on all scenarios after training!
|
||||
evaluation_config=dict(env_config=dict(start_scenario_index=0,
|
||||
num_scenarios=1000,
|
||||
sequential_seed=True,
|
||||
curriculum_level=1, # turn off
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "nuplan_test"))),
|
||||
evaluation_num_workers=10,
|
||||
metrics_smoothing_episodes=10,
|
||||
|
||||
# ===== Training =====
|
||||
model=dict(fcnet_hiddens=[512, 256, 128]),
|
||||
horizon=600,
|
||||
num_sgd_iter=20,
|
||||
lr=1e-4,
|
||||
rollout_fragment_length=500,
|
||||
sgd_minibatch_size=200,
|
||||
train_batch_size=50000,
|
||||
num_gpus=0.5,
|
||||
num_cpus_per_worker=0.3,
|
||||
num_cpus_for_driver=1,
|
||||
num_workers=20,
|
||||
framework="tf"
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# PG data is generated with seeds 10,000 to 60,000
|
||||
args = get_train_parser().parse_args()
|
||||
exp_name = get_exp_name(args)
|
||||
stop = int(100_000_000)
|
||||
config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
|
||||
|
||||
train(
|
||||
MultiWorkerPPO,
|
||||
exp_name=exp_name,
|
||||
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
|
||||
keep_checkpoints_num=5,
|
||||
stop=stop,
|
||||
config=config,
|
||||
num_gpus=args.num_gpus,
|
||||
# num_seeds=args.num_seeds,
|
||||
num_seeds=5,
|
||||
test_mode=args.test,
|
||||
# local_mode=True,
|
||||
# TODO remove this when we release our code
|
||||
# wandb_key_file="~/wandb_api_key_file.txt",
|
||||
wandb_project="scenarionet",
|
||||
)
|
||||
95
scenarionet_training/scripts/train_pg.py
Normal file
95
scenarionet_training/scripts/train_pg.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os.path
|
||||
from ray.tune import grid_search
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
|
||||
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
|
||||
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
|
||||
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
|
||||
|
||||
config = dict(
|
||||
env=ScenarioEnv,
|
||||
env_config=dict(
|
||||
# scenario
|
||||
start_scenario_index=0,
|
||||
num_scenarios=40000,
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg_train"),
|
||||
sequential_seed=True,
|
||||
|
||||
# curriculum training
|
||||
curriculum_level=100,
|
||||
target_success_rate=0.8,
|
||||
# episodes_to_evaluate_curriculum=400, # default=num_scenarios/curriculum_level
|
||||
|
||||
# traffic & light
|
||||
reactive_traffic=False,
|
||||
no_static_vehicles=True,
|
||||
no_light=True,
|
||||
static_traffic_object=True,
|
||||
|
||||
# training scheme
|
||||
horizon=None,
|
||||
steering_range_penalty=2,
|
||||
heading_penalty=1.0,
|
||||
lateral_penalty=1.0,
|
||||
no_negative_reward=True,
|
||||
on_lane_line_penalty=0,
|
||||
crash_vehicle_penalty=2,
|
||||
crash_human_penalty=2,
|
||||
out_of_road_penalty=2,
|
||||
max_lateral_dist=2,
|
||||
# crash_vehicle_done=True,
|
||||
|
||||
vehicle_config=dict(side_detector=dict(num_lasers=0))
|
||||
|
||||
),
|
||||
|
||||
# ===== Evaluation =====
|
||||
evaluation_interval=15,
|
||||
evaluation_num_episodes=1000,
|
||||
# TODO (LQY), this is a sample from testset do eval on all scenarios after training!
|
||||
evaluation_config=dict(env_config=dict(start_scenario_index=0,
|
||||
num_scenarios=1000,
|
||||
sequential_seed=True,
|
||||
curriculum_level=1, # turn off
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg_test"))),
|
||||
evaluation_num_workers=10,
|
||||
metrics_smoothing_episodes=10,
|
||||
|
||||
# ===== Training =====
|
||||
model=dict(fcnet_hiddens=[512, 256, 128]),
|
||||
horizon=600,
|
||||
num_sgd_iter=20,
|
||||
lr=1e-4,
|
||||
rollout_fragment_length=500,
|
||||
sgd_minibatch_size=200,
|
||||
train_batch_size=50000,
|
||||
num_gpus=0.5,
|
||||
num_cpus_per_worker=0.3,
|
||||
num_cpus_for_driver=1,
|
||||
num_workers=20,
|
||||
framework="tf"
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# PG data is generated with seeds 10,000 to 60,000
|
||||
args = get_train_parser().parse_args()
|
||||
exp_name = get_exp_name(args)
|
||||
stop = int(100_000_000)
|
||||
config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
|
||||
|
||||
train(
|
||||
MultiWorkerPPO,
|
||||
exp_name=exp_name,
|
||||
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
|
||||
keep_checkpoints_num=5,
|
||||
stop=stop,
|
||||
config=config,
|
||||
num_gpus=args.num_gpus,
|
||||
# num_seeds=args.num_seeds,
|
||||
num_seeds=5,
|
||||
test_mode=args.test,
|
||||
# local_mode=True,
|
||||
# TODO remove this when we release our code
|
||||
# wandb_key_file="~/wandb_api_key_file.txt",
|
||||
wandb_project="scenarionet",
|
||||
)
|
||||
95
scenarionet_training/scripts/train_waymo.py
Normal file
95
scenarionet_training/scripts/train_waymo.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os.path
|
||||
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
|
||||
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
|
||||
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
|
||||
|
||||
config = dict(
|
||||
env=ScenarioEnv,
|
||||
env_config=dict(
|
||||
# scenario
|
||||
start_scenario_index=0,
|
||||
num_scenarios=40000,
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_train"),
|
||||
sequential_seed=True,
|
||||
|
||||
# curriculum training
|
||||
curriculum_level=100,
|
||||
target_success_rate=0.8,
|
||||
# episodes_to_evaluate_curriculum=400, # default=num_scenarios/curriculum_level
|
||||
|
||||
# traffic & light
|
||||
reactive_traffic=True,
|
||||
no_static_vehicles=True,
|
||||
no_light=True,
|
||||
static_traffic_object=True,
|
||||
|
||||
# training scheme
|
||||
horizon=None,
|
||||
driving_reward=1,
|
||||
steering_range_penalty=0,
|
||||
heading_penalty=1,
|
||||
lateral_penalty=1.0,
|
||||
no_negative_reward=True,
|
||||
on_lane_line_penalty=0,
|
||||
crash_vehicle_penalty=2,
|
||||
crash_human_penalty=2,
|
||||
out_of_road_penalty=2,
|
||||
max_lateral_dist=2,
|
||||
# crash_vehicle_done=True,
|
||||
|
||||
vehicle_config=dict(side_detector=dict(num_lasers=0))
|
||||
|
||||
),
|
||||
|
||||
# ===== Evaluation =====
|
||||
evaluation_interval=15,
|
||||
evaluation_num_episodes=1000,
|
||||
# TODO (LQY), this is a sample from testset do eval on all scenarios after training!
|
||||
evaluation_config=dict(env_config=dict(start_scenario_index=0,
|
||||
num_scenarios=1000,
|
||||
sequential_seed=True,
|
||||
curriculum_level=1, # turn off
|
||||
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_test"))),
|
||||
evaluation_num_workers=10,
|
||||
metrics_smoothing_episodes=10,
|
||||
|
||||
# ===== Training =====
|
||||
model=dict(fcnet_hiddens=[512, 256, 128]),
|
||||
horizon=600,
|
||||
num_sgd_iter=20,
|
||||
lr=1e-4,
|
||||
rollout_fragment_length=500,
|
||||
sgd_minibatch_size=200,
|
||||
train_batch_size=50000,
|
||||
num_gpus=0.5,
|
||||
num_cpus_per_worker=0.3,
|
||||
num_cpus_for_driver=1,
|
||||
num_workers=20,
|
||||
framework="tf"
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# PG data is generated with seeds 10,000 to 60,000
|
||||
args = get_train_parser().parse_args()
|
||||
exp_name = get_exp_name(args)
|
||||
stop = int(100_000_000)
|
||||
config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
|
||||
|
||||
train(
|
||||
MultiWorkerPPO,
|
||||
exp_name=exp_name,
|
||||
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
|
||||
keep_checkpoints_num=5,
|
||||
stop=stop,
|
||||
config=config,
|
||||
num_gpus=args.num_gpus,
|
||||
# num_seeds=args.num_seeds,
|
||||
num_seeds=5,
|
||||
test_mode=args.test,
|
||||
# local_mode=True,
|
||||
# TODO remove this when we release our code
|
||||
# wandb_key_file="~/wandb_api_key_file.txt",
|
||||
wandb_project="scenarionet",
|
||||
)
|
||||
0
scenarionet_training/train_utils/__init__.py
Normal file
0
scenarionet_training/train_utils/__init__.py
Normal file
42
scenarionet_training/train_utils/anisotropic_workerset.py
Normal file
42
scenarionet_training/train_utils/anisotropic_workerset.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import copy
|
||||
import logging
|
||||
from typing import TypeVar
|
||||
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Generic type var for foreach_* methods.
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class AnisotropicWorkerSet(WorkerSet):
|
||||
"""
|
||||
Workers are assigned to different scenarios for saving memory/speeding up sampling
|
||||
"""
|
||||
|
||||
def add_workers(self, num_workers: int) -> None:
|
||||
"""
|
||||
Workers are assigned to different scenarios
|
||||
"""
|
||||
remote_args = {
|
||||
"num_cpus": self._remote_config["num_cpus_per_worker"],
|
||||
"num_gpus": self._remote_config["num_gpus_per_worker"],
|
||||
# memory=0 is an error, but memory=None means no limits.
|
||||
"memory": self._remote_config["memory_per_worker"] or None,
|
||||
"object_store_memory": self.
|
||||
_remote_config["object_store_memory_per_worker"] or None,
|
||||
"resources": self._remote_config["custom_resources_per_worker"],
|
||||
}
|
||||
cls = RolloutWorker.as_remote(**remote_args).remote
|
||||
for i in range(num_workers):
|
||||
config = copy.deepcopy(self._remote_config)
|
||||
config["env_config"]["worker_index"] = i
|
||||
config["env_config"]["num_workers"] = num_workers
|
||||
self._remote_workers.append(self._make_worker(cls, self._env_creator, self._policy_class, i + 1, config))
|
||||
110
scenarionet_training/train_utils/callbacks.py
Normal file
110
scenarionet_training/train_utils/callbacks.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
from ray.rllib.env import BaseEnv
|
||||
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
|
||||
from ray.rllib.policy import Policy
|
||||
|
||||
|
||||
class DrivingCallbacks(DefaultCallbacks):
|
||||
def on_episode_start(
|
||||
self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
|
||||
env_index: int, **kwargs
|
||||
):
|
||||
episode.user_data["velocity"] = []
|
||||
episode.user_data["steering"] = []
|
||||
episode.user_data["step_reward"] = []
|
||||
episode.user_data["acceleration"] = []
|
||||
episode.user_data["lateral_dist"] = []
|
||||
episode.user_data["cost"] = []
|
||||
episode.user_data["num_crash_vehicle"] = []
|
||||
episode.user_data["num_crash_human"] = []
|
||||
episode.user_data["num_crash_object"] = []
|
||||
episode.user_data["num_on_line"] = []
|
||||
|
||||
episode.user_data["step_reward_lateral"] = []
|
||||
episode.user_data["step_reward_heading"] = []
|
||||
episode.user_data["step_reward_action_smooth"] = []
|
||||
|
||||
def on_episode_step(
|
||||
self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs
|
||||
):
|
||||
info = episode.last_info_for()
|
||||
if info is not None:
|
||||
episode.user_data["velocity"].append(info["velocity"])
|
||||
episode.user_data["steering"].append(info["steering"])
|
||||
episode.user_data["step_reward"].append(info["step_reward"])
|
||||
episode.user_data["acceleration"].append(info["acceleration"])
|
||||
episode.user_data["lateral_dist"].append(info["lateral_dist"])
|
||||
episode.user_data["cost"].append(info["cost"])
|
||||
for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
|
||||
episode.user_data[x].append(info[x])
|
||||
|
||||
for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
|
||||
episode.user_data[x].append(info[x])
|
||||
|
||||
def on_episode_end(
|
||||
self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
|
||||
**kwargs
|
||||
):
|
||||
arrive_dest = episode.last_info_for()["arrive_dest"]
|
||||
crash = episode.last_info_for()["crash"]
|
||||
out_of_road = episode.last_info_for()["out_of_road"]
|
||||
max_step_rate = not (arrive_dest or crash or out_of_road)
|
||||
episode.custom_metrics["success_rate"] = float(arrive_dest)
|
||||
episode.custom_metrics["crash_rate"] = float(crash)
|
||||
episode.custom_metrics["out_of_road_rate"] = float(out_of_road)
|
||||
episode.custom_metrics["max_step_rate"] = float(max_step_rate)
|
||||
episode.custom_metrics["velocity_max"] = float(np.max(episode.user_data["velocity"]))
|
||||
episode.custom_metrics["velocity_mean"] = float(np.mean(episode.user_data["velocity"]))
|
||||
episode.custom_metrics["velocity_min"] = float(np.min(episode.user_data["velocity"]))
|
||||
|
||||
episode.custom_metrics["lateral_dist_min"] = float(np.min(episode.user_data["lateral_dist"]))
|
||||
episode.custom_metrics["lateral_dist_max"] = float(np.max(episode.user_data["lateral_dist"]))
|
||||
episode.custom_metrics["lateral_dist_mean"] = float(np.mean(episode.user_data["lateral_dist"]))
|
||||
|
||||
episode.custom_metrics["steering_max"] = float(np.max(episode.user_data["steering"]))
|
||||
episode.custom_metrics["steering_mean"] = float(np.mean(episode.user_data["steering"]))
|
||||
episode.custom_metrics["steering_min"] = float(np.min(episode.user_data["steering"]))
|
||||
episode.custom_metrics["acceleration_min"] = float(np.min(episode.user_data["acceleration"]))
|
||||
episode.custom_metrics["acceleration_mean"] = float(np.mean(episode.user_data["acceleration"]))
|
||||
episode.custom_metrics["acceleration_max"] = float(np.max(episode.user_data["acceleration"]))
|
||||
episode.custom_metrics["step_reward_max"] = float(np.max(episode.user_data["step_reward"]))
|
||||
episode.custom_metrics["step_reward_mean"] = float(np.mean(episode.user_data["step_reward"]))
|
||||
episode.custom_metrics["step_reward_min"] = float(np.min(episode.user_data["step_reward"]))
|
||||
|
||||
episode.custom_metrics["cost"] = float(sum(episode.user_data["cost"]))
|
||||
for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
|
||||
episode.custom_metrics[x] = float(sum(episode.user_data[x]))
|
||||
|
||||
for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
|
||||
episode.custom_metrics[x] = float(np.mean(episode.user_data[x]))
|
||||
|
||||
episode.custom_metrics["route_completion"] = float(episode.last_info_for()["route_completion"])
|
||||
episode.custom_metrics["curriculum_level"] = int(episode.last_info_for()["curriculum_level"])
|
||||
episode.custom_metrics["scenario_index"] = int(episode.last_info_for()["scenario_index"])
|
||||
episode.custom_metrics["track_length"] = float(episode.last_info_for()["track_length"])
|
||||
episode.custom_metrics["num_stored_maps"] = int(episode.last_info_for()["num_stored_maps"])
|
||||
episode.custom_metrics["scenario_difficulty"] = float(episode.last_info_for()["scenario_difficulty"])
|
||||
episode.custom_metrics["data_coverage"] = float(episode.last_info_for()["data_coverage"])
|
||||
episode.custom_metrics["curriculum_success"] = float(episode.last_info_for()["curriculum_success"])
|
||||
episode.custom_metrics["curriculum_route_completion"] = float(
|
||||
episode.last_info_for()["curriculum_route_completion"])
|
||||
|
||||
def on_train_result(self, *, trainer, result: dict, **kwargs):
|
||||
result["success"] = np.nan
|
||||
result["out"] = np.nan
|
||||
result["max_step"] = np.nan
|
||||
result["level"] = np.nan
|
||||
result["length"] = result["episode_len_mean"]
|
||||
result["coverage"] = np.nan
|
||||
if "custom_metrics" not in result:
|
||||
return
|
||||
|
||||
if "success_rate_mean" in result["custom_metrics"]:
|
||||
result["success"] = result["custom_metrics"]["success_rate_mean"]
|
||||
result["out"] = result["custom_metrics"]["out_of_road_rate_mean"]
|
||||
result["max_step"] = result["custom_metrics"]["max_step_rate_mean"]
|
||||
result["level"] = result["custom_metrics"]["curriculum_level_mean"]
|
||||
result["coverage"] = result["custom_metrics"]["data_coverage_mean"]
|
||||
11
scenarionet_training/train_utils/check_env.py
Normal file
11
scenarionet_training/train_utils/check_env.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ray.rllib.utils import check_env
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
|
||||
from gym import Env
|
||||
|
||||
if __name__ == '__main__':
|
||||
env = GymnasiumEnvWrapper.build(ScenarioEnv)()
|
||||
print(isinstance(ScenarioEnv, Env))
|
||||
print(isinstance(env, Env))
|
||||
print(env.observation_space)
|
||||
check_env(env)
|
||||
46
scenarionet_training/train_utils/multi_worker_PPO.py
Normal file
46
scenarionet_training/train_utils/multi_worker_PPO.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import logging
|
||||
from typing import Callable, Type
|
||||
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, \
|
||||
EnvType
|
||||
|
||||
from scenarionet_training.train_utils.anisotropic_workerset import AnisotropicWorkerSet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiWorkerPPO(PPOTrainer):
|
||||
"""
|
||||
In this class, each work will have different config for speeding up and saving memory. More importantly, it can
|
||||
allow us to cover all test/train cases more evenly
|
||||
"""
|
||||
|
||||
def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
|
||||
policy_class: Type[Policy], config: TrainerConfigDict,
|
||||
num_workers: int):
|
||||
"""Default factory method for a WorkerSet running under this Trainer.
|
||||
|
||||
Override this method by passing a custom `make_workers` into
|
||||
`build_trainer`.
|
||||
|
||||
Args:
|
||||
env_creator (callable): A function that return and Env given an env
|
||||
config.
|
||||
policy (Type[Policy]): The Policy class to use for creating the
|
||||
policies of the workers.
|
||||
config (TrainerConfigDict): The Trainer's config.
|
||||
num_workers (int): Number of remote rollout workers to create.
|
||||
0 for local only.
|
||||
|
||||
Returns:
|
||||
WorkerSet: The created WorkerSet.
|
||||
"""
|
||||
return AnisotropicWorkerSet(
|
||||
env_creator=env_creator,
|
||||
policy_class=policy_class,
|
||||
trainer_config=config,
|
||||
num_workers=num_workers,
|
||||
logdir=self.logdir)
|
||||
356
scenarionet_training/train_utils/utils.py
Normal file
356
scenarionet_training/train_utils/utils.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from metadrive.constants import TerminationState
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from ray import tune
|
||||
from ray.tune import CLIReporter
|
||||
|
||||
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
|
||||
from scenarionet_training.wandb_utils import WANDB_KEY_FILE
|
||||
|
||||
root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
|
||||
def get_api_key_file(wandb_key_file):
|
||||
if wandb_key_file is not None:
|
||||
default_path = os.path.expanduser(wandb_key_file)
|
||||
else:
|
||||
default_path = WANDB_KEY_FILE
|
||||
if os.path.exists(default_path):
|
||||
print("We are using this wandb key file: ", default_path)
|
||||
return default_path
|
||||
path = os.path.join(root, "scenarionet_training/wandb", "wandb_api_key_file.txt")
|
||||
print("We are using this wandb key file: ", path)
|
||||
return path
|
||||
|
||||
|
||||
def train(
|
||||
trainer,
|
||||
config,
|
||||
stop,
|
||||
exp_name,
|
||||
num_seeds=1,
|
||||
num_gpus=0,
|
||||
test_mode=False,
|
||||
suffix="",
|
||||
checkpoint_freq=10,
|
||||
keep_checkpoints_num=None,
|
||||
start_seed=0,
|
||||
local_mode=False,
|
||||
save_pkl=True,
|
||||
custom_callback=None,
|
||||
max_failures=0,
|
||||
wandb_key_file=None,
|
||||
wandb_project=None,
|
||||
wandb_team="drivingforce",
|
||||
wandb_log_config=True,
|
||||
init_kws=None,
|
||||
save_dir=None,
|
||||
**kwargs
|
||||
):
|
||||
init_kws = init_kws or dict()
|
||||
# initialize ray
|
||||
if not os.environ.get("redis_password"):
|
||||
initialize_ray(test_mode=test_mode, local_mode=local_mode, num_gpus=num_gpus, **init_kws)
|
||||
else:
|
||||
password = os.environ.get("redis_password")
|
||||
assert os.environ.get("ip_head")
|
||||
print(
|
||||
"We detect redis_password ({}) exists in environment! So "
|
||||
"we will start a ray cluster!".format(password)
|
||||
)
|
||||
if num_gpus:
|
||||
print(
|
||||
"We are in cluster mode! So GPU specification is disable and"
|
||||
" should be done when submitting task to cluster! You are "
|
||||
"requiring {} GPU for each machine!".format(num_gpus)
|
||||
)
|
||||
initialize_ray(address=os.environ["ip_head"], test_mode=test_mode, redis_password=password, **init_kws)
|
||||
|
||||
# prepare config
|
||||
|
||||
if custom_callback:
|
||||
callback = custom_callback
|
||||
else:
|
||||
from scenarionet_training.train_utils.callbacks import DrivingCallbacks
|
||||
callback = DrivingCallbacks
|
||||
|
||||
used_config = {
|
||||
"seed": tune.grid_search([i * 100 + start_seed for i in range(num_seeds)]) if num_seeds is not None else None,
|
||||
"log_level": "DEBUG" if test_mode else "INFO",
|
||||
"callbacks": callback
|
||||
}
|
||||
if custom_callback is False:
|
||||
used_config.pop("callbacks")
|
||||
if config:
|
||||
used_config.update(config)
|
||||
config = copy.deepcopy(used_config)
|
||||
|
||||
if isinstance(trainer, str):
|
||||
trainer_name = trainer
|
||||
elif hasattr(trainer, "_name"):
|
||||
trainer_name = trainer._name
|
||||
else:
|
||||
trainer_name = trainer.__name__
|
||||
|
||||
if not isinstance(stop, dict) and stop is not None:
|
||||
assert np.isscalar(stop)
|
||||
stop = {"timesteps_total": int(stop)}
|
||||
|
||||
if keep_checkpoints_num is not None and not test_mode:
|
||||
assert isinstance(keep_checkpoints_num, int)
|
||||
kwargs["keep_checkpoints_num"] = keep_checkpoints_num
|
||||
kwargs["checkpoint_score_attr"] = "episode_reward_mean"
|
||||
|
||||
if "verbose" not in kwargs:
|
||||
kwargs["verbose"] = 1 if not test_mode else 2
|
||||
|
||||
# This functionality is not supported yet!
|
||||
metric_columns = CLIReporter.DEFAULT_COLUMNS.copy()
|
||||
progress_reporter = CLIReporter(metric_columns=metric_columns)
|
||||
progress_reporter.add_metric_column("success")
|
||||
progress_reporter.add_metric_column("coverage")
|
||||
progress_reporter.add_metric_column("out")
|
||||
progress_reporter.add_metric_column("max_step")
|
||||
progress_reporter.add_metric_column("length")
|
||||
progress_reporter.add_metric_column("level")
|
||||
kwargs["progress_reporter"] = progress_reporter
|
||||
|
||||
if wandb_key_file is not None:
|
||||
assert wandb_project is not None
|
||||
if wandb_project is not None:
|
||||
assert wandb_project is not None
|
||||
failed_wandb = False
|
||||
try:
|
||||
from scenarionet_training.wandb_utils.our_wandb_callbacks import OurWandbLoggerCallback
|
||||
except Exception as e:
|
||||
# print("Please install wandb: pip install wandb")
|
||||
failed_wandb = True
|
||||
|
||||
if failed_wandb:
|
||||
from ray.tune.logger import DEFAULT_LOGGERS
|
||||
from scenarionet_training.wandb_utils.our_wandb_callbacks_ray100 import OurWandbLogger
|
||||
kwargs["loggers"] = DEFAULT_LOGGERS + (OurWandbLogger,)
|
||||
config["logger_config"] = {
|
||||
"wandb":
|
||||
{
|
||||
"group": exp_name,
|
||||
"exp_name": exp_name,
|
||||
"entity": wandb_team,
|
||||
"project": wandb_project,
|
||||
"api_key_file": get_api_key_file(wandb_key_file),
|
||||
"log_config": wandb_log_config,
|
||||
}
|
||||
}
|
||||
else:
|
||||
kwargs["callbacks"] = [
|
||||
OurWandbLoggerCallback(
|
||||
exp_name=exp_name,
|
||||
api_key_file=get_api_key_file(wandb_key_file),
|
||||
project=wandb_project,
|
||||
group=exp_name,
|
||||
log_config=wandb_log_config,
|
||||
entity=wandb_team
|
||||
)
|
||||
]
|
||||
|
||||
# start training
|
||||
analysis = tune.run(
|
||||
trainer,
|
||||
name=exp_name,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=True if "checkpoint_at_end" not in kwargs else kwargs.pop("checkpoint_at_end"),
|
||||
stop=stop,
|
||||
config=config,
|
||||
max_failures=max_failures if not test_mode else 0,
|
||||
reuse_actors=False,
|
||||
local_dir=save_dir or ".",
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# save training progress as insurance
|
||||
if save_pkl:
|
||||
pkl_path = "{}-{}{}.pkl".format(exp_name, trainer_name, "" if not suffix else "-" + suffix)
|
||||
with open(pkl_path, "wb") as f:
|
||||
data = analysis.fetch_trial_dataframes()
|
||||
pickle.dump(data, f)
|
||||
print("Result is saved at: <{}>".format(pkl_path))
|
||||
return analysis
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def initialize_ray(local_mode=False, num_gpus=None, test_mode=False, **kwargs):
|
||||
os.environ['OMP_NUM_THREADS'] = '1'
|
||||
|
||||
if ray.__version__.split(".")[0] == "1": # 1.0 version Ray
|
||||
if "redis_password" in kwargs:
|
||||
redis_password = kwargs.pop("redis_password")
|
||||
kwargs["_redis_password"] = redis_password
|
||||
|
||||
ray.init(
|
||||
logging_level=logging.ERROR if not test_mode else logging.DEBUG,
|
||||
log_to_driver=test_mode,
|
||||
local_mode=local_mode,
|
||||
num_gpus=num_gpus,
|
||||
ignore_reinit_error=True,
|
||||
include_dashboard=False,
|
||||
**kwargs
|
||||
)
|
||||
print("Successfully initialize Ray!")
|
||||
try:
|
||||
print("Available resources: ", ray.available_resources())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_train_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--exp-name", type=str, default="")
|
||||
parser.add_argument("--num-gpus", type=int, default=0)
|
||||
parser.add_argument("--num-seeds", type=int, default=3)
|
||||
parser.add_argument("--num-cpus-per-worker", type=float, default=0.5)
|
||||
parser.add_argument("--num-gpus-per-trial", type=float, default=0.25)
|
||||
parser.add_argument("--test", action="store_true")
|
||||
return parser
|
||||
|
||||
|
||||
def setup_logger(debug=False):
|
||||
import logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if debug else logging.WARNING,
|
||||
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
|
||||
)
|
||||
|
||||
|
||||
def get_time_str():
|
||||
return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
|
||||
|
||||
def get_exp_name(args):
|
||||
if args.exp_name != "":
|
||||
exp_name = args.exp_name + "_" + get_time_str()
|
||||
else:
|
||||
exp_name = "TEST"
|
||||
return exp_name
|
||||
|
||||
|
||||
def get_eval_config(config):
|
||||
eval_config = copy.deepcopy(config)
|
||||
eval_config.pop("evaluation_interval")
|
||||
eval_config.pop("evaluation_num_episodes")
|
||||
eval_config.pop("evaluation_config")
|
||||
eval_config.pop("evaluation_num_workers")
|
||||
return eval_config
|
||||
|
||||
|
||||
def get_function(ckpt, explore, config):
|
||||
trainer = MultiWorkerPPO(get_eval_config(config))
|
||||
trainer.restore(ckpt)
|
||||
|
||||
def _f(obs):
|
||||
ret = trainer.compute_actions({"default_policy": obs}, explore=explore)
|
||||
return ret
|
||||
|
||||
return _f
|
||||
|
||||
|
||||
def eval_ckpt(config,
|
||||
ckpt_path,
|
||||
scenario_data_path,
|
||||
num_scenarios,
|
||||
start_scenario_index,
|
||||
horizon=600,
|
||||
render=False,
|
||||
# PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
|
||||
explore=True,
|
||||
log_interval=None,
|
||||
):
|
||||
initialize_ray(test_mode=False, num_gpus=1)
|
||||
# 27 29 30 37 39
|
||||
env_config = get_eval_config(config)["env_config"]
|
||||
env_config.update(dict(
|
||||
start_scenario_index=start_scenario_index,
|
||||
num_scenarios=num_scenarios,
|
||||
sequential_seed=True,
|
||||
curriculum_level=1, # disable curriculum
|
||||
target_success_rate=1,
|
||||
horizon=horizon,
|
||||
episodes_to_evaluate_curriculum=num_scenarios,
|
||||
data_directory=scenario_data_path,
|
||||
use_render=render))
|
||||
env = ScenarioEnv(env_config)
|
||||
|
||||
super_data = defaultdict(list)
|
||||
EPISODE_NUM = env.config["num_scenarios"]
|
||||
compute_actions = get_function(ckpt_path, explore=explore, config=config)
|
||||
|
||||
o = env.reset()
|
||||
assert env.current_seed == start_scenario_index, "Wrong start seed!"
|
||||
|
||||
total_cost = 0
|
||||
total_reward = 0
|
||||
success_rate = 0
|
||||
ep_cost = 0
|
||||
ep_reward = 0
|
||||
success_flag = False
|
||||
step = 0
|
||||
|
||||
def log_msg():
|
||||
print("CKPT:{} | success_rate:{}, mean_episode_reward:{}, mean_episode_cost:{}".format(epi_num,
|
||||
success_rate / epi_num,
|
||||
total_reward / epi_num,
|
||||
total_cost / epi_num))
|
||||
|
||||
for epi_num in tqdm.tqdm(range(0, EPISODE_NUM)):
|
||||
step += 1
|
||||
action_to_send = compute_actions(o)["default_policy"]
|
||||
o, r, d, info = env.step(action_to_send)
|
||||
if env.config["use_render"]:
|
||||
env.render(text={"reward": r})
|
||||
total_reward += r
|
||||
ep_reward += r
|
||||
total_cost += info["cost"]
|
||||
ep_cost += info["cost"]
|
||||
if d or step > horizon:
|
||||
if info["arrive_dest"]:
|
||||
success_rate += 1
|
||||
success_flag = True
|
||||
o = env.reset()
|
||||
|
||||
super_data[0].append(
|
||||
{"reward": ep_reward,
|
||||
"success": success_flag,
|
||||
"out_of_road": info[TerminationState.OUT_OF_ROAD],
|
||||
"cost": ep_cost,
|
||||
"seed": env.current_seed,
|
||||
"route_completion": info["route_completion"]})
|
||||
|
||||
ep_cost = 0.0
|
||||
ep_reward = 0.0
|
||||
success_flag = False
|
||||
step = 0
|
||||
|
||||
if log_interval is not None and epi_num % log_interval == 0:
|
||||
log_msg()
|
||||
if log_interval is not None:
|
||||
log_msg()
|
||||
del compute_actions
|
||||
env.close()
|
||||
with open("eval_ret_{}_{}_{}.json".format(start_scenario_index,
|
||||
start_scenario_index + num_scenarios,
|
||||
get_time_str()), "w") as f:
|
||||
json.dump(super_data, f)
|
||||
return super_data
|
||||
3
scenarionet_training/wandb_utils/__init__.py
Normal file
3
scenarionet_training/wandb_utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import os
|
||||
|
||||
WANDB_KEY_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "wandb_api_key_file.txt")
|
||||
67
scenarionet_training/wandb_utils/our_wandb_callbacks.py
Normal file
67
scenarionet_training/wandb_utils/our_wandb_callbacks.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from ray.tune.integration.wandb import WandbLoggerCallback, _clean_log, \
|
||||
Queue, WandbLogger
|
||||
|
||||
|
||||
class OurWandbLoggerCallback(WandbLoggerCallback):
|
||||
def __init__(self, exp_name, *args, **kwargs):
|
||||
super(OurWandbLoggerCallback, self).__init__(*args, **kwargs)
|
||||
self.exp_name = exp_name
|
||||
|
||||
def log_trial_start(self, trial: "Trial"):
|
||||
config = trial.config.copy()
|
||||
|
||||
config.pop("callbacks", None) # Remove callbacks
|
||||
|
||||
exclude_results = self._exclude_results.copy()
|
||||
|
||||
# Additional excludes
|
||||
exclude_results += self.excludes
|
||||
|
||||
# Log config keys on each result?
|
||||
if not self.log_config:
|
||||
exclude_results += ["config"]
|
||||
|
||||
# Fill trial ID and name
|
||||
trial_id = trial.trial_id if trial else None
|
||||
# trial_name = str(trial) if trial else None
|
||||
|
||||
# Project name for Wandb
|
||||
wandb_project = self.project
|
||||
|
||||
# Grouping
|
||||
wandb_group = self.group or trial.trainable_name if trial else None
|
||||
|
||||
# remove unpickleable items!
|
||||
config = _clean_log(config)
|
||||
|
||||
assert trial_id is not None
|
||||
run_name = "{}_{}".format(self.exp_name, trial_id)
|
||||
|
||||
wandb_init_kwargs = dict(
|
||||
id=trial_id,
|
||||
name=run_name,
|
||||
resume=True,
|
||||
reinit=True,
|
||||
allow_val_change=True,
|
||||
group=wandb_group,
|
||||
project=wandb_project,
|
||||
config=config
|
||||
)
|
||||
wandb_init_kwargs.update(self.kwargs)
|
||||
|
||||
self._trial_queues[trial] = Queue()
|
||||
self._trial_processes[trial] = self._logger_process_cls(
|
||||
queue=self._trial_queues[trial],
|
||||
exclude=exclude_results,
|
||||
to_config=self._config_results,
|
||||
**wandb_init_kwargs
|
||||
)
|
||||
self._trial_processes[trial].start()
|
||||
|
||||
def __del__(self):
|
||||
if self._trial_processes:
|
||||
for v in self._trial_processes.values():
|
||||
if hasattr(v, "close"):
|
||||
v.close()
|
||||
self._trial_processes.clear()
|
||||
self._trial_processes = {}
|
||||
@@ -0,0 +1,80 @@
|
||||
from multiprocessing import Queue
|
||||
|
||||
from ray.tune.integration.wandb import WandbLogger, _clean_log, _set_api_key
|
||||
|
||||
|
||||
class OurWandbLogger(WandbLogger):
|
||||
def __init__(self, config, logdir, trial):
|
||||
self.exp_name = config["logger_config"]["wandb"].pop("exp_name")
|
||||
super(OurWandbLogger, self).__init__(config, logdir, trial)
|
||||
|
||||
def _init(self):
|
||||
|
||||
config = self.config.copy()
|
||||
|
||||
config.pop("callbacks", None) # Remove callbacks
|
||||
|
||||
try:
|
||||
if config.get("logger_config", {}).get("wandb"):
|
||||
logger_config = config.pop("logger_config")
|
||||
wandb_config = logger_config.get("wandb").copy()
|
||||
else:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Wandb logger specified but no configuration has been passed. "
|
||||
"Make sure to include a `wandb` key in your `config` dict "
|
||||
"containing at least a `project` specification.")
|
||||
|
||||
_set_api_key(wandb_config)
|
||||
|
||||
exclude_results = self._exclude_results.copy()
|
||||
|
||||
# Additional excludes
|
||||
additional_excludes = wandb_config.pop("excludes", [])
|
||||
exclude_results += additional_excludes
|
||||
|
||||
# Log config keys on each result?
|
||||
log_config = wandb_config.pop("log_config", False)
|
||||
if not log_config:
|
||||
exclude_results += ["config"]
|
||||
|
||||
# Fill trial ID and name
|
||||
trial_id = self.trial.trial_id if self.trial else None
|
||||
trial_name = str(self.trial) if self.trial else None
|
||||
|
||||
# Project name for Wandb
|
||||
try:
|
||||
wandb_project = wandb_config.pop("project")
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"You need to specify a `project` in your wandb `config` dict.")
|
||||
|
||||
# Grouping
|
||||
wandb_group = wandb_config.pop(
|
||||
"group", self.trial.trainable_name if self.trial else None)
|
||||
|
||||
# remove unpickleable items!
|
||||
config = _clean_log(config)
|
||||
|
||||
assert trial_id is not None
|
||||
run_name = "{}_{}".format(self.exp_name, trial_id)
|
||||
|
||||
wandb_init_kwargs = dict(
|
||||
id=trial_id,
|
||||
name=run_name,
|
||||
resume=True,
|
||||
reinit=True,
|
||||
allow_val_change=True,
|
||||
group=wandb_group,
|
||||
project=wandb_project,
|
||||
config=config)
|
||||
wandb_init_kwargs.update(wandb_config)
|
||||
|
||||
self._queue = Queue()
|
||||
self._wandb = self._logger_process_cls(
|
||||
queue=self._queue,
|
||||
exclude=exclude_results,
|
||||
to_config=self._config_results,
|
||||
**wandb_init_kwargs)
|
||||
self._wandb.start()
|
||||
38
scenarionet_training/wandb_utils/test_wandb.py
Normal file
38
scenarionet_training/wandb_utils/test_wandb.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Procedure to use wandb:
|
||||
|
||||
1. Logup in wandb: https://wandb.ai/
|
||||
2. Get the API key in personal setting
|
||||
3. Store API key (a string)to some file as: ~/wandb_api_key_file.txt
|
||||
4. Install wandb: pip install wandb
|
||||
5. Fill the "wandb_key_file", "wandb_project" keys in our train function.
|
||||
|
||||
Note1: You don't need to specify who own "wandb_project", for example, in team "drivingforce"'s project
|
||||
"representation", you only need to fill wandb_project="representation"
|
||||
|
||||
Note2: In wanbd, there are "team name", "project name", "group name" and "trial_name". We only need to care
|
||||
"team name" and "project name". The "team name" is set to "drivingforce" by default. You can also use None to
|
||||
log result to your personal domain. The "group name" of the experiment is exactly the "exp_name" in our context, like
|
||||
"0304_train_ppo" or so.
|
||||
|
||||
Note3: It would be great to change the x-axis in wandb website to "timesteps_total".
|
||||
|
||||
Peng Zhenghao, 20210402
|
||||
"""
|
||||
from ray import tune
|
||||
|
||||
from scenarionet_training.train_utils.utils import train
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = dict(env="CartPole-v0", num_workers=0, lr=tune.grid_search([1e-2, 1e-4]))
|
||||
train(
|
||||
"PPO",
|
||||
exp_name="test_wandb",
|
||||
stop=10000,
|
||||
config=config,
|
||||
custom_callback=False,
|
||||
test_mode=False,
|
||||
local_mode=False,
|
||||
wandb_project="TEST",
|
||||
wandb_team="drivingforce" # drivingforce is set to default. Use None to log to your personal domain!
|
||||
)
|
||||
1
scenarionet_training/wandb_utils/wandb_api_key_file.txt
Normal file
1
scenarionet_training/wandb_utils/wandb_api_key_file.txt
Normal file
@@ -0,0 +1 @@
|
||||
132a8add578bdaeea5ab7a4942f35f2a17742df2
|
||||
20
setup.py
20
setup.py
@@ -40,9 +40,18 @@ install_requires = [
|
||||
"shapely"
|
||||
]
|
||||
|
||||
train_requirement = [
|
||||
"ray[rllib]==1.0.0",
|
||||
# "torch",
|
||||
"wandb==0.12.1",
|
||||
"aiohttp==3.6.0",
|
||||
"gymnasium",
|
||||
"tensorflow",
|
||||
"tensorflow_probability"]
|
||||
|
||||
setup(
|
||||
name="scenarionet",
|
||||
python_requires='>=3.6, <3.12', # do version check with assert
|
||||
python_requires='>=3.8', # do version check with assert
|
||||
version=version,
|
||||
description="Scalable Traffic Scenario Management System",
|
||||
url="https://github.com/metadriverse/ScenarioNet",
|
||||
@@ -50,12 +59,9 @@ setup(
|
||||
author_email="quanyili0057@gmail.com, pzh@cs.ucla.edu",
|
||||
packages=packages,
|
||||
install_requires=install_requires,
|
||||
# extras_require={
|
||||
# "cuda": cuda_requirement,
|
||||
# "nuplan": nuplan_requirement,
|
||||
# "waymo": waymo_requirement,
|
||||
# "all": nuplan_requirement + cuda_requirement
|
||||
# },
|
||||
extras_require={
|
||||
"train": train_requirement,
|
||||
},
|
||||
include_package_data=True,
|
||||
license="Apache 2.0",
|
||||
long_description=long_description,
|
||||
|
||||
Reference in New Issue
Block a user