Files
scenarionet/scenarionet_training/train_utils/utils.py
Quanyi Li 64bf9811b9 Rebuttal (#15)
* pg+nuplan train

* Need map

* use gym wrapper

* use createGymWrapper

* doc

* use all scenarios!

* update 80000 scenario

* train script
2023-08-08 17:33:02 +01:00

358 lines
12 KiB
Python

import copy
import datetime
import json
import os
import pickle
from collections import defaultdict
import numpy as np
import tqdm
from metadrive.constants import TerminationState
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gym_wrapper import createGymWrapper
from ray import tune
from ray.tune import CLIReporter
from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
from scenarionet_training.wandb_utils import WANDB_KEY_FILE
# Absolute path of the parent of the directory that contains this file.
# get_api_key_file() joins its fallback wandb key path onto this.
root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
def get_api_key_file(wandb_key_file):
    """Pick the wandb API key file to use and announce the choice on stdout.

    :param wandb_key_file: optional user-supplied path (``~`` is expanded). When it is
        None, the module-level ``WANDB_KEY_FILE`` is tried instead.
    :return: the first candidate if it exists on disk, otherwise the fallback path
        ``<root>/scenarionet_training/wandb/wandb_api_key_file.txt`` (returned without
        checking that it exists).
    """
    candidate = WANDB_KEY_FILE if wandb_key_file is None else os.path.expanduser(wandb_key_file)
    if not os.path.exists(candidate):
        # Neither the requested nor the default key file is present: use the bundled location.
        candidate = os.path.join(root, "scenarionet_training/wandb", "wandb_api_key_file.txt")
    print("We are using this wandb key file: ", candidate)
    return candidate
def train(
    trainer,
    config,
    stop,
    exp_name,
    num_seeds=1,
    num_gpus=0,
    test_mode=False,
    suffix="",
    checkpoint_freq=10,
    keep_checkpoints_num=None,
    start_seed=0,
    local_mode=False,
    save_pkl=True,
    custom_callback=None,
    max_failures=0,
    wandb_key_file=None,
    wandb_project=None,
    wandb_team="drivingforce",
    wandb_log_config=True,
    init_kws=None,
    save_dir=None,
    **kwargs
):
    """Initialize Ray, run a Tune experiment and optionally pickle the trial dataframes.

    :param trainer: trainable handed to ``tune.run``; either a string or an object
        exposing ``_name`` or ``__name__`` (used to build the pickle file name).
    :param config: user config; merged on top of the defaults built here
        (``seed``, ``log_level``, ``callbacks``) and then deep-copied.
    :param stop: stop criteria dict, None, or a scalar that is turned into
        ``{"timesteps_total": int(stop)}``.
    :param exp_name: experiment name; also used as wandb group and in the pickle name.
    :param num_seeds: number of grid-searched seeds (``i * 100 + start_seed``);
        None sets ``seed`` to None.
    :param num_gpus: forwarded to ``initialize_ray`` in local mode; only reported
        in cluster mode.
    :param test_mode: switches log level to DEBUG, default verbosity to 2, forces
        ``max_failures`` to 0 and ignores ``keep_checkpoints_num``.
    :param suffix: optional suffix appended to the pickle file name.
    :param checkpoint_freq: forwarded to ``tune.run``.
    :param keep_checkpoints_num: if set (and not in test mode), forwarded to
        ``tune.run`` together with ``checkpoint_score_attr="episode_reward_mean"``.
    :param start_seed: offset added to every generated seed.
    :param local_mode: forwarded to ``initialize_ray`` in local mode.
    :param save_pkl: whether to dump ``analysis.fetch_trial_dataframes()`` to disk.
    :param custom_callback: callbacks class to use instead of ``DrivingCallbacks``;
        passing ``False`` removes the ``callbacks`` entry from the default config.
    :param max_failures: forwarded to ``tune.run`` unless in test mode.
    :param wandb_key_file: optional wandb key file; requires ``wandb_project``.
    :param wandb_project: enables wandb logging when not None.
    :param wandb_team: wandb entity.
    :param wandb_log_config: forwarded to the wandb logger as ``log_config``.
    :param init_kws: extra keyword arguments for ``initialize_ray``.
    :param save_dir: Tune ``local_dir``; defaults to ``"."``.
    :param kwargs: any further keyword arguments for ``tune.run``.
    :return: the object returned by ``tune.run``.
    """
    init_kws = init_kws or dict()

    # Bring up Ray: join an existing cluster when `redis_password` is exported,
    # otherwise start a local instance.
    cluster_password = os.environ.get("redis_password")
    if cluster_password:
        assert os.environ.get("ip_head")
        print(
            "We detect redis_password ({}) exists in environment! So "
            "we will start a ray cluster!".format(cluster_password)
        )
        if num_gpus:
            print(
                "We are in cluster mode! So GPU specification is disable and"
                " should be done when submitting task to cluster! You are "
                "requiring {} GPU for each machine!".format(num_gpus)
            )
        initialize_ray(
            address=os.environ["ip_head"], test_mode=test_mode, redis_password=cluster_password, **init_kws
        )
    else:
        initialize_ray(test_mode=test_mode, local_mode=local_mode, num_gpus=num_gpus, **init_kws)

    # Assemble the default config, then let the caller's config override it.
    callback = custom_callback
    if not callback:
        from scenarionet_training.train_utils.callbacks import DrivingCallbacks
        callback = DrivingCallbacks

    if num_seeds is None:
        seed_spec = None
    else:
        seed_spec = tune.grid_search([idx * 100 + start_seed for idx in range(num_seeds)])

    merged_config = {
        "seed": seed_spec,
        "log_level": "INFO" if not test_mode else "DEBUG",
        "callbacks": callback,
    }
    if custom_callback is False:
        # Explicit opt-out: run without any callbacks entry.
        del merged_config["callbacks"]
    if config:
        merged_config.update(config)
    config = copy.deepcopy(merged_config)

    # Human-readable trainer name, only used for the pickle file name.
    if isinstance(trainer, str):
        trainer_name = trainer
    else:
        trainer_name = trainer._name if hasattr(trainer, "_name") else trainer.__name__

    # A bare number is interpreted as a total-timesteps budget.
    if stop is not None and not isinstance(stop, dict):
        assert np.isscalar(stop)
        stop = {"timesteps_total": int(stop)}

    if not test_mode and keep_checkpoints_num is not None:
        assert isinstance(keep_checkpoints_num, int)
        kwargs["keep_checkpoints_num"] = keep_checkpoints_num
        kwargs["checkpoint_score_attr"] = "episode_reward_mean"

    if "verbose" not in kwargs:
        kwargs["verbose"] = 2 if test_mode else 1

    # This functionality is not supported yet!
    progress_reporter = CLIReporter(metric_columns=CLIReporter.DEFAULT_COLUMNS.copy())
    for column in ("success", "coverage", "out", "max_step", "length", "level"):
        progress_reporter.add_metric_column(column)
    kwargs["progress_reporter"] = progress_reporter

    # Optional wandb logging.
    if wandb_key_file is not None:
        assert wandb_project is not None
    if wandb_project is not None:
        callback_api_available = True
        try:
            from scenarionet_training.wandb_utils.our_wandb_callbacks import OurWandbLoggerCallback
        except Exception:
            # Fall back to the logger-class based integration below.
            callback_api_available = False

        if callback_api_available:
            kwargs["callbacks"] = [
                OurWandbLoggerCallback(
                    exp_name=exp_name,
                    api_key_file=get_api_key_file(wandb_key_file),
                    project=wandb_project,
                    group=exp_name,
                    log_config=wandb_log_config,
                    entity=wandb_team
                )
            ]
        else:
            from ray.tune.logger import DEFAULT_LOGGERS
            from scenarionet_training.wandb_utils.our_wandb_callbacks_ray100 import OurWandbLogger
            kwargs["loggers"] = DEFAULT_LOGGERS + (OurWandbLogger,)
            wandb_settings = {
                "group": exp_name,
                "exp_name": exp_name,
                "entity": wandb_team,
                "project": wandb_project,
                "api_key_file": get_api_key_file(wandb_key_file),
                "log_config": wandb_log_config,
            }
            config["logger_config"] = {"wandb": wandb_settings}

    # Run the experiment. `checkpoint_at_end` must be taken out of kwargs first so it is
    # not passed twice.
    checkpoint_at_end = kwargs.pop("checkpoint_at_end", True)
    analysis = tune.run(
        trainer,
        name=exp_name,
        checkpoint_freq=checkpoint_freq,
        checkpoint_at_end=checkpoint_at_end,
        stop=stop,
        config=config,
        max_failures=0 if test_mode else max_failures,
        reuse_actors=False,
        local_dir=save_dir or ".",
        **kwargs
    )

    # Keep a pickled copy of the training progress as insurance.
    if save_pkl:
        suffix_part = "-" + suffix if suffix else ""
        pkl_path = "{}-{}{}.pkl".format(exp_name, trainer_name, suffix_part)
        with open(pkl_path, "wb") as f:
            pickle.dump(analysis.fetch_trial_dataframes(), f)
        print("Result is saved at: <{}>".format(pkl_path))
    return analysis
import argparse
import logging
import os
import ray
def initialize_ray(local_mode=False, num_gpus=None, test_mode=False, **kwargs):
    """Call ``ray.init`` with this project's defaults and print the available resources.

    :param local_mode: forwarded to ``ray.init``.
    :param num_gpus: forwarded to ``ray.init``.
    :param test_mode: when true, Ray logs at DEBUG level and worker logs are sent to
        the driver; otherwise Ray logs at ERROR level only.
    :param kwargs: extra keyword arguments for ``ray.init``. On Ray 1.x a
        ``redis_password`` entry is renamed to ``_redis_password``.
    """
    os.environ['OMP_NUM_THREADS'] = '1'

    ray_major_version = ray.__version__.split(".")[0]
    if ray_major_version == "1" and "redis_password" in kwargs:  # 1.0 version Ray
        kwargs["_redis_password"] = kwargs.pop("redis_password")

    log_level = logging.DEBUG if test_mode else logging.ERROR
    ray.init(
        logging_level=log_level,
        log_to_driver=test_mode,
        local_mode=local_mode,
        num_gpus=num_gpus,
        ignore_reinit_error=True,
        include_dashboard=False,
        **kwargs
    )
    print("Successfully initialize Ray!")

    # Reporting resources is best effort; a failure here must not abort start-up.
    try:
        resources = ray.available_resources()
        print("Available resources: ", resources)
    except Exception:
        pass
def get_train_parser():
    """Build the command-line parser shared by the training scripts.

    :return: an ``argparse.ArgumentParser`` with ``--exp-name``, ``--num-gpus``,
        ``--num-seeds``, ``--num-cpus-per-worker``, ``--num-gpus-per-trial`` and the
        boolean flag ``--test``.
    """
    # (flag, type, default) for every option that takes a value.
    valued_options = (
        ("--exp-name", str, ""),
        ("--num-gpus", int, 0),
        ("--num-seeds", int, 3),
        ("--num-cpus-per-worker", float, 0.5),
        ("--num-gpus-per-trial", float, 0.25),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in valued_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    parser.add_argument("--test", action="store_true")
    return parser
def setup_logger(debug=False):
    """Configure the root logger through ``logging.basicConfig``.

    :param debug: use the DEBUG level when true, WARNING otherwise.
    """
    import logging

    level = logging.WARNING if not debug else logging.DEBUG
    message_format = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    logging.basicConfig(level=level, format=message_format)
def get_time_str():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    timestamp_format = "%y%m%d-%H%M%S"
    now = datetime.datetime.now()
    return now.strftime(timestamp_format)
def get_exp_name(args):
    """Derive the experiment name from parsed command-line arguments.

    :param args: object with an ``exp_name`` attribute.
    :return: ``"TEST"`` when ``args.exp_name`` is the empty string, otherwise
        ``args.exp_name`` followed by ``"_"`` and the current time string.
    """
    if args.exp_name == "":
        return "TEST"
    return args.exp_name + "_" + get_time_str()
def get_eval_config(config):
    """Return a deep copy of ``config`` with the evaluation-related entries removed.

    The input is left untouched.

    :param config: trainer config dict; must contain all four evaluation keys.
    :return: the stripped copy.
    :raises KeyError: if any of the four evaluation keys is missing.
    """
    stripped = copy.deepcopy(config)
    for key in (
        "evaluation_interval",
        "evaluation_num_episodes",
        "evaluation_config",
        "evaluation_num_workers",
    ):
        del stripped[key]
    return stripped
def get_function(ckpt, explore, config):
    """Restore a ``MultiWorkerPPO`` trainer from a checkpoint and wrap it as a policy function.

    :param ckpt: checkpoint passed to ``trainer.restore``.
    :param explore: forwarded as ``explore`` to every ``compute_actions`` call.
    :param config: trainer config; its evaluation entries are stripped with
        ``get_eval_config`` before the trainer is built.
    :return: a callable taking one observation and returning whatever
        ``trainer.compute_actions({"default_policy": obs}, explore=explore)`` returns.
    """
    policy_trainer = MultiWorkerPPO(get_eval_config(config))
    policy_trainer.restore(ckpt)

    def _f(obs):
        # The single agent is addressed under the "default_policy" key.
        return policy_trainer.compute_actions({"default_policy": obs}, explore=explore)

    return _f
def eval_ckpt(config,
              ckpt_path,
              scenario_data_path,
              num_scenarios,
              start_scenario_index,
              horizon=600,
              render=False,
              # PPO is a stochastic policy, turning off exploration can reduce jitter but may harm performance
              explore=True,
              log_interval=None,
              ):
    """Roll out a restored checkpoint on a range of scenarios and save per-episode results.

    One full episode is run per scenario. Results are written to
    ``eval_ret_<start>_<start + num_scenarios>_<time>.json`` in the working directory.

    :param config: trainer config; must contain ``env_config`` and the evaluation keys
        removed by ``get_eval_config``.
    :param ckpt_path: checkpoint to restore the policy from.
    :param scenario_data_path: written to the env config as ``data_directory``.
    :param num_scenarios: number of scenarios (and therefore episodes) to evaluate.
    :param start_scenario_index: index of the first scenario; the first reset must land on it.
    :param horizon: written to the env config; an episode is also cut once its step
        count exceeds this value.
    :param render: written to the env config as ``use_render``.
    :param explore: forwarded to the policy's ``compute_actions``.
    :param log_interval: if set, print running statistics every ``log_interval``
        finished episodes and once more at the end.
    :return: ``defaultdict(list)`` whose key ``0`` holds one dict per episode with
        ``reward``, ``success``, ``out_of_road``, ``cost``, ``seed`` and ``route_completion``.
    """
    initialize_ray(test_mode=False, num_gpus=1)
    # 27 29 30 37 39
    env_config = get_eval_config(config)["env_config"]
    env_config.update(dict(
        start_scenario_index=start_scenario_index,
        num_scenarios=num_scenarios,
        sequential_seed=True,
        curriculum_level=1,  # disable curriculum
        target_success_rate=1,
        horizon=horizon,
        episodes_to_evaluate_curriculum=num_scenarios,
        data_directory=scenario_data_path,
        use_render=render))
    env = createGymWrapper(ScenarioEnv)(env_config)

    super_data = defaultdict(list)
    EPISODE_NUM = env.config["num_scenarios"]
    compute_actions = get_function(ckpt_path, explore=explore, config=config)

    total_cost = 0
    total_reward = 0
    success_count = 0
    finished_episodes = 0

    def log_msg():
        # Means are taken over the number of finished episodes; nothing to report before
        # the first one completes (also avoids dividing by zero).
        if finished_episodes == 0:
            return
        print("CKPT:{} | success_rate:{}, mean_episode_reward:{}, mean_episode_cost:{}".format(
            finished_episodes,
            success_count / finished_episodes,
            total_reward / finished_episodes,
            total_cost / finished_episodes))

    try:
        o = env.reset()
        assert env.current_seed == start_scenario_index, "Wrong start seed!"

        for _ in tqdm.tqdm(range(0, EPISODE_NUM)):
            # Remember which scenario this episode belongs to *before* any later reset
            # moves the environment on to the next one.
            episode_seed = env.current_seed
            ep_cost = 0.0
            ep_reward = 0.0
            step = 0

            # Step until the episode terminates or exceeds the horizon.
            while True:
                step += 1
                action_to_send = compute_actions(o)["default_policy"]
                o, r, d, info = env.step(action_to_send)
                if env.config["use_render"]:
                    env.render(text={"reward": r})
                total_reward += r
                ep_reward += r
                total_cost += info["cost"]
                ep_cost += info["cost"]
                if d or step > horizon:
                    break

            success_flag = bool(info["arrive_dest"])
            if success_flag:
                success_count += 1

            super_data[0].append(
                {"reward": ep_reward,
                 "success": success_flag,
                 "out_of_road": info[TerminationState.OUT_OF_ROAD],
                 "cost": ep_cost,
                 "seed": episode_seed,
                 "route_completion": info["route_completion"]})
            finished_episodes += 1

            if log_interval and finished_episodes % log_interval == 0:
                log_msg()

            # Only advance to the next scenario when there is one left to evaluate.
            if finished_episodes < EPISODE_NUM:
                o = env.reset()

        if log_interval is not None:
            log_msg()
    finally:
        # Release the restored trainer and the simulator even if the rollout failed.
        del compute_actions
        env.close()

    with open("eval_ret_{}_{}_{}.json".format(start_scenario_index,
                                              start_scenario_index + num_scenarios,
                                              get_time_str()), "w") as f:
        json.dump(super_data, f)
    return super_data