diff --git a/.gitignore b/.gitignore
index c69ffb2..3c6d605 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,7 @@ dataset/*
 **/passed_scenarios/
 **/waymo_origin
 /dataset/
+/scenarionet_training/wandb/*.pkl
+**/TEST/
+**/experiment/
+**/wandb/
diff --git a/README.md b/README.md
index 922b2c9..5ce5f0a 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,8 @@ pip install -e .
 ## Usage
 
 We provide some explanation and demo for all scripts here.
-**You are encouraged to try them on your own, add ```-h``` or ```--help``` argument to know more details about these scripts.**
+**You are encouraged to try them on your own; add the ```-h``` or ```--help``` argument to see more details about these
+scripts.**
 
 ### Convert
 
@@ -45,31 +46,42 @@ python -m scenarionet.scripts.convert_pg -d pg --num_workers=16 --num_scenarios=
 ```
 
 ### Merge & move
+
 For merging two or more database, use
+
 ```
-python -m scenarionet.merge_database -d /destination/path --from_databases /database1 /2 ...
+python -m scenarionet.merge_database -d /destination/path --from /database1 /2 ...
 ```
+
 As a database contains a path mapping, one should move database folder with the following script instead of ```cp```
-command
+command.
+Using ```--copy_raw_data``` will copy the raw scenario files into the target directory and cancel the virtual mapping.
+
 ```
-python -m scenarionet.move_database --to /destination/path --from /source/path
+python -m scenarionet.copy_database --to /destination/path --from /source/path
 ```
 
 ### Verify
+
 The following scripts will check whether all scenarios exist or can be loaded into simulator.
-The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be
-generated.
+The missing or broken scenarios will be recorded and stored into the error file. Otherwise, no error file will be
+generated. With the error file, one can build a new database excluding or including the broken or missing scenarios.
 
 **Existence check**
+
 ```
-python -m scenarionet.verify_existence -d /database/to/check --error_file_path /error/file/path
+python -m scenarionet.check_existence -d /database/to/check --error_file_path /error/file/path
 ```
+
 **Runnable check**
+
 ```
-python -m scenarionet.verify_simulation -d /database/to/check --error_file_path /error/file/path
+python -m scenarionet.check_simulation -d /database/to/check --error_file_path /error/file/path
 ```
+
 **Generating new database**
+
 ```
 python -m scenarionet.generate_from_error_file -d /new/database/path --file /error/file/path
 ```
@@ -77,6 +89,7 @@ python -m scenarionet.generate_from_error_file -d /new/database/path --file /err
 ### visualization
 
 Visualizing the simulated scenario
+
 ```
 python -m scenarionet.run_simulation -d /path/to/database --render --scenario_index
 ```
diff --git a/scenarionet/builder/filters.py b/scenarionet/builder/filters.py
index a292c78..6b89991 100644
--- a/scenarionet/builder/filters.py
+++ b/scenarionet/builder/filters.py
@@ -1,6 +1,8 @@
 from functools import partial
 
+import numpy as np
 from metadrive.scenario.scenario_description import ScenarioDescription as SD
+from metadrive.scenario.utils import read_scenario_data
 
 
 class ScenarioFilter:
@@ -8,7 +10,7 @@ class ScenarioFilter:
     SMALLER = "smaller"
 
     @staticmethod
-    def sdc_moving_dist(metadata, target_dist, condition=GREATER):
+    def sdc_moving_dist(metadata, file_path, target_dist, condition=GREATER):
         """
         This function filters the scenario based on SDC information.
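+        :param metadata: metadata in each scenario
+        :param file_path: where to find this scenario; kept so every filter can be called as fil(metadata, file_path, ...)
+        :param target_dist: moving distance threshold the SDC is compared against
+        :param condition: GREATER or SMALLER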
""" @@ -22,10 +24,11 @@ class ScenarioFilter: return False @staticmethod - def object_number(metadata, number_threshold, object_type=None, condition=SMALLER): + def object_number(metadata, file_path, number_threshold, object_type=None, condition=SMALLER): """ Return True if the scenario satisfying the object number condition :param metadata: metadata in each scenario + :param file_path: where to find this scenario :param number_threshold: number of objects threshold :param object_type: MetaDriveType.VEHICLE or other object type. If none, calculate number for all object types :param condition: SMALLER or GREATER @@ -43,9 +46,33 @@ class ScenarioFilter: return False @staticmethod - def has_traffic_light(metadata): + def has_traffic_light(metadata, file_path): return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] > 0 + @staticmethod + def no_traffic_light(metadata, file_path): + return metadata[SD.SUMMARY.NUMBER_SUMMARY][SD.SUMMARY.NUM_TRAFFIC_LIGHTS] == 0 + + @staticmethod + def no_overpass(metadata, file_path): + """ + We need read the map data to do overpass filter + """ + max_height_diff = 5 + if SD.SUMMARY.MAP_HEIGHT_DIFF in metadata: + return metadata[SD.SUMMARY.MAP_HEIGHT_DIFF] < max_height_diff + else: + # calculate online + map_features = read_scenario_data(file_path)[SD.MAP_FEATURES] + return abs(SD.map_height_diff(map_features, target=max_height_diff)) < max_height_diff + + @staticmethod + def id_filter(metadata, file_path, ids): + for id in ids: + if metadata["id"] in id: + return False + return True + @staticmethod def make(func, **kwargs): """ diff --git a/scenarionet/builder/utils.py b/scenarionet/builder/utils.py index 9d002fe..041bfd4 100644 --- a/scenarionet/builder/utils.py +++ b/scenarionet/builder/utils.py @@ -1,4 +1,6 @@ import copy +from random import sample +from metadrive.scenario.utils import read_dataset_summary import logging import os import os.path as osp @@ -27,21 +29,23 @@ def try_generating_summary(file_folder): def merge_database( - output_path, - *dataset_paths, - exist_ok=False, - overwrite=False, - try_generate_missing_file=True, - filters: List[Callable] = None + output_path, + *dataset_paths, + exist_ok=False, + overwrite=False, + try_generate_missing_file=True, + filters: List[Callable] = None, + save=True, ): """ - Combine multiple datasets. Each dataset should have a dataset_summary.pkl - :param output_path: The path to store the output dataset + Combine multiple datasets. Each database should have a dataset_summary.pkl + :param output_path: The path to store the output database :param exist_ok: If True, though the output_path already exist, still write into it :param overwrite: If True, overwrite existing dataset_summary.pkl and mapping.pkl. 
        Otherwise, raise error
     :param try_generate_missing_file: If dataset_summary.pkl and mapping.pkl are missing, whether to try generating them
-    :param dataset_paths: Path of each dataset
-    :param filters: a set of filters to choose which scenario to be selected and added into this combined dataset
+    :param dataset_paths: Path of each database
+    :param filters: a set of filters to choose which scenario to be selected and added into this combined database
+    :param save: save to the output path immediately
     :return: summary, mapping
     """
     filters = filters or []
@@ -60,15 +64,15 @@ def merge_database(
     mappings = {}
 
     # collect
-    for dataset_path in tqdm.tqdm(dataset_paths):
+    for dataset_path in tqdm.tqdm(dataset_paths, desc="Merge Data"):
         abs_dir_path = osp.abspath(dataset_path)
         # summary
-        assert osp.exists(abs_dir_path), "Wrong dataset path. Can not find dataset at: {}".format(abs_dir_path)
+        assert osp.exists(abs_dir_path), "Wrong database path. Can not find database at: {}".format(abs_dir_path)
         if not osp.exists(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE)):
             if try_generate_missing_file:
                 summary = try_generating_summary(abs_dir_path)
             else:
-                raise FileNotFoundError("Can not find summary file for dataset: {}".format(abs_dir_path))
+                raise FileNotFoundError("Can not find summary file for database: {}".format(abs_dir_path))
         else:
             with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.SUMMARY_FILE), "rb+") as f:
                 summary = pickle.load(f)
@@ -85,7 +89,7 @@ def merge_database(
             if try_generate_missing_file:
                 mapping = {k: "" for k in summary}
             else:
-                raise FileNotFoundError("Can not find mapping file for dataset: {}".format(abs_dir_path))
+                raise FileNotFoundError("Can not find mapping file for database: {}".format(abs_dir_path))
         else:
             with open(osp.join(abs_dir_path, ScenarioDescription.DATASET.MAPPING_FILE), "rb+") as f:
                 mapping = pickle.load(f)
@@ -98,37 +102,108 @@ def merge_database(
 
     # apply filter stage
     file_to_pop = []
-    for file_name, metadata, in summaries.items():
-        if not all([fil(metadata) for fil in filters]):
+    for file_name in tqdm.tqdm(summaries.keys(), desc="Filter Scenarios"):
+        metadata = summaries[file_name]
+        if not all([fil(metadata, os.path.join(output_abs_path, mappings[file_name], file_name)) for fil in filters]):
             file_to_pop.append(file_name)
     for file in file_to_pop:
         summaries.pop(file)
         mappings.pop(file)
-
-    save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
+    if save:
+        save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings)
     return summaries, mappings
 
 
-def move_database(
-        from_path,
-        to_path,
-        exist_ok=False,
-        overwrite=False,
+def copy_database(
+    from_path,
+    to_path,
+    exist_ok=False,
+    overwrite=False,
+    copy_raw_data=False,
+    remove_source=False
 ):
     if not os.path.exists(from_path):
-        raise FileNotFoundError("Can not find dataset: {}".format(from_path))
+        raise FileNotFoundError("Can not find database: {}".format(from_path))
     if os.path.exists(to_path):
-        assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a dataset"
+        assert exist_ok, "to_directory already exists. Set exist_ok to allow turning it into a database"
        assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!"
 
-    merge_database(
+    files = os.listdir(from_path)
+    if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len(
+            files) > 2:
+        raise RuntimeError("The source database is not allowed to be moved! "
" + "This will break the relationship between this database and other database built on it." + "If it is ok for you, use 'mv' to move it manually ") + + summaries, mappings = merge_database( to_path, from_path, exist_ok=exist_ok, overwrite=overwrite, try_generate_missing_file=True, + save=False ) - files = os.listdir(from_path) - if ScenarioDescription.DATASET.MAPPING_FILE in files and ScenarioDescription.DATASET.SUMMARY_FILE in files and len( - files) == 2: + summary_file = osp.join(to_path, ScenarioDescription.DATASET.SUMMARY_FILE) + mapping_file = osp.join(to_path, ScenarioDescription.DATASET.MAPPING_FILE) + + if copy_raw_data: + logger.info("Copy raw data...") + for scenario_file in tqdm.tqdm(mappings.keys()): + rel_path = mappings[scenario_file] + shutil.copyfile(os.path.join(to_path, rel_path, scenario_file), os.path.join(to_path, scenario_file)) + mappings = {key: "./" for key in summaries.keys()} + save_summary_anda_mapping(summary_file, mapping_file, summaries, mappings) + + if remove_source and ScenarioDescription.DATASET.MAPPING_FILE in files and \ + ScenarioDescription.DATASET.SUMMARY_FILE in files and len(files) == 2: shutil.rmtree(from_path) + + +def split_database( + from_path, + to_path, + start_index, + num_scenarios, + exist_ok=False, + overwrite=False, + random=False, +): + if not os.path.exists(from_path): + raise FileNotFoundError("Can not find database: {}".format(from_path)) + if os.path.exists(to_path): + assert exist_ok, "to_directory already exists. Set exists_ok to allow turning it into a database" + assert not os.path.samefile(from_path, to_path), "to_directory is the same as from_directory. Abort!" + overwrite = overwrite, + output_abs_path = osp.abspath(to_path) + os.makedirs(output_abs_path, exist_ok=exist_ok) + summary_file = osp.join(output_abs_path, ScenarioDescription.DATASET.SUMMARY_FILE) + mapping_file = osp.join(output_abs_path, ScenarioDescription.DATASET.MAPPING_FILE) + for file in [summary_file, mapping_file]: + if os.path.exists(file): + if overwrite: + os.remove(file) + else: + raise FileExistsError("{} already exists at: {}!".format(file, output_abs_path)) + + # collect + abs_dir_path = osp.abspath(from_path) + # summary + assert osp.exists(abs_dir_path), "Wrong database path. 
+                                     "Can not find database at: {}".format(abs_dir_path)
+    summaries, lookup, mappings = read_dataset_summary(from_path)
+    assert start_index >= 0 and start_index + num_scenarios <= len(lookup), \
+        "Not enough scenarios in the source database: total {}, start_index: {}, need: {}".format(
+            len(lookup), start_index, num_scenarios)
+    if random:
+        selected = sample(lookup[start_index:], k=num_scenarios)
+    else:
+        selected = lookup[start_index: start_index + num_scenarios]
+    selected_summary = {}
+    selected_mapping = {}
+    for scenario in selected:
+        selected_summary[scenario] = summaries[scenario]
+        selected_mapping[scenario] = os.path.relpath(osp.join(abs_dir_path, mappings[scenario]), output_abs_path)
+
+    save_summary_anda_mapping(summary_file, mapping_file, selected_summary, selected_mapping)
+
+    return summaries, mappings
diff --git a/scenarionet/verify_existence.py b/scenarionet/check_existence.py
similarity index 100%
rename from scenarionet/verify_existence.py
rename to scenarionet/check_existence.py
diff --git a/scenarionet/check_overlap.py b/scenarionet/check_overlap.py
new file mode 100644
index 0000000..5c74cf7
--- /dev/null
+++ b/scenarionet/check_overlap.py
@@ -0,0 +1,22 @@
+"""
+Check if there is any overlap between two databases
+"""
+
+import argparse
+
+from scenarionet.common_utils import read_dataset_summary
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--database_1', type=str, required=True, help="The path of the first database")
+    parser.add_argument('--database_2', type=str, required=True, help="The path of the second database")
+    args = parser.parse_args()
+
+    summary_1, _, _ = read_dataset_summary(args.database_1)
+    summary_2, _, _ = read_dataset_summary(args.database_2)
+
+    intersection = set(summary_1.keys()).intersection(set(summary_2.keys()))
+    if len(intersection) == 0:
+        print("No overlapping scenarios in the two databases!")
+    else:
+        print("Found overlapping scenarios: {}".format(intersection))
diff --git a/scenarionet/verify_simulation.py b/scenarionet/check_simulation.py
similarity index 100%
rename from scenarionet/verify_simulation.py
rename to scenarionet/check_simulation.py
diff --git a/scenarionet/convert_waymo.py b/scenarionet/convert_waymo.py
index 99c589d..afedb76 100644
--- a/scenarionet/convert_waymo.py
+++ b/scenarionet/convert_waymo.py
@@ -1,4 +1,5 @@
 import pkg_resources  # for suppress warning
+import shutil
 import argparse
 import logging
 import os
@@ -28,6 +29,20 @@ if __name__ == '__main__':
         default=os.path.join(SCENARIONET_REPO_PATH, "waymo_origin"),
         help="The directory stores all waymo tfrecord"
     )
+    parser.add_argument(
+        "--start_file_index",
+        default=0,
+        type=int,
+        help="Control how many files to use. We will list all files in the raw data folder "
+        "and select files[start_file_index: start_file_index+num_files]"
+    )
+    parser.add_argument(
+        "--num_files",
+        default=1000,
+        type=int,
+        help="Control how many files to use. We will list all files in the raw data folder "
+        "and select files[start_file_index: start_file_index+num_files]"
+    )
     args = parser.parse_args()
 
     overwrite = args.overwrite
@@ -35,8 +50,19 @@ if __name__ == '__main__':
     output_path = args.database_path
     version = args.version
 
+    save_path = output_path
+    if os.path.exists(output_path):
+        if not overwrite:
+            raise ValueError(
+                "Directory {} already exists! Abort. "
" + "\n Try setting overwrite=True or adding --overwrite".format(output_path) + ) + else: + shutil.rmtree(output_path) + waymo_data_directory = os.path.join(SCENARIONET_DATASET_PATH, args.raw_data_path) - scenarios = get_waymo_scenarios(waymo_data_directory) + scenarios = get_waymo_scenarios(waymo_data_directory, args.start_file_index, args.num_files, + num_workers=8) # do not use too much worker to read data write_to_directory( convert_func=convert_waymo_scenario, diff --git a/scenarionet/converter/nuplan/utils.py b/scenarionet/converter/nuplan/utils.py index 85f04e5..ccf1ed5 100644 --- a/scenarionet/converter/nuplan/utils.py +++ b/scenarionet/converter/nuplan/utils.py @@ -357,9 +357,9 @@ def extract_traffic(scenario: NuPlanScenario, center): type=MetaDriveType.UNSET, state=dict( position=np.zeros(shape=(episode_len, 3)), - heading=np.zeros(shape=(episode_len, )), + heading=np.zeros(shape=(episode_len,)), velocity=np.zeros(shape=(episode_len, 2)), - valid=np.zeros(shape=(episode_len, )), + valid=np.zeros(shape=(episode_len,)), length=np.zeros(shape=(episode_len, 1)), width=np.zeros(shape=(episode_len, 1)), height=np.zeros(shape=(episode_len, 1)) @@ -490,3 +490,84 @@ example_scenario_types = "[behind_pedestrian_on_pickup_dropoff, \ starting_unprotected_cross_turn, \ starting_protected_noncross_turn, \ on_pickup_dropoff]" + +# - accelerating_at_crosswalk +# - accelerating_at_stop_sign +# - accelerating_at_stop_sign_no_crosswalk +# - accelerating_at_traffic_light +# - accelerating_at_traffic_light_with_lead +# - accelerating_at_traffic_light_without_lead +# - behind_bike +# - behind_long_vehicle +# - behind_pedestrian_on_driveable +# - behind_pedestrian_on_pickup_dropoff +# - changing_lane +# - changing_lane_to_left +# - changing_lane_to_right +# - changing_lane_with_lead +# - changing_lane_with_trail +# - crossed_by_bike +# - crossed_by_vehicle +# - following_lane_with_lead +# - following_lane_with_slow_lead +# - following_lane_without_lead +# - high_lateral_acceleration +# - high_magnitude_jerk +# - high_magnitude_speed +# - low_magnitude_speed +# - medium_magnitude_speed +# - near_barrier_on_driveable +# - near_construction_zone_sign +# - near_high_speed_vehicle +# - near_long_vehicle +# - near_multiple_bikes +# - near_multiple_pedestrians +# - near_multiple_vehicles +# - near_pedestrian_at_pickup_dropoff +# - near_pedestrian_on_crosswalk +# - near_pedestrian_on_crosswalk_with_ego +# - near_trafficcone_on_driveable +# - on_all_way_stop_intersection +# - on_carpark +# - on_intersection +# - on_pickup_dropoff +# - on_stopline_crosswalk +# - on_stopline_stop_sign +# - on_stopline_traffic_light +# - on_traffic_light_intersection +# - starting_high_speed_turn +# - starting_left_turn +# - starting_low_speed_turn +# - starting_protected_cross_turn +# - starting_protected_noncross_turn +# - starting_right_turn +# - starting_straight_stop_sign_intersection_traversal +# - starting_straight_traffic_light_intersection_traversal +# - starting_u_turn +# - starting_unprotected_cross_turn +# - starting_unprotected_noncross_turn +# - stationary +# - stationary_at_crosswalk +# - stationary_at_traffic_light_with_lead +# - stationary_at_traffic_light_without_lead +# - stationary_in_traffic +# - stopping_at_crosswalk +# - stopping_at_stop_sign_no_crosswalk +# - stopping_at_stop_sign_with_lead +# - stopping_at_stop_sign_without_lead +# - stopping_at_traffic_light_with_lead +# - stopping_at_traffic_light_without_lead +# - stopping_with_lead +# - traversing_crosswalk +# - traversing_intersection +# - 
+# - traversing_pickup_dropoff
+# - traversing_traffic_light_intersection
+# - waiting_for_pedestrian_to_cross
+#
+
+all_scenario_types = "[near_pedestrian_on_crosswalk_with_ego," \
+                     "near_trafficcone_on_driveable, " \
+                     "following_lane_with_lead, " \
+                     "following_lane_with_slow_lead, " \
+                     "following_lane_without_lead]"
diff --git a/scenarionet/converter/waymo/utils.py b/scenarionet/converter/waymo/utils.py
index 9096552..c2250e4 100644
--- a/scenarionet/converter/waymo/utils.py
+++ b/scenarionet/converter/waymo/utils.py
@@ -3,6 +3,8 @@ import multiprocessing
 import os
 import pickle
 
+import tqdm
+
 from scenarionet.converter.utils import mph_to_kmh
 from scenarionet.converter.waymo.type import WaymoLaneType, WaymoAgentType, WaymoRoadLineType, WaymoRoadEdgeType
 
@@ -418,13 +420,20 @@ def convert_waymo_scenario(scenario, version):
         }
         for count, id in enumerate(track_id)
     }
+    # clean memory
+    del scenario
+    scenario = None
     return md_scenario
 
 
-def get_waymo_scenarios(waymo_data_directory, num_workers=8):
+def get_waymo_scenarios(waymo_data_directory, start_index, num, num_workers=8):
     # parse raw data from input path to output path,
     # there is 1000 raw data in google cloud, each of them produce about 500 pkl file
+    logger.info("\n Reading raw data")
     file_list = os.listdir(waymo_data_directory)
+    assert len(file_list) >= start_index + num and start_index >= 0, \
+        "Not enough files ({}) in raw_data_directory. need: {}, start: {}".format(len(file_list), num, start_index)
+    file_list = file_list[start_index: start_index + num]
     num_files = len(file_list)
     if num_files < num_workers:
         # single process
@@ -441,23 +450,29 @@
         argument_list.append([waymo_data_directory, file_list[i * num_files_each_worker:end_idx]])
 
     # Run, workers and process result from worker
-    with multiprocessing.Pool(num_workers) as p:
-        all_result = list(p.imap(read_from_files, argument_list))
-    ret = []
-
-    # get result
-    for r in all_result:
-        if len(r) == 0:
-            logger.warning("0 scenarios found")
-        ret += r
-    logger.info("\n Find {} waymo scenarios from {} files".format(len(ret), num_files))
-    return ret
+    # with multiprocessing.Pool(num_workers) as p:
+    #     all_result = list(p.imap(read_from_files, argument_list))
+    # Disable multiprocessing read
+    all_result = read_from_files([waymo_data_directory, file_list])
+    # ret = []
+    #
+    # # get result
+    # for r in all_result:
+    #     if len(r) == 0:
+    #         logger.warning("0 scenarios found")
+    #     ret += r
+    logger.info("\n Found {} waymo scenarios from {} files".format(len(all_result), num_files))
+    return all_result
 
 
 def read_from_files(arg):
+    try:
+        scenario_pb2
+    except NameError:
+        raise ImportError("Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0==1.5.0")
     waymo_data_directory, file_list = arg[0], arg[1]
     scenarios = []
-    for file_count, file in enumerate(file_list):
+    for file in tqdm.tqdm(file_list):
         file_path = os.path.join(waymo_data_directory, file)
         if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
             continue
diff --git a/scenarionet/copy_database.py b/scenarionet/copy_database.py
new file mode 100644
index 0000000..758e18e
--- /dev/null
+++ b/scenarionet/copy_database.py
@@ -0,0 +1,47 @@
+import argparse
+
+from scenarionet.builder.utils import copy_database
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--from', required=True, help="Which database to copy.")
+    parser.add_argument(
+        "--to",
+        required=True,
help="The name of the new database. " + "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. " + "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn " + "that directory into a database." + ) + parser.add_argument( + "--remove_source", + action="store_true", + help="Remove the `from_database` if set this flag" + ) + parser.add_argument( + "--copy_raw_data", + action="store_true", + help="Instead of creating virtual file mapping, copy raw scenario.pkl file" + ) + parser.add_argument( + "--exist_ok", + action="store_true", + help="Still allow to write, if the to_folder exists already. " + "This write will only create two .pkl files and this directory will become a database." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, " + "whether to overwrite both files" + ) + args = parser.parse_args() + from_path = args.__getattribute__("from") + copy_database( + from_path, + args.to, + exist_ok=args.exist_ok, + overwrite=args.overwrite, + copy_raw_data=args.copy_raw_data, + remove_source=args.remove_source + ) diff --git a/scenarionet/filter_database.py b/scenarionet/filter_database.py new file mode 100644 index 0000000..619741a --- /dev/null +++ b/scenarionet/filter_database.py @@ -0,0 +1,111 @@ +import argparse + +from scenarionet.builder.filters import ScenarioFilter +from scenarionet.builder.utils import merge_database + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--database_path", + "-d", + required=True, + help="The name of the new database. " + "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. " + "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn " + "that directory into a database." + ) + parser.add_argument( + '--from', + required=True, + type=str, + help="Which dataset to filter. It takes one directory path as input" + ) + parser.add_argument( + "--exist_ok", + action="store_true", + help="Still allow to write, if the dir exists already. " + "This write will only create two .pkl files and this directory will become a database." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, " + "whether to overwrite both files" + ) + parser.add_argument( + "--moving_dist", + action="store_true", + help="add this flag to select cases with SDC moving dist > sdc_moving_dist_min" + ) + parser.add_argument( + "--sdc_moving_dist_min", + default=10, + type=float, + help="Selecting case with sdc_moving_dist > this value. 
" + ) + + parser.add_argument( + "--num_object", + action="store_true", + help="add this flag to select cases with object_num < max_num_object" + ) + parser.add_argument( + "--max_num_object", + default=30, + type=float, + help="case will be selected if num_obj < this argument" + ) + + parser.add_argument( + "--no_overpass", + action="store_true", + help="Scenarios with overpass WON'T be selected" + ) + + parser.add_argument( + "--no_traffic_light", + action="store_true", + help="Scenarios with traffic light WON'T be selected" + ) + + parser.add_argument( + "--id_filter", + action="store_true", + help="Scenarios with indicated name will NOT be selected" + ) + + parser.add_argument( + "--exclude_ids", + nargs='+', + default=[], + help="Scenarios with indicated name will NOT be selected" + ) + + args = parser.parse_args() + target = args.sdc_moving_dist_min + obj_threshold = args.max_num_object + from_path = args.__getattribute__("from") + + filters = [] + if args.no_overpass: + filters.append(ScenarioFilter.make(ScenarioFilter.no_overpass)) + if args.num_object: + filters.append(ScenarioFilter.make(ScenarioFilter.object_number, number_threshold=obj_threshold)) + if args.moving_dist: + filters.append(ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater")) + if args.no_traffic_light: + filters.append(ScenarioFilter.make(ScenarioFilter.no_traffic_light)) + if args.id_filter: + filters.append(ScenarioFilter.make(ScenarioFilter.id_filter, ids=args.exclude_ids)) + + if len(filters) == 0: + raise ValueError("No filters are applied. Abort.") + + merge_database( + args.database_path, + from_path, + exist_ok=args.exist_ok, + overwrite=args.overwrite, + try_generate_missing_file=True, + filters=filters + ) diff --git a/scenarionet/generate_from_error_file.py b/scenarionet/generate_from_error_file.py index 5897f24..a254db9 100644 --- a/scenarionet/generate_from_error_file.py +++ b/scenarionet/generate_from_error_file.py @@ -5,14 +5,14 @@ from scenarionet.verifier.error import ErrorFile if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated dataset") + parser.add_argument("--database_path", "-d", required=True, help="The path of the newly generated database") parser.add_argument("--file", "-f", required=True, help="The path of the error file, should be xyz.json") parser.add_argument("--overwrite", action="store_true", help="If the database_path exists, overwrite it") parser.add_argument( "--broken", action="store_true", - help="By default, only successful scenarios will be picked to build the new dataset. " - "If turn on this flog, it will generate dataset containing only broken scenarios." + help="By default, only successful scenarios will be picked to build the new database. " + "If turn on this flog, it will generate database containing only broken scenarios." 
     )
     args = parser.parse_args()
     ErrorFile.generate_dataset(args.file, args.database_path, args.overwrite, args.broken)
diff --git a/scenarionet/merge_database.py b/scenarionet/merge_database.py
index be23d5c..fd43799 100644
--- a/scenarionet/merge_database.py
+++ b/scenarionet/merge_database.py
@@ -1,5 +1,5 @@
-import pkg_resources  # for suppress warning
 import argparse
+
 from scenarionet.builder.filters import ScenarioFilter
 from scenarionet.builder.utils import merge_database
 
@@ -9,13 +9,13 @@ if __name__ == '__main__':
         "--database_path",
         "-d",
         required=True,
-        help="The name of the new combined dataset. "
-        "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
-        "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
-        "that directory into a dataset."
+        help="The name of the new combined database. "
+        "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
+        "If exist_ok=True, those two .pkl files will be stored in an existing directory and turn "
+        "that directory into a database."
     )
     parser.add_argument(
-        '--from_datasets',
+        '--from',
         required=True,
         nargs='+',
         default=[],
@@ -25,32 +25,38 @@
         "--exist_ok",
         action="store_true",
         help="Still allow to write, if the dir exists already. "
-        "This write will only create two .pkl files and this directory will become a dataset."
+        "This write will only create two .pkl files and this directory will become a database."
     )
     parser.add_argument(
         "--overwrite",
         action="store_true",
         help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, "
-             "whether to overwrite both files"
+        "whether to overwrite both files"
     )
+    parser.add_argument(
+        "--filter_moving_dist",
+        action="store_true",
+        help="add this flag to select cases with SDC moving dist > sdc_moving_dist_min"
+    )
     parser.add_argument(
         "--sdc_moving_dist_min",
-        default=20,
+        default=5,
+        type=float,
         help="Selecting case with sdc_moving_dist > this value. "
-             "We will add more filter conditions in the future."
+        "We will add more filter conditions in the future."
     )
     args = parser.parse_args()
     target = args.sdc_moving_dist_min
     filters = [ScenarioFilter.make(ScenarioFilter.sdc_moving_dist, target_dist=target, condition="greater")]
-
-    if len(args.from_datasets) != 0:
+    source = getattr(args, "from")
+    if len(source) != 0:
         merge_database(
             args.database_path,
-            *args.from_datasets,
+            *source,
             exist_ok=args.exist_ok,
             overwrite=args.overwrite,
             try_generate_missing_file=True,
-            filters=filters
+            filters=filters if args.filter_moving_dist else []
         )
     else:
-        raise ValueError("No source dataset are provided. Abort.")
+        raise ValueError("No source database is provided. Abort.")
diff --git a/scenarionet/move_database.py b/scenarionet/move_database.py
deleted file mode 100644
index d32ca78..0000000
--- a/scenarionet/move_database.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import argparse
-
-from scenarionet.builder.utils import move_database
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--from', required=True, help="Which dataset to move.")
-    parser.add_argument(
-        "--to",
-        required=True,
-        help="The name of the new dataset. "
-        "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. "
-        "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn "
-        "that directory into a dataset."
- ) - parser.add_argument( - "--exist_ok", - action="store_true", - help="Still allow to write, if the to_folder exists already. " - "This write will only create two .pkl files and this directory will become a dataset." - ) - parser.add_argument( - "--overwrite", - action="store_true", - help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, " - "whether to overwrite both files" - ) - args = parser.parse_args() - from_path = args.__getattr__("from") - move_database( - from_path, - args.to, - exist_ok=args.exist_ok, - overwrite=args.overwrite, - ) diff --git a/scenarionet/num_scenarios.py b/scenarionet/num_scenarios.py new file mode 100644 index 0000000..162b2cd --- /dev/null +++ b/scenarionet/num_scenarios.py @@ -0,0 +1,18 @@ +import pkg_resources # for suppress warning +import argparse +import logging +from scenarionet.common_utils import read_dataset_summary + +logger = logging.getLogger(__file__) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--database_path", + "-d", + required=True, + help="Database to check number of scenarios" + ) + args = parser.parse_args() + summary, _, _, = read_dataset_summary(args.database_path) + logger.info("Number of scenarios: {}".format(len(summary))) diff --git a/scenarionet/run_simulation.py b/scenarionet/run_simulation.py index 65bb13f..aff6d1f 100644 --- a/scenarionet/run_simulation.py +++ b/scenarionet/run_simulation.py @@ -7,7 +7,7 @@ from metadrive.scenario.utils import get_number_of_scenarios if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--database_path", "-d", required=True, help="The path of the dataset") + parser.add_argument("--database_path", "-d", required=True, help="The path of the database") parser.add_argument("--render", action="store_true", help="Enable 3D rendering") parser.add_argument("--scenario_index", default=None, type=int, help="Specifying a scenario to run") args = parser.parse_args() @@ -38,13 +38,13 @@ if __name__ == '__main__': "data_directory": database_path, } ) - for seed in range(num_scenario if args.scenario_index is not None else 1000000): - env.reset(force_seed=seed if args.scenario_index is None else args.scenario_index) + for index in range(num_scenario if args.scenario_index is not None else 1000000): + env.reset(force_seed=index if args.scenario_index is None else args.scenario_index) for t in range(10000): o, r, d, info = env.step([0, 0]) if env.config["use_render"]: env.render(text={ - "seed": env.engine.global_seed + env.config["start_scenario_index"], + "scenario index": env.engine.global_seed + env.config["start_scenario_index"], }) if d and info["arrive_dest"]: diff --git a/scenarionet/split_database.py b/scenarionet/split_database.py new file mode 100644 index 0000000..04420f7 --- /dev/null +++ b/scenarionet/split_database.py @@ -0,0 +1,47 @@ +""" +This script is for extracting a subset of data from an existing database +""" +import pkg_resources # for suppress warning +import argparse + +from scenarionet.builder.utils import split_database + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--from', required=True, help="Which database to extract data from.") + parser.add_argument( + "--to", + required=True, + help="The name of the new database. " + "It will create a new directory to store dataset_summary.pkl and dataset_mapping.pkl. 
" + "If exists_ok=True, those two .pkl files will be stored in an existing directory and turn " + "that directory into a database." + ) + parser.add_argument("--num_scenarios", type=int, default=64, help="how many scenarios to extract (default: 30)") + parser.add_argument("--start_index", type=int, default=0, help="which index to start") + parser.add_argument("--random", action="store_true", help="If set to true, it will choose scenarios randomly " + "from all_scenarios[start_index:]. " + "Otherwise, the scenarios will be selected sequentially") + parser.add_argument( + "--exist_ok", + action="store_true", + help="Still allow to write, if the to_folder exists already. " + "This write will only create two .pkl files and this directory will become a database." + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="When exists ok is set but summary.pkl and map.pkl exists in existing dir, " + "whether to overwrite both files" + ) + args = parser.parse_args() + from_path = args.__getattribute__("from") + split_database( + from_path, + args.to, + args.start_index, + args.num_scenarios, + exist_ok=args.exist_ok, + overwrite=args.overwrite, + random=args.random + ) diff --git a/scenarionet/tests/local_test/combine_verify_generate.sh b/scenarionet/tests/local_test/combine_verify_generate.sh index 3e4ad08..a6f82fa 100644 --- a/scenarionet/tests/local_test/combine_verify_generate.sh +++ b/scenarionet/tests/local_test/combine_verify_generate.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from_datasets ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite -python ../../verify_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --error_file_path ../tmp/test_combine_dataset --random_drop --num_workers=16 +python ../../merge_database.py --overwrite --exist_ok --database_path ../tmp/test_combine_dataset --from ../../../dataset/waymo ../../../dataset/pg ../../../dataset/nuscenes ../../../dataset/nuplan --overwrite +python ../../check_simulation.py --overwrite --database_path ../tmp/test_combine_dataset --error_file_path ../tmp/test_combine_dataset --random_drop --num_workers=16 python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_pass python ../../generate_from_error_file.py --file ../tmp/test_combine_dataset/error_scenarios_for_test_combine_dataset.json --overwrite --database_path ../tmp/verify_fail --broken \ No newline at end of file diff --git a/scenarionet/tests/local_test/convert_pg_large.sh b/scenarionet/tests/local_test/convert_pg_large.sh index 8e65e6b..58eaa79 100644 --- a/scenarionet/tests/local_test/convert_pg_large.sh +++ b/scenarionet/tests/local_test/convert_pg_large.sh @@ -32,8 +32,8 @@ done # combine the datasets if [ "$overwrite" = true ]; then - python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok + python -m scenarionet.scripts.merge_database --database_path $dataset_path --from $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --overwrite --exist_ok else - python -m scenarionet.scripts.merge_database --database_path $dataset_path --from_datasets $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n 
"${dataset_path}/pg_$i "; done) --exist_ok + python -m scenarionet.scripts.merge_database --database_path $dataset_path --from $(for i in $(seq 0 $((num_sub_dataset-1))); do echo -n "${dataset_path}/pg_$i "; done) --exist_ok fi diff --git a/scenarionet/tests/script/capture_pg.py b/scenarionet/tests/script/capture_pg.py new file mode 100644 index 0000000..1bf1313 --- /dev/null +++ b/scenarionet/tests/script/capture_pg.py @@ -0,0 +1,119 @@ +import pygame +from metadrive.envs.metadrive_env import MetaDriveEnv +from metadrive.utils import setup_logger + +if __name__ == "__main__": + setup_logger(True) + env = MetaDriveEnv( + { + "num_scenarios": 1, + "traffic_density": 0.15, + "traffic_mode": "hybrid", + "start_seed": 74, + # "_disable_detector_mask":True, + # "debug_physics_world": True, + # "debug": True, + # "global_light": False, + # "debug_static_world": True, + "show_interface": False, + "cull_scene": False, + "random_spawn_lane_index": False, + "random_lane_width": False, + # "image_observation": True, + # "controller": "joystick", + # "show_coordinates": True, + "random_agent_model": False, + "manual_control": True, + "use_render": True, + "accident_prob": 1, + "decision_repeat": 5, + "interface_panel": [], + "need_inverse_traffic": False, + "rgb_clip": True, + "map": 2, + # "agent_policy": ExpertPolicy, + "random_traffic": False, + # "random_lane_width": True, + "driving_reward": 1.0, + # "pstats": True, + "force_destroy": False, + # "show_skybox": False, + "show_fps": False, + "render_pipeline": True, + # "camera_dist": 8, + "window_size": (1600, 900), + "camera_dist": 9, + # "camera_pitch": 30, + # "camera_height": 1, + # "camera_smooth": False, + # "camera_height": -1, + "vehicle_config": { + "enable_reverse": False, + # "vehicle_model": "xl", + # "rgb_camera": (1024, 1024), + # "spawn_velocity": [8.728615581032535, -0.24411703918728195], + "spawn_velocity_car_frame": True, + # "image_source": "depth_camera", + # "random_color": True + # "show_lidar": True, + "spawn_lane_index": None, + # "destination":"2R1_3_", + # "show_side_detector": True, + # "show_lane_line_detector": True, + # "side_detector": dict(num_lasers=2, distance=50), + # "lane_line_detector": dict(num_lasers=2, distance=50), + # "show_line_to_navi_mark": True, + "show_navi_mark": False, + # "show_dest_mark": True + }, + } + ) + + o = env.reset() + + + def capture(): + env.capture() + ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(2000, 2000), track_target_vehicle=True) + pygame.image.save(ret, "top_down_{}.png".format(env.current_seed)) + + env.engine.accept("c", capture) + # env.main_camera.set_follow_lane(True) + # env.vehicle.get_camera("rgb_camera").save_image(env.vehicle) + # for line in env.engine.coordinate_line: + # line.reparentTo(env.vehicle.origin) + # env.vehicle.set_velocity([5, 0], in_local_frame=True) + for s in range(1, 100000): + # env.vehicle.set_velocity([1, 0], in_local_frame=True) + o, r, d, info = env.step([0, 0]) + + # env.vehicle.set_pitch(-np.pi/4) + # [0.09231533, 0.491018, 0.47076905, 0.7691619, 0.5, 0.5, 1.0, 0.0, 0.48037243, 0.8904728, 0.81229943, 0.7317231, 1.0, 0.85320455, 0.9747932, 0.65675277, 0.0, 0.5, 0.5] + # else: + # if s % 100 == 0: + # env.close() + # env.reset() + # info["fuel"] = env.vehicle.energy_consumption + # env.render( + # text={ + # # "heading_diff": env.vehicle.heading_diff(env.vehicle.lane), + # # "lane_width": env.vehicle.lane.width, + # # "lane_index": env.vehicle.lane_index, + # # "lateral": 
env.vehicle.lane.local_coordinates(env.vehicle.position), + # "current_seed": env.current_seed + # } + # ) + # if d: + # env.reset() + # # assert env.observation_space.contains(o) + # if (s + 1) % 100 == 0: + # # print( + # "Finish {}/10000 simulation steps. Time elapse: {:.4f}. Average FPS: {:.4f}".format( + # s + 1,f + # time.time() - start, (s + 1) / (time.time() - start) + # ) + # ) + # if d: + # # # env.close() + # # # print(len(env.engine._spawned_objects)) + # env.reset() diff --git a/scenarionet/tests/script/compare_data.py b/scenarionet/tests/script/compare_data.py new file mode 100644 index 0000000..a271761 --- /dev/null +++ b/scenarionet/tests/script/compare_data.py @@ -0,0 +1,17 @@ +from metadrive.scenario.scenario_description import ScenarioDescription as SD +from metadrive.scenario.utils import read_scenario_data, read_dataset_summary, assert_scenario_equal +from scenarionet.common_utils import read_scenario + +if __name__ == '__main__': + data_1 = "D:\\code\\scenarionet\\dataset\pg_2000" + data_2 = "C:\\Users\\x1\\Desktop\\remote" + summary_1, lookup_1, mapping_1 = read_dataset_summary(data_1) + summary_2, lookup_2, mapping_2 = read_dataset_summary(data_2) + # assert lookup_1[-10:] == lookup_2 + scenarios_1 = {} + scenarios_2 = {} + + for i in range(9): + scenarios_1[str(i)] = read_scenario(data_1, mapping_1, lookup_1[-9+i]) + scenarios_2[str(i)] = read_scenario(data_2, mapping_2, lookup_2[i]) + # assert_scenario_equal(scenarios_1, scenarios_2, check_self_type=False, only_compare_sdc=True) diff --git a/scenarionet/tests/script/generate_sensor.py b/scenarionet/tests/script/generate_sensor.py new file mode 100644 index 0000000..e4bc506 --- /dev/null +++ b/scenarionet/tests/script/generate_sensor.py @@ -0,0 +1,97 @@ +import time +import pygame +from metadrive.engine.asset_loader import AssetLoader +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy + +NuScenesEnv = ScenarioEnv + +if __name__ == "__main__": + env = NuScenesEnv( + { + "use_render": True, + "agent_policy": ReplayEgoCarPolicy, + "show_interface": False, + # "need_lane_localization": False, + "show_logo": False, + "no_traffic": False, + "sequential_seed": True, + "reactive_traffic": False, + "show_fps": False, + # "debug": True, + # "render_pipeline": True, + "daytime": "11:01", + "window_size": (1600, 900), + "camera_dist": 0.8, + "camera_height": 1.5, + "camera_pitch": None, + "camera_fov": 60, + "start_scenario_index": 0, + "num_scenarios": 10, + # "force_reuse_object_name": True, + # "data_directory": "/home/shady/Downloads/test_processed", + "horizon": 1000, + # "no_static_vehicles": True, + # "show_policy_mark": True, + # "show_coordinates": True, + # "force_destroy": True, + # "default_vehicle_in_traffic": True, + "vehicle_config": dict( + # light=True, + # random_color=True, + show_navi_mark=False, + use_special_color=False, + image_source="depth_camera", + rgb_camera=(1600, 900), + depth_camera=(1600, 900, True), + # no_wheel_friction=True, + lidar=dict(num_lasers=120, distance=50), + lane_line_detector=dict(num_lasers=0, distance=50), + side_detector=dict(num_lasers=12, distance=50) + ), + "data_directory": AssetLoader.file_path("nuscenes", return_raw_style=False), + "image_observation": True, + } + ) + + # 0,1,3,4,5,6 + + success = [] + reset_num = 0 + start = time.time() + reset_used_time = 0 + s = 0 + while True: + # for i in range(10): + start_reset = time.time() + env.reset(force_seed=0) + + reset_used_time += time.time() - start_reset + 
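        # accumulate reset overhead separately so the FPS printed at episode end counts stepping time only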
reset_num += 1 + for t in range(10000): + if t==30: + # env.capture("camera_deluxe.jpg") + # ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(5000, 5000), track_target_vehicle=True) + # pygame.image.save(ret, "top_down.png") + env.vehicle.get_camera("depth_camera").save_image(env.vehicle, "camera.jpg") + o, r, d, info = env.step([1, 0.88]) + assert env.observation_space.contains(o) + s += 1 + # if env.config["use_render"]: + # env.render(text={"seed": env.current_seed, + # # "num_map": info["num_stored_maps"], + # "data_coverage": info["data_coverage"], + # "reward": r, + # "heading_r": info["step_reward_heading"], + # "lateral_r": info["step_reward_lateral"], + # "smooth_action_r": info["step_reward_action_smooth"]}) + if d: + print( + "Time elapse: {:.4f}. Average FPS: {:.4f}, AVG_Reset_time: {:.4f}".format( + time.time() - start, s / (time.time() - start - reset_used_time), + reset_used_time / reset_num + ) + ) + print("seed:{}, success".format(env.engine.global_random_seed)) + print(list(env.engine.curriculum_manager.recent_success.dict.values())) + break diff --git a/scenarionet/tests/script/replay_origin.py b/scenarionet/tests/script/replay_origin.py new file mode 100644 index 0000000..83777c9 --- /dev/null +++ b/scenarionet/tests/script/replay_origin.py @@ -0,0 +1,104 @@ +import time + +import pygame +from metadrive.envs.scenario_env import ScenarioEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy + +NuScenesEnv = ScenarioEnv + +if __name__ == "__main__": + env = NuScenesEnv( + { + "use_render": True, + "agent_policy": ReplayEgoCarPolicy, + "show_interface": False, + "image_observation": False, + "show_logo": False, + "no_traffic": False, + "drivable_region_extension": 15, + "sequential_seed": True, + "reactive_traffic": False, + "show_fps": False, + # "debug": True, + "render_pipeline": True, + "daytime": "19:30", + "window_size": (1600, 900), + "camera_dist": 9, + # "camera_height": 1.5, + # "camera_pitch": None, + # "camera_fov": 60, + "start_scenario_index": 0, + "num_scenarios": 4, + # "force_reuse_object_name": True, + # "data_directory": "/home/shady/Downloads/test_processed", + "horizon": 1000, + # "no_static_vehicles": True, + # "show_policy_mark": True, + # "show_coordinates": True, + # "force_destroy": True, + # "default_vehicle_in_traffic": True, + "vehicle_config": dict( + # light=True, + # random_color=True, + show_navi_mark=False, + use_special_color=False, + image_source="depth_camera", + # rgb_camera=(1600, 900), + # depth_camera=(1600, 900, True), + # no_wheel_friction=True, + lidar=dict(num_lasers=120, distance=50), + lane_line_detector=dict(num_lasers=0, distance=50), + side_detector=dict(num_lasers=12, distance=50) + ), + "data_directory": "D:\\code\\scenarionet\\scenarionet\\tests\\script\\waymo_scenes_adv" + } + ) + + # 0,1,3,4,5,6 + + success = [] + reset_num = 0 + start = time.time() + reset_used_time = 0 + s = 0 + + env.reset() + + + def capture(): + env.capture() + ret = env.render(mode="topdown", screen_size=(1600, 900), film_size=(7000, 7000), track_target_vehicle=True) + pygame.image.save(ret, "top_down_{}.png".format(env.current_seed)) + + + env.engine.accept("c", capture) + + while True: + # for i in range(10): + start_reset = time.time() + env.reset() + + reset_used_time += time.time() - start_reset + reset_num += 1 + for t in range(10000): + o, r, d, info = env.step([1, 0.88]) + assert env.observation_space.contains(o) + s += 1 + # if env.config["use_render"]: + # env.render(text={"seed": env.current_seed, + # # 
"num_map": info["num_stored_maps"], + # "data_coverage": info["data_coverage"], + # "reward": r, + # "heading_r": info["step_reward_heading"], + # "lateral_r": info["step_reward_lateral"], + # "smooth_action_r": info["step_reward_action_smooth"]}) + if d: + print( + "Time elapse: {:.4f}. Average FPS: {:.4f}, AVG_Reset_time: {:.4f}".format( + time.time() - start, s / (time.time() - start - reset_used_time), + reset_used_time / reset_num + ) + ) + print("seed:{}, success".format(env.engine.global_random_seed)) + print(list(env.engine.curriculum_manager.recent_success.dict.values())) + break diff --git a/scenarionet/tests/script/run_env.py b/scenarionet/tests/script/run_env.py index 682a145..77061b4 100644 --- a/scenarionet/tests/script/run_env.py +++ b/scenarionet/tests/script/run_env.py @@ -50,9 +50,9 @@ if __name__ == '__main__': long, lat, = c_lane.local_coordinates(env.vehicle.position) if env.config["use_render"]: env.render(text={ - "seed": env.engine.global_seed + env.config["start_scenario_index"], + "scenario index": env.engine.global_seed + env.config["start_scenario_index"], }) if d and info["arrive_dest"]: - print("seed:{}, success".format(env.engine.global_random_seed)) + print("scenario index:{}, success".format(env.engine.global_random_seed)) break diff --git a/scenarionet/tests/test_dataset/overpass/sd_waymo_v1.2_eb4b91b10ca94ff2.pkl b/scenarionet/tests/test_dataset/overpass/sd_waymo_v1.2_eb4b91b10ca94ff2.pkl new file mode 100644 index 0000000..9c99d8a Binary files /dev/null and b/scenarionet/tests/test_dataset/overpass/sd_waymo_v1.2_eb4b91b10ca94ff2.pkl differ diff --git a/scenarionet/tests/test_filter_overpass.py b/scenarionet/tests/test_filter_overpass.py new file mode 100644 index 0000000..4e005ae --- /dev/null +++ b/scenarionet/tests/test_filter_overpass.py @@ -0,0 +1,37 @@ +import os +import os.path + +from metadrive.engine.asset_loader import AssetLoader + +from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH +from scenarionet.builder.filters import ScenarioFilter +from scenarionet.builder.utils import merge_database + + +def test_filter_overpass(): + overpass_1 = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", "overpass") + overpass_in_md = AssetLoader.file_path("waymo", return_raw_style=False) + dataset_paths = [overpass_1, overpass_in_md] + + output_path = os.path.join(TMP_PATH, "combine") + merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + + # filter + filters = [] + filters.append(ScenarioFilter.make(ScenarioFilter.no_overpass)) + + summaries, _ = merge_database( + output_path, + *dataset_paths, + exist_ok=True, + overwrite=True, + try_generate_missing_file=True, + filters=filters + ) + assert len(summaries) == 3 + for scenario in summaries: + assert scenario != "sd_waymo_v1.2_eb4b91b10ca94ff2.pkl" + + +if __name__ == '__main__': + test_filter_overpass() diff --git a/scenarionet/tests/test_move.py b/scenarionet/tests/test_move.py index 0e0f320..492d500 100644 --- a/scenarionet/tests/test_move.py +++ b/scenarionet/tests/test_move.py @@ -4,13 +4,13 @@ import os.path import pytest from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH -from scenarionet.builder.utils import move_database, merge_database +from scenarionet.builder.utils import copy_database, merge_database from scenarionet.common_utils import read_dataset_summary, read_scenario from scenarionet.verifier.utils import verify_database @pytest.mark.order("first") -def test_move_database(): +def test_copy_database(): 
dataset_name = "nuscenes" original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name) dataset_paths = [original_dataset_path + "_{}".format(i) for i in range(5)] @@ -19,7 +19,7 @@ def test_move_database(): # move for k, from_path in enumerate(dataset_paths): to = os.path.join(TMP_PATH, str(k)) - move_database(from_path, to) + copy_database(from_path, to) moved_path.append(to) assert os.path.exists(from_path) merge_database(output_path, *moved_path, exist_ok=True, overwrite=True, try_generate_missing_file=True) @@ -37,7 +37,7 @@ def test_move_database(): for k, from_path in enumerate(moved_path): new_p = os.path.join(TMP_PATH, str(k) + str(k)) new_move_pathes.append(new_p) - move_database(from_path, new_p, exist_ok=True, overwrite=True) + copy_database(from_path, new_p, exist_ok=True, overwrite=True) assert not os.path.exists(from_path) merge_database(output_path, *new_move_pathes, exist_ok=True, overwrite=True, try_generate_missing_file=True) # verify @@ -51,4 +51,4 @@ def test_move_database(): if __name__ == '__main__': - test_move_database() + test_copy_database() diff --git a/scenarionet/tests/test_split_dataset.py b/scenarionet/tests/test_split_dataset.py new file mode 100644 index 0000000..cd3336c --- /dev/null +++ b/scenarionet/tests/test_split_dataset.py @@ -0,0 +1,37 @@ +import os +import os.path + +from scenarionet import SCENARIONET_PACKAGE_PATH, TMP_PATH +from scenarionet.builder.utils import merge_database, split_database +from scenarionet.common_utils import read_dataset_summary + + +def test_split_dataset(): + dataset_name = "nuscenes" + original_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset", dataset_name) + test_dataset_path = os.path.join(SCENARIONET_PACKAGE_PATH, "tests", "test_dataset") + dataset_paths = [original_dataset_path + "_{}".format(i) for i in [0, 1, 3, 4]] + + output_path = os.path.join(TMP_PATH, "combine") + merge_database(output_path, *dataset_paths, exist_ok=True, overwrite=True, try_generate_missing_file=True) + + # split + from_path = output_path + to_path = os.path.join(TMP_PATH, "split", "split") + summary_1, lookup_1, mapping_1 = read_dataset_summary(from_path) + + split_database(from_path, to_path, start_index=3, random=True, num_scenarios=4, overwrite=True, exist_ok=True) + summary_2, lookup_2, mapping_2 = read_dataset_summary(to_path) + assert len(summary_2) == 4 + for scenario in summary_2: + assert scenario not in lookup_1[:3] + + split_database(from_path, to_path, start_index=3, num_scenarios=4, overwrite=True, exist_ok=True) + summary_2, lookup_2, mapping_2 = read_dataset_summary(to_path) + assert lookup_1[3:7] == lookup_2 + for k in range(4): + assert summary_1[lookup_2[k]] == summary_2[lookup_2[k]] + + +if __name__ == '__main__': + test_split_dataset() diff --git a/scenarionet/verifier/error.py b/scenarionet/verifier/error.py index f7e6b64..e7d8351 100644 --- a/scenarionet/verifier/error.py +++ b/scenarionet/verifier/error.py @@ -53,12 +53,12 @@ class ErrorFile: @classmethod def generate_dataset(cls, error_file_path, new_dataset_path, overwrite=False, broken_scenario=False): """ - Generate a new dataset containing all broken scenarios or all good scenarios + Generate a new database containing all broken scenarios or all good scenarios :param error_file_path: error file path :param new_dataset_path: a directory where you want to store your data :param overwrite: if new_dataset_path exists, whether to overwrite :param broken_scenario: generate broken scenarios. 
You can generate such a broken scenarios for debugging - :return: dataset summary, dataset mapping + :return: database summary, database mapping """ new_dataset_path = os.path.abspath(new_dataset_path) if os.path.exists(new_dataset_path): diff --git a/scenarionet/verifier/utils.py b/scenarionet/verifier/utils.py index 9d42716..0d7de97 100644 --- a/scenarionet/verifier/utils.py +++ b/scenarionet/verifier/utils.py @@ -108,6 +108,7 @@ def loading_into_metadrive( "agent_policy": ReplayEgoCarPolicy, "num_scenarios": num_scenario, "horizon": 1000, + "store_map": False, "start_scenario_index": start_scenario_index, "no_static_vehicles": False, "data_directory": dataset_path, diff --git a/scenarionet_training/__init__.py b/scenarionet_training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scenarionet_training/scripts/__init__.py b/scenarionet_training/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scenarionet_training/scripts/evaluate_nuplan.py b/scenarionet_training/scripts/evaluate_nuplan.py new file mode 100644 index 0000000..aee11b5 --- /dev/null +++ b/scenarionet_training/scripts/evaluate_nuplan.py @@ -0,0 +1,23 @@ +from scenarionet_training.scripts.train_nuplan import config +from scenarionet_training.train_utils.utils import eval_ckpt + +if __name__ == '__main__': + # 27 29 30 37 39 + ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_510\\checkpoint-510" + scenario_data_path = "D:\\scenarionet_testset\\nuplan_test\\nuplan_test_w_raw" + num_scenarios = 2000 + start_scenario_index = 0 + horizon = 600 + render = False + explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance + log_interval = 10 + + eval_ckpt(config, + ckpt_path, + scenario_data_path, + num_scenarios, + start_scenario_index, + horizon, + render, + explore, + log_interval) diff --git a/scenarionet_training/scripts/evaluate_pg.py b/scenarionet_training/scripts/evaluate_pg.py new file mode 100644 index 0000000..e79b252 --- /dev/null +++ b/scenarionet_training/scripts/evaluate_pg.py @@ -0,0 +1,27 @@ +import os.path + +from scenarionet import SCENARIONET_DATASET_PATH +from scenarionet_training.scripts.train_pg import config +from scenarionet_training.train_utils.utils import eval_ckpt + +if __name__ == '__main__': + # Merge all evaluate script + # 10/15/20/26/30/31/32 + ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_330\\checkpoint-330" + scenario_data_path = os.path.join(SCENARIONET_DATASET_PATH, "pg_2000") + num_scenarios = 2000 + start_scenario_index = 0 + horizon = 600 + render = False + explore = True # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance + log_interval = 2 + + eval_ckpt(config, + ckpt_path, + scenario_data_path, + num_scenarios, + start_scenario_index, + horizon, + render, + explore, + log_interval) diff --git a/scenarionet_training/scripts/evaluate_waymo.py b/scenarionet_training/scripts/evaluate_waymo.py new file mode 100644 index 0000000..c00fa7e --- /dev/null +++ b/scenarionet_training/scripts/evaluate_waymo.py @@ -0,0 +1,22 @@ +from scenarionet_training.scripts.train_waymo import config +from scenarionet_training.train_utils.utils import eval_ckpt + +if __name__ == '__main__': + ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_170\\checkpoint-170" + scenario_data_path = "D:\\scenarionet_testset\\waymo_test_raw_data" + num_scenarios = 2000 + start_scenario_index = 0 + horizon = 600 + render = True + explore = True # PPO is a stochastic policy, turning off 
diff --git a/scenarionet_training/scripts/evaluate_waymo.py b/scenarionet_training/scripts/evaluate_waymo.py
new file mode 100644
index 0000000..c00fa7e
--- /dev/null
+++ b/scenarionet_training/scripts/evaluate_waymo.py
@@ -0,0 +1,22 @@
+from scenarionet_training.scripts.train_waymo import config
+from scenarionet_training.train_utils.utils import eval_ckpt
+
+if __name__ == '__main__':
+    ckpt_path = "C:\\Users\\x1\\Desktop\\checkpoint_170\\checkpoint-170"
+    scenario_data_path = "D:\\scenarionet_testset\\waymo_test_raw_data"
+    num_scenarios = 2000
+    start_scenario_index = 0
+    horizon = 600
+    render = True
+    explore = True  # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
+    log_interval = 2
+
+    eval_ckpt(config,
+              ckpt_path,
+              scenario_data_path,
+              num_scenarios,
+              start_scenario_index,
+              horizon,
+              render,
+              explore,
+              log_interval)
diff --git a/scenarionet_training/scripts/local_test.py b/scenarionet_training/scripts/local_test.py
new file mode 100644
index 0000000..e785a5d
--- /dev/null
+++ b/scenarionet_training/scripts/local_test.py
@@ -0,0 +1,80 @@
+import os.path
+
+from metadrive.envs.scenario_env import ScenarioEnv
+
+from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
+from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
+from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
+
+if __name__ == '__main__':
+    env = ScenarioEnv
+    args = get_train_parser().parse_args()
+    exp_name = get_exp_name(args)
+    stop = int(100_000_000)
+
+    config = dict(
+        env=env,
+        env_config=dict(
+            # scenario
+            start_scenario_index=0,
+            num_scenarios=32,
+            data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg"),
+            sequential_seed=True,
+
+            # traffic & light
+            reactive_traffic=False,
+            no_static_vehicles=True,
+            no_light=True,
+            static_traffic_object=True,
+
+            # curriculum training
+            curriculum_level=4,
+            target_success_rate=0.8,
+
+            # training
+            horizon=None,
+            use_lateral_reward=True,
+        ),
+
+        # ===== Evaluation =====
+        evaluation_interval=2,
+        evaluation_num_episodes=32,
+        evaluation_config=dict(env_config=dict(start_scenario_index=32,
+                                               num_scenarios=32,
+                                               sequential_seed=True,
+                                               curriculum_level=1,  # turn off
+                                               data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg"))),
+        evaluation_num_workers=2,
+        metrics_smoothing_episodes=10,
+
+        # ===== Training =====
+        model=dict(fcnet_hiddens=[512, 256, 128]),
+        horizon=600,
+        num_sgd_iter=20,
+        lr=5e-5,
+        rollout_fragment_length=500,
+        sgd_minibatch_size=100,
+        train_batch_size=4000,
+        num_gpus=0.5 if args.num_gpus != 0 else 0,
+        num_cpus_per_worker=0.4,
+        num_cpus_for_driver=1,
+        num_workers=2,
+        framework="tf"
+    )
+
+    train(
+        MultiWorkerPPO,
+        exp_name=exp_name,
+        save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
+        keep_checkpoints_num=5,
+        stop=stop,
+        config=config,
+        num_gpus=args.num_gpus,
+        # num_seeds=args.num_seeds,
+        num_seeds=1,
+        # test_mode=args.test,
+        # local_mode=True,
+        # TODO remove this when we release our code
+        # wandb_key_file="~/wandb_api_key_file.txt",
+        wandb_project="scenarionet",
+    )
diff --git a/scenarionet_training/scripts/multi_worker_eval.py b/scenarionet_training/scripts/multi_worker_eval.py
new file mode 100644
index 0000000..7c51d84
--- /dev/null
+++ b/scenarionet_training/scripts/multi_worker_eval.py
@@ -0,0 +1,74 @@
+import argparse
+import pickle
+import json
+import os
+
+import numpy as np
+
+from scenarionet_training.scripts.train_nuplan import config
+from scenarionet_training.train_utils.callbacks import DrivingCallbacks
+from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
+from scenarionet_training.train_utils.utils import initialize_ray
+
+
+class NumpyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        elif isinstance(obj, np.int32):
+            return int(obj)
+        elif isinstance(obj, np.int64):
+            return int(obj)
+        return json.JSONEncoder.default(self, obj)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--start_index", type=int, default=0)
+    parser.add_argument("--ckpt_path", type=str, required=True)
+    parser.add_argument("--database_path", type=str, required=True)
+    parser.add_argument("--id", type=str, default="")
+    parser.add_argument("--num_scenarios", type=int, default=5000)
+    parser.add_argument("--num_workers", type=int, default=10)
+    parser.add_argument("--horizon", type=int, default=600)
+    parser.add_argument("--allowed_more_steps", type=int, default=50)
+    parser.add_argument("--max_lateral_dist", type=float, default=2.5)
+    parser.add_argument("--overwrite", action="store_true")
+
+    args = parser.parse_args()
+    file = "eval_{}_{}_{}".format(args.id, os.path.basename(args.ckpt_path), os.path.basename(args.database_path))
+    if os.path.exists(file) and not args.overwrite:
+        raise FileExistsError("Please remove {} or set --overwrite".format(file))
+    initialize_ray(test_mode=True, num_gpus=1)
+
+    config["callbacks"] = DrivingCallbacks
+    config["evaluation_num_workers"] = args.num_workers
+    config["evaluation_num_episodes"] = args.num_scenarios
+    config["metrics_smoothing_episodes"] = args.num_scenarios
+    config["custom_eval_function"] = None
+    config["num_workers"] = 0
+    config["evaluation_config"]["env_config"].update(dict(
+        start_scenario_index=args.start_index,
+        num_scenarios=args.num_scenarios,
+        sequential_seed=True,
+        store_map=False,
+        store_data=False,
+        allowed_more_steps=args.allowed_more_steps,
+        # no_map=True,
+        max_lateral_dist=args.max_lateral_dist,
+        curriculum_level=1,  # disable curriculum
+        target_success_rate=1,
+        horizon=args.horizon,
+        episodes_to_evaluate_curriculum=args.num_scenarios,
+        data_directory=args.database_path,
+        use_render=False))
+
+    trainer = MultiWorkerPPO(config)
+    trainer.restore(args.ckpt_path)
+
+    ret = trainer._evaluate()["evaluation"]
+    with open(file + ".json", "w") as f:
+        json.dump(ret, f, cls=NumpyEncoder)
+
+    with open(file + ".pkl", "wb+") as f:
+        pickle.dump(ret, f)
diff --git a/scenarionet_training/scripts/train_nuplan.py b/scenarionet_training/scripts/train_nuplan.py
new file mode 100644
index 0000000..766e6c6
--- /dev/null
+++ b/scenarionet_training/scripts/train_nuplan.py
@@ -0,0 +1,96 @@
+import os.path
+
+from metadrive.envs.scenario_env import ScenarioEnv
+from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
+from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
+from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
+
+config = dict(
+    env=ScenarioEnv,
+    env_config=dict(
+        # scenario
+        start_scenario_index=0,
+        num_scenarios=40000,
+        data_directory=os.path.join(SCENARIONET_DATASET_PATH, "nuplan_train"),
+        sequential_seed=True,
+
+        # curriculum training
+        curriculum_level=100,
+        target_success_rate=0.8,  # or 0.7
+        # episodes_to_evaluate_curriculum=400,  # default=num_scenarios/curriculum_level
+
+        # traffic & light
+        reactive_traffic=True,
+        no_static_vehicles=True,
+        no_light=True,
+        static_traffic_object=True,
+
+        # training scheme
+        horizon=None,
+        driving_reward=4,
+        steering_range_penalty=1.0,
+        heading_penalty=2,
+        lateral_penalty=2.0,
+        no_negative_reward=True,
+        on_lane_line_penalty=0,
+        crash_vehicle_penalty=2,
+        crash_human_penalty=2,
+        crash_object_penalty=0.5,
+        # out_of_road_penalty=2,
+        max_lateral_dist=2,
+        # crash_vehicle_done=True,
+
+        vehicle_config=dict(side_detector=dict(num_lasers=0))
+    ),
+
+    # ===== Evaluation =====
+    evaluation_interval=15,
+    evaluation_num_episodes=1000,
+    # TODO (LQY): this is a sample from the test set. Evaluate on all scenarios after training!
+    evaluation_config=dict(env_config=dict(start_scenario_index=0,
+                                           num_scenarios=1000,
+                                           sequential_seed=True,
+                                           curriculum_level=1,  # turn off
+                                           data_directory=os.path.join(SCENARIONET_DATASET_PATH, "nuplan_test"))),
+    evaluation_num_workers=10,
+    metrics_smoothing_episodes=10,
+
+    # ===== Training =====
+    model=dict(fcnet_hiddens=[512, 256, 128]),
+    horizon=600,
+    num_sgd_iter=20,
+    lr=1e-4,
+    rollout_fragment_length=500,
+    sgd_minibatch_size=200,
+    train_batch_size=50000,
+    num_gpus=0.5,
+    num_cpus_per_worker=0.3,
+    num_cpus_for_driver=1,
+    num_workers=20,
+    framework="tf"
+)
+
+if __name__ == '__main__':
+    args = get_train_parser().parse_args()
+    exp_name = get_exp_name(args)
+    stop = int(100_000_000)
+    config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
+
+    train(
+        MultiWorkerPPO,
+        exp_name=exp_name,
+        save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
+        keep_checkpoints_num=5,
+        stop=stop,
+        config=config,
+        num_gpus=args.num_gpus,
+        # num_seeds=args.num_seeds,
+        num_seeds=5,
+        test_mode=args.test,
+        # local_mode=True,
+        # TODO remove this when we release our code
+        # wandb_key_file="~/wandb_api_key_file.txt",
+        wandb_project="scenarionet",
+    )
diff --git a/scenarionet_training/scripts/train_pg.py b/scenarionet_training/scripts/train_pg.py
new file mode 100644
index 0000000..80aae0d
--- /dev/null
+++ b/scenarionet_training/scripts/train_pg.py
@@ -0,0 +1,95 @@
+import os.path
+
+from metadrive.envs.scenario_env import ScenarioEnv
+
+from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
+from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
+from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
+
+config = dict(
+    env=ScenarioEnv,
+    env_config=dict(
+        # scenario
+        start_scenario_index=0,
+        num_scenarios=40000,
+        data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg_train"),
+        sequential_seed=True,
+
+        # curriculum training
+        curriculum_level=100,
+        target_success_rate=0.8,
+        # episodes_to_evaluate_curriculum=400,  # default=num_scenarios/curriculum_level
+
+        # traffic & light
+        reactive_traffic=False,
+        no_static_vehicles=True,
+        no_light=True,
+        static_traffic_object=True,
+
+        # training scheme
+        horizon=None,
+        steering_range_penalty=2,
+        heading_penalty=1.0,
+        lateral_penalty=1.0,
+        no_negative_reward=True,
+        on_lane_line_penalty=0,
+        crash_vehicle_penalty=2,
+        crash_human_penalty=2,
+        out_of_road_penalty=2,
+        max_lateral_dist=2,
+        # crash_vehicle_done=True,
+
+        vehicle_config=dict(side_detector=dict(num_lasers=0))
+    ),
+
+    # ===== Evaluation =====
+    evaluation_interval=15,
+    evaluation_num_episodes=1000,
+    # TODO (LQY): this is a sample from the test set. Evaluate on all scenarios after training!
+    evaluation_config=dict(env_config=dict(start_scenario_index=0,
+                                           num_scenarios=1000,
+                                           sequential_seed=True,
+                                           curriculum_level=1,  # turn off
+                                           data_directory=os.path.join(SCENARIONET_DATASET_PATH, "pg_test"))),
+    evaluation_num_workers=10,
+    metrics_smoothing_episodes=10,
+
+    # ===== Training =====
+    model=dict(fcnet_hiddens=[512, 256, 128]),
+    horizon=600,
+    num_sgd_iter=20,
+    lr=1e-4,
+    rollout_fragment_length=500,
+    sgd_minibatch_size=200,
+    train_batch_size=50000,
+    num_gpus=0.5,
+    num_cpus_per_worker=0.3,
+    num_cpus_for_driver=1,
+    num_workers=20,
+    framework="tf"
+)
+
+if __name__ == '__main__':
+    # PG data is generated with seeds 10,000 to 60,000
+    args = get_train_parser().parse_args()
+    exp_name = get_exp_name(args)
+    stop = int(100_000_000)
+    config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
+
+    train(
+        MultiWorkerPPO,
+        exp_name=exp_name,
+        save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
+        keep_checkpoints_num=5,
+        stop=stop,
+        config=config,
+        num_gpus=args.num_gpus,
+        # num_seeds=args.num_seeds,
+        num_seeds=5,
+        test_mode=args.test,
+        # local_mode=True,
+        # TODO remove this when we release our code
+        # wandb_key_file="~/wandb_api_key_file.txt",
+        wandb_project="scenarionet",
+    )
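All three train_* scripts follow the same pattern: a module-level `config` dict (so the evaluate scripts can import it) plus a `train(...)` call. A hedged sketch of shrinking the PG config for a quick local smoke test; the numbers below are illustrative, not recommended settings:

```python
# Sketch: shrink the PG training config for a quick local smoke test.
from scenarionet_training.scripts.train_pg import config
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.train_utils.utils import train

config["env_config"].update(num_scenarios=32, curriculum_level=4)
config["num_workers"] = 2
config["evaluation_num_workers"] = 2
config["evaluation_num_episodes"] = 32
config["train_batch_size"] = 4000
config["sgd_minibatch_size"] = 100

train(
    MultiWorkerPPO,
    config=config,
    stop=100_000,  # converted to {"timesteps_total": 100_000} by train()
    exp_name="pg_smoke_test",
    num_seeds=1,
)
```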
diff --git a/scenarionet_training/scripts/train_waymo.py b/scenarionet_training/scripts/train_waymo.py
new file mode 100644
index 0000000..767e51d
--- /dev/null
+++ b/scenarionet_training/scripts/train_waymo.py
@@ -0,0 +1,95 @@
+import os.path
+
+from metadrive.envs.scenario_env import ScenarioEnv
+from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
+from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
+from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
+
+config = dict(
+    env=ScenarioEnv,
+    env_config=dict(
+        # scenario
+        start_scenario_index=0,
+        num_scenarios=40000,
+        data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_train"),
+        sequential_seed=True,
+
+        # curriculum training
+        curriculum_level=100,
+        target_success_rate=0.8,
+        # episodes_to_evaluate_curriculum=400,  # default=num_scenarios/curriculum_level
+
+        # traffic & light
+        reactive_traffic=True,
+        no_static_vehicles=True,
+        no_light=True,
+        static_traffic_object=True,
+
+        # training scheme
+        horizon=None,
+        driving_reward=1,
+        steering_range_penalty=0,
+        heading_penalty=1,
+        lateral_penalty=1.0,
+        no_negative_reward=True,
+        on_lane_line_penalty=0,
+        crash_vehicle_penalty=2,
+        crash_human_penalty=2,
+        out_of_road_penalty=2,
+        max_lateral_dist=2,
+        # crash_vehicle_done=True,
+
+        vehicle_config=dict(side_detector=dict(num_lasers=0))
+    ),
+
+    # ===== Evaluation =====
+    evaluation_interval=15,
+    evaluation_num_episodes=1000,
+    # TODO (LQY): this is a sample from the test set. Evaluate on all scenarios after training!
+    evaluation_config=dict(env_config=dict(start_scenario_index=0,
+                                           num_scenarios=1000,
+                                           sequential_seed=True,
+                                           curriculum_level=1,  # turn off
+                                           data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_test"))),
+    evaluation_num_workers=10,
+    metrics_smoothing_episodes=10,
+
+    # ===== Training =====
+    model=dict(fcnet_hiddens=[512, 256, 128]),
+    horizon=600,
+    num_sgd_iter=20,
+    lr=1e-4,
+    rollout_fragment_length=500,
+    sgd_minibatch_size=200,
+    train_batch_size=50000,
+    num_gpus=0.5,
+    num_cpus_per_worker=0.3,
+    num_cpus_for_driver=1,
+    num_workers=20,
+    framework="tf"
+)
+
+if __name__ == '__main__':
+    args = get_train_parser().parse_args()
+    exp_name = get_exp_name(args)
+    stop = int(100_000_000)
+    config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
+
+    train(
+        MultiWorkerPPO,
+        exp_name=exp_name,
+        save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
+        keep_checkpoints_num=5,
+        stop=stop,
+        config=config,
+        num_gpus=args.num_gpus,
+        # num_seeds=args.num_seeds,
+        num_seeds=5,
+        test_mode=args.test,
+        # local_mode=True,
+        # TODO remove this when we release our code
+        # wandb_key_file="~/wandb_api_key_file.txt",
+        wandb_project="scenarionet",
+    )
diff --git a/scenarionet_training/train_utils/__init__.py b/scenarionet_training/train_utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scenarionet_training/train_utils/anisotropic_workerset.py b/scenarionet_training/train_utils/anisotropic_workerset.py
new file mode 100644
index 0000000..4b055c7
--- /dev/null
+++ b/scenarionet_training/train_utils/anisotropic_workerset.py
@@ -0,0 +1,42 @@
+import copy
+import logging
+from typing import TypeVar
+
+from ray.rllib.evaluation.rollout_worker import RolloutWorker
+from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.framework import try_import_tf
+
+tf1, tf, tfv = try_import_tf()
+
+logger = logging.getLogger(__name__)
+
+# Generic type var for foreach_* methods.
+T = TypeVar("T")
+
+
+@DeveloperAPI
+class AnisotropicWorkerSet(WorkerSet):
+    """
+    Workers are assigned different scenario subsets to save memory and speed up sampling
+    """
+
+    def add_workers(self, num_workers: int) -> None:
+        """
+        Each worker receives its own index so that it can load a different slice of scenarios
+        """
+        remote_args = {
+            "num_cpus": self._remote_config["num_cpus_per_worker"],
+            "num_gpus": self._remote_config["num_gpus_per_worker"],
+            # memory=0 is an error, but memory=None means no limits.
+            "memory": self._remote_config["memory_per_worker"] or None,
+            "object_store_memory": self._remote_config["object_store_memory_per_worker"] or None,
+            "resources": self._remote_config["custom_resources_per_worker"],
+        }
+        cls = RolloutWorker.as_remote(**remote_args).remote
+        for i in range(num_workers):
+            config = copy.deepcopy(self._remote_config)
+            config["env_config"]["worker_index"] = i
+            config["env_config"]["num_workers"] = num_workers
+            self._remote_workers.append(self._make_worker(cls, self._env_creator, self._policy_class, i + 1, config))
diff --git a/scenarionet_training/train_utils/callbacks.py b/scenarionet_training/train_utils/callbacks.py
new file mode 100644
index 0000000..13a9d53
--- /dev/null
+++ b/scenarionet_training/train_utils/callbacks.py
@@ -0,0 +1,110 @@
+from typing import Dict
+
+import numpy as np
+from ray.rllib.agents.callbacks import DefaultCallbacks
+from ray.rllib.env import BaseEnv
+from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
+from ray.rllib.policy import Policy
+
+
+class DrivingCallbacks(DefaultCallbacks):
+    def on_episode_start(
+        self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
+        env_index: int, **kwargs
+    ):
+        episode.user_data["velocity"] = []
+        episode.user_data["steering"] = []
+        episode.user_data["step_reward"] = []
+        episode.user_data["acceleration"] = []
+        episode.user_data["lateral_dist"] = []
+        episode.user_data["cost"] = []
+        episode.user_data["num_crash_vehicle"] = []
+        episode.user_data["num_crash_human"] = []
+        episode.user_data["num_crash_object"] = []
+        episode.user_data["num_on_line"] = []
+
+        episode.user_data["step_reward_lateral"] = []
+        episode.user_data["step_reward_heading"] = []
+        episode.user_data["step_reward_action_smooth"] = []
+
+    def on_episode_step(
+        self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs
+    ):
+        info = episode.last_info_for()
+        if info is not None:
+            episode.user_data["velocity"].append(info["velocity"])
+            episode.user_data["steering"].append(info["steering"])
+            episode.user_data["step_reward"].append(info["step_reward"])
+            episode.user_data["acceleration"].append(info["acceleration"])
+            episode.user_data["lateral_dist"].append(info["lateral_dist"])
+            episode.user_data["cost"].append(info["cost"])
+            for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
+                episode.user_data[x].append(info[x])
+
+            for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
+                episode.user_data[x].append(info[x])
+
+    def on_episode_end(
+        self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
+        **kwargs
+    ):
+        arrive_dest = episode.last_info_for()["arrive_dest"]
+        crash = episode.last_info_for()["crash"]
+        out_of_road = episode.last_info_for()["out_of_road"]
+        max_step_rate = not (arrive_dest or crash or out_of_road)
+        episode.custom_metrics["success_rate"] = float(arrive_dest)
+        episode.custom_metrics["crash_rate"] = float(crash)
+        episode.custom_metrics["out_of_road_rate"] = float(out_of_road)
+        episode.custom_metrics["max_step_rate"] = float(max_step_rate)
+        episode.custom_metrics["velocity_max"] = float(np.max(episode.user_data["velocity"]))
+        episode.custom_metrics["velocity_mean"] = float(np.mean(episode.user_data["velocity"]))
+        episode.custom_metrics["velocity_min"] = float(np.min(episode.user_data["velocity"]))
+
+        episode.custom_metrics["lateral_dist_min"] = float(np.min(episode.user_data["lateral_dist"]))
+        episode.custom_metrics["lateral_dist_max"] = float(np.max(episode.user_data["lateral_dist"]))
+        episode.custom_metrics["lateral_dist_mean"] = float(np.mean(episode.user_data["lateral_dist"]))
+
+        episode.custom_metrics["steering_max"] = float(np.max(episode.user_data["steering"]))
+        episode.custom_metrics["steering_mean"] = float(np.mean(episode.user_data["steering"]))
+        episode.custom_metrics["steering_min"] = float(np.min(episode.user_data["steering"]))
+        episode.custom_metrics["acceleration_min"] = float(np.min(episode.user_data["acceleration"]))
+        episode.custom_metrics["acceleration_mean"] = float(np.mean(episode.user_data["acceleration"]))
+        episode.custom_metrics["acceleration_max"] = float(np.max(episode.user_data["acceleration"]))
+        episode.custom_metrics["step_reward_max"] = float(np.max(episode.user_data["step_reward"]))
+        episode.custom_metrics["step_reward_mean"] = float(np.mean(episode.user_data["step_reward"]))
+        episode.custom_metrics["step_reward_min"] = float(np.min(episode.user_data["step_reward"]))
+
+        episode.custom_metrics["cost"] = float(sum(episode.user_data["cost"]))
+        for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
+            episode.custom_metrics[x] = float(sum(episode.user_data[x]))
+
+        for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
+            episode.custom_metrics[x] = float(np.mean(episode.user_data[x]))
+
+        episode.custom_metrics["route_completion"] = float(episode.last_info_for()["route_completion"])
+        episode.custom_metrics["curriculum_level"] = int(episode.last_info_for()["curriculum_level"])
+        episode.custom_metrics["scenario_index"] = int(episode.last_info_for()["scenario_index"])
+        episode.custom_metrics["track_length"] = float(episode.last_info_for()["track_length"])
+        episode.custom_metrics["num_stored_maps"] = int(episode.last_info_for()["num_stored_maps"])
+        episode.custom_metrics["scenario_difficulty"] = float(episode.last_info_for()["scenario_difficulty"])
+        episode.custom_metrics["data_coverage"] = float(episode.last_info_for()["data_coverage"])
+        episode.custom_metrics["curriculum_success"] = float(episode.last_info_for()["curriculum_success"])
+        episode.custom_metrics["curriculum_route_completion"] = float(
+            episode.last_info_for()["curriculum_route_completion"])
+
+    def on_train_result(self, *, trainer, result: dict, **kwargs):
+        result["success"] = np.nan
+        result["out"] = np.nan
+        result["max_step"] = np.nan
+        result["level"] = np.nan
+        result["length"] = result["episode_len_mean"]
+        result["coverage"] = np.nan
+        if "custom_metrics" not in result:
+            return
+
+        if "success_rate_mean" in result["custom_metrics"]:
+            result["success"] = result["custom_metrics"]["success_rate_mean"]
+            result["out"] = result["custom_metrics"]["out_of_road_rate_mean"]
+            result["max_step"] = result["custom_metrics"]["max_step_rate_mean"]
+            result["level"] = result["custom_metrics"]["curriculum_level_mean"]
+            result["coverage"] = result["custom_metrics"]["data_coverage_mean"]
diff --git a/scenarionet_training/train_utils/check_env.py b/scenarionet_training/train_utils/check_env.py
new file mode 100644
index 0000000..d84a3eb
--- /dev/null
+++ b/scenarionet_training/train_utils/check_env.py
@@ -0,0 +1,11 @@
+from ray.rllib.utils import check_env
+from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
+from gym import Env
+
+if __name__ == '__main__':
+    env = GymnasiumEnvWrapper.build(ScenarioEnv)()
+    print(issubclass(ScenarioEnv, Env))
+    print(isinstance(env, Env))
+    print(env.observation_space)
+    check_env(env)
diff --git a/scenarionet_training/train_utils/multi_worker_PPO.py b/scenarionet_training/train_utils/multi_worker_PPO.py
new file mode 100644
index 0000000..255d60d
--- /dev/null
+++ b/scenarionet_training/train_utils/multi_worker_PPO.py
@@ -0,0 +1,46 @@
+import logging
+from typing import Callable, Type
+
+from ray.rllib.agents.ppo.ppo import PPOTrainer
+from ray.rllib.env.env_context import EnvContext
+from ray.rllib.policy import Policy
+from ray.rllib.utils.typing import TrainerConfigDict, \
+    EnvType
+
+from scenarionet_training.train_utils.anisotropic_workerset import AnisotropicWorkerSet
+
+logger = logging.getLogger(__name__)
+
+
+class MultiWorkerPPO(PPOTrainer):
+    """
+    In this class, each worker has a different config to speed up sampling and save memory. More importantly,
+    it allows us to cover all test/train cases more evenly
+    """
+
+    def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
+                      policy_class: Type[Policy], config: TrainerConfigDict,
+                      num_workers: int):
+        """Default factory method for a WorkerSet running under this Trainer.
+
+        Override this method by passing a custom `make_workers` into
+        `build_trainer`.
+
+        Args:
+            env_creator (callable): A function that returns an Env given an env
+                config.
+            policy (Type[Policy]): The Policy class to use for creating the
+                policies of the workers.
+            config (TrainerConfigDict): The Trainer's config.
+            num_workers (int): Number of remote rollout workers to create.
+                0 for local only.
+
+        Returns:
+            WorkerSet: The created WorkerSet.
+        """
+        return AnisotropicWorkerSet(
+            env_creator=env_creator,
+            policy_class=policy_class,
+            trainer_config=config,
+            num_workers=num_workers,
+            logdir=self.logdir)
diff --git a/scenarionet_training/train_utils/utils.py b/scenarionet_training/train_utils/utils.py
new file mode 100644
index 0000000..edb7e11
--- /dev/null
+++ b/scenarionet_training/train_utils/utils.py
@@ -0,0 +1,356 @@
+import copy
+import datetime
+import json
+import os
+import pickle
+from collections import defaultdict
+
+import numpy as np
+import tqdm
+from metadrive.constants import TerminationState
+from metadrive.envs.scenario_env import ScenarioEnv
+from ray import tune
+from ray.tune import CLIReporter
+
+from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
+from scenarionet_training.wandb_utils import WANDB_KEY_FILE
+
+root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+
+
+def get_api_key_file(wandb_key_file):
+    if wandb_key_file is not None:
+        default_path = os.path.expanduser(wandb_key_file)
+    else:
+        default_path = WANDB_KEY_FILE
+    if os.path.exists(default_path):
+        print("We are using this wandb key file: ", default_path)
+        return default_path
+    path = os.path.join(root, "scenarionet_training/wandb", "wandb_api_key_file.txt")
+    print("We are using this wandb key file: ", path)
+    return path
+
+
+def train(
+    trainer,
+    config,
+    stop,
+    exp_name,
+    num_seeds=1,
+    num_gpus=0,
+    test_mode=False,
+    suffix="",
+    checkpoint_freq=10,
+    keep_checkpoints_num=None,
+    start_seed=0,
+    local_mode=False,
+    save_pkl=True,
+    custom_callback=None,
+    max_failures=0,
+    wandb_key_file=None,
+    wandb_project=None,
+    wandb_team="drivingforce",
+    wandb_log_config=True,
+    init_kws=None,
+    save_dir=None,
+    **kwargs
+):
+    init_kws = init_kws or dict()
+    # initialize ray
+    if not os.environ.get("redis_password"):
+        initialize_ray(test_mode=test_mode, local_mode=local_mode, num_gpus=num_gpus, **init_kws)
+    else:
+        password = os.environ.get("redis_password")
+        assert os.environ.get("ip_head")
+        print(
+            "We detected that redis_password ({}) exists in the environment, so "
+            "we will start a Ray cluster!".format(password)
+        )
+        if num_gpus:
+            print(
+                "We are in cluster mode, so this GPU specification is disabled and "
+                "should be done when submitting the task to the cluster! You are "
+                "requesting {} GPU(s) for each machine!".format(num_gpus)
+            )
+        initialize_ray(address=os.environ["ip_head"], test_mode=test_mode, redis_password=password, **init_kws)
+
+    # prepare config
+    if custom_callback:
+        callback = custom_callback
+    else:
+        from scenarionet_training.train_utils.callbacks import DrivingCallbacks
+        callback = DrivingCallbacks
+
+    used_config = {
+        "seed": tune.grid_search([i * 100 + start_seed for i in range(num_seeds)]) if num_seeds is not None else None,
+        "log_level": "DEBUG" if test_mode else "INFO",
+        "callbacks": callback
+    }
+    if custom_callback is False:
+        used_config.pop("callbacks")
+    if config:
+        used_config.update(config)
+    config = copy.deepcopy(used_config)
+
+    if isinstance(trainer, str):
+        trainer_name = trainer
+    elif hasattr(trainer, "_name"):
+        trainer_name = trainer._name
+    else:
+        trainer_name = trainer.__name__
+
+    if not isinstance(stop, dict) and stop is not None:
+        assert np.isscalar(stop)
+        stop = {"timesteps_total": int(stop)}
+
+    if keep_checkpoints_num is not None and not test_mode:
+        assert isinstance(keep_checkpoints_num, int)
+        kwargs["keep_checkpoints_num"] = keep_checkpoints_num
+        kwargs["checkpoint_score_attr"] = "episode_reward_mean"
+
+    if "verbose" not in kwargs:
+        kwargs["verbose"] = 1 if not test_mode else 2
+
+    # This functionality is not supported yet!
+    metric_columns = CLIReporter.DEFAULT_COLUMNS.copy()
+    progress_reporter = CLIReporter(metric_columns=metric_columns)
+    progress_reporter.add_metric_column("success")
+    progress_reporter.add_metric_column("coverage")
+    progress_reporter.add_metric_column("out")
+    progress_reporter.add_metric_column("max_step")
+    progress_reporter.add_metric_column("length")
+    progress_reporter.add_metric_column("level")
+    kwargs["progress_reporter"] = progress_reporter
+
+    if wandb_key_file is not None:
+        assert wandb_project is not None
+    if wandb_project is not None:
+        failed_wandb = False
+        try:
+            from scenarionet_training.wandb_utils.our_wandb_callbacks import OurWandbLoggerCallback
+        except Exception:
+            # print("Please install wandb: pip install wandb")
+            failed_wandb = True
+
+        if failed_wandb:
+            from ray.tune.logger import DEFAULT_LOGGERS
+            from scenarionet_training.wandb_utils.our_wandb_callbacks_ray100 import OurWandbLogger
+            kwargs["loggers"] = DEFAULT_LOGGERS + (OurWandbLogger,)
+            config["logger_config"] = {
+                "wandb":
+                    {
+                        "group": exp_name,
+                        "exp_name": exp_name,
+                        "entity": wandb_team,
+                        "project": wandb_project,
+                        "api_key_file": get_api_key_file(wandb_key_file),
+                        "log_config": wandb_log_config,
+                    }
+            }
+        else:
+            kwargs["callbacks"] = [
+                OurWandbLoggerCallback(
+                    exp_name=exp_name,
+                    api_key_file=get_api_key_file(wandb_key_file),
+                    project=wandb_project,
+                    group=exp_name,
+                    log_config=wandb_log_config,
+                    entity=wandb_team
+                )
+            ]
+
+    # start training
+    analysis = tune.run(
+        trainer,
+        name=exp_name,
+        checkpoint_freq=checkpoint_freq,
+        checkpoint_at_end=True if "checkpoint_at_end" not in kwargs else kwargs.pop("checkpoint_at_end"),
+        stop=stop,
+        config=config,
+        max_failures=max_failures if not test_mode else 0,
+        reuse_actors=False,
+        local_dir=save_dir or ".",
+        **kwargs
+    )
+
+    # save training progress as insurance
+    if save_pkl:
+        pkl_path = "{}-{}{}.pkl".format(exp_name, trainer_name, "" if not suffix else "-" + suffix)
+        with open(pkl_path, "wb") as f:
+            data = analysis.fetch_trial_dataframes()
+            pickle.dump(data, f)
+            print("Result is saved at: <{}>".format(pkl_path))
+    return analysis
+
+
+import argparse
+import logging
+import os
+
+import ray
+
+
+def initialize_ray(local_mode=False, num_gpus=None, test_mode=False, **kwargs):
+    os.environ['OMP_NUM_THREADS'] = '1'
+
+    if ray.__version__.split(".")[0] == "1":  # 1.0 version Ray
+        if "redis_password" in kwargs:
+            redis_password = kwargs.pop("redis_password")
+            kwargs["_redis_password"] = redis_password
+
+    ray.init(
+        logging_level=logging.ERROR if not test_mode else logging.DEBUG,
+        log_to_driver=test_mode,
+        local_mode=local_mode,
+        num_gpus=num_gpus,
+        ignore_reinit_error=True,
+        include_dashboard=False,
+        **kwargs
+    )
+    print("Successfully initialized Ray!")
+    try:
+        print("Available resources: ", ray.available_resources())
+    except Exception:
+        pass
+
+
+def get_train_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--exp-name", type=str, default="")
+    parser.add_argument("--num-gpus", type=int, default=0)
+    parser.add_argument("--num-seeds", type=int, default=3)
+    parser.add_argument("--num-cpus-per-worker", type=float, default=0.5)
+    parser.add_argument("--num-gpus-per-trial", type=float, default=0.25)
+    parser.add_argument("--test", action="store_true")
+    return parser
+
+
+def setup_logger(debug=False):
+    import logging
+    logging.basicConfig(
+        level=logging.DEBUG if debug else logging.WARNING,
+        format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
+    )
+
+
+def get_time_str():
+    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+
+
+def get_exp_name(args):
+    if args.exp_name != "":
+        exp_name = args.exp_name + "_" + get_time_str()
+    else:
+        exp_name = "TEST"
+    return exp_name
+
+
+def get_eval_config(config):
+    eval_config = copy.deepcopy(config)
+    eval_config.pop("evaluation_interval")
+    eval_config.pop("evaluation_num_episodes")
+    eval_config.pop("evaluation_config")
+    eval_config.pop("evaluation_num_workers")
+    return eval_config
+
+
+def get_function(ckpt, explore, config):
+    trainer = MultiWorkerPPO(get_eval_config(config))
+    trainer.restore(ckpt)
+
+    def _f(obs):
+        ret = trainer.compute_actions({"default_policy": obs}, explore=explore)
+        return ret
+
+    return _f
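`get_function` restores a trainer only to expose a single-observation policy query. A sketch of using it outside `eval_ckpt`; it assumes the helpers above (`initialize_ray`, `get_eval_config`, `get_function`), a train-script `config` dict, and a hypothetical checkpoint path:

```python
# Sketch: query the restored policy step by step (checkpoint path is hypothetical).
from metadrive.envs.scenario_env import ScenarioEnv

initialize_ray(test_mode=False, num_gpus=0)
compute_actions = get_function("/tmp/checkpoints/checkpoint-100", explore=False, config=config)
env = ScenarioEnv(get_eval_config(config)["env_config"])
o = env.reset()
action = compute_actions(o)["default_policy"]
o, r, d, info = env.step(action)
```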
def eval_ckpt(config,
+              ckpt_path,
+              scenario_data_path,
+              num_scenarios,
+              start_scenario_index,
+              horizon=600,
+              render=False,
+              # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
+              explore=True,
+              log_interval=None,
+              ):
+    initialize_ray(test_mode=False, num_gpus=1)
+    # 27 29 30 37 39
+    env_config = get_eval_config(config)["env_config"]
+    env_config.update(dict(
+        start_scenario_index=start_scenario_index,
+        num_scenarios=num_scenarios,
+        sequential_seed=True,
+        curriculum_level=1,  # disable curriculum
+        target_success_rate=1,
+        horizon=horizon,
+        episodes_to_evaluate_curriculum=num_scenarios,
+        data_directory=scenario_data_path,
+        use_render=render))
+    env = ScenarioEnv(env_config)
+
+    super_data = defaultdict(list)
+    EPISODE_NUM = env.config["num_scenarios"]
+    compute_actions = get_function(ckpt_path, explore=explore, config=config)
+
+    o = env.reset()
+    assert env.current_seed == start_scenario_index, "Wrong start seed!"
+
+    total_cost = 0
+    total_reward = 0
+    success_rate = 0
+    ep_cost = 0
+    ep_reward = 0
+    success_flag = False
+    step = 0
+    epi_num = 0
+
+    def log_msg():
+        print("Episode {} | success_rate:{}, mean_episode_reward:{}, mean_episode_cost:{}".format(
+            epi_num, success_rate / epi_num, total_reward / epi_num, total_cost / epi_num))
+
+    # Loop over environment steps until EPISODE_NUM episodes have finished.
+    pbar = tqdm.tqdm(total=EPISODE_NUM)
+    while epi_num < EPISODE_NUM:
+        step += 1
+        action_to_send = compute_actions(o)["default_policy"]
+        o, r, d, info = env.step(action_to_send)
+        if env.config["use_render"]:
+            env.render(text={"reward": r})
+        total_reward += r
+        ep_reward += r
+        total_cost += info["cost"]
+        ep_cost += info["cost"]
+        if d or step > horizon:
+            if info["arrive_dest"]:
+                success_rate += 1
+                success_flag = True
+            o = env.reset()
+
+            super_data[0].append(
+                {"reward": ep_reward,
+                 "success": success_flag,
+                 "out_of_road": info[TerminationState.OUT_OF_ROAD],
+                 "cost": ep_cost,
+                 "seed": env.current_seed,
+                 "route_completion": info["route_completion"]})
+
+            ep_cost = 0.0
+            ep_reward = 0.0
+            success_flag = False
+            step = 0
+
+            epi_num += 1
+            pbar.update(1)
+            # epi_num >= 1 here, so log_msg() never divides by zero
+            if log_interval is not None and epi_num % log_interval == 0:
+                log_msg()
+    pbar.close()
+    if log_interval is not None:
+        log_msg()
+    del compute_actions
+    env.close()
+    with open("eval_ret_{}_{}_{}.json".format(start_scenario_index,
+                                              start_scenario_index + num_scenarios,
+                                              get_time_str()), "w") as f:
+        json.dump(super_data, f)
+    return super_data
diff --git a/scenarionet_training/wandb_utils/__init__.py b/scenarionet_training/wandb_utils/__init__.py
new file mode 100644
index 0000000..0959e0f
--- /dev/null
+++ b/scenarionet_training/wandb_utils/__init__.py
@@ -0,0 +1,3 @@
+import os
+
+WANDB_KEY_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "wandb_api_key_file.txt")
diff --git a/scenarionet_training/wandb_utils/our_wandb_callbacks.py b/scenarionet_training/wandb_utils/our_wandb_callbacks.py
new file mode 100644
index 0000000..ffb69b3
--- /dev/null
+++ b/scenarionet_training/wandb_utils/our_wandb_callbacks.py
@@ -0,0 +1,67 @@
+from ray.tune.integration.wandb import WandbLoggerCallback, _clean_log, \
+    Queue, WandbLogger
+
+
+class OurWandbLoggerCallback(WandbLoggerCallback):
+    def __init__(self, exp_name, *args, **kwargs):
+        super(OurWandbLoggerCallback, self).__init__(*args, **kwargs)
+        self.exp_name = exp_name
+
+    def log_trial_start(self, trial: "Trial"):
+        config = trial.config.copy()
+
+        config.pop("callbacks", None)  # Remove callbacks
+
+        exclude_results = self._exclude_results.copy()
+
+        # Additional excludes
+        exclude_results += self.excludes
+
+        # Log config keys on each result?
+        if not self.log_config:
+            exclude_results += ["config"]
+
+        # Fill trial ID and name
+        trial_id = trial.trial_id if trial else None
+        # trial_name = str(trial) if trial else None
+
+        # Project name for Wandb
+        wandb_project = self.project
+
+        # Grouping
+        wandb_group = self.group or trial.trainable_name if trial else None
+
+        # remove unpickleable items!
+        config = _clean_log(config)
+
+        assert trial_id is not None
+        run_name = "{}_{}".format(self.exp_name, trial_id)
+
+        wandb_init_kwargs = dict(
+            id=trial_id,
+            name=run_name,
+            resume=True,
+            reinit=True,
+            allow_val_change=True,
+            group=wandb_group,
+            project=wandb_project,
+            config=config
+        )
+        wandb_init_kwargs.update(self.kwargs)
+
+        self._trial_queues[trial] = Queue()
+        self._trial_processes[trial] = self._logger_process_cls(
+            queue=self._trial_queues[trial],
+            exclude=exclude_results,
+            to_config=self._config_results,
+            **wandb_init_kwargs
+        )
+        self._trial_processes[trial].start()
+
+    def __del__(self):
+        if self._trial_processes:
+            for v in self._trial_processes.values():
+                if hasattr(v, "close"):
+                    v.close()
+            self._trial_processes.clear()
+            self._trial_processes = {}
diff --git a/scenarionet_training/wandb_utils/our_wandb_callbacks_ray100.py b/scenarionet_training/wandb_utils/our_wandb_callbacks_ray100.py
new file mode 100644
index 0000000..1c516db
--- /dev/null
+++ b/scenarionet_training/wandb_utils/our_wandb_callbacks_ray100.py
@@ -0,0 +1,80 @@
+from multiprocessing import Queue
+
+from ray.tune.integration.wandb import WandbLogger, _clean_log, _set_api_key
+
+
+class OurWandbLogger(WandbLogger):
+    def __init__(self, config, logdir, trial):
+        self.exp_name = config["logger_config"]["wandb"].pop("exp_name")
+        super(OurWandbLogger, self).__init__(config, logdir, trial)
+
+    def _init(self):
+
+        config = self.config.copy()
+
+        config.pop("callbacks", None)  # Remove callbacks
+
+        try:
+            if config.get("logger_config", {}).get("wandb"):
+                logger_config = config.pop("logger_config")
+                wandb_config = logger_config.get("wandb").copy()
+            else:
+                wandb_config = config.pop("wandb").copy()
+        except KeyError:
+            raise ValueError(
+                "Wandb logger specified but no configuration has been passed. "
+                "Make sure to include a `wandb` key in your `config` dict "
+                "containing at least a `project` specification.")
+
+        _set_api_key(wandb_config)
+
+        exclude_results = self._exclude_results.copy()
+
+        # Additional excludes
+        additional_excludes = wandb_config.pop("excludes", [])
+        exclude_results += additional_excludes
+
+        # Log config keys on each result?
+        log_config = wandb_config.pop("log_config", False)
+        if not log_config:
+            exclude_results += ["config"]
+
+        # Fill trial ID and name
+        trial_id = self.trial.trial_id if self.trial else None
+        trial_name = str(self.trial) if self.trial else None
+
+        # Project name for Wandb
+        try:
+            wandb_project = wandb_config.pop("project")
+        except KeyError:
+            raise ValueError(
+                "You need to specify a `project` in your wandb `config` dict.")
+
+        # Grouping
+        wandb_group = wandb_config.pop(
+            "group", self.trial.trainable_name if self.trial else None)
+
+        # remove unpickleable items!
+        config = _clean_log(config)
+
+        assert trial_id is not None
+        run_name = "{}_{}".format(self.exp_name, trial_id)
+
+        wandb_init_kwargs = dict(
+            id=trial_id,
+            name=run_name,
+            resume=True,
+            reinit=True,
+            allow_val_change=True,
+            group=wandb_group,
+            project=wandb_project,
+            config=config)
+        wandb_init_kwargs.update(wandb_config)
+
+        self._queue = Queue()
+        self._wandb = self._logger_process_cls(
+            queue=self._queue,
+            exclude=exclude_results,
+            to_config=self._config_results,
+            **wandb_init_kwargs)
+        self._wandb.start()
diff --git a/scenarionet_training/wandb_utils/test_wandb.py b/scenarionet_training/wandb_utils/test_wandb.py
new file mode 100644
index 0000000..56d76d8
--- /dev/null
+++ b/scenarionet_training/wandb_utils/test_wandb.py
@@ -0,0 +1,38 @@
+"""
+Procedure to use wandb:
+
+1. Sign up for wandb: https://wandb.ai/
+2. Get the API key from your personal settings
+3. Store the API key (a string) in some file, e.g.: ~/wandb_api_key_file.txt
+4. Install wandb: pip install wandb
+5. Fill the "wandb_key_file", "wandb_project" keys in our train function.
+
+Note1: You don't need to specify who owns "wandb_project". For example, in team "drivingforce"'s project
+"representation", you only need to fill wandb_project="representation"
+
+Note2: In wandb, there are "team name", "project name", "group name" and "trial_name". We only need to care about
+"team name" and "project name". The "team name" is set to "drivingforce" by default. You can also use None to
+log results to your personal domain. The "group name" of the experiment is exactly the "exp_name" in our context, like
+"0304_train_ppo" or so.
+
+Note3: It would be great to change the x-axis in the wandb website to "timesteps_total".
+
+Peng Zhenghao, 20210402
+"""
+from ray import tune
+
+from scenarionet_training.train_utils.utils import train
+
+if __name__ == "__main__":
+    config = dict(env="CartPole-v0", num_workers=0, lr=tune.grid_search([1e-2, 1e-4]))
+    train(
+        "PPO",
+        exp_name="test_wandb",
+        stop=10000,
+        config=config,
+        custom_callback=False,
+        test_mode=False,
+        local_mode=False,
+        wandb_project="TEST",
+        wandb_team="drivingforce"  # drivingforce is the default. Use None to log to your personal domain!
+    )
diff --git a/scenarionet_training/wandb_utils/wandb_api_key_file.txt b/scenarionet_training/wandb_utils/wandb_api_key_file.txt
new file mode 100644
index 0000000..a19a4a2
--- /dev/null
+++ b/scenarionet_training/wandb_utils/wandb_api_key_file.txt
@@ -0,0 +1 @@
+132a8add578bdaeea5ab7a4942f35f2a17742df2
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 02c4969..8b6b3ef 100644
--- a/setup.py
+++ b/setup.py
@@ -40,9 +40,18 @@ install_requires = [
     "shapely"
 ]
 
+train_requirement = [
+    "ray[rllib]==1.0.0",
+    # "torch",
+    "wandb==0.12.1",
+    "aiohttp==3.6.0",
+    "gymnasium",
+    "tensorflow",
+    "tensorflow_probability"]
+
 setup(
     name="scenarionet",
-    python_requires='>=3.6, <3.12',  # do version check with assert
+    python_requires='>=3.8',  # do version check with assert
     version=version,
     description="Scalable Traffic Scenario Management System",
     url="https://github.com/metadriverse/ScenarioNet",
@@ -50,12 +59,9 @@
     author_email="quanyili0057@gmail.com, pzh@cs.ucla.edu",
     packages=packages,
     install_requires=install_requires,
-    # extras_require={
-    #     "cuda": cuda_requirement,
-    #     "nuplan": nuplan_requirement,
-    #     "waymo": waymo_requirement,
-    #     "all": nuplan_requirement + cuda_requirement
-    # },
+    extras_require={
+        "train": train_requirement,
+    },
     include_package_data=True,
     license="Apache 2.0",
     long_description=long_description,
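The new `extras_require` entry makes the RLlib/wandb training stack optional. Assuming a standard pip setup, the training dependencies would be pulled in with:

```
pip install -e ".[train]"
```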