Add some updates for NeurIPS paper (#4)
* scenarionet training
* wandb
* train utils
* fix callback
* run PPO
* use pg test
* save path
* use torch
* add dependency
* update ignore
* update training
* large model
* use curriculum training
* add time to exp name
* storage_path
* restore
* update training
* use my key
* add log message
* check seed
* restore callback
* restore call bacl
* add log message
* add logging message
* restore ray1.4
* length 500
* ray 100
* wandb
* use tf
* more levels
* add callback
* 10 worker
* show level
* no env horizon
* callback result level
* more call back
* add diffuculty
* add mroen stat
* mroe stat
* show levels
* add callback
* new
* ep len 600
* fix setup
* fix stepup
* fix to 3.8
* update setup
* parallel worker!
* new exp
* add callback
* lateral dist
* pg dataset
* evaluate
* modify config
* align config
* train single RL
* update training script
* 100w eval
* less eval to reveal
* 2000 env eval
* new trianing
* eval 1000
* update eval
* more workers
* more worker
* 20 worker
* dataset to database
* split tool!
* split dataset
* try fix
* train 003
* fix mapping
* fix test
* add waymo tqdm
* utils
* fix bug
* fix bug
* waymo
* int type
* 8 worker read
* disable
* read file
* add log message
* check existence
* dist 0
* int
* check num
* suprass warning
* add filter API
* filter
* store map false
* new
* ablation
* filter
* fix
* update filyter
* reanme to from
* random select
* add overlapping checj
* fix
* new training sceheme
* new reward
* add waymo train script
* waymo different config
* copy raw data
* fix bug
* add tqdm
* update readme
* waymo
* pg
* max lateral dist 3
* pg
* crash_done instead of penalty
* no crash done
* gpu
* update eval script
* steering range penalty
* evaluate
* finish pg
* update setup
* fix bug
* test
* fix
* add on line
* train nuplan
* generate sensor
* udpate training
* static obj
* multi worker eval
* filx bug
* use ray for testing
* eval!
* filter senario
* id filter
* fox bug
* dist = 2
* filter
* eval
* eval ret
* ok
* update training pg
* test before use
* store data=False
* collect figures
* capture pic

---------

Co-authored-by: Quanyi Li <quanyi@bolei-gpu02.cs.ucla.edu>
scenarionet_training/train_utils/__init__.py (new file, 0 lines)
scenarionet_training/train_utils/anisotropic_workerset.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import copy
import logging
from typing import TypeVar

from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

# Generic type var for foreach_* methods.
T = TypeVar("T")


@DeveloperAPI
class AnisotropicWorkerSet(WorkerSet):
    """
    A WorkerSet whose rollout workers are assigned to different scenarios,
    saving memory and speeding up sampling.
    """

    def add_workers(self, num_workers: int) -> None:
        """
        Create the remote rollout workers and tell each one, via its env_config,
        which shard of scenarios it is responsible for.
        """
        remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            # memory=0 is an error, but memory=None means no limits.
            "memory": self._remote_config["memory_per_worker"] or None,
            "object_store_memory": self._remote_config["object_store_memory_per_worker"] or None,
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        cls = RolloutWorker.as_remote(**remote_args).remote
        for i in range(num_workers):
            config = copy.deepcopy(self._remote_config)
            # Each worker sees its own index and the total worker count, so the
            # environment can load only its share of the scenario database.
            config["env_config"]["worker_index"] = i
            config["env_config"]["num_workers"] = num_workers
            self._remote_workers.append(
                self._make_worker(cls, self._env_creator, self._policy_class, i + 1, config)
            )
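The worker set above only injects `worker_index` and `num_workers` into each worker's `env_config`; the environment is expected to use them to pick its own shard of scenarios. Below is a minimal sketch of that consumer side, under the assumption of hypothetical `start_scenario_index`/`num_scenarios` config fields; it is illustrative only and not part of this commit.

# Illustrative sketch: how an environment could consume the per-worker fields
# injected by AnisotropicWorkerSet. All names here are hypothetical.
def shard_scenarios(env_config):
    worker_index = env_config.get("worker_index", 0)
    num_workers = max(env_config.get("num_workers", 1), 1)
    total = env_config.get("num_scenarios", 1000)
    start = env_config.get("start_scenario_index", 0)

    # Give each worker a contiguous slice; the last worker absorbs the remainder.
    per_worker = total // num_workers
    shard_start = start + worker_index * per_worker
    shard_count = per_worker if worker_index < num_workers - 1 else total - worker_index * per_worker
    return shard_start, shard_count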
scenarionet_training/train_utils/callbacks.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from typing import Dict

import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy


class DrivingCallbacks(DefaultCallbacks):
    """Collect per-step driving statistics and report them as episode-level custom metrics."""

    def on_episode_start(
        self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
        env_index: int, **kwargs
    ):
        episode.user_data["velocity"] = []
        episode.user_data["steering"] = []
        episode.user_data["step_reward"] = []
        episode.user_data["acceleration"] = []
        episode.user_data["lateral_dist"] = []
        episode.user_data["cost"] = []
        episode.user_data["num_crash_vehicle"] = []
        episode.user_data["num_crash_human"] = []
        episode.user_data["num_crash_object"] = []
        episode.user_data["num_on_line"] = []

        episode.user_data["step_reward_lateral"] = []
        episode.user_data["step_reward_heading"] = []
        episode.user_data["step_reward_action_smooth"] = []

    def on_episode_step(
        self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs
    ):
        info = episode.last_info_for()
        if info is not None:
            episode.user_data["velocity"].append(info["velocity"])
            episode.user_data["steering"].append(info["steering"])
            episode.user_data["step_reward"].append(info["step_reward"])
            episode.user_data["acceleration"].append(info["acceleration"])
            episode.user_data["lateral_dist"].append(info["lateral_dist"])
            episode.user_data["cost"].append(info["cost"])
            for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
                episode.user_data[x].append(info[x])

            for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
                episode.user_data[x].append(info[x])

    def on_episode_end(
        self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
        **kwargs
    ):
        info = episode.last_info_for()
        arrive_dest = info["arrive_dest"]
        crash = info["crash"]
        out_of_road = info["out_of_road"]
        # If the episode ended without reaching the goal, crashing, or leaving
        # the road, it must have been truncated at the step limit.
        max_step_rate = not (arrive_dest or crash or out_of_road)
        episode.custom_metrics["success_rate"] = float(arrive_dest)
        episode.custom_metrics["crash_rate"] = float(crash)
        episode.custom_metrics["out_of_road_rate"] = float(out_of_road)
        episode.custom_metrics["max_step_rate"] = float(max_step_rate)
        episode.custom_metrics["velocity_max"] = float(np.max(episode.user_data["velocity"]))
        episode.custom_metrics["velocity_mean"] = float(np.mean(episode.user_data["velocity"]))
        episode.custom_metrics["velocity_min"] = float(np.min(episode.user_data["velocity"]))

        episode.custom_metrics["lateral_dist_min"] = float(np.min(episode.user_data["lateral_dist"]))
        episode.custom_metrics["lateral_dist_max"] = float(np.max(episode.user_data["lateral_dist"]))
        episode.custom_metrics["lateral_dist_mean"] = float(np.mean(episode.user_data["lateral_dist"]))

        episode.custom_metrics["steering_max"] = float(np.max(episode.user_data["steering"]))
        episode.custom_metrics["steering_mean"] = float(np.mean(episode.user_data["steering"]))
        episode.custom_metrics["steering_min"] = float(np.min(episode.user_data["steering"]))
        episode.custom_metrics["acceleration_min"] = float(np.min(episode.user_data["acceleration"]))
        episode.custom_metrics["acceleration_mean"] = float(np.mean(episode.user_data["acceleration"]))
        episode.custom_metrics["acceleration_max"] = float(np.max(episode.user_data["acceleration"]))
        episode.custom_metrics["step_reward_max"] = float(np.max(episode.user_data["step_reward"]))
        episode.custom_metrics["step_reward_mean"] = float(np.mean(episode.user_data["step_reward"]))
        episode.custom_metrics["step_reward_min"] = float(np.min(episode.user_data["step_reward"]))

        episode.custom_metrics["cost"] = float(sum(episode.user_data["cost"]))
        for x in ["num_crash_vehicle", "num_crash_object", "num_crash_human", "num_on_line"]:
            episode.custom_metrics[x] = float(sum(episode.user_data[x]))

        for x in ["step_reward_lateral", "step_reward_heading", "step_reward_action_smooth"]:
            episode.custom_metrics[x] = float(np.mean(episode.user_data[x]))

        episode.custom_metrics["route_completion"] = float(info["route_completion"])
        episode.custom_metrics["curriculum_level"] = int(info["curriculum_level"])
        episode.custom_metrics["scenario_index"] = int(info["scenario_index"])
        episode.custom_metrics["track_length"] = float(info["track_length"])
        episode.custom_metrics["num_stored_maps"] = int(info["num_stored_maps"])
        episode.custom_metrics["scenario_difficulty"] = float(info["scenario_difficulty"])
        episode.custom_metrics["data_coverage"] = float(info["data_coverage"])
        episode.custom_metrics["curriculum_success"] = float(info["curriculum_success"])
        episode.custom_metrics["curriculum_route_completion"] = float(info["curriculum_route_completion"])

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        # Mirror the most important custom metrics to short top-level keys so
        # they show up as columns in the CLI progress reporter.
        result["success"] = np.nan
        result["out"] = np.nan
        result["max_step"] = np.nan
        result["level"] = np.nan
        result["length"] = result["episode_len_mean"]
        result["coverage"] = np.nan
        if "custom_metrics" not in result:
            return

        if "success_rate_mean" in result["custom_metrics"]:
            result["success"] = result["custom_metrics"]["success_rate_mean"]
            result["out"] = result["custom_metrics"]["out_of_road_rate_mean"]
            result["max_step"] = result["custom_metrics"]["max_step_rate_mean"]
            result["level"] = result["custom_metrics"]["curriculum_level_mean"]
            result["coverage"] = result["custom_metrics"]["data_coverage_mean"]
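For reference, a minimal sketch of how this callback class is typically attached to an RLlib trainer config (Ray 1.x style); the env name and worker count below are placeholders, not values from this commit.

# Illustrative sketch: attach DrivingCallbacks via the trainer config.
config = {
    "env": "scenario_env",          # placeholder for a registered env name
    "num_workers": 8,
    "callbacks": DrivingCallbacks,  # RLlib instantiates the class itself
}
# After each training iteration, episode-level statistics appear under
# result["custom_metrics"] (e.g. "success_rate_mean"), and on_train_result
# mirrors a few of them to short top-level keys: "success", "out",
# "max_step", "level", "length", "coverage".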
scenarionet_training/train_utils/check_env.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from ray.rllib.utils import check_env
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
from gym import Env

if __name__ == '__main__':
    env = GymnasiumEnvWrapper.build(ScenarioEnv)()
    # The raw ScenarioEnv is gymnasium-based, so it is not a gym.Env subclass;
    # the wrapped instance should be recognized as one.
    print(issubclass(ScenarioEnv, Env))
    print(isinstance(env, Env))
    print(env.observation_space)
    check_env(env)
scenarionet_training/train_utils/multi_worker_PPO.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import logging
from typing import Callable, Type

from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict, EnvType

from scenarionet_training.train_utils.anisotropic_workerset import AnisotropicWorkerSet

logger = logging.getLogger(__name__)


class MultiWorkerPPO(PPOTrainer):
    """
    A PPOTrainer whose rollout workers each receive a different config. This speeds up
    sampling, saves memory, and lets the workers cover the train/test scenarios more evenly.
    """

    def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
                      policy_class: Type[Policy], config: TrainerConfigDict,
                      num_workers: int):
        """Default factory method for a WorkerSet running under this Trainer.

        Override this method by passing a custom `make_workers` into
        `build_trainer`.

        Args:
            env_creator (callable): A function that returns an Env given an env
                context.
            policy_class (Type[Policy]): The Policy class to use for creating the
                policies of the workers.
            config (TrainerConfigDict): The Trainer's config.
            num_workers (int): Number of remote rollout workers to create.
                0 for local only.

        Returns:
            WorkerSet: The created WorkerSet.
        """
        return AnisotropicWorkerSet(
            env_creator=env_creator,
            policy_class=policy_class,
            trainer_config=config,
            num_workers=num_workers,
            logdir=self.logdir
        )
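A minimal usage sketch (Ray 1.x API), assuming a placeholder env name and factory: MultiWorkerPPO is constructed exactly like a regular PPOTrainer, and the per-worker `env_config` fields are filled in when AnisotropicWorkerSet creates the rollout workers.

# Illustrative sketch only; "scenario_env", make_scenario_env, and the config
# values are placeholders, not part of this commit.
from ray.tune.registry import register_env
from metadrive.envs.scenario_env import ScenarioEnv

def make_scenario_env(env_config):
    # Placeholder factory; the real training scripts additionally wrap this
    # env for gym-API compatibility (see check_env.py above).
    return ScenarioEnv(dict(env_config))

register_env("scenario_env", make_scenario_env)

trainer = MultiWorkerPPO(config={
    "env": "scenario_env",
    "num_workers": 4,
    "env_config": {"num_scenarios": 40000},  # placeholder settings
})
result = trainer.train()  # each rollout worker now holds its own scenario shard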
scenarionet_training/train_utils/utils.py (new file, 356 lines)
@@ -0,0 +1,356 @@
import argparse
import copy
import datetime
import json
import logging
import os
import pickle
from collections import defaultdict

import numpy as np
import ray
import tqdm
from metadrive.constants import TerminationState
from metadrive.envs.scenario_env import ScenarioEnv
from ray import tune
from ray.tune import CLIReporter

from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.wandb_utils import WANDB_KEY_FILE

root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))


def get_api_key_file(wandb_key_file):
    """Return the path of the W&B API key file, preferring an explicitly given one."""
    if wandb_key_file is not None:
        default_path = os.path.expanduser(wandb_key_file)
    else:
        default_path = WANDB_KEY_FILE
    if os.path.exists(default_path):
        print("We are using this wandb key file: ", default_path)
        return default_path
    path = os.path.join(root, "scenarionet_training/wandb", "wandb_api_key_file.txt")
    print("We are using this wandb key file: ", path)
    return path

def train(
    trainer,
    config,
    stop,
    exp_name,
    num_seeds=1,
    num_gpus=0,
    test_mode=False,
    suffix="",
    checkpoint_freq=10,
    keep_checkpoints_num=None,
    start_seed=0,
    local_mode=False,
    save_pkl=True,
    custom_callback=None,
    max_failures=0,
    wandb_key_file=None,
    wandb_project=None,
    wandb_team="drivingforce",
    wandb_log_config=True,
    init_kws=None,
    save_dir=None,
    **kwargs
):
    init_kws = init_kws or dict()
    # Initialize Ray, either locally or by attaching to an existing cluster.
    if not os.environ.get("redis_password"):
        initialize_ray(test_mode=test_mode, local_mode=local_mode, num_gpus=num_gpus, **init_kws)
    else:
        password = os.environ.get("redis_password")
        assert os.environ.get("ip_head")
        print(
            "We detected that redis_password ({}) exists in the environment, so "
            "we will attach to the existing Ray cluster!".format(password)
        )
        if num_gpus:
            print(
                "We are in cluster mode, so the GPU specification here is ignored; "
                "it should be done when submitting the task to the cluster. You are "
                "requesting {} GPU(s) for each machine!".format(num_gpus)
            )
        initialize_ray(address=os.environ["ip_head"], test_mode=test_mode, redis_password=password, **init_kws)

    # Prepare the config.
    if custom_callback:
        callback = custom_callback
    else:
        from scenarionet_training.train_utils.callbacks import DrivingCallbacks
        callback = DrivingCallbacks

    used_config = {
        "seed": tune.grid_search([i * 100 + start_seed for i in range(num_seeds)]) if num_seeds is not None else None,
        "log_level": "DEBUG" if test_mode else "INFO",
        "callbacks": callback
    }
    if custom_callback is False:
        used_config.pop("callbacks")
    if config:
        used_config.update(config)
    config = copy.deepcopy(used_config)

    if isinstance(trainer, str):
        trainer_name = trainer
    elif hasattr(trainer, "_name"):
        trainer_name = trainer._name
    else:
        trainer_name = trainer.__name__

    # A scalar `stop` is interpreted as a total timestep budget.
    if not isinstance(stop, dict) and stop is not None:
        assert np.isscalar(stop)
        stop = {"timesteps_total": int(stop)}

    if keep_checkpoints_num is not None and not test_mode:
        assert isinstance(keep_checkpoints_num, int)
        kwargs["keep_checkpoints_num"] = keep_checkpoints_num
        kwargs["checkpoint_score_attr"] = "episode_reward_mean"

    if "verbose" not in kwargs:
        kwargs["verbose"] = 1 if not test_mode else 2

    # Report the custom columns filled in by DrivingCallbacks.on_train_result.
    metric_columns = CLIReporter.DEFAULT_COLUMNS.copy()
    progress_reporter = CLIReporter(metric_columns=metric_columns)
    progress_reporter.add_metric_column("success")
    progress_reporter.add_metric_column("coverage")
    progress_reporter.add_metric_column("out")
    progress_reporter.add_metric_column("max_step")
    progress_reporter.add_metric_column("length")
    progress_reporter.add_metric_column("level")
    kwargs["progress_reporter"] = progress_reporter

    if wandb_key_file is not None:
        assert wandb_project is not None
    if wandb_project is not None:
        failed_wandb = False
        try:
            from scenarionet_training.wandb_utils.our_wandb_callbacks import OurWandbLoggerCallback
        except Exception:
            # wandb is not installed or the new-style callback is unavailable.
            failed_wandb = True

        if failed_wandb:
            # Fall back to the Ray 1.0.0-style wandb logger.
            from ray.tune.logger import DEFAULT_LOGGERS
            from scenarionet_training.wandb_utils.our_wandb_callbacks_ray100 import OurWandbLogger
            kwargs["loggers"] = DEFAULT_LOGGERS + (OurWandbLogger,)
            config["logger_config"] = {
                "wandb": {
                    "group": exp_name,
                    "exp_name": exp_name,
                    "entity": wandb_team,
                    "project": wandb_project,
                    "api_key_file": get_api_key_file(wandb_key_file),
                    "log_config": wandb_log_config,
                }
            }
        else:
            kwargs["callbacks"] = [
                OurWandbLoggerCallback(
                    exp_name=exp_name,
                    api_key_file=get_api_key_file(wandb_key_file),
                    project=wandb_project,
                    group=exp_name,
                    log_config=wandb_log_config,
                    entity=wandb_team
                )
            ]

    # Start training.
    analysis = tune.run(
        trainer,
        name=exp_name,
        checkpoint_freq=checkpoint_freq,
        checkpoint_at_end=True if "checkpoint_at_end" not in kwargs else kwargs.pop("checkpoint_at_end"),
        stop=stop,
        config=config,
        max_failures=max_failures if not test_mode else 0,
        reuse_actors=False,
        local_dir=save_dir or ".",
        **kwargs
    )

    # Save the training progress as insurance.
    if save_pkl:
        pkl_path = "{}-{}{}.pkl".format(exp_name, trainer_name, "" if not suffix else "-" + suffix)
        with open(pkl_path, "wb") as f:
            data = analysis.fetch_trial_dataframes()
            pickle.dump(data, f)
        print("Result is saved at: <{}>".format(pkl_path))
    return analysis

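For reference, a hypothetical invocation of `train` (not part of this file); the experiment name, env name, and hyperparameters below are placeholders.

# Illustrative sketch only; all values are placeholders.
if __name__ == "__main__":
    from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO

    train(
        trainer=MultiWorkerPPO,
        exp_name="ppo_scenarionet_demo",
        stop=int(1e7),                       # becomes {"timesteps_total": 10000000}
        config=dict(
            env="scenario_env",              # placeholder registered env name
            num_workers=8,
            env_config=dict(num_scenarios=40000),
            lr=3e-4,
        ),
        num_gpus=1,
        num_seeds=3,
        wandb_project=None,                  # set a project name to enable W&B logging
    )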
def initialize_ray(local_mode=False, num_gpus=None, test_mode=False, **kwargs):
    os.environ['OMP_NUM_THREADS'] = '1'

    if ray.__version__.split(".")[0] == "1":  # Ray 1.x expects the underscore-prefixed argument.
        if "redis_password" in kwargs:
            redis_password = kwargs.pop("redis_password")
            kwargs["_redis_password"] = redis_password

    ray.init(
        logging_level=logging.ERROR if not test_mode else logging.DEBUG,
        log_to_driver=test_mode,
        local_mode=local_mode,
        num_gpus=num_gpus,
        ignore_reinit_error=True,
        include_dashboard=False,
        **kwargs
    )
    print("Successfully initialized Ray!")
    try:
        print("Available resources: ", ray.available_resources())
    except Exception:
        pass


def get_train_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default="")
    parser.add_argument("--num-gpus", type=int, default=0)
    parser.add_argument("--num-seeds", type=int, default=3)
    parser.add_argument("--num-cpus-per-worker", type=float, default=0.5)
    parser.add_argument("--num-gpus-per-trial", type=float, default=0.25)
    parser.add_argument("--test", action="store_true")
    return parser


def setup_logger(debug=False):
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.WARNING,
        format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )


def get_time_str():
    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")


def get_exp_name(args):
    if args.exp_name != "":
        exp_name = args.exp_name + "_" + get_time_str()
    else:
        exp_name = "TEST"
    return exp_name


def get_eval_config(config):
    # Strip the evaluation-related keys so the config can be reused for a plain trainer.
    eval_config = copy.deepcopy(config)
    eval_config.pop("evaluation_interval")
    eval_config.pop("evaluation_num_episodes")
    eval_config.pop("evaluation_config")
    eval_config.pop("evaluation_num_workers")
    return eval_config

def get_function(ckpt, explore, config):
    """Restore a trainer from a checkpoint and return a function mapping an obs dict to actions."""
    trainer = MultiWorkerPPO(get_eval_config(config))
    trainer.restore(ckpt)

    def _f(obs):
        ret = trainer.compute_actions({"default_policy": obs}, explore=explore)
        return ret

    return _f


def eval_ckpt(
    config,
    ckpt_path,
    scenario_data_path,
    num_scenarios,
    start_scenario_index,
    horizon=600,
    render=False,
    # PPO is a stochastic policy; turning off exploration can reduce jitter but may harm performance.
    explore=True,
    log_interval=None,
):
    initialize_ray(test_mode=False, num_gpus=1)
    # 27 29 30 37 39
    env_config = get_eval_config(config)["env_config"]
    env_config.update(dict(
        start_scenario_index=start_scenario_index,
        num_scenarios=num_scenarios,
        sequential_seed=True,
        curriculum_level=1,  # disable curriculum
        target_success_rate=1,
        horizon=horizon,
        episodes_to_evaluate_curriculum=num_scenarios,
        data_directory=scenario_data_path,
        use_render=render
    ))
    env = ScenarioEnv(env_config)

    super_data = defaultdict(list)
    EPISODE_NUM = env.config["num_scenarios"]
    compute_actions = get_function(ckpt_path, explore=explore, config=config)

    o = env.reset()
    assert env.current_seed == start_scenario_index, "Wrong start seed!"

    total_cost = 0
    total_reward = 0
    success_rate = 0
    ep_cost = 0
    ep_reward = 0
    success_flag = False
    step = 0

    def log_msg(num_finished):
        print("Evaluated {} episodes | success_rate:{}, mean_episode_reward:{}, mean_episode_cost:{}".format(
            num_finished, success_rate / num_finished, total_reward / num_finished, total_cost / num_finished))

    for epi_num in tqdm.tqdm(range(EPISODE_NUM)):
        # Roll out one full episode, stopping at the step horizon if it is not done earlier.
        d = False
        while not d and step <= horizon:
            step += 1
            action_to_send = compute_actions(o)["default_policy"]
            o, r, d, info = env.step(action_to_send)
            if env.config["use_render"]:
                env.render(text={"reward": r})
            total_reward += r
            ep_reward += r
            total_cost += info["cost"]
            ep_cost += info["cost"]

        if info["arrive_dest"]:
            success_rate += 1
            success_flag = True
        o = env.reset()

        super_data[0].append(
            {"reward": ep_reward,
             "success": success_flag,
             "out_of_road": info[TerminationState.OUT_OF_ROAD],
             "cost": ep_cost,
             "seed": env.current_seed,
             "route_completion": info["route_completion"]}
        )

        ep_cost = 0.0
        ep_reward = 0.0
        success_flag = False
        step = 0

        if log_interval is not None and (epi_num + 1) % log_interval == 0:
            log_msg(epi_num + 1)
    if log_interval is not None:
        log_msg(EPISODE_NUM)
    del compute_actions
    env.close()
    with open("eval_ret_{}_{}_{}.json".format(start_scenario_index,
                                              start_scenario_index + num_scenarios,
                                              get_time_str()), "w") as f:
        json.dump(super_data, f)
    return super_data
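For reference, a hypothetical call into `eval_ckpt` for evaluating a saved checkpoint on a held-out scenario set; the paths, counts, and the stub training config below are placeholders (in practice the full trainer config used for training would be passed in).

# Illustrative sketch only; all values are placeholders.
if __name__ == "__main__":
    training_config = dict(
        env="scenario_env",                  # placeholder registered env name
        env_config=dict(),
        evaluation_interval=None,
        evaluation_num_episodes=10,
        evaluation_config=dict(),
        evaluation_num_workers=0,
    )
    eval_ckpt(
        config=training_config,
        ckpt_path="/path/to/exp/checkpoint_000500/checkpoint-500",
        scenario_data_path="/path/to/scenarionet_database",
        num_scenarios=1000,
        start_scenario_index=0,
        horizon=600,
        render=False,
        explore=True,       # keep the stochastic behaviour used during training
        log_interval=100,
    )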