Commit
Refactor train.py and val.py loggers (#4137)
* Update loggers

* Config

* Update val.py

* cleanup

* fix1

* fix2

* fix3 and reformat

* format sweep.py

* Logger() class

* cleanup

* cleanup2

* wandb package import fix

* wandb package import fix2

* txt fix

* fix4

* fix5

* fix6

* drop wandb into utils/loggers

* fix 7

* rename loggers/wandb_logging to loggers/wandb

* Update message

* Update message

* Update message

* cleanup

* Fix x axis bug

* fix rank 0 issue

* cleanup
glenn-jocher committed Jul 24, 2021
1 parent 63dd65e commit efe60b5
Showing 9 changed files with 172 additions and 91 deletions.
train.py (87 changes: 21 additions & 66 deletions)
@@ -10,7 +10,6 @@
 import random
 import sys
 import time
-import warnings
 from copy import deepcopy
 from pathlib import Path
 from threading import Thread
@@ -24,7 +23,6 @@
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam, SGD, lr_scheduler
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 FILE = Path(__file__).absolute()
@@ -42,8 +40,9 @@
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
-from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
+from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.metrics import fitness
+from utils.loggers import Loggers
 
 LOGGER = logging.getLogger(__name__)
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -76,37 +75,23 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(save_dir / 'opt.yaml', 'w') as f:
         yaml.safe_dump(vars(opt), f, sort_keys=False)
 
-    # Configure
+    # Config
     plots = not evolve  # create plots
     cuda = device.type != 'cpu'
     init_seeds(1 + RANK)
     with open(data) as f:
         data_dict = yaml.safe_load(f)  # data dict
-
-    # Loggers
-    loggers = {'wandb': None, 'tb': None}  # loggers dict
-    if RANK in [-1, 0]:
-        # TensorBoard
-        if plots:
-            prefix = colorstr('tensorboard: ')
-            LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            loggers['tb'] = SummaryWriter(str(save_dir))
-
-        # W&B
-        opt.hyp = hyp  # add hyperparameters
-        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
-        run_id = run_id if opt.resume else None  # start fresh run if transfer learning
-        wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
-        loggers['wandb'] = wandb_logger.wandb
-        if loggers['wandb']:
-            data_dict = wandb_logger.data_dict
-            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update values if resuming
-
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
     is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
+    # Loggers
+    if RANK in [-1, 0]:
+        loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start()  # loggers dict
+        if loggers.wandb and resume:
+            weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
+
     # Model
     pretrained = weights.endswith('.pt')
     if pretrained:
@@ -351,30 +336,24 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 pbar.set_description(s)
 
             # Plot
-            if plots and ni < 3:
-                f = save_dir / f'train_batch{ni}.jpg'  # filename
-                Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                if loggers['tb'] and ni == 0:  # TensorBoard
-                    with warnings.catch_warnings():
-                        warnings.simplefilter('ignore')  # suppress jit trace warning
-                        loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
-            elif plots and ni == 10 and loggers['wandb']:
-                wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
-                                              save_dir.glob('train*.jpg') if x.exists()]})
+            if plots:
+                if ni < 3:
+                    f = save_dir / f'train_batch{ni}.jpg'  # filename
+                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
+                loggers.on_train_batch_end(ni, model, imgs)
 
             # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()
 
-        # DDP process 0 or single-GPU
         if RANK in [-1, 0]:
             # mAP
+            loggers.on_train_epoch_end(epoch)
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not noval or final_epoch:  # Calculate mAP
-                wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
@@ -385,29 +364,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                            save_json=is_coco and final_epoch,
                                            verbose=nc < 50 and final_epoch,
                                            plots=plots and final_epoch,
-                                           wandb_logger=wandb_logger,
+                                           loggers=loggers,
                                            compute_loss=compute_loss)
 
-            # Write
-            with open(results_file, 'a') as f:
-                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
-
-            # Log
-            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
-                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
-                    'x/lr0', 'x/lr1', 'x/lr2']  # params
-            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if loggers['tb']:
-                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
-                if loggers['wandb']:
-                    wandb_logger.log({tag: x})  # W&B
-
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
             if fi > best_fitness:
                 best_fitness = fi
-            wandb_logger.end_epoch(best_result=best_fitness == fi)
+            loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)
 
             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
@@ -418,27 +382,21 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
+                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
-                if loggers['wandb']:
-                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
+                loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
     if RANK in [-1, 0]:
         LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-        if loggers['wandb']:
-            files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-            wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
-                                          if (save_dir / f).exists()]})
 
         if not evolve:
             if is_coco:  # COCO dataset
@@ -458,11 +416,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             for f in last, best:
                 if f.exists():
                     strip_optimizer(f)  # strip optimizers
-            if loggers['wandb']:  # Log the stripped model
-                loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
-                                              name='run_' + wandb_logger.wandb_run.id + '_model',
-                                              aliases=['latest', 'best', 'stripped'])
-            wandb_logger.finish_run()
 
+        loggers.on_train_end(last, best)
+
     torch.cuda.empty_cache()
     return results
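A toy check, not part of the commit, of the value/tag contract behind the new loggers.on_train_val_end call above: it pairs the 3 running train losses (mloss minus the total), the 7 validation results, and 3 learning rates against 13 fixed tags.

import numpy as np

mloss = np.array([0.05, 0.03, 0.01, 0.09])  # running box, obj, cls, total losses (total is dropped)
results = (0.60, 0.50, 0.55, 0.35, 0.04, 0.02, 0.01)  # P, R, mAP@0.5, mAP@0.5:0.95, val box/obj/cls losses
lr = [0.01, 0.01, 0.10]  # one value per optimizer param group
vals = list(mloss[:-1]) + list(results) + lr  # same construction as Loggers.on_train_val_end
assert len(vals) == 13  # must line up one-to-one with the 13 tags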
utils/loggers/__init__.py (129 changes: 129 additions & 0 deletions)
@@ -0,0 +1,129 @@
# YOLOv5 experiment logging utils

import warnings

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.torch_utils import de_parallel

LOGGERS = ('txt', 'tb', 'wandb')  # text-file, TensorBoard, Weights & Biases

try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
except (ImportError, AssertionError):
    wandb = None


class Loggers():
    # YOLOv5 Loggers class
    def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
                 data_dict=None, logger=None, include=LOGGERS):
        self.save_dir = save_dir
        self.results_file = results_file
        self.weights = weights
        self.opt = opt
        self.hyp = hyp
        self.data_dict = data_dict
        self.logger = logger  # for printing results to console
        self.include = include
        for k in LOGGERS:
            setattr(self, k, None)  # init empty logger dictionary

    def start(self):
        self.txt = True  # always log to txt

        # Message
        try:
            import wandb
        except ImportError:
            prefix = colorstr('Weights & Biases: ')
            s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
            print(emojis(s))

        # TensorBoard
        s = self.save_dir
        if 'tb' in self.include and not self.opt.evolve:
            prefix = colorstr('TensorBoard: ')
            self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
            self.tb = SummaryWriter(str(s))

        # W&B
        try:
            assert 'wandb' in self.include and wandb
            run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
            self.opt.hyp = self.hyp  # add hyperparameters
            self.wandb = WandbLogger(self.opt, s.stem, run_id, self.data_dict)
        except:
            self.wandb = None

        return self

    def on_train_batch_end(self, ni, model, imgs):
        # Callback runs on train batch end
        if ni == 0:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # suppress jit trace warning
                self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
        if self.wandb and ni == 10:
            files = sorted(self.save_dir.glob('train*.jpg'))
            self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

    def on_train_epoch_end(self, epoch):
        # Callback runs on train epoch end
        if self.wandb:
            self.wandb.current_epoch = epoch + 1

    def on_val_batch_end(self, pred, predn, path, names, im):
        # Callback runs on val batch end
        if self.wandb:
            self.wandb.val_one_image(pred, predn, path, names, im)

    def on_val_end(self):
        # Callback runs on val end
        if self.wandb:
            files = sorted(self.save_dir.glob('val*.jpg'))
            self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

    def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
        # Callback runs on validation end during training
        vals = list(mloss[:-1]) + list(results) + lr
        tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                'x/lr0', 'x/lr1', 'x/lr2']  # params
        if self.txt:
            with open(self.results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
        if self.tb:
            for x, tag in zip(vals, tags):
                self.tb.add_scalar(tag, x, epoch)  # TensorBoard
        if self.wandb:
            self.wandb.log({k: v for k, v in zip(tags, vals)})
            self.wandb.end_epoch(best_result=best_fitness == fi)

    def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        # Callback runs on model save event
        if self.wandb:
            if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
                self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

    def on_train_end(self, last, best):
        # Callback runs on training end
        files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
        files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()]  # filter
        if self.wandb:
            wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
            wandb.log_artifact(str(best if best.exists() else last), type='model',
                               name='run_' + self.wandb.wandb_run.id + '_model',
                               aliases=['latest', 'best', 'stripped'])
            self.wandb.finish_run()

    def log_images(self, paths):
        # Log images
        if self.wandb:
            self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
2 files renamed without changes (utils/wandb_logging/ → utils/loggers/wandb/).
utils/loggers/wandb/sweep.py
@@ -1,12 +1,12 @@
import sys
from pathlib import Path

import wandb

FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[2].as_posix()) # add utils/ to path

from train import train, parse_opt
import test
from utils.general import increment_path
from utils.torch_utils import select_device

utils/loggers/wandb/sweep.yaml
@@ -14,7 +14,7 @@
 # You can use grid, bayesian and hyperopt search strategy
 # For more info on configuring sweeps visit - https://docs.wandb.ai/guides/sweeps/configuration
 
-program: utils/wandb_logging/sweep.py
+program: utils/loggers/wandb/sweep.py
 method: random
 metric:
   name: metrics/mAP_0.5
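With the program path updated, the sweep can be registered against the renamed config. A hedged sketch using the wandb Python API (wandb.sweep and wandb.agent are the standard entry points); the project name is an assumption:

import yaml

import wandb

with open('utils/loggers/wandb/sweep.yaml') as f:
    config = yaml.safe_load(f)
sweep_id = wandb.sweep(config, project='YOLOv5')  # 'YOLOv5' project name is illustrative
print(f'run: wandb agent {sweep_id}')  # each agent executes program: utils/loggers/wandb/sweep.py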