feat(engine): support wandb logger (Megvii-BaseDetection#1144)
manangoel99 committed Feb 25, 2022
1 parent 4f90007 commit 4d99098
Showing 5 changed files with 184 additions and 6 deletions.
20 changes: 20 additions & 0 deletions docs/quick_run.md
@@ -56,6 +56,26 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
* --fp16: mixed precision training
* --cache: caching images into RAM to accelerate training, which requires a large amount of system RAM.

**Weights & Biases for Logging**

To use W&B for logging, install wandb in your environment and log in to your W&B account using

```shell
pip install wandb
wandb login
```


To start logging metrics to W&B during training, add the flag `--logger wandb` to the previous command and use the prefix `wandb-` to pass arguments for initializing the wandb run.

```shell
python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
                         yolox-m
                         yolox-l
                         yolox-x
```
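
Multiple `wandb-` prefixed arguments can be combined in one command; each `wandb-<key> <value>` pair is forwarded to the corresponding argument of the logger. A hypothetical sketch (the project, run name, and entity below are placeholders):

```shell
python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o --logger wandb \
    wandb-project my-yolox-project \
    wandb-name baseline-run \
    wandb-entity my-team
```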

**Multi Machine Training**

We also support multi-nodes training. Just add the following args:
7 changes: 7 additions & 0 deletions tools/train.py
@@ -80,6 +80,13 @@ def make_parser():
action="store_true",
help="occupy GPU memory first for training.",
)
parser.add_argument(
"-l",
"--logger",
type=str,
help="Logger to be used for metrics",
default="tensorboard"
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
37 changes: 33 additions & 4 deletions yolox/core/trainer.py
@@ -15,6 +15,7 @@
from yolox.utils import (
MeterBuffer,
ModelEMA,
WandbLogger,
all_reduce_norm,
get_local_rank,
get_model_info,
@@ -173,9 +174,18 @@ def before_train(self):
self.evaluator = self.exp.get_evaluator(
batch_size=self.args.batch_size, is_distributed=self.is_distributed
)
# Tensorboard logger
# Tensorboard and Wandb loggers
if self.rank == 0:
self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
if self.args.logger == "tensorboard":
self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
elif self.args.logger == "wandb":
wandb_params = dict()
for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
if k.startswith("wandb-"):
wandb_params.update({k[len("wandb-"):]: v})  # slice off the prefix; str.lstrip would strip characters, not a prefix
self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
else:
raise ValueError("logger must be either 'tensorboard' or 'wandb'")

logger.info("Training start...")
logger.info("\n{}".format(model))
@@ -184,6 +194,9 @@ def after_train(self):
logger.info(
"Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
)
if self.rank == 0:
if self.args.logger == "wandb":
self.wandb_logger.finish()

def before_epoch(self):
logger.info("---> start train epoch{}".format(self.epoch + 1))
@@ -246,6 +259,12 @@ def after_iter(self):
)
+ (", size: {:d}, {}".format(self.input_size[0], eta_str))
)

if self.rank == 0:
if self.args.logger == "wandb":
self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})

self.meter.clear_meters()

# random resizing
@@ -309,8 +328,15 @@ def evaluate_and_save_model(self):

self.model.train()
if self.rank == 0:
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
if self.args.logger == "tensorboard":
self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
if self.args.logger == "wandb":
self.wandb_logger.log_metrics({
"val/COCOAP50": ap50,
"val/COCOAP50_95": ap50_95,
"epoch": self.epoch + 1,
})
logger.info("\n" + summary)
synchronize()

@@ -334,3 +360,6 @@ def save_ckpt(self, ckpt_name, update_best_ckpt=False):
self.file_name,
ckpt_name,
)

if self.args.logger == "wandb":
self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
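
For reference, a minimal standalone sketch of how the `wandb-` prefix parsing in `before_train` behaves, assuming `opts` holds the trailing command-line tokens (the values below are placeholders):

```python
opts = ["wandb-project", "my-project", "wandb-name", "run1", "max_epoch", "300"]

wandb_params = {}
for k, v in zip(opts[0::2], opts[1::2]):
    if k.startswith("wandb-"):
        # Slice off the prefix; str.lstrip("wandb-") would strip characters, not the prefix.
        wandb_params[k[len("wandb-"):]] = v

print(wandb_params)  # {'project': 'my-project', 'name': 'run1'}
```
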
2 changes: 1 addition & 1 deletion yolox/utils/__init__.py
@@ -8,7 +8,7 @@
from .demo_utils import *
from .dist import *
from .ema import *
from .logger import setup_logger
from .logger import WandbLogger, setup_logger
from .lr_scheduler import LRScheduler
from .metric import *
from .model_utils import *
124 changes: 123 additions & 1 deletion yolox/utils/logger.py
@@ -7,11 +7,14 @@
import sys
from loguru import logger

import torch


def get_caller_name(depth=0):
"""
Args:
depth (int): Depth of caller context, use 0 for caller depth. Default value: 0.
depth (int): Depth of caller context, use 0 for caller depth.
Default value: 0.
Returns:
str: module name of the caller
@@ -93,3 +96,122 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):

# redirect stdout/stderr to loguru
redirect_sys_output("INFO")


class WandbLogger(object):
"""
Log training runs, datasets, models, and predictions to Weights & Biases.
This logger sends information to W&B at wandb.ai.
By default, this information includes hyperparameters,
system configuration and metrics, model metrics,
and basic data metrics and analyses.
For more information, please refer to:
https://docs.wandb.ai/guides/track
"""
def __init__(self,
project=None,
name=None,
id=None,
entity=None,
save_dir=None,
config=None,
**kwargs):
"""
Args:
project (str): wandb project name.
name (str): wandb run name.
id (str): wandb run id.
entity (str): wandb entity name.
save_dir (str): save directory.
config (dict): config dict.
**kwargs: other kwargs.
"""
try:
import wandb
self.wandb = wandb
except ModuleNotFoundError:
raise ModuleNotFoundError(
"wandb is not installed."
"Please install wandb using pip install wandb"
)

self.project = project
self.name = name
self.id = id
self.save_dir = save_dir
self.config = config
self.kwargs = kwargs
self.entity = entity
self._run = None
self._wandb_init = dict(
project=self.project,
name=self.name,
id=self.id,
entity=self.entity,
dir=self.save_dir,
resume="allow"
)
self._wandb_init.update(**kwargs)

_ = self.run

if self.config:
self.run.config.update(self.config)
self.run.define_metric("epoch")
self.run.define_metric("val/", step_metric="epoch")

@property
def run(self):
if self._run is None:
if self.wandb.run is not None:
logger.info(
"There is a wandb run already in progress "
"and newly created instances of `WandbLogger` will reuse"
" this run. If this is not desired, call `wandb.finish()`"
"before instantiating `WandbLogger`."
)
self._run = self.wandb.run
else:
self._run = self.wandb.init(**self._wandb_init)
return self._run

def log_metrics(self, metrics, step=None):
"""
Args:
metrics (dict): metrics dict.
step (int): step number.
"""

for k, v in metrics.items():
if isinstance(v, torch.Tensor):
metrics[k] = v.item()

if step is not None:
self.run.log(metrics, step=step)
else:
self.run.log(metrics)

def save_checkpoint(self, save_dir, model_name, is_best):
"""
Args:
save_dir (str): save directory.
model_name (str): model name.
is_best (bool): whether the model is the best model.
"""
filename = os.path.join(save_dir, model_name + "_ckpt.pth")
artifact = self.wandb.Artifact(
name=f"model-{self.run.id}",
type="model"
)
artifact.add_file(filename, name="model_ckpt.pth")

aliases = ["latest"]

if is_best:
aliases.append("best")

self.run.log_artifact(artifact, aliases=aliases)

def finish(self):
self.run.finish()
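
For a quick sanity check outside the trainer, the new logger can also be driven directly (a minimal sketch, assuming wandb is installed and you are logged in; project, run name, and metric values are placeholders):

```python
from yolox.utils import WandbLogger

# Placeholder project/name/config values, for illustration only.
wandb_logger = WandbLogger(
    project="yolox-sandbox",
    name="smoke-test",
    config={"model": "yolox-s", "batch_size": 64},
)

# Scalars (and torch tensors) are sent to the active run.
wandb_logger.log_metrics({"total_loss": 3.21, "lr": 0.01}, step=10)
wandb_logger.log_metrics({"val/COCOAP50": 0.42, "epoch": 1})

wandb_logger.finish()
```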
