Add come updates for Neurips paper (#4)

* scenarionet training

* wandb

* train utils

* fix callback

* run PPO

* use pg test

* save path

* use torch

* add dependency

* update ignore

* update training

* large model

* use curriculum training

* add time to exp name

* storage_path

* restore

* update training

* use my key

* add log message

* check seed

* restore callback

* restore call bacl

* add log message

* add logging message

* restore ray1.4

* length 500

* ray 100

* wandb

* use tf

* more levels

* add callback

* 10 worker

* show level

* no env horizon

* callback result level

* more call back

* add diffuculty

* add mroen stat

* mroe stat

* show levels

* add callback

* new

* ep len 600

* fix setup

* fix stepup

* fix to 3.8

* update setup

* parallel worker!

* new exp

* add callback

* lateral dist

* pg dataset

* evaluate

* modify config

* align config

* train single RL

* update training script

* 100w eval

* less eval to reveal

* 2000 env eval

* new trianing

* eval 1000

* update eval

* more workers

* more worker

* 20 worker

* dataset to database

* split tool!

* split dataset

* try fix

* train 003

* fix mapping

* fix test

* add waymo tqdm

* utils

* fix bug

* fix bug

* waymo

* int type

* 8 worker read

* disable

* read file

* add log message

* check existence

* dist 0

* int

* check num

* suprass warning

* add filter API

* filter

* store map false

* new

* ablation

* filter

* fix

* update filyter

* reanme to from

* random select

* add overlapping checj

* fix

* new training sceheme

* new reward

* add waymo train script

* waymo different config

* copy raw data

* fix bug

* add tqdm

* update readme

* waymo

* pg

* max lateral dist 3

* pg

* crash_done instead of penalty

* no crash done

* gpu

* update eval script

* steering range penalty

* evaluate

* finish pg

* update setup

* fix bug

* test

* fix

* add on line

* train nuplan

* generate sensor

* udpate training

* static obj

* multi worker eval

* filx bug

* use ray for testing

* eval!

* filter senario

* id filter

* fox bug

* dist = 2

* filter

* eval

* eval ret

* ok

* update training pg

* test before use

* store data=False

* collect figures

* capture pic

---------

Co-authored-by: Quanyi Li <quanyi@bolei-gpu02.cs.ucla.edu>
This commit is contained in:
Quanyi Li
2023-06-10 18:56:33 +01:00
committed by GitHub
parent 41c0b01f39
commit db50bca7fd
53 changed files with 2274 additions and 133 deletions

4
.gitignore vendored
View File

@@ -19,3 +19,7 @@ dataset/*
**/passed_scenarios/
**/waymo_origin
/dataset/
/scenarionet_training/wandb/*.pkl
**/TEST/
**/experiment/
**/wandb/

View File

@@ -13,7 +13,8 @@ pip install -e .
## Usage
We provide some explanation and demo for all scripts here.
**You are encouraged to try them on your own, add ```-h``` or ```--help``` argument to know more details about these scripts.**
**You are encouraged to try them on your own, add ```-h``` or ```--help``` argument to know more details about these
scripts.**
### Convert
@@ -45,31 +46,42 @@ python -m scenarionet.scripts.convert_pg -d pg --num_workers=16 --num_scenarios=
```
### Merge & move
For merging two or more database, use
```
python -m scenarionet.merge_database -d /destination/path --from_databases /database1 /2 ...
python -m scenarionet.merge_database -d /destination/path --from /database1 /2 ...
```
As a database contains a path mapping, one should move database folder with the following script instead of ```cp```
command
command.
Using ```--copy_raw_data``` will copy the raw scenario file into target directory and cancel the virtual mapping.
```
python -m scenarionet.move_database --to /destination/path --from /source/path
python -m scenarionet.copy_database --to /destination/path --from /source/path
```
### Verify
The following scripts will check whether all scenarios exist or can be loaded into simulator.
The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be
generated.
The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be
generated.
With teh error file, one can build a new database excluding or including the broken or missing scenarios.
**Existence check**
```
python -m scenarionet.verify_existence -d /database/to/check --error_file_path /error/file/path
python -m scenarionet.check_existence -d /database/to/check --error_file_path /error/file/path
```
**Runnable check**
```
python -m scenarionet.verify_simulation -d /database/to/check --error_file_path /error/file/path
python -m scenarionet.check_simulation -d /database/to/check --error_file_path /error/file/path
```
**Generating new database**
```
python -m scenarionet.generate_from_error_file -d /new/database/path --file /error/file/path
```
@@ -77,6 +89,7 @@ python -m scenarionet.generate_from_error_file -d /new/database/path --file /err
### visualization
Visualizing the simulated scenario
```
python -m scenarionet.run_simulation -d /path/to/database --render --scenario_index
```

View File

@@ -1,6 +1,8 @@
from functools import partial
import numpy as np
from metadrive.scenario.scenario_description import ScenarioDescription as SD
from metadrive.scenario.utils import read_scenario_data
class ScenarioFilter:
@@ -8,7 +10,7 @@ class ScenarioFilter:
SMALLER = "smaller"
@staticmethod
def sdc_moving_dist(metadata, target_dist, condition=GREATER):
def sdc_moving_dist(metadata, file_path, target_dist, condition=GREATER):
"""
This function filters the scenario based on SDC information.
"""
@@ -22,10 +24,11 @@ class ScenarioFilter:
return False
@staticmethod
def object_number(metadata, number_threshold, object_type=None, condition=SMALLER):
def object_number(metadata, file_path, number_threshold, object_type=None, condition=SMALLER):
"""
Return True if the scenario satisfying the object number condition
:param metadata: metadata in each scenario
:param file_path: where to find this scenario
:param number_threshold: number of objects threshold
:param object_type: MetaDriveType.VEHICLE or other object type. If none, calculate number for all object types
:param condition: SMALLER or GREATER
@@ -43,9 +46,33 @@ class ScenarioFilter:
return False
@staticmethod
def has_traffic_light(metadata):
def has_traffic_light(metadata, file_path):
return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0
@staticmethod
def no_traffic_light(metadata, file_path):
return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] == 0
@staticmethod
def no_overpass(metadata, file_path):
"""
We need read the map data to do overpass filter
"""
max_height_diff = 5
if SD.SUMMARY.MAP_HEIGHT_DIFF in metadata:
return metadata[SD.SUMMARY.MAP_HEIGHT_DIFF] < max_height_diff
else:
# calculate online
map_features = read_scenario_data(file_path)[SD.MAP_FEATURES]
return abs(SD.map_height_diff(map_features, target=max_height_diff)) < max_height_diff
@staticmethod
def id_filter(metadata, file_path, ids):
for id in ids:
if metadata["id"] in id:
return False
return True
@staticmethod
def make(func, **kwargs):
"""

View File

@@ -1,4 +1,6 @@
import copy
from random import sample
from metadrive.scenario.utils import read_dataset_summary
import logging
import os
import os.path as osp
@@ -27,21 +29,23 @@ def try_generating_summary(file_folder):
def merge_database(
output_path,
*dataset_paths,
exist_ok=False,
overwrite=False,
try_generate_missing_file=True,
filters: List[Callable] = None
output_path,
*dataset_paths,
exist_ok=False,
overwrite=False,
try_generate_missing_file=True,
filters: List[Callable] = None,
save=True,
):
"""
Combine multiple datasets. Each dataset should have a dataset_summary.pkl
:param output_path: The path to store the output dataset
Combine multiple datasets. Each database should have a dataset_summary.pkl
:param output_path: The path to store the output database
:param exist_ok: If True, though the output_path already exist, still write into it
:param overwrite: If True, overwrite existing dataset_summary.pkl and mapping.pkl. Otherwise, raise error
:param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
:param dataset_paths: Path of each dataset
:param filters: a set of filters to choose which scenario to be selected and added into this combined dataset
:param dataset_paths: Path of each database
:param filters: a set of filters to choose which scenario to be selected and added into this combined database
:param save: save to output path, immediately
:return: summary, mapping
"""
filters = filters or []
@@ -60,15 +64,15 @@ def merge_database(
mappings = {}
# collect
for dataset_path in tqdm.tqdm(dataset_paths):
for dataset_path in tqdm.tqdm(dataset_paths, desc="Merge Data"):
abs_dir_path = osp.abspath(dataset_path)
# summary
assert osp.exists(abs_dir_path), "Wrong dataset path. Can not find dataset at: {}".format(abs_dir_path)
assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE)):
if try_generate_missing_file:
summary = try_generating_summary(abs_dir_path)
else:
raise FileNotFoundError("Can not find summary file for dataset: {}".format(abs_dir_path))
raise FileNotFoundError("Can not find summary file for database: {}".format(abs_dir_path))
else:
with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE), "rb+") as f:
summary = pickle.load(f)
@@ -85,7 +89,7 @@ def merge_database(
if try_generate_missing_file:
mapping = {k: "" for k in summary}
else:
raise FileNotFoundError("Can not find mapping file for dataset: {}".format(abs_dir_path))
raise FileNotFoundError("Can not find mapping file for database: {}".format(abs_dir_path))
else:
with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE), "rb+") as f:
mapping = pickle.load(f)
@@ -98,37 +102,108 @@ def merge_database(
# apply filter stage
file_to_pop = []
for file_name, metadata, in summaries.items():
if not all([fil(metadata) for fil in filters]):
for file_name in tqdm.tqdm(summaries.keys(), desc="Filter Scenarios"):
metadata = summaries[file_name]
if not all([fil(metadata, os.path.join(output_abs_path, mappings[file_name], file_name)) for fil in filters]):
file_to_pop.append(file_name)
for file in file_to_pop:
summaries.pop(file)
mappings.pop(file)
save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
if save:
save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
return summaries, mappings
def move_database(
from_path,
to_path,
exist_ok=False,
overwrite=False,
def copy_database(
from_path,
to_path,
exist_ok=False,
overwrite=False,
copy_raw_data=False,
remove_source=False
):
if not os.path.exists(from_path):
raise FileNotFoundError("Can not find dataset: {}".format(from_path))
raise FileNotFoundError("Can not find database: {}".format(from_path))
if os.path.exists(to_path):
assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a dataset"
assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a database"
assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
merge_database(
files = os.listdir(from_path)
if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len(
files) > 2:
raise RuntimeError("The source database is not allowed to move! "
"This will break the relationship between this database and other database built on it."
"If it is ok for you, use 'mv' to move it manually ")
summaries, mappings = merge_database(
to_path,
from_path,
exist_ok=exist_ok,
overwrite=overwrite,
try_generate_missing_file=True,
save=False
)
files = os.listdir(from_path)
if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len(
files) == 2:
summary_file = osp.join(to_path, ScenarioDescription.DATASET.SUMMARY_FILE)
mapping_file = osp.join(to_path, ScenarioDescription.DATASET.MAPPING_FILE)
if copy_raw_data:
logger.info("Copy raw data...")
for scenario_file in tqdm.tqdm(mappings.keys()):
rel_path = mappings[scenario_file]
shutil.copyfile(os.path.join(to_path, rel_path, scenario_file), os.path.join(to_path, scenario_file))
mappings = {key: "./" for key in summaries.keys()}
save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
if remove_source and ScenarioDescription.DATASET.MAPPING_FILE in files and \
ScenarioDescription.DATASET.SUMMARY_FILE in files and len(files) == 2:
shutil.rmtree(from_path)
def split_database(
from_path,
to_path,
start_index,
num_scenarios,
exist_ok=False,
overwrite=False,
random=False,
):
if not os.path.exists(from_path):
raise FileNotFoundError("Can not find database: {}".format(from_path))
if os.path.exists(to_path):
assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a database"
assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
overwrite = overwrite,
output_abs_path = osp.abspath(to_path)
os.makedirs(output_abs_path, exist_ok=exist_ok)
summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE)
mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE)
for file in [summary_file, mapping_file]:
if os.path.exists(file):
if overwrite:
os.remove(file)
else:
raise FileExistsError("{} already exists at: {}!".format(file, output_abs_path))
# collect
abs_dir_path = osp.abspath(from_path)
# summary
assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
summaries, lookup, mappings = read_dataset_summary(from_path)
assert start_index >= 0 and start_index + num_scenarios <= len(
lookup), "No enough scenarios in source dataset: total {}, start_index: {}, need: {}".format(len(lookup),
start_index,
num_scenarios)
if random:
selected = sample(lookup[start_index:], k=num_scenarios)
else:
selected = lookup[start_index: start_index + num_scenarios]
selected_summary = {}
selected_mapping = {}
for scenario in selected:
selected_summary[scenario] = summaries[scenario]
selected_mapping[scenario] = os.path.relpath(osp.join(abs_dir_path, mappings[scenario]), output_abs_path)
save_summary_anda_mapping(summary_file, mapping_file, selected_summary, selected_mapping)
return summaries, mappings

View File

@@ -0,0 +1,22 @@
"""
Check If any overlap between two database
"""
import argparse
from scenarionet.common_utils import read_dataset_summary
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--database_1', type=str, required=True, help="The path of the first database")
parser.add_argument('--database_2', type=str, required=True, help="The path of the second database")
args = parser.parse_args()
summary_1, _, _ = read_dataset_summary(args.database_1)
summary_2, _, _ = read_dataset_summary(args.database_2)
intersection = set(summary_1.keys()).intersection(set(summary_2.keys()))
if len(intersection) == 0:
print("No overlapping in two database!")
else:
print("Find overlapped scenarios: {}".format(intersection))

View File

@@ -1,4 +1,5 @@
import pkg_resources # for suppress warning
import shutil
import argparse
import logging
import os
@@ -28,6 +29,20 @@ if __name__ == '__main__':
default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
help="The directory stores all waymo tfrecord"
)
parser.add_argument(
"--start_file_index",
default=0,
type=int,
help="Control how many files to use. We will list all files in the raw data folder "
"and select files[start_file_index: start_file_index+num_files]"
)
parser.add_argument(
"--num_files",
default=1000,
type=int,
help="Control how many files to use. We will list all files in the raw data folder "
"and select files[start_file_index: start_file_index+num_files]"
)
args = parser.parse_args()
overwrite = args.overwrite
@@ -35,8 +50,19 @@ if __name__ == '__main__':
output_path = args.database_path
version = args.version
save_path = output_path
if os.path.exists(output_path):
if not overwrite:
raise ValueError(
"Directory {} already exists! Abort. "
"\n Try setting overwrite=True or adding --overwrite".format(output_path)
)
else:
shutil.rmtree(output_path)
waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path)
scenarios = get_waymo_scenarios(waymo_data_directory)
scenarios = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files,
num_workers=8) # do not use too much worker to read data
write_to_directory(
convert_func=convert_waymo_scenario,

View File

@@ -357,9 +357,9 @@ def extract_traffic(scenario: NuPlanScenario, center):
type=MetaDriveType.UNSET,
state=dict(
position=np.zeros(shape=(episode_len, 3)),
heading=np.zeros(shape=(episode_len, )),
heading=np.zeros(shape=(episode_len,)),
velocity=np.zeros(shape=(episode_len, 2)),
valid=np.zeros(shape=(episode_len, )),
valid=np.zeros(shape=(episode_len,)),
length=np.zeros(shape=(episode_len, 1)),
width=np.zeros(shape=(episode_len, 1)),
height=np.zeros(shape=(episode_len, 1))
@@ -490,3 +490,84 @@ example_scenario_types = "[behind_pedestrian_on_pickup_dropoff, \
starting_unprotected_cross_turn, \
starting_protected_noncross_turn, \
on_pickup_dropoff]"
# - accelerating_at_crosswalk
# - accelerating_at_stop_sign
# - accelerating_at_stop_sign_no_crosswalk
# - accelerating_at_traffic_light
# - accelerating_at_traffic_light_with_lead
# - accelerating_at_traffic_light_without_lead
# - behind_bike
# - behind_long_vehicle
# - behind_pedestrian_on_driveable
# - behind_pedestrian_on_pickup_dropoff
# - changing_lane
# - changing_lane_to_left
# - changing_lane_to_right
# - changing_lane_with_lead
# - changing_lane_with_trail
# - crossed_by_bike
# - crossed_by_vehicle
# - following_lane_with_lead
# - following_lane_with_slow_lead
# - following_lane_without_lead
# - high_lateral_acceleration
# - high_magnitude_jerk
# - high_magnitude_speed
# - low_magnitude_speed
# - medium_magnitude_speed
# - near_barrier_on_driveable
# - near_construction_zone_sign
# - near_high_speed_vehicle
# - near_long_vehicle
# - near_multiple_bikes
# - near_multiple_pedestrians
# - near_multiple_vehicles
# - near_pedestrian_at_pickup_dropoff
# - near_pedestrian_on_crosswalk
# - near_pedestrian_on_crosswalk_with_ego
# - near_trafficcone_on_driveable
# - on_all_way_stop_intersection
# - on_carpark
# - on_intersection
# - on_pickup_dropoff
# - on_stopline_crosswalk
# - on_stopline_stop_sign
# - on_stopline_traffic_light
# - on_traffic_light_intersection
# - starting_high_speed_turn
# - starting_left_turn
# - starting_low_speed_turn
# - starting_protected_cross_turn
# - starting_protected_noncross_turn
# - starting_right_turn
# - starting_straight_stop_sign_intersection_traversal
# - starting_straight_traffic_light_intersection_traversal
# - starting_u_turn
# - starting_unprotected_cross_turn
# - starting_unprotected_noncross_turn
# - stationary
# - stationary_at_crosswalk
# - stationary_at_traffic_light_with_lead
# - stationary_at_traffic_light_without_lead
# - stationary_in_traffic
# - stopping_at_crosswalk
# - stopping_at_stop_sign_no_crosswalk
# - stopping_at_stop_sign_with_lead
# - stopping_at_stop_sign_without_lead
# - stopping_at_traffic_light_with_lead
# - stopping_at_traffic_light_without_lead
# - stopping_with_lead
# - traversing_crosswalk
# - traversing_intersection
# - traversing_narrow_lane
# - traversing_pickup_dropoff
# - traversing_traffic_light_intersection
# - waiting_for_pedestrian_to_cross
#
all_scenario_types = "[near_pedestrian_on_crosswalk_with_ego," \
"near_trafficcone_on_driveable, " \
"following_lane_with_lead, " \
"following_lane_with_slow_lead, " \
"following_lane_without_lead]"

View File

@@ -3,6 +3,8 @@ import multiprocessing
import os
import pickle
import tqdm
from scenarionet.converter.utils import mph_to_kmh
from scenarionet.converter.waymo.type import WaymoLaneType, WaymoAgentType, WaymoRoadLineType, WaymoRoadEdgeType
@@ -418,13 +420,20 @@ def convert_waymo_scenario(scenario, version):
}
for count, id in enumerate(track_id)
}
# clean memory
del scenario
scenario = None
return md_scenario
def get_waymo_scenarios(waymo_data_directory, num_workers=8):
def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
# parse raw data from input path to output path,
# there is 1000 raw data in google cloud, each of them produce about 500 pkl file
logger.info("\n Reading raw data")
file_list = os.listdir(waymo_data_directory)
assert len(file_list) >= start_index + num and start_index >= 0, \
"No sufficient files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
file_list = file_list[start_index: start_index + num]
num_files = len(file_list)
if num_files < num_workers:
# single process
@@ -441,23 +450,29 @@ def get_waymo_scenarios(waymo_data_directory, num_workers=8):
argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
# Run, workers and process result from worker
with multiprocessing.Pool(num_workers) as p:
all_result = list(p.imap(read_from_files, argument_list))
ret = []
# get result
for r in all_result:
if len(r) == 0:
logger.warning("0 scenarios found")
ret += r
logger.info("\n Find {} waymo scenarios from {} files".format(len(ret), num_files))
return ret
# with multiprocessing.Pool(num_workers) as p:
# all_result = list(p.imap(read_from_files, argument_list))
# Disable multiprocessing read
all_result = read_from_files([waymo_data_directory, file_list])
# ret = []
#
# # get result
# for r in all_result:
# if len(r) == 0:
# logger.warning("0 scenarios found")
# ret += r
logger.info("\n Find {} waymo scenarios from {} files".format(len(all_result), num_files))
return all_result
def read_from_files(arg):
try:
scenario_pb2
except NameError:
raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0==1.5.0")
waymo_data_directory, file_list = arg[0], arg[1]
scenarios = []
for file_count, file in enumerate(file_list):
for file in tqdm.tqdm(file_list):
file_path = os.path.join(waymo_data_directory, file)
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
continue

View File

@@ -0,0 +1,47 @@
import argparse
from scenarionet.builder.utils import copy_database
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--from', required=True, help="Which database to move.")
parser.add_argument(
"--to",
required=True,
help="The name of the new database. "
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
"that directory into a database."
)
parser.add_argument(
"--remove_source",
action="store_true",
help="Remove the `from_database` if set this flag"
)
parser.add_argument(
"--copy_raw_data",
action="store_true",
help="Instead of creating virtual file mapping, copy raw scenario.pkl file"
)
parser.add_argument(
"--exist_ok",
action="store_true",
help="Still allow to write, if the to_folder exists already. "
"This write will only create two .pkl files and this directory will become a database."
)
parser.add_argument(
"--overwrite",
action="store_true",
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
"whether to overwrite both files"
)
args = parser.parse_args()
from_path = args.__getattribute__("from")
copy_database(
from_path,
args.to,
exist_ok=args.exist_ok,
overwrite=args.overwrite,
copy_raw_data=args.copy_raw_data,
remove_source=args.remove_source
)

View File

@@ -0,0 +1,111 @@
import argparse
from scenarionet.builder.filters import ScenarioFilter
from scenarionet.builder.utils import merge_database
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--database_path",
"-d",
required=True,
help="The name of the new database. "
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
"that directory into a database."
)
parser.add_argument(
'--from',
required=True,
type=str,
help="Which dataset to filter. It takes one directory path as input"
)
parser.add_argument(
"--exist_ok",
action="store_true",
help="Still allow to write, if the dir exists already. "
"This write will only create two .pkl files and this directory will become a database."
)
parser.add_argument(
"--overwrite",
action="store_true",
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
"whether to overwrite both files"
)
parser.add_argument(
"--moving_dist",
action="store_true",
help="add this flag to select cases with SDC moving dist > sdc_moving_dist_min"
)
parser.add_argument(
"--sdc_moving_dist_min",
default=10,
type=float,
help="Selecting case with sdc_moving_dist > this value. "
)
parser.add_argument(
"--num_object",
action="store_true",
help="add this flag to select cases with object_num < max_num_object"
)
parser.add_argument(
"--max_num_object",
default=30,
type=float,
help="case will be selected if num_obj < this argument"
)
parser.add_argument(
"--no_overpass",
action="store_true",
help="Scenarios with overpass WON'T be selected"
)
parser.add_argument(
"--no_traffic_light",
action="store_true",
help="Scenarios with traffic light WON'T be selected"
)
parser.add_argument(
"--id_filter",
action="store_true",
help="Scenarios with indicated name will NOT be selected"
)
parser.add_argument(
"--exclude_ids",
nargs='+',
default=[],
help="Scenarios with indicated name will NOT be selected"
)
args = parser.parse_args()
target = args.sdc_moving_dist_min
obj_threshold = args.max_num_object
from_path = args.__getattribute__("from")
filters = []
if args.no_overpass:
filters.append(ScenarioFilter.make(ScenarioFilter.no_overpass))
if args.num_object:
filters.append(ScenarioFilter.make(ScenarioFilter.object_number, number_threshold=obj_threshold))
if args.moving_dist:
filters.append(ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater"))
if args.no_traffic_light:
filters.append(ScenarioFilter.make(ScenarioFilter.no_traffic_light))
if args.id_filter:
filters.append(ScenarioFilter.make(ScenarioFilter.id_filter, ids=args.exclude_ids))
if len(filters) == 0:
raise ValueError("No filters are applied. Abort.")
merge_database(
args.database_path,
from_path,
exist_ok=args.exist_ok,
overwrite=args.overwrite,
try_generate_missing_file=True,
filters=filters
)

View File

@@ -5,14 +5,14 @@ from scenarionet.verifier.error import ErrorFile
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated dataset")
parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated database")
parser.add_argument("--file", "-f", required=True, help="The path of the error file, should be xyz.json")
parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, overwrite it")
parser.add_argument(
"--broken",
action="store_true",
help="By default, only successful scenarios will be picked to build the new dataset. "
"If turn on this flog, it will generate dataset containing only broken scenarios."
help="By default, only successful scenarios will be picked to build the new database. "
"If turn on this flog, it will generate database containing only broken scenarios."
)
args = parser.parse_args()
ErrorFile.generate_dataset(args.file, args.database_path, args.overwrite, args.broken)

View File

@@ -1,5 +1,5 @@
import pkg_resources # for suppress warning
import argparse
from scenarionet.builder.filters import ScenarioFilter
from scenarionet.builder.utils import merge_database
@@ -9,13 +9,13 @@ if __name__ == '__main__':
"--database_path",
"-d",
required=True,
help="The name of the new combined dataset. "
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
"that directory into a dataset."
help="The name of the new combined database. "
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
"that directory into a database."
)
parser.add_argument(
'--from_datasets',
'--from',
required=True,
nargs='+',
default=[],
@@ -25,32 +25,38 @@ if __name__ == '__main__':
"--exist_ok",
action="store_true",
help="Still allow to write, if the dir exists already. "
"This write will only create two .pkl files and this directory will become a dataset."
"This write will only create two .pkl files and this directory will become a database."
)
parser.add_argument(
"--overwrite",
action="store_true",
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
"whether to overwrite both files"
"whether to overwrite both files"
)
parser.add_argument(
"--filter_moving_dist",
action="store_true",
help="add this flag to select cases with SDC moving dist > sdc_moving_dist_min"
)
parser.add_argument(
"--sdc_moving_dist_min",
default=20,
default=5,
type=float,
help="Selecting case with sdc_moving_dist > this value. "
"We will add more filter conditions in the future."
"We will add more filter conditions in the future."
)
args = parser.parse_args()
target = args.sdc_moving_dist_min
filters = [ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater")]
if len(args.from_datasets) != 0:
source = args.__getattribute__("from")
if len(source) != 0:
merge_database(
args.database_path,
*args.from_datasets,
*source,
exist_ok=args.exist_ok,
overwrite=args.overwrite,
try_generate_missing_file=True,
filters=filters
filters=filters if args.filter_moving_dist else []
)
else:
raise ValueError("No source dataset are provided. Abort.")
raise ValueError("No source database are provided. Abort.")

View File

@@ -1,35 +0,0 @@
import argparse
from scenarionet.builder.utils import move_database
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--from', required=True, help="Which dataset to move.")
parser.add_argument(
"--to",
required=True,
help="The name of the new dataset. "
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
"that directory into a dataset."
)
parser.add_argument(
"--exist_ok",
action="store_true",
help="Still allow to write, if the to_folder exists already. "
"This write will only create two .pkl files and this directory will become a dataset."
)
parser.add_argument(
"--overwrite",
action="store_true",
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
"whether to overwrite both files"
)
args = parser.parse_args()
from_path = args.__getattr__("from")
move_database(
from_path,
args.to,
exist_ok=args.exist_ok,
overwrite=args.overwrite,
)

View File

@@ -0,0 +1,18 @@
import pkg_resources # for suppress warning
import argparse
import logging
from scenarionet.common_utils import read_dataset_summary
logger = logging.getLogger(__file__)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--database_path",
"-d",
required=True,
help="Database to check number of scenarios"
)
args = parser.parse_args()
summary, _, _, = read_dataset_summary(args.database_path)
logger.info("Number of scenarios: {}".format(len(summary)))

View File

@@ -7,7 +7,7 @@ from metadrive.scenario.utils import get_number_of_scenarios
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--database_path", "-d", required=True, help="The path of the dataset")
parser.add_argument("--database_path", "-d", required=True, help="The path of the database")
parser.add_argument("--render", action="store_true", help="Enable 3D rendering")
parser.add_argument("--scenario_index", default=None, type=int, help="Specifying a scenario to run")
args = parser.parse_args()
@@ -38,13 +38,13 @@ if __name__ == '__main__':
"data_directory": database_path,
}
)
for seed in range(num_scenario if args.scenario_index is not None else 1000000):
env.reset(force_seed=seed if args.scenario_index is None else args.scenario_index)
for index in range(num_scenario if args.scenario_index is not None else 1000000):
env.reset(force_seed=index if args.scenario_index is None else args.scenario_index)
for t in range(10000):
o, r, d, info = env.step([0, 0])
if env.config["use_render"]:
env.render(text={
"seed": env.engine.global_seed + env.config["start_scenario_index"],
"scenario index": env.engine.global_seed + env.config["start_scenario_index"],
})
if d and info["arrive_dest"]:

View File

@@ -0,0 +1,47 @@
"""
This script is for extracting a subset of data from an existing database
"""
import pkg_resources # for suppress warning
import argparse
from scenarionet.builder.utils import split_database
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--from', required=True, help="Which database to extract data from.")
parser.add_argument(
"--to",
required=True,
help="The name of the new database. "
"It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
"If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
"that directory into a database."
)
parser.add_argument("--num_scenarios", type=int, default=64, help="how many scenarios to extract (default: 30)")
parser.add_argument("--start_index", type=int, default=0, help="which index to start")
parser.add_argument("--random", action="store_true", help="If set to true, it will choose scenarios randomly "
"from all_scenarios[start_index:]. "
"Otherwise, the scenarios will be selected sequentially")
parser.add_argument(
"--exist_ok",
action="store_true",
help="Still allow to write, if the to_folder exists already. "
"This write will only create two .pkl files and this directory will become a database."
)
parser.add_argument(
"--overwrite",
action="store_true",
help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
"whether to overwrite both files"
)
args = parser.parse_args()
from_path = args.__getattribute__("from")
split_database(
from_path,
args.to,
args.start_index,
args.num_scenarios,
exist_ok=args.exist_ok,
overwrite=args.overwrite,
random=args.random
)

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env bash
python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite
python ../../verify_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --error_file_path ../tmp/test_combine_dataset --random_drop --num_workers=16
python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite
python ../../check_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --error_file_path ../tmp/test_combine_dataset --random_drop --num_workers=16
python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_pass
python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_fail --broken

View File

@@ -32,8 +32,8 @@ done
# combine the datasets
if [ "$overwrite" = true ]; then
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok
else
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok
python -m scenarionet.scripts.merge_database --database_path $dataset_path --from $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok
fi

View File

@@ -0,0 +1,119 @@
import pygame
from metadrive.envs.metadrive_env import MetaDriveEnv
from metadrive.utils import setup_logger
if __name__ == "__main__":
setup_logger(True)
env = MetaDriveEnv(
{
"num_scenarios": 1,
"traffic_density": 0.15,
"traffic_mode": "hybrid",
"start_seed": 74,
# "_disable_detector_mask":True,
# "debug_physics_world": True,
# "debug": True,
# "global_light": False,
# "debug_static_world": True,
"show_interface": False,
"cull_scene": False,
"random_spawn_lane_index": False,
"random_lane_width": False,
# "image_observation": True,
# "controller": "joystick",
# "show_coordinates": True,
"random_agent_model": False,
"manual_control": True,
"use_render": True,
"accident_prob": 1,
"decision_repeat": 5,
"interface_panel": [],
"need_inverse_traffic": False,
"rgb_clip": True,
"map": 2,
# "agent_policy": ExpertPolicy,
"random_traffic": False,
# "random_lane_width": True,
"driving_reward": 1.0,
# "pstats": True,
"force_destroy": False,
# "show_skybox": False,
"show_fps": False,
"render_pipeline": True,
# "camera_dist": 8,
"window_size": (1600, 900),
"camera_dist": 9,
# "camera_pitch": 30,
# "camera_height": 1,
# "camera_smooth": False,
# "camera_height": -1,
"vehicle_config": {
"enable_reverse": False,
# "vehicle_model": "xl",
# "rgb_camera": (1024, 1024),
# "spawn_velocity": [8.728615581032535, -0.24411703918728195],
"spawn_velocity_car_frame": True,
# "image_source": "depth_camera",
# "random_color": True
# "show_lidar": True,
"spawn_lane_index": None,
# "destination":"2R1_3_",
# "show_side_detector": True,
# "show_lane_line_detector": True,
# "side_detector": dict(num_lasers=2, distance=50),
# "lane_line_detector": dict(num_lasers=2, distance=50),
# "show_line_to_navi_mark": True,
"show_navi_mark": False,
# "show_dest_mark": True
},
}
)
o = env.reset()
def capture():
env.capture()
ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(2000, 2000), track_target_vehicle=True)
pygame.image.save(ret, "top_down_{}.png".format(env.current_seed))
env.engine.accept("c", capture)
# env.main_camera.set_follow_lane(True)
# env.vehicle.get_camera("rgb_camera").save_image(env.vehicle)
# for line in env.engine.coordinate_line:
# line.reparentTo(env.vehicle.origin)
# env.vehicle.set_velocity([5, 0], in_local_frame=True)
for s in range(1, 100000):
# env.vehicle.set_velocity([1, 0], in_local_frame=True)
o, r, d, info = env.step([0, 0])
# env.vehicle.set_pitch(-np.pi/4)
# [0.09231533, 0.491018, 0.47076905, 0.7691619, 0.5, 0.5, 1.0, 0.0, 0.48037243, 0.8904728, 0.81229943, 0.7317231, 1.0, 0.85320455, 0.9747932, 0.65675277, 0.0, 0.5, 0.5]
# else:
# if s % 100 == 0:
# env.close()
# env.reset()
# info["fuel"] = env.vehicle.energy_consumption
# env.render(
# text={
# # "heading_diff": env.vehicle.heading_diff(env.vehicle.lane),
# # "lane_width": env.vehicle.lane.width,
# # "lane_index": env.vehicle.lane_index,
# # "lateral": env.vehicle.lane.local_coordinates(env.vehicle.position),
# "current_seed": env.current_seed
# }
# )
# if d:
# env.reset()
# # assert env.observation_space.contains(o)
# if (s + 1) % 100 == 0:
# # print(
# "Finish {}/10000 simulation steps. Time elapse: {:.4f}. Average FPS: {:.4f}".format(
# s + 1,f
# time.time() - start, (s + 1) / (time.time() - start)
# )
# )
# if d:
# # # env.close()
# # # print(len(env.engine._spawned_objects))
# env.reset()

View File

@@ -0,0 +1,17 @@
from metadrive.scenario.scenario_description import ScenarioDescription as SD
from metadrive.scenario.utils import read_scenario_data, read_dataset_summary, assert_scenario_equal
from scenarionet.common_utils import read_scenario
if __name__ == '__main__':
data_1 = "D:\\code\\scenarionet\\dataset\pg_2000"
data_2 = "C:\\Users\\x1\\Desktop\\remote"
summary_1, lookup_1, mapping_1 = read_dataset_summary(data_1)
summary_2, lookup_2, mapping_2 = read_dataset_summary(data_2)
# assert lookup_1[-10:] == lookup_2
scenarios_1 = {}
scenarios_2 = {}
for i in range(9):
scenarios_1[str(i)] = read_scenario(data_1, mapping_1, lookup_1[-9+i])
scenarios_2[str(i)] = read_scenario(data_2, mapping_2, lookup_2[i])
# assert_scenario_equal(scenarios_1, scenarios_2, check_self_type=False, only_compare_sdc=True)

View File

@@ -0,0 +1,97 @@
import time
import pygame
from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
NuScenesEnv = ScenarioEnv
if __name__ == "__main__":
env = NuScenesEnv(
{
"use_render": True,
"agent_policy": ReplayEgoCarPolicy,
"show_interface": False,
# "need_lane_localization": False,
"show_logo": False,
"no_traffic": False,
"sequential_seed": True,
"reactive_traffic": False,
"show_fps": False,
# "debug": True,
# "render_pipeline": True,
"daytime": "11:01",
"window_size": (1600, 900),
"camera_dist": 0.8,
"camera_height": 1.5,
"camera_pitch": None,
"camera_fov": 60,
"start_scenario_index": 0,
"num_scenarios": 10,
# "force_reuse_object_name": True,
# "data_directory": "/home/shady/Downloads/test_processed",
"horizon": 1000,
# "no_static_vehicles": True,
# "show_policy_mark": True,
# "show_coordinates": True,
# "force_destroy": True,
# "default_vehicle_in_traffic": True,
"vehicle_config": dict(
# light=True,
# random_color=True,
show_navi_mark=False,
use_special_color=False,
image_source="depth_camera",
rgb_camera=(1600, 900),
depth_camera=(1600, 900, True),
# no_wheel_friction=True,
lidar=dict(num_lasers=120, distance=50),
lane_line_detector=dict(num_lasers=0, distance=50),
side_detector=dict(num_lasers=12, distance=50)
),
"data_directory": AssetLoader.file_path("nuscenes", return_raw_style=False),
"image_observation": True,
}
)
# 0,1,3,4,5,6
success = []
reset_num = 0
start = time.time()
reset_used_time = 0
s = 0
while True:
# for i in range(10):
start_reset = time.time()
env.reset(force_seed=0)
reset_used_time += time.time() - start_reset
reset_num += 1
for t in range(10000):
if t==30:
# env.capture("camera_deluxe.jpg")
# ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(5000, 5000), track_target_vehicle=True)
# pygame.image.save(ret, "top_down.png")
env.vehicle.get_camera("depth_camera").save_image(env.vehicle, "camera.jpg")
o, r, d, info = env.step([1, 0.88])
assert env.observation_space.contains(o)
s += 1
# if env.config["use_render"]:
# env.render(text={"seed": env.current_seed,
# # "num_map": info["num_stored_maps"],
# "data_coverage": info["data_coverage"],
# "reward": r,
# "heading_r": info["step_reward_heading"],
# "lateral_r": info["step_reward_lateral"],
# "smooth_action_r": info["step_reward_action_smooth"]})
if d:
print(
"Time elapse: {:.4f}. Average FPS: {:.4f}, AVG_Reset_time: {:.4f}".format(
time.time() - start, s / (time.time() - start - reset_used_time),
reset_used_time / reset_num
)
)
print("seed:{}, success".format(env.engine.global_random_seed))
print(list(env.engine.curriculum_manager.recent_success.dict.values()))
break

View File

@@ -0,0 +1,104 @@
import time
import pygame
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
NuScenesEnv = ScenarioEnv
if __name__ == "__main__":
env = NuScenesEnv(
{
"use_render": True,
"agent_policy": ReplayEgoCarPolicy,
"show_interface": False,
"image_observation": False,
"show_logo": False,
"no_traffic": False,
"drivable_region_extension": 15,
"sequential_seed": True,
"reactive_traffic": False,
"show_fps": False,
# "debug": True,
"render_pipeline": True,
"daytime": "19:30",
"window_size": (1600, 900),
"camera_dist": 9,
# "camera_height": 1.5,
# "camera_pitch": None,
# "camera_fov": 60,
"start_scenario_index": 0,
"num_scenarios": 4,
# "force_reuse_object_name": True,
# "data_directory": "/home/shady/Downloads/test_processed",
"horizon": 1000,
# "no_static_vehicles": True,
# "show_policy_mark": True,
# "show_coordinates": True,
# "force_destroy": True,
# "default_vehicle_in_traffic": True,
"vehicle_config": dict(
# light=True,
# random_color=True,
show_navi_mark=False,
use_special_color=False,
image_source="depth_camera",
# rgb_camera=(1600, 900),
# depth_camera=(1600, 900, True),
# no_wheel_friction=True,
lidar=dict(num_lasers=120, distance=50),
lane_line_detector=dict(num_lasers=0, distance=50),
side_detector=dict(num_lasers=12, distance=50)
),
"data_directory": "D:\\code\\scenarionet\\scenarionet\\tests\\script\\waymo_scenes_adv"
}
)
# 0,1,3,4,5,6
success = []
reset_num = 0
start = time.time()
reset_used_time = 0
s = 0
env.reset()
def capture():
env.capture()
ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(7000, 7000), track_target_vehicle=True)
pygame.image.save(ret, "top_down_{}.png".format(env.current_seed))
env.engine.accept("c", capture)
while True:
# for i in range(10):
start_reset = time.time()
env.reset()
reset_used_time += time.time() - start_reset
reset_num += 1
for t in range(10000):
o, r, d, info = env.step([1, 0.88])
assert env.observation_space.contains(o)
s += 1
# if env.config["use_render"]:
# env.render(text={"seed": env.current_seed,
# # "num_map": info["num_stored_maps"],
# "data_coverage": info["data_coverage"],
# "reward": r,
# "heading_r": info["step_reward_heading"],
# "lateral_r": info["step_reward_lateral"],
# "smooth_action_r": info["step_reward_action_smooth"]})
if d:
print(
"Time elapse: {:.4f}. Average FPS: {:.4f}, AVG_Reset_time: {:.4f}".format(
time.time() - start, s / (time.time() - start - reset_used_time),
reset_used_time / reset_num
)
)
print("seed:{}, success".format(env.engine.global_random_seed))
print(list(env.engine.curriculum_manager.recent_success.dict.values()))
break

View File

@@ -50,9 +50,9 @@ if __name__ == '__main__':
long, lat, = c_lane.local_coordinates(env.vehicle.position)
if env.config["use_render"]:
env.render(text={
"seed": env.engine.global_seed + env.config["start_scenario_index"],
"scenario index": env.engine.global_seed + env.config["start_scenario_index"],
})
if d and info["arrive_dest"]:
print("seed:{}, success".format(env.engine.global_random_seed))
print("scenario index:{}, success".format(env.engine.global_random_seed))
break

View File

@@ -0,0 +1,37 @@
import os
import os.path
from metadrive.engine.asset_loader import AssetLoader
from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
from scenarionet.builder.filters import ScenarioFilter
from scenarionet.builder.utils import merge_database
def test_filter_overpass():
overpass_1 = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", "overpass")
overpass_in_md = AssetLoader.file_path("waymo", return_raw_style=False)
dataset_paths = [overpass_1, overpass_in_md]
output_path = os.path.join(TMP_PATH, "combine")
merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
# filter
filters = []
filters.append(ScenarioFilter.make(ScenarioFilter.no_overpass))
summaries, _ = merge_database(
output_path,
*dataset_paths,
exist_ok=True,
overwrite=True,
try_generate_missing_file=True,
filters=filters
)
assert len(summaries) == 3
for scenario in summaries:
assert scenario != "sd_waymo_v1.2_eb4b91b10ca94ff2.pkl"
if __name__ == '__main__':
test_filter_overpass()

View File

@@ -4,13 +4,13 @@ import os.path
import pytest
from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
from scenarionet.builder.utils import move_database, merge_database
from scenarionet.builder.utils import copy_database, merge_database
from scenarionet.common_utils import read_dataset_summary, read_scenario
from scenarionet.verifier.utils import verify_database
@pytest.mark.order("first")
def test_move_database():
def test_copy_database():
dataset_name = "nuscenes"
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)]
@@ -19,7 +19,7 @@ def test_move_database():
# move
for k, from_path in enumerate(dataset_paths):
to = os.path.join(TMP_PATH, str(k))
move_database(from_path, to)
copy_database(from_path, to)
moved_path.append(to)
assert os.path.exists(from_path)
merge_database(output_path, *moved_path, exist_ok=True, overwrite=True, try_generate_missing_file=True)
@@ -37,7 +37,7 @@ def test_move_database():
for k, from_path in enumerate(moved_path):
new_p = os.path.join(TMP_PATH, str(k) + str(k))
new_move_pathes.append(new_p)
move_database(from_path, new_p, exist_ok=True, overwrite=True)
copy_database(from_path, new_p, exist_ok=True, overwrite=True)
assert not os.path.exists(from_path)
merge_database(output_path, *new_move_pathes, exist_ok=True, overwrite=True, try_generate_missing_file=True)
# verify
@@ -51,4 +51,4 @@ def test_move_database():
if __name__ == '__main__':
test_move_database()
test_copy_database()

View File

@@ -0,0 +1,37 @@
import os
import os.path
from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH
from scenarionet.builder.utils import merge_database, split_database
from scenarionet.common_utils import read_dataset_summary
def test_split_dataset():
dataset_name = "nuscenes"
original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name)
test_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset")
dataset_paths = [original_dataset_path + "_{}".format(i) for i in [0, 1, 3, 4]]
output_path = os.path.join(TMP_PATH, "combine")
merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True)
# split
from_path = output_path
to_path = os.path.join(TMP_PATH, "split", "split")
summary_1, lookup_1, mapping_1 = read_dataset_summary(from_path)
split_database(from_path, to_path, start_index=3, random=True, num_scenarios=4, overwrite=True, exist_ok=True)
summary_2, lookup_2, mapping_2 = read_dataset_summary(to_path)
assert len(summary_2) == 4
for scenario in summary_2:
assert scenario not in lookup_1[:3]
split_database(from_path, to_path, start_index=3, num_scenarios=4, overwrite=True, exist_ok=True)
summary_2, lookup_2, mapping_2 = read_dataset_summary(to_path)
assert lookup_1[3:7] == lookup_2
for k in range(4):
assert summary_1[lookup_2[k]] == summary_2[lookup_2[k]]
if __name__ == '__main__':
test_split_dataset()

View File

@@ -53,12 +53,12 @@ class ErrorFile:
@classmethod
def generate_dataset(cls, error_file_path, new_dataset_path, overwrite=False, broken_scenario=False):
"""
Generate a new dataset containing all broken scenarios or all good scenarios
Generate a new database containing all broken scenarios or all good scenarios
:param error_file_path: error file path
:param new_dataset_path: a directory where you want to store your data
:param overwrite: if new_dataset_path exists, whether to overwrite
:param broken_scenario: generate broken scenarios. You can generate such a broken scenarios for debugging
:return: dataset summary, dataset mapping
:return: database summary, database mapping
"""
new_dataset_path = os.path.abspath(new_dataset_path)
if os.path.exists(new_dataset_path):

View File

@@ -108,6 +108,7 @@ def loading_into_metadrive(
"agent_policy": ReplayEgoCarPolicy,
"num_scenarios": num_scenario,
"horizon": 1000,
"store_map": False,
"start_scenario_index": start_scenario_index,
"no_static_vehicles": False,
"data_directory": dataset_path,

View File

View File

View File

@@ -0,0 +1,23 @@
from scenarionet_training.scripts.train_nuplan import config
from scenarionet_training.train_utils.utils import eval_ckpt
if __name__ == '__main__':
# 27 29 30 37 39
ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_510\\checkpoint-510"
scenario_data_path = "D:\\scenarionet_testset\\nuplan_test\\nuplan_test_w_raw"
num_scenarios = 2000
start_scenario_index = 0
horizon = 600
render = False
explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
log_interval = 10
eval_ckpt(config,
ckpt_path,
scenario_data_path,
num_scenarios,
start_scenario_index,
horizon,
render,
explore,
log_interval)

View File

@@ -0,0 +1,27 @@
import os.path
from scenarionet import SCENARIONET_DATASET_PATH
from scenarionet_training.scripts.train_pg import config
from scenarionet_training.train_utils.utils import eval_ckpt
if __name__ == '__main__':
# Merge all evaluate script
# 10/15/20/26/30/31/32
ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_330\\checkpoint-330"
scenario_data_path = os.path.join(SCENARIONET_DATASET_PATH, "pg_2000")
num_scenarios = 2000
start_scenario_index = 0
horizon = 600
render = False
explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
log_interval = 2
eval_ckpt(config,
ckpt_path,
scenario_data_path,
num_scenarios,
start_scenario_index,
horizon,
render,
explore,
log_interval)

View File

@@ -0,0 +1,22 @@
from scenarionet_training.scripts.train_waymo import config
from scenarionet_training.train_utils.utils import eval_ckpt
if __name__ == '__main__':
ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_170\\checkpoint-170"
scenario_data_path = "D:\\scenarionet_testset\\waymo_test_raw_data"
num_scenarios = 2000
start_scenario_index = 0
horizon = 600
render = True
explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
log_interval = 2
eval_ckpt(config,
ckpt_path,
scenario_data_path,
num_scenarios,
start_scenario_index,
horizon,
render,
explore,
log_interval)

View File

@@ -0,0 +1,80 @@
import os.path
from metadrive.envs.scenario_env import ScenarioEnv
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
if __name__ == '__main__':
env = ScenarioEnv
args = get_train_parser().parse_args()
exp_name = get_exp_name(args)
stop = int(100_000_000)
config = dict(
env=env,
env_config=dict(
# scenario
start_scenario_index=0,
num_scenarios=32,
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg"),
sequential_seed=True,
# traffic & light
reactive_traffic=False,
no_static_vehicles=True,
no_light=True,
static_traffic_object=True,
# curriculum training
curriculum_level=4,
target_success_rate=0.8,
# training
horizon=None,
use_lateral_reward=True,
),
# # ===== Evaluation =====
evaluation_interval=2,
evaluation_num_episodes=32,
evaluation_config=dict(env_config=dict(start_scenario_index=32,
num_scenarios=32,
sequential_seed=True,
curriculum_level=1, # turn off
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg"))),
evaluation_num_workers=2,
metrics_smoothing_episodes=10,
# ===== Training =====
model=dict(fcnet_hiddens=[512, 256, 128]),
horizon=600,
num_sgd_iter=20,
lr=5e-5,
rollout_fragment_length=500,
sgd_minibatch_size=100,
train_batch_size=4000,
num_gpus=0.5 if args.num_gpus != 0 else 0,
num_cpus_per_worker=0.4,
num_cpus_for_driver=1,
num_workers=2,
framework="tf"
)
train(
MultiWorkerPPO,
exp_name=exp_name,
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
keep_checkpoints_num=5,
stop=stop,
config=config,
num_gpus=args.num_gpus,
# num_seeds=args.num_seeds,
num_seeds=1,
# test_mode=args.test,
# local_mode=True,
# TODO remove this when we release our code
# wandb_key_file="~/wandb_api_key_file.txt",
wandb_project="scenarionet",
)

View File

@@ -0,0 +1,74 @@
import argparse
import pickle
import json
import os
import numpy as np
from scenarionet_training.scripts.train_nuplan import config
from scenarionet_training.train_utils.callbacks import DrivingCallbacks
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.train_utils.utils import initialize_ray
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.int32):
return int(obj)
elif isinstance(obj, np.int64):
return int(obj)
return json.JSONEncoder.default(self, obj)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--start_index", type=int, default=0)
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--database_path", type=str, required=True)
parser.add_argument("--id", type=str, default="")
parser.add_argument("--num_scenarios", type=int, default=5000)
parser.add_argument("--num_workers", type=int, default=10)
parser.add_argument("--horizon", type=int, default=600)
parser.add_argument("--allowed_more_steps", type=int, default=50)
parser.add_argument("--max_lateral_dist", type=int, default=2.5)
parser.add_argument("--overwrite", action="store_true")
args = parser.parse_args()
file = "eval_{}_{}_{}".format(args.id, os.path.basename(args.ckpt_path), os.path.basename(args.database_path))
if os.path.exists(file) and not args.overwrite:
raise FileExistsError("Please remove {} or set --overwrite".format(file))
initialize_ray(test_mode=True, num_gpus=1)
config["callbacks"] = DrivingCallbacks
config["evaluation_num_workers"] = args.num_workers
config["evaluation_num_episodes"] = args.num_scenarios
config["metrics_smoothing_episodes"] = args.num_scenarios
config["custom_eval_function"] = None
config["num_workers"] = 0
config["evaluation_config"]["env_config"].update(dict(
start_scenario_index=args.start_index,
num_scenarios=args.num_scenarios,
sequential_seed=True,
store_map=False,
store_data=False,
allowed_more_steps=args.allowed_more_steps,
# no_map=True,
max_lateral_dist=args.max_lateral_dist,
curriculum_level=1, # disable curriculum
target_success_rate=1,
horizon=args.horizon,
episodes_to_evaluate_curriculum=args.num_scenarios,
data_directory=args.database_path,
use_render=False))
trainer = MultiWorkerPPO(config)
trainer.restore(args.ckpt_path)
ret = trainer._evaluate()["evaluation"]
with open(file + ".json", "w") as f:
json.dump(ret, f, cls=NumpyEncoder)
with open(file + ".pkl", "wb+") as f:
pickle.dump(ret, f)

View File

@@ -0,0 +1,96 @@
import os.path
from metadrive.envs.scenario_env import ScenarioEnv
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
config = dict(
env=ScenarioEnv,
env_config=dict(
# scenario
start_scenario_index=0,
num_scenarios=40000,
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "nuplan_train"),
sequential_seed=True,
# curriculum training
curriculum_level=100,
target_success_rate=0.8, # or 0.7
# episodes_to_evaluate_curriculum=400, # default=num_scenarios/curriculum_level
# traffic & light
reactive_traffic=True,
no_static_vehicles=True,
no_light=True,
static_traffic_object=True,
# training scheme
horizon=None,
driving_reward=4,
steering_range_penalty=1.0,
heading_penalty=2,
lateral_penalty=2.0,
no_negative_reward=True,
on_lane_line_penalty=0,
crash_vehicle_penalty=2,
crash_human_penalty=2,
crash_object_penalty=0.5,
# out_of_road_penalty=2,
max_lateral_dist=2,
# crash_vehicle_done=True,
vehicle_config=dict(side_detector=dict(num_lasers=0))
),
# ===== Evaluation =====
evaluation_interval=15,
evaluation_num_episodes=1000,
# TODO (LQY), this is a sample from testset do eval on all scenarios after training!
evaluation_config=dict(env_config=dict(start_scenario_index=0,
num_scenarios=1000,
sequential_seed=True,
curriculum_level=1, # turn off
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "nuplan_test"))),
evaluation_num_workers=10,
metrics_smoothing_episodes=10,
# ===== Training =====
model=dict(fcnet_hiddens=[512, 256, 128]),
horizon=600,
num_sgd_iter=20,
lr=1e-4,
rollout_fragment_length=500,
sgd_minibatch_size=200,
train_batch_size=50000,
num_gpus=0.5,
num_cpus_per_worker=0.3,
num_cpus_for_driver=1,
num_workers=20,
framework="tf"
)
if __name__ == '__main__':
# PG data is generated with seeds 10,000 to 60,000
args = get_train_parser().parse_args()
exp_name = get_exp_name(args)
stop = int(100_000_000)
config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
train(
MultiWorkerPPO,
exp_name=exp_name,
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
keep_checkpoints_num=5,
stop=stop,
config=config,
num_gpus=args.num_gpus,
# num_seeds=args.num_seeds,
num_seeds=5,
test_mode=args.test,
# local_mode=True,
# TODO remove this when we release our code
# wandb_key_file="~/wandb_api_key_file.txt",
wandb_project="scenarionet",
)

View File

@@ -0,0 +1,95 @@
import os.path
from ray.tune import grid_search
from metadrive.envs.scenario_env import ScenarioEnv
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
config = dict(
env=ScenarioEnv,
env_config=dict(
# scenario
start_scenario_index=0,
num_scenarios=40000,
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg_train"),
sequential_seed=True,
# curriculum training
curriculum_level=100,
target_success_rate=0.8,
# episodes_to_evaluate_curriculum=400, # default=num_scenarios/curriculum_level
# traffic & light
reactive_traffic=False,
no_static_vehicles=True,
no_light=True,
static_traffic_object=True,
# training scheme
horizon=None,
steering_range_penalty=2,
heading_penalty=1.0,
lateral_penalty=1.0,
no_negative_reward=True,
on_lane_line_penalty=0,
crash_vehicle_penalty=2,
crash_human_penalty=2,
out_of_road_penalty=2,
max_lateral_dist=2,
# crash_vehicle_done=True,
vehicle_config=dict(side_detector=dict(num_lasers=0))
),
# ===== Evaluation =====
evaluation_interval=15,
evaluation_num_episodes=1000,
# TODO (LQY), this is a sample from testset do eval on all scenarios after training!
evaluation_config=dict(env_config=dict(start_scenario_index=0,
num_scenarios=1000,
sequential_seed=True,
curriculum_level=1, # turn off
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg_test"))),
evaluation_num_workers=10,
metrics_smoothing_episodes=10,
# ===== Training =====
model=dict(fcnet_hiddens=[512, 256, 128]),
horizon=600,
num_sgd_iter=20,
lr=1e-4,
rollout_fragment_length=500,
sgd_minibatch_size=200,
train_batch_size=50000,
num_gpus=0.5,
num_cpus_per_worker=0.3,
num_cpus_for_driver=1,
num_workers=20,
framework="tf"
)
if __name__ == '__main__':
# PG data is generated with seeds 10,000 to 60,000
args = get_train_parser().parse_args()
exp_name = get_exp_name(args)
stop = int(100_000_000)
config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
train(
MultiWorkerPPO,
exp_name=exp_name,
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
keep_checkpoints_num=5,
stop=stop,
config=config,
num_gpus=args.num_gpus,
# num_seeds=args.num_seeds,
num_seeds=5,
test_mode=args.test,
# local_mode=True,
# TODO remove this when we release our code
# wandb_key_file="~/wandb_api_key_file.txt",
wandb_project="scenarionet",
)

View File

@@ -0,0 +1,95 @@
import os.path
from metadrive.envs.scenario_env import ScenarioEnv
from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
config = dict(
env=ScenarioEnv,
env_config=dict(
# scenario
start_scenario_index=0,
num_scenarios=40000,
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_train"),
sequential_seed=True,
# curriculum training
curriculum_level=100,
target_success_rate=0.8,
# episodes_to_evaluate_curriculum=400, # default=num_scenarios/curriculum_level
# traffic & light
reactive_traffic=True,
no_static_vehicles=True,
no_light=True,
static_traffic_object=True,
# training scheme
horizon=None,
driving_reward=1,
steering_range_penalty=0,
heading_penalty=1,
lateral_penalty=1.0,
no_negative_reward=True,
on_lane_line_penalty=0,
crash_vehicle_penalty=2,
crash_human_penalty=2,
out_of_road_penalty=2,
max_lateral_dist=2,
# crash_vehicle_done=True,
vehicle_config=dict(side_detector=dict(num_lasers=0))
),
# ===== Evaluation =====
evaluation_interval=15,
evaluation_num_episodes=1000,
# TODO (LQY), this is a sample from testset do eval on all scenarios after training!
evaluation_config=dict(env_config=dict(start_scenario_index=0,
num_scenarios=1000,
sequential_seed=True,
curriculum_level=1, # turn off
data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_test"))),
evaluation_num_workers=10,
metrics_smoothing_episodes=10,
# ===== Training =====
model=dict(fcnet_hiddens=[512, 256, 128]),
horizon=600,
num_sgd_iter=20,
lr=1e-4,
rollout_fragment_length=500,
sgd_minibatch_size=200,
train_batch_size=50000,
num_gpus=0.5,
num_cpus_per_worker=0.3,
num_cpus_for_driver=1,
num_workers=20,
framework="tf"
)
if __name__ == '__main__':
# PG data is generated with seeds 10,000 to 60,000
args = get_train_parser().parse_args()
exp_name = get_exp_name(args)
stop = int(100_000_000)
config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
train(
MultiWorkerPPO,
exp_name=exp_name,
save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
keep_checkpoints_num=5,
stop=stop,
config=config,
num_gpus=args.num_gpus,
# num_seeds=args.num_seeds,
num_seeds=5,
test_mode=args.test,
# local_mode=True,
# TODO remove this when we release our code
# wandb_key_file="~/wandb_api_key_file.txt",
wandb_project="scenarionet",
)

View File

@@ -0,0 +1,42 @@
import copy
import logging
from typing import TypeVar
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
# Generic type var for foreach_* methods.
T = TypeVar("T")
@DeveloperAPI
class AnisotropicWorkerSet(WorkerSet):
"""
Workers are assigned to different scenarios for saving memory/speeding up sampling
"""
def add_workers(self, num_workers: int) -> None:
"""
Workers are assigned to different scenarios
"""
remote_args = {
"num_cpus": self._remote_config["num_cpus_per_worker"],
"num_gpus": self._remote_config["num_gpus_per_worker"],
# memory=0 is an error, but memory=None means no limits.
"memory": self._remote_config["memory_per_worker"] or None,
"object_store_memory": self.
_remote_config["object_store_memory_per_worker"] or None,
"resources": self._remote_config["custom_resources_per_worker"],
}
cls = RolloutWorker.as_remote(**remote_args).remote
for i in range(num_workers):
config = copy.deepcopy(self._remote_config)
config["env_config"]["worker_index"] = i
config["env_config"]["num_workers"] = num_workers
self._remote_workers.append(self._make_worker(cls, self._env_creator, self._policy_class, i + 1, config))

View File

@@ -0,0 +1,110 @@
from typing import Dict
import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
class DrivingCallbacks(DefaultCallbacks):
def on_episode_start(
self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
env_index: int, **kwargs
):
episode.user_data["velocity"] = []
episode.user_data["steering"] = []
episode.user_data["step_reward"] = []
episode.user_data["acceleration"] = []
episode.user_data["lateral_dist"] = []
episode.user_data["cost"] = []
episode.user_data["num_crash_vehicle"] = []
episode.user_data["num_crash_human"] = []
episode.user_data["num_crash_object"] = []
episode.user_data["num_on_line"] = []
episode.user_data["step_reward_lateral"] = []
episode.user_data["step_reward_heading"] = []
episode.user_data["step_reward_action_smooth"] = []
def on_episode_step(
self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs
):
info = episode.last_info_for()
if info is not None:
episode.user_data["velocity"].append(info["velocity"])
episode.user_data["steering"].append(info["steering"])
episode.user_data["step_reward"].append(info["step_reward"])
episode.user_data["acceleration"].append(info["acceleration"])
episode.user_data["lateral_dist"].append(info["lateral_dist"])
episode.user_data["cost"].append(info["cost"])
for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
episode.user_data[x].append(info[x])
for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
episode.user_data[x].append(info[x])
def on_episode_end(
self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
**kwargs
):
arrive_dest = episode.last_info_for()["arrive_dest"]
crash = episode.last_info_for()["crash"]
out_of_road = episode.last_info_for()["out_of_road"]
max_step_rate = not (arrive_dest or crash or out_of_road)
episode.custom_metrics["success_rate"] = float(arrive_dest)
episode.custom_metrics["crash_rate"] = float(crash)
episode.custom_metrics["out_of_road_rate"] = float(out_of_road)
episode.custom_metrics["max_step_rate"] = float(max_step_rate)
episode.custom_metrics["velocity_max"] = float(np.max(episode.user_data["velocity"]))
episode.custom_metrics["velocity_mean"] = float(np.mean(episode.user_data["velocity"]))
episode.custom_metrics["velocity_min"] = float(np.min(episode.user_data["velocity"]))
episode.custom_metrics["lateral_dist_min"] = float(np.min(episode.user_data["lateral_dist"]))
episode.custom_metrics["lateral_dist_max"] = float(np.max(episode.user_data["lateral_dist"]))
episode.custom_metrics["lateral_dist_mean"] = float(np.mean(episode.user_data["lateral_dist"]))
episode.custom_metrics["steering_max"] = float(np.max(episode.user_data["steering"]))
episode.custom_metrics["steering_mean"] = float(np.mean(episode.user_data["steering"]))
episode.custom_metrics["steering_min"] = float(np.min(episode.user_data["steering"]))
episode.custom_metrics["acceleration_min"] = float(np.min(episode.user_data["acceleration"]))
episode.custom_metrics["acceleration_mean"] = float(np.mean(episode.user_data["acceleration"]))
episode.custom_metrics["acceleration_max"] = float(np.max(episode.user_data["acceleration"]))
episode.custom_metrics["step_reward_max"] = float(np.max(episode.user_data["step_reward"]))
episode.custom_metrics["step_reward_mean"] = float(np.mean(episode.user_data["step_reward"]))
episode.custom_metrics["step_reward_min"] = float(np.min(episode.user_data["step_reward"]))
episode.custom_metrics["cost"] = float(sum(episode.user_data["cost"]))
for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
episode.custom_metrics[x] = float(sum(episode.user_data[x]))
for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
episode.custom_metrics[x] = float(np.mean(episode.user_data[x]))
episode.custom_metrics["route_completion"] = float(episode.last_info_for()["route_completion"])
episode.custom_metrics["curriculum_level"] = int(episode.last_info_for()["curriculum_level"])
episode.custom_metrics["scenario_index"] = int(episode.last_info_for()["scenario_index"])
episode.custom_metrics["track_length"] = float(episode.last_info_for()["track_length"])
episode.custom_metrics["num_stored_maps"] = int(episode.last_info_for()["num_stored_maps"])
episode.custom_metrics["scenario_difficulty"] = float(episode.last_info_for()["scenario_difficulty"])
episode.custom_metrics["data_coverage"] = float(episode.last_info_for()["data_coverage"])
episode.custom_metrics["curriculum_success"] = float(episode.last_info_for()["curriculum_success"])
episode.custom_metrics["curriculum_route_completion"] = float(
episode.last_info_for()["curriculum_route_completion"])
def on_train_result(self, *, trainer, result: dict, **kwargs):
result["success"] = np.nan
result["out"] = np.nan
result["max_step"] = np.nan
result["level"] = np.nan
result["length"] = result["episode_len_mean"]
result["coverage"] = np.nan
if "custom_metrics" not in result:
return
if "success_rate_mean" in result["custom_metrics"]:
result["success"] = result["custom_metrics"]["success_rate_mean"]
result["out"] = result["custom_metrics"]["out_of_road_rate_mean"]
result["max_step"] = result["custom_metrics"]["max_step_rate_mean"]
result["level"] = result["custom_metrics"]["curriculum_level_mean"]
result["coverage"] = result["custom_metrics"]["data_coverage_mean"]

View File

@@ -0,0 +1,11 @@
from ray.rllib.utils import check_env
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
from gym import Env
if __name__ == '__main__':
env = GymnasiumEnvWrapper.build(ScenarioEnv)()
print(isinstance(ScenarioEnv, Env))
print(isinstance(env, Env))
print(env.observation_space)
check_env(env)

View File

@@ -0,0 +1,46 @@
import logging
from typing import Callable, Type
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, \
EnvType
from scenarionet_training.train_utils.anisotropic_workerset import AnisotropicWorkerSet
logger = logging.getLogger(__name__)
class MultiWorkerPPO(PPOTrainer):
"""
In this class, each work will have different config for speeding up and saving memory. More importantly, it can
allow us to cover all test/train cases more evenly
"""
def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
policy_class: Type[Policy], config: TrainerConfigDict,
num_workers: int):
"""Default factory method for a WorkerSet running under this Trainer.
Override this method by passing a custom `make_workers` into
`build_trainer`.
Args:
env_creator (callable): A function that return and Env given an env
config.
policy (Type[Policy]): The Policy class to use for creating the
policies of the workers.
config (TrainerConfigDict): The Trainer's config.
num_workers (int): Number of remote rollout workers to create.
0 for local only.
Returns:
WorkerSet: The created WorkerSet.
"""
return AnisotropicWorkerSet(
env_creator=env_creator,
policy_class=policy_class,
trainer_config=config,
num_workers=num_workers,
logdir=self.logdir)

View File

@@ -0,0 +1,356 @@
import copy
import datetime
import json
import os
import pickle
from collections import defaultdict
import numpy as np
import tqdm
from metadrive.constants import TerminationState
from metadrive.envs.scenario_env import ScenarioEnv
from ray import tune
from ray.tune import CLIReporter
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.wandb_utils import WANDB_KEY_FILE
root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
def get_api_key_file(wandb_key_file):
if wandb_key_file is not None:
default_path = os.path.expanduser(wandb_key_file)
else:
default_path = WANDB_KEY_FILE
if os.path.exists(default_path):
print("We are using this wandb key file: ", default_path)
return default_path
path = os.path.join(root, "scenarionet_training/wandb", "wandb_api_key_file.txt")
print("We are using this wandb key file: ", path)
return path
def train(
trainer,
config,
stop,
exp_name,
num_seeds=1,
num_gpus=0,
test_mode=False,
suffix="",
checkpoint_freq=10,
keep_checkpoints_num=None,
start_seed=0,
local_mode=False,
save_pkl=True,
custom_callback=None,
max_failures=0,
wandb_key_file=None,
wandb_project=None,
wandb_team="drivingforce",
wandb_log_config=True,
init_kws=None,
save_dir=None,
**kwargs
):
init_kws = init_kws or dict()
# initialize ray
if not os.environ.get("redis_password"):
initialize_ray(test_mode=test_mode, local_mode=local_mode, num_gpus=num_gpus, **init_kws)
else:
password = os.environ.get("redis_password")
assert os.environ.get("ip_head")
print(
"We detect redis_password ({}) exists in environment! So "
"we will start a ray cluster!".format(password)
)
if num_gpus:
print(
"We are in cluster mode! So GPU specification is disable and"
" should be done when submitting task to cluster! You are "
"requiring {} GPU for each machine!".format(num_gpus)
)
initialize_ray(address=os.environ["ip_head"], test_mode=test_mode, redis_password=password, **init_kws)
# prepare config
if custom_callback:
callback = custom_callback
else:
from scenarionet_training.train_utils.callbacks import DrivingCallbacks
callback = DrivingCallbacks
used_config = {
"seed": tune.grid_search([i * 100 + start_seed for i in range(num_seeds)]) if num_seeds is not None else None,
"log_level": "DEBUG" if test_mode else "INFO",
"callbacks": callback
}
if custom_callback is False:
used_config.pop("callbacks")
if config:
used_config.update(config)
config = copy.deepcopy(used_config)
if isinstance(trainer, str):
trainer_name = trainer
elif hasattr(trainer, "_name"):
trainer_name = trainer._name
else:
trainer_name = trainer.__name__
if not isinstance(stop, dict) and stop is not None:
assert np.isscalar(stop)
stop = {"timesteps_total": int(stop)}
if keep_checkpoints_num is not None and not test_mode:
assert isinstance(keep_checkpoints_num, int)
kwargs["keep_checkpoints_num"] = keep_checkpoints_num
kwargs["checkpoint_score_attr"] = "episode_reward_mean"
if "verbose" not in kwargs:
kwargs["verbose"] = 1 if not test_mode else 2
# This functionality is not supported yet!
metric_columns = CLIReporter.DEFAULT_COLUMNS.copy()
progress_reporter = CLIReporter(metric_columns=metric_columns)
progress_reporter.add_metric_column("success")
progress_reporter.add_metric_column("coverage")
progress_reporter.add_metric_column("out")
progress_reporter.add_metric_column("max_step")
progress_reporter.add_metric_column("length")
progress_reporter.add_metric_column("level")
kwargs["progress_reporter"] = progress_reporter
if wandb_key_file is not None:
assert wandb_project is not None
if wandb_project is not None:
assert wandb_project is not None
failed_wandb = False
try:
from scenarionet_training.wandb_utils.our_wandb_callbacks import OurWandbLoggerCallback
except Exception as e:
# print("Please install wandb: pip install wandb")
failed_wandb = True
if failed_wandb:
from ray.tune.logger import DEFAULT_LOGGERS
from scenarionet_training.wandb_utils.our_wandb_callbacks_ray100 import OurWandbLogger
kwargs["loggers"] = DEFAULT_LOGGERS + (OurWandbLogger,)
config["logger_config"] = {
"wandb":
{
"group": exp_name,
"exp_name": exp_name,
"entity": wandb_team,
"project": wandb_project,
"api_key_file": get_api_key_file(wandb_key_file),
"log_config": wandb_log_config,
}
}
else:
kwargs["callbacks"] = [
OurWandbLoggerCallback(
exp_name=exp_name,
api_key_file=get_api_key_file(wandb_key_file),
project=wandb_project,
group=exp_name,
log_config=wandb_log_config,
entity=wandb_team
)
]
# start training
analysis = tune.run(
trainer,
name=exp_name,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=True if "checkpoint_at_end" not in kwargs else kwargs.pop("checkpoint_at_end"),
stop=stop,
config=config,
max_failures=max_failures if not test_mode else 0,
reuse_actors=False,
local_dir=save_dir or ".",
**kwargs
)
# save training progress as insurance
if save_pkl:
pkl_path = "{}-{}{}.pkl".format(exp_name, trainer_name, "" if not suffix else "-" + suffix)
with open(pkl_path, "wb") as f:
data = analysis.fetch_trial_dataframes()
pickle.dump(data, f)
print("Result is saved at: <{}>".format(pkl_path))
return analysis
import argparse
import logging
import os
import ray
def initialize_ray(local_mode=False, num_gpus=None, test_mode=False, **kwargs):
os.environ['OMP_NUM_THREADS'] = '1'
if ray.__version__.split(".")[0] == "1": # 1.0 version Ray
if "redis_password" in kwargs:
redis_password = kwargs.pop("redis_password")
kwargs["_redis_password"] = redis_password
ray.init(
logging_level=logging.ERROR if not test_mode else logging.DEBUG,
log_to_driver=test_mode,
local_mode=local_mode,
num_gpus=num_gpus,
ignore_reinit_error=True,
include_dashboard=False,
**kwargs
)
print("Successfully initialize Ray!")
try:
print("Available resources: ", ray.available_resources())
except Exception:
pass
def get_train_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default="")
parser.add_argument("--num-gpus", type=int, default=0)
parser.add_argument("--num-seeds", type=int, default=3)
parser.add_argument("--num-cpus-per-worker", type=float, default=0.5)
parser.add_argument("--num-gpus-per-trial", type=float, default=0.25)
parser.add_argument("--test", action="store_true")
return parser
def setup_logger(debug=False):
import logging
logging.basicConfig(
level=logging.DEBUG if debug else logging.WARNING,
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
)
def get_time_str():
return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
def get_exp_name(args):
if args.exp_name != "":
exp_name = args.exp_name + "_" + get_time_str()
else:
exp_name = "TEST"
return exp_name
def get_eval_config(config):
eval_config = copy.deepcopy(config)
eval_config.pop("evaluation_interval")
eval_config.pop("evaluation_num_episodes")
eval_config.pop("evaluation_config")
eval_config.pop("evaluation_num_workers")
return eval_config
def get_function(ckpt, explore, config):
trainer = MultiWorkerPPO(get_eval_config(config))
trainer.restore(ckpt)
def _f(obs):
ret = trainer.compute_actions({"default_policy": obs}, explore=explore)
return ret
return _f
def eval_ckpt(config,
ckpt_path,
scenario_data_path,
num_scenarios,
start_scenario_index,
horizon=600,
render=False,
# PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
explore=True,
log_interval=None,
):
initialize_ray(test_mode=False, num_gpus=1)
# 27 29 30 37 39
env_config = get_eval_config(config)["env_config"]
env_config.update(dict(
start_scenario_index=start_scenario_index,
num_scenarios=num_scenarios,
sequential_seed=True,
curriculum_level=1, # disable curriculum
target_success_rate=1,
horizon=horizon,
episodes_to_evaluate_curriculum=num_scenarios,
data_directory=scenario_data_path,
use_render=render))
env = ScenarioEnv(env_config)
super_data = defaultdict(list)
EPISODE_NUM = env.config["num_scenarios"]
compute_actions = get_function(ckpt_path, explore=explore, config=config)
o = env.reset()
assert env.current_seed == start_scenario_index, "Wrong start seed!"
total_cost = 0
total_reward = 0
success_rate = 0
ep_cost = 0
ep_reward = 0
success_flag = False
step = 0
def log_msg():
print("CKPT:{} | success_rate:{}, mean_episode_reward:{}, mean_episode_cost:{}".format(epi_num,
success_rate / epi_num,
total_reward / epi_num,
total_cost / epi_num))
for epi_num in tqdm.tqdm(range(0, EPISODE_NUM)):
step += 1
action_to_send = compute_actions(o)["default_policy"]
o, r, d, info = env.step(action_to_send)
if env.config["use_render"]:
env.render(text={"reward": r})
total_reward += r
ep_reward += r
total_cost += info["cost"]
ep_cost += info["cost"]
if d or step > horizon:
if info["arrive_dest"]:
success_rate += 1
success_flag = True
o = env.reset()
super_data[0].append(
{"reward": ep_reward,
"success": success_flag,
"out_of_road": info[TerminationState.OUT_OF_ROAD],
"cost": ep_cost,
"seed": env.current_seed,
"route_completion": info["route_completion"]})
ep_cost = 0.0
ep_reward = 0.0
success_flag = False
step = 0
if log_interval is not None and epi_num % log_interval == 0:
log_msg()
if log_interval is not None:
log_msg()
del compute_actions
env.close()
with open("eval_ret_{}_{}_{}.json".format(start_scenario_index,
start_scenario_index + num_scenarios,
get_time_str()), "w") as f:
json.dump(super_data, f)
return super_data

View File

@@ -0,0 +1,3 @@
import os
WANDB_KEY_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "wandb_api_key_file.txt")

View File

@@ -0,0 +1,67 @@
from ray.tune.integration.wandb import WandbLoggerCallback, _clean_log, \
Queue, WandbLogger
class OurWandbLoggerCallback(WandbLoggerCallback):
def __init__(self, exp_name, *args, **kwargs):
super(OurWandbLoggerCallback, self).__init__(*args, **kwargs)
self.exp_name = exp_name
def log_trial_start(self, trial: "Trial"):
config = trial.config.copy()
config.pop("callbacks", None) # Remove callbacks
exclude_results = self._exclude_results.copy()
# Additional excludes
exclude_results += self.excludes
# Log config keys on each result?
if not self.log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = trial.trial_id if trial else None
# trial_name = str(trial) if trial else None
# Project name for Wandb
wandb_project = self.project
# Grouping
wandb_group = self.group or trial.trainable_name if trial else None
# remove unpickleable items!
config = _clean_log(config)
assert trial_id is not None
run_name = "{}_{}".format(self.exp_name, trial_id)
wandb_init_kwargs = dict(
id=trial_id,
name=run_name,
resume=True,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config
)
wandb_init_kwargs.update(self.kwargs)
self._trial_queues[trial] = Queue()
self._trial_processes[trial] = self._logger_process_cls(
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs
)
self._trial_processes[trial].start()
def __del__(self):
if self._trial_processes:
for v in self._trial_processes.values():
if hasattr(v, "close"):
v.close()
self._trial_processes.clear()
self._trial_processes = {}

View File

@@ -0,0 +1,80 @@
from multiprocessing import Queue
from ray.tune.integration.wandb import WandbLogger, _clean_log, _set_api_key
class OurWandbLogger(WandbLogger):
def __init__(self, config, logdir, trial):
self.exp_name = config["logger_config"]["wandb"].pop("exp_name")
super(OurWandbLogger, self).__init__(config, logdir, trial)
def _init(self):
config = self.config.copy()
config.pop("callbacks", None) # Remove callbacks
try:
if config.get("logger_config", {}).get("wandb"):
logger_config = config.pop("logger_config")
wandb_config = logger_config.get("wandb").copy()
else:
wandb_config = config.pop("wandb").copy()
except KeyError:
raise ValueError(
"Wandb logger specified but no configuration has been passed. "
"Make sure to include a `wandb` key in your `config` dict "
"containing at least a `project` specification.")
_set_api_key(wandb_config)
exclude_results = self._exclude_results.copy()
# Additional excludes
additional_excludes = wandb_config.pop("excludes", [])
exclude_results += additional_excludes
# Log config keys on each result?
log_config = wandb_config.pop("log_config", False)
if not log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = self.trial.trial_id if self.trial else None
trial_name = str(self.trial) if self.trial else None
# Project name for Wandb
try:
wandb_project = wandb_config.pop("project")
except KeyError:
raise ValueError(
"You need to specify a `project` in your wandb `config` dict.")
# Grouping
wandb_group = wandb_config.pop(
"group", self.trial.trainable_name if self.trial else None)
# remove unpickleable items!
config = _clean_log(config)
assert trial_id is not None
run_name = "{}_{}".format(self.exp_name, trial_id)
wandb_init_kwargs = dict(
id=trial_id,
name=run_name,
resume=True,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config)
wandb_init_kwargs.update(wandb_config)
self._queue = Queue()
self._wandb = self._logger_process_cls(
queue=self._queue,
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs)
self._wandb.start()

View File

@@ -0,0 +1,38 @@
"""
Procedure to use wandb:
1. Logup in wandb: https://wandb.ai/
2. Get the API key in personal setting
3. Store API key (a string)to some file as: ~/wandb_api_key_file.txt
4. Install wandb: pip install wandb
5. Fill the "wandb_key_file", "wandb_project" keys in our train function.
Note1: You don't need to specify who own "wandb_project", for example, in team "drivingforce"'s project
"representation", you only need to fill wandb_project="representation"
Note2: In wanbd, there are "team name", "project name", "group name" and "trial_name". We only need to care
"team name" and "project name". The "team name" is set to "drivingforce" by default. You can also use None to
log result to your personal domain. The "group name" of the experiment is exactly the "exp_name" in our context, like
"0304_train_ppo" or so.
Note3: It would be great to change the x-axis in wandb website to "timesteps_total".
Peng Zhenghao, 20210402
"""
from ray import tune
from scenarionet_training.train_utils.utils import train
if __name__ == "__main__":
config = dict(env="CartPole-v0", num_workers=0, lr=tune.grid_search([1e-2, 1e-4]))
train(
"PPO",
exp_name="test_wandb",
stop=10000,
config=config,
custom_callback=False,
test_mode=False,
local_mode=False,
wandb_project="TEST",
wandb_team="drivingforce" # drivingforce is set to default. Use None to log to your personal domain!
)

View File

@@ -0,0 +1 @@
132a8add578bdaeea5ab7a4942f35f2a17742df2

View File

@@ -40,9 +40,18 @@ install_requires = [
"shapely"
]
train_requirement = [
"ray[rllib]==1.0.0",
# "torch",
"wandb==0.12.1",
"aiohttp==3.6.0",
"gymnasium",
"tensorflow",
"tensorflow_probability"]
setup(
name="scenarionet",
python_requires='>=3.6, <3.12', # do version check with assert
python_requires='>=3.8', # do version check with assert
version=version,
description="Scalable Traffic Scenario Management System",
url="https://github.com/metadriverse/ScenarioNet",
@@ -50,12 +59,9 @@ setup(
author_email="quanyili0057@gmail.com, pzh@cs.ucla.edu",
packages=packages,
install_requires=install_requires,
# extras_require={
# "cuda": cuda_requirement,
# "nuplan": nuplan_requirement,
# "waymo": waymo_requirement,
# "all": nuplan_requirement + cuda_requirement
# },
extras_require={
"train": train_requirement,
},
include_package_data=True,
license="Apache 2.0",
long_description=long_description,