Add train.py and val.py callbacks (#4220)
* added callbacks

* Update callbacks.py

* Update train.py

* Update val.py

* Fix CamelCase, add staticmethod

* Refactor logger into callbacks

* Cleanup

* New callback on_val_image_end()

* Add curves and results images to TensorBoard

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
kalenmike and glenn-jocher committed Jul 31, 2021
1 parent d8f1883 commit b74929c
Showing 6 changed files with 230 additions and 41 deletions.
29 changes: 19 additions & 10 deletions train.py
@@ -34,14 +34,15 @@
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
-    check_requirements, print_mutation, set_logging, one_cycle, colorstr
+    check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
+from utils.callbacks import Callbacks

LOGGER = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -52,6 +53,7 @@
def train(hyp,  # path/to/hyp.yaml or hyp dictionary
          opt,
          device,
+          callbacks=Callbacks()
          ):
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
Expand All @@ -77,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
if loggers.wandb:
data_dict = loggers.wandb.data_dict
if resume:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

# Register actions
for k in methods(loggers):
callbacks.register_action(k, callback=getattr(loggers, k))

# Config
plots = not evolve # create plots
cuda = device.type != 'cpu'
@@ -215,13 +221,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
        # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
        # model._initialize_biases(cf.to(device))
        if plots:
-            plot_labels(labels, names, save_dir, loggers)
+            plot_labels(labels, names, save_dir)

        # Anchors
        if not opt.noautoanchor:
            check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
        model.half().float()  # pre-reduce anchor precision

+        callbacks.on_pretrain_routine_end()
+
    # DDP mode
    if cuda and RANK != -1:
        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
@@ -329,8 +337,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-                loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
-
+                callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
@@ -339,7 +346,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary

        if RANK in [-1, 0]:
            # mAP
-            loggers.on_train_epoch_end(epoch)
+            callbacks.on_train_epoch_end(epoch=epoch)
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not noval or final_epoch:  # Calculate mAP
@@ -353,14 +360,14 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                           save_json=is_coco and final_epoch,
                                           verbose=nc < 50 and final_epoch,
                                           plots=plots and final_epoch,
-                                           loggers=loggers,
+                                           callbacks=callbacks,
                                           compute_loss=compute_loss)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi
-            loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
+            callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
@@ -377,7 +384,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
-                loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
+                callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
@@ -400,7 +407,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
-        loggers.on_train_end(last, best, plots)
+        callbacks.on_train_end(last, best, plots, epoch)
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

    torch.cuda.empty_cache()
    return results
@@ -448,6 +456,7 @@ def parse_opt(known=False):


def main(opt):
+    # Checks
    set_logging(RANK)
    if RANK in [-1, 0]:
        print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
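Taken together, the train.py changes replace direct Loggers calls with hook firings, so the training loop no longer needs to know what logging (or anything else) happens at each point. A minimal sketch of wiring in a custom action from outside train.py (the notify function and its message are hypothetical, for illustration only):

from utils.callbacks import Callbacks

callbacks = Callbacks()

def notify(last, epoch, final_epoch, best_fitness, fi):
    # Signature matches the arguments train.py passes to callbacks.on_model_save()
    print(f'Checkpoint saved at epoch {epoch}: {last}')

callbacks.register_action('on_model_save', name='notify', callback=notify)
# train(hyp, opt, device, callbacks=callbacks)  # every registered action now fires inside the loop

Note that the auto-registration loop above (for k in methods(loggers)) relies on every public Loggers method being named after a hook, since register_action asserts that the hook name exists in Callbacks._callbacks.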
176 changes: 176 additions & 0 deletions utils/callbacks.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python


class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks
    """

    _callbacks = {
        'on_pretrain_routine_start': [],
        'on_pretrain_routine_end': [],

        'on_train_start': [],
        'on_train_epoch_start': [],
        'on_train_batch_start': [],
        'optimizer_step': [],
        'on_before_zero_grad': [],
        'on_train_batch_end': [],
        'on_train_epoch_end': [],

        'on_val_start': [],
        'on_val_batch_start': [],
        'on_val_image_end': [],
        'on_val_batch_end': [],
        'on_val_end': [],

        'on_fit_epoch_end': [],  # fit = train + val
        'on_model_save': [],
        'on_train_end': [],

        'teardown': [],
    }

    def __init__(self):
        return

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action
            callback: The callback to fire
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook: The name of the hook to check, defaults to all
        """
        if hook:
            return self._callbacks[hook]
        else:
            return self._callbacks

    @staticmethod
    def run_callbacks(register, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks
        """
        for logger in register:
            # print(f"Running callbacks.{logger['callback'].__name__}()")
            logger['callback'](*args, **kwargs)

    def on_pretrain_routine_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each pretraining routine
        """
        self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)

    def on_pretrain_routine_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each pretraining routine
        """
        self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)

    def on_train_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training
        """
        self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)

    def on_train_epoch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training epoch
        """
        self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)

    def on_train_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training batch
        """
        self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)

    def optimizer_step(self, *args, **kwargs):
        """
        Fires all registered callbacks on each optimizer step
        """
        self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)

    def on_before_zero_grad(self, *args, **kwargs):
        """
        Fires all registered callbacks before zero grad
        """
        self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)

    def on_train_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training batch
        """
        self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training epoch
        """
        self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)

    def on_val_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of the validation
        """
        self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)

    def on_val_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each validation batch
        """
        self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)

    def on_val_image_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each val image
        """
        self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)

    def on_val_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each validation batch
        """
        self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)

    def on_val_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of the validation
        """
        self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)

    def on_fit_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each fit (train+val) epoch
        """
        self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)

    def on_model_save(self, *args, **kwargs):
        """
        Fires all registered callbacks after each model save
        """
        self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)

    def on_train_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of training
        """
        self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)

    def teardown(self, *args, **kwargs):
        """
        Fires all registered callbacks before teardown
        """
        self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
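The class is a plain observer registry: each hook name maps to a list of {'name', 'callback'} dicts, and each on_* method fires its list in registration order. A quick standalone sketch (the lambda action and its message are made up for illustration):

from utils.callbacks import Callbacks

cb = Callbacks()
cb.register_action('on_val_end', name='print_done', callback=lambda: print('validation finished'))

print(len(cb.get_registered_actions('on_val_end')))  # 1 action registered on this hook
cb.on_val_end()  # fires every action on the hook -> prints 'validation finished'

One design note: _callbacks is a class-level dict, so actions registered through one Callbacks instance are visible to all instances.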
5 changes: 5 additions & 0 deletions utils/general.py
@@ -67,6 +67,11 @@ def handler(*args, **kwargs):
    return handler


+def methods(instance):
+    # Get class/instance methods
+    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
+
+
def set_logging(rank=-1, verbose=True):
    logging.basicConfig(
        format="%(message)s",
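For reference, a small sketch of what the new helper returns (the Demo class is illustrative only):

class Demo:
    def hello(self):
        return 'hi'

    @staticmethod
    def world():
        return 'earth'

print(methods(Demo()))  # ['hello', 'world'] -- dunders and non-callable attributes are filtered out

This one-liner is what lets train.py register every Loggers method as a callback action in a single loop.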