scenarionet/scenarionet_training/wandb_utils/our_wandb_callbacks_ray100.py
Quanyi Li db50bca7fd Add come updates for Neurips paper (#4)

Co-authored-by: Quanyi Li <quanyi@bolei-gpu02.cs.ucla.edu>
2023-06-10 18:56:33 +01:00


from multiprocessing import Queue

from ray.tune.integration.wandb import WandbLogger, _clean_log, _set_api_key


class OurWandbLogger(WandbLogger):
    def __init__(self, config, logdir, trial):
        # `exp_name` is not a wandb.init() argument, so pop it from the wandb
        # config before the parent constructor calls self._init().
        self.exp_name = config["logger_config"]["wandb"].pop("exp_name")
        super(OurWandbLogger, self).__init__(config, logdir, trial)

    def _init(self):
        config = self.config.copy()

        config.pop("callbacks", None)  # Remove callbacks

        try:
            if config.get("logger_config", {}).get("wandb"):
                logger_config = config.pop("logger_config")
                wandb_config = logger_config.get("wandb").copy()
            else:
                wandb_config = config.pop("wandb").copy()
        except KeyError:
            raise ValueError(
                "Wandb logger specified but no configuration has been passed. "
                "Make sure to include a `wandb` key in your `config` dict "
                "containing at least a `project` specification.")

        _set_api_key(wandb_config)

        exclude_results = self._exclude_results.copy()

        # Additional excludes
        additional_excludes = wandb_config.pop("excludes", [])
        exclude_results += additional_excludes

        # Log config keys on each result?
        log_config = wandb_config.pop("log_config", False)
        if not log_config:
            exclude_results += ["config"]

        # Fill trial ID and name
        trial_id = self.trial.trial_id if self.trial else None
        trial_name = str(self.trial) if self.trial else None

        # Project name for Wandb
        try:
            wandb_project = wandb_config.pop("project")
        except KeyError:
            raise ValueError(
                "You need to specify a `project` in your wandb `config` dict.")

        # Grouping
        wandb_group = wandb_config.pop(
            "group", self.trial.trainable_name if self.trial else None)

        # remove unpickleable items!
        config = _clean_log(config)

        # Name each W&B run "<exp_name>_<trial_id>"; this requires a trial ID.
        assert trial_id is not None
        run_name = "{}_{}".format(self.exp_name, trial_id)

        wandb_init_kwargs = dict(
            id=trial_id,
            name=run_name,
            resume=True,
            reinit=True,
            allow_val_change=True,
            group=wandb_group,
            project=wandb_project,
            config=config)
        # Any remaining wandb config keys are forwarded to wandb.init().
        wandb_init_kwargs.update(wandb_config)

        self._queue = Queue()
        self._wandb = self._logger_process_cls(
            queue=self._queue,
            exclude=exclude_results,
            to_config=self._config_results,
            **wandb_init_kwargs)
        self._wandb.start()
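Because the class above only runs inside Ray Tune 1.x, its config handling is hard to check in isolation. The following is a minimal, Ray-free sketch that mirrors the steps of `_init` as a plain function. The name `resolve_wandb_kwargs` and its parameters `trainable_name` and `base_excludes` are hypothetical and not part of this repository or of Ray. The sketch deliberately omits the API-key handling (`_set_api_key`), the unpickleable-value cleanup (`_clean_log`), and the background logging process, and it raises `ValueError` where the original asserts that a trial ID exists.

```python
from typing import Any, Dict, List, Optional, Tuple


def resolve_wandb_kwargs(
    config: Dict[str, Any],
    exp_name: str,
    trial_id: Optional[str],
    trainable_name: Optional[str] = None,
    base_excludes: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical, Ray-free mirror of the config handling in OurWandbLogger._init.

    Returns (kwargs that would be passed to wandb.init, result keys to exclude).
    """
    config = dict(config)  # shallow copy, like self.config.copy()
    config.pop("callbacks", None)

    # Prefer config["logger_config"]["wandb"], fall back to config["wandb"].
    if config.get("logger_config", {}).get("wandb"):
        wandb_config = dict(config.pop("logger_config")["wandb"])
    elif "wandb" in config:
        wandb_config = dict(config.pop("wandb"))
    else:
        raise ValueError("No `wandb` configuration was passed.")

    # Build the exclusion list: defaults, user excludes, and "config" unless logged.
    exclude_results = list(base_excludes or [])
    exclude_results += wandb_config.pop("excludes", [])
    if not wandb_config.pop("log_config", False):
        exclude_results.append("config")

    if "project" not in wandb_config:
        raise ValueError("You need to specify a `project` in your wandb config.")
    project = wandb_config.pop("project")
    group = wandb_config.pop("group", trainable_name)

    if trial_id is None:
        raise ValueError("A trial ID is required to build the run name.")

    kwargs = dict(
        id=trial_id,
        name="{}_{}".format(exp_name, trial_id),
        resume=True,
        reinit=True,
        allow_val_change=True,
        group=group,
        project=project,
        config=config,
    )
    kwargs.update(wandb_config)  # leftover keys (e.g. entity) are forwarded as-is
    return kwargs, exclude_results
```

For example, a config of `{"logger_config": {"wandb": {"project": "scenarionet"}}}` with `exp_name="ppo"` and `trial_id="abc123"` yields a run named `ppo_abc123`, and the `logger_config` key itself is not forwarded to W&B as part of the run config.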