Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for wandb #1144

Merged
merged 7 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/quick_run.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
* --fp16: mixed precision training
* --cache: caching imgs into RAM to accelerate training, which needs a large amount of system RAM.

**Weights & Biases for Logging**

To use W&B for logging, install wandb in your environment and log in to your W&B account using

```shell
pip install wandb
wandb login
```

Log in to your W&B account

To start logging metrics to W&B during training, add the flag `--logger wandb` to the previous command and use the prefix "wandb-" to specify arguments for initializing the wandb run.

```shell
python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
                         yolox-m
                         yolox-l
                         yolox-x
```

**Multi Machine Training**

We also support multi-nodes training. Just add the following args:
Expand Down
7 changes: 7 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ def make_parser():
action="store_true",
help="occupy GPU memory first for training.",
)
parser.add_argument(
"-l",
"--logger",
type=str,
help="Logger to be used for metrics",
default="tensorboard"
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
Expand Down
37 changes: 33 additions & 4 deletions yolox/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from yolox.utils import (
MeterBuffer,
ModelEMA,
WandbLogger,
all_reduce_norm,
get_local_rank,
get_model_info,
Expand Down Expand Up @@ -173,9 +174,18 @@ def before_train(self):
self.evaluator = self.exp.get_evaluator(
batch_size=self.args.batch_size, is_distributed=self.is_distributed
)
# Tensorboard logger
# Tensorboard and Wandb loggers
if self.rank == 0:
self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
if self.args.logger == "tensorboard":
self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
elif self.args.logger == "wandb":
wandb_params = dict()
for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
if k.startswith("wandb-"):
wandb_params.update({k.lstrip("wandb-"): v})
self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
else:
raise ValueError("logger must be either 'tensorboard' or 'wandb'")

logger.info("Training start...")
logger.info("\n{}".format(model))
Expand All @@ -184,6 +194,9 @@ def after_train(self):
logger.info(
"Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
)
if self.rank == 0:
if self.args.logger == "wandb":
self.wandb_logger.finish()

def before_epoch(self):
logger.info("---> start train epoch{}".format(self.epoch + 1))
Expand Down Expand Up @@ -246,6 +259,12 @@ def after_iter(self):
)
+ (", size: {:d}, {}".format(self.input_size[0], eta_str))
)

if self.rank == 0:
if self.args.logger == "wandb":
self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})

self.meter.clear_meters()

# random resizing
Expand Down Expand Up @@ -309,8 +328,15 @@ def evaluate_and_save_model(self):

self.model.train()
if self.rank == 0:
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
if self.args.logger == "tensorboard":
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
if self.args.logger == "wandb":
self.wandb_logger.log_metrics({
"val/COCOAP50": ap50,
"val/COCOAP50_95": ap50_95,
"epoch": self.epoch + 1,
})
logger.info("\n" + summary)
synchronize()

Expand All @@ -334,3 +360,6 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
self.file_name,
ckpt_name,
)

if self.args.logger == "wandb":
self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
2 changes: 1 addition & 1 deletion yolox/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .demo_utils import *
from .dist import *
from .ema import *
from .logger import setup_logger
from .logger import WandbLogger, setup_logger
from .lr_scheduler import LRScheduler
from .metric import *
from .model_utils import *
Expand Down
124 changes: 123 additions & 1 deletion yolox/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import sys
from loguru import logger

import torch


def get_caller_name(depth=0):
"""
Args:
depth (int): Depth of caller conext, use 0 for caller depth. Default value: 0.
depth (int): Depth of caller conext, use 0 for caller depth.
Default value: 0.
Returns:
str: module name of the caller
Expand Down Expand Up @@ -93,3 +96,122 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):

# redirect stdout/stderr to loguru
redirect_sys_output("INFO")


class WandbLogger(object):
    """
    Log training runs, datasets, models, and predictions to Weights & Biases.
    This logger sends information to W&B at wandb.ai.
    By default, this information includes hyperparameters,
    system configuration and metrics, model metrics,
    and basic data metrics and analyses.
    For more information, please refer to:
    https://docs.wandb.ai/guides/track
    """
    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """
        Args:
            project (str): wandb project name.
            name (str): wandb run name.
            id (str): wandb run id, used to resume an existing run.
            entity (str): wandb entity (user or team) name.
            save_dir (str): directory wandb writes its local files to.
            config (dict): experiment config logged to the run.
            **kwargs: extra keyword args forwarded to ``wandb.init``.

        Raises:
            ModuleNotFoundError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "wandb is not installed. "
                "Please install wandb using pip install wandb"
            )

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        # resume="allow" lets a run with the same id continue instead of
        # failing when training is restarted.
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        self._wandb_init.update(**kwargs)

        # Eagerly create (or attach to) the underlying wandb run.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)
        # NOTE(review): wandb glob patterns normally need a trailing "*"
        # (i.e. "val/*") to match every validation metric — confirm that
        # "val/" behaves as intended here.
        self.run.define_metric("epoch")
        self.run.define_metric("val/", step_metric="epoch")

    @property
    def run(self):
        """Lazily initialize and return the underlying wandb run object."""
        if self._run is None:
            if self.wandb.run is not None:
                # A run already exists in this process (e.g. created by user
                # code before the trainer started); reuse it rather than
                # spawning a second one.
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, step=None):
        """
        Log a dict of scalar metrics to the wandb run.

        Args:
            metrics (dict): metric name -> value; ``torch.Tensor`` values
                are converted to Python scalars via ``.item()``.
            step (int): step to associate with the metrics. If None, wandb
                uses its own internal step counter.
        """
        # Build a new payload instead of mutating the caller's dict
        # (the previous implementation wrote ``v.item()`` back into
        # ``metrics`` in place).
        payload = {
            k: v.item() if isinstance(v, torch.Tensor) else v
            for k, v in metrics.items()
        }

        if step is not None:
            self.run.log(payload, step=step)
        else:
            self.run.log(payload)

    def save_checkpoint(self, save_dir, model_name, is_best):
        """
        Upload a saved checkpoint file to wandb as a model artifact.

        Args:
            save_dir (str): directory containing the checkpoint file.
            model_name (str): checkpoint name stem; the file read is
                ``<save_dir>/<model_name>_ckpt.pth``.
            is_best (bool): whether this checkpoint is the best so far;
                if True the artifact additionally gets the "best" alias.
        """
        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
        artifact = self.wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model"
        )
        artifact.add_file(filename, name="model_ckpt.pth")

        aliases = ["latest"]
        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)

    def finish(self):
        """Finish the wandb run, flushing any pending data."""
        self.run.finish()