
Optimised Callback Class to Reduce Code and Fix Errors (ultralytics#4688)

* added callbacks

* added back callback to main

* added save_dir to callback output

* reduced code count

* updated callbacks

* added default callback class to main, added missing parameters to on_model_save

* Glenn updates

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
2 people authored and CesarBazanAV committed Sep 29, 2021
1 parent 9cb0aee commit aa35ada
Showing 3 changed files with 22 additions and 125 deletions.
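
In practice the change replaces one wrapper method per hook with a single name-keyed dispatcher, which is what removes roughly a hundred lines from utils/callbacks.py. As the diffs below show, the call pattern in train.py and val.py changes from

    callbacks.on_pretrain_routine_end()

to

    callbacks.run('on_pretrain_routine_end')

so adding a new hook no longer requires adding a matching method to the Callbacks class.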
20 changes: 10 additions & 10 deletions train.py
@@ -56,7 +56,7 @@
def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
-callbacks=Callbacks()
+callbacks
):
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
model.half().float() # pre-reduce anchor precision

-callbacks.on_pretrain_routine_end()
+callbacks.run('on_pretrain_routine_end')

# DDP mode
if cuda and RANK != -1:
@@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn)
+callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
# end batch ------------------------------------------------------------------------------------------------

# Scheduler
@@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

if RANK in [-1, 0]:
# mAP
-callbacks.on_train_epoch_end(epoch=epoch)
+callbacks.run('on_train_epoch_end', epoch=epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
if not noval or final_epoch: # Calculate mAP
@@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if fi > best_fitness:
best_fitness = fi
log_vals = list(mloss) + list(results) + lr
-callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi)
+callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)

# Save model
if (not nosave) or (final_epoch and not evolve): # if save
@@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if best_fitness == fi:
torch.save(ckpt, best)
del ckpt
-callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
+callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

# Stop Single-GPU
if RANK == -1 and stopper(epoch=epoch, fitness=fi):
@@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
-callbacks.on_train_end(last, best, plots, epoch)
+callbacks.run('on_train_end', last, best, plots, epoch)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

torch.cuda.empty_cache()
@@ -467,7 +467,7 @@ def parse_opt(known=False):
return opt


-def main(opt):
+def main(opt, callbacks=Callbacks()):
# Checks
set_logging(RANK)
if RANK in [-1, 0]:
@@ -505,7 +505,7 @@ def main(opt):

# Train
if not opt.evolve:
-train(opt.hyp, opt, device)
+train(opt.hyp, opt, device, callbacks)
if WORLD_SIZE > 1 and RANK == 0:
_ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]

@@ -585,7 +585,7 @@ def main(opt):
hyp[k] = round(hyp[k], 5) # significant digits

# Train mutation
-results = train(hyp.copy(), opt, device)
+results = train(hyp.copy(), opt, device, callbacks)

# Write mutation results
print_mutation(results, hyp.copy(), save_dir, opt.bucket)
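
Because main() now accepts a Callbacks instance, a caller can register custom actions before training without modifying train.py. A minimal sketch of the intended use, assuming it runs from the repository root (log_metrics is an illustrative name; its signature mirrors the arguments train.py passes to callbacks.run('on_fit_epoch_end', ...) above):

    from utils.callbacks import Callbacks
    import train

    def log_metrics(vals, epoch, best_fitness, fi):
        # vals is list(mloss) + list(results) + lr from train.py; fi is this epoch's fitness
        print(f'epoch {epoch}: fitness {fi:.4f} (best {best_fitness:.4f})')

    callbacks = Callbacks()
    callbacks.register_action('on_fit_epoch_end', name='log_metrics', callback=log_metrics)

    opt = train.parse_opt()
    train.main(opt, callbacks=callbacks)  # forwarded to train(hyp, opt, device, callbacks)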
123 changes: 10 additions & 113 deletions utils/callbacks.py
@@ -9,6 +9,7 @@ class Callbacks:
Handles all registered callbacks for YOLOv5 Hooks
"""

+# Define the available callbacks
_callbacks = {
'on_pretrain_routine_start': [],
'on_pretrain_routine_end': [],
@@ -34,16 +35,13 @@ class Callbacks:
'teardown': [],
}

-def __init__(self):
-return

def register_action(self, hook, name='', callback=None):
"""
Register a new action to a callback hook
Args:
hook The callback hook name to register the action to
-name The name of the action
+name The name of the action for later reference
callback The callback to fire
"""
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
@@ -62,118 +60,17 @@ def get_registered_actions(self, hook=None):
else:
return self._callbacks

-def run_callbacks(self, hook, *args, **kwargs):
+def run(self, hook, *args, **kwargs):
"""
Loop through the registered actions and fire all callbacks
+
+Args:
+hook The name of the hook to check, defaults to all
+args Arguments to receive from YOLOv5
+kwargs Keyword Arguments to receive from YOLOv5
"""
-for logger in self._callbacks[hook]:
-# print(f"Running callbacks.{logger['callback'].__name__}()")
-logger['callback'](*args, **kwargs)
-
-def on_pretrain_routine_start(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the start of each pretraining routine
-"""
-self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
-
-def on_pretrain_routine_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of each pretraining routine
-"""
-self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
-
-def on_train_start(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the start of each training
-"""
-self.run_callbacks('on_train_start', *args, **kwargs)
-
-def on_train_epoch_start(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the start of each training epoch
-"""
-self.run_callbacks('on_train_epoch_start', *args, **kwargs)
-
-def on_train_batch_start(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the start of each training batch
-"""
-self.run_callbacks('on_train_batch_start', *args, **kwargs)
-
-def optimizer_step(self, *args, **kwargs):
-"""
-Fires all registered callbacks on each optimizer step
-"""
-self.run_callbacks('optimizer_step', *args, **kwargs)
-
-def on_before_zero_grad(self, *args, **kwargs):
-"""
-Fires all registered callbacks before zero grad
-"""
-self.run_callbacks('on_before_zero_grad', *args, **kwargs)
-
-def on_train_batch_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of each training batch
-"""
-self.run_callbacks('on_train_batch_end', *args, **kwargs)
-
-def on_train_epoch_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of each training epoch
-"""
-self.run_callbacks('on_train_epoch_end', *args, **kwargs)
-
-def on_val_start(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the start of the validation
-"""
-self.run_callbacks('on_val_start', *args, **kwargs)
-
-def on_val_batch_start(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the start of each validation batch
-"""
-self.run_callbacks('on_val_batch_start', *args, **kwargs)
-
-def on_val_image_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of each val image
-"""
-self.run_callbacks('on_val_image_end', *args, **kwargs)
-
-def on_val_batch_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of each validation batch
-"""
-self.run_callbacks('on_val_batch_end', *args, **kwargs)
-
-def on_val_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of the validation
-"""
-self.run_callbacks('on_val_end', *args, **kwargs)
-
-def on_fit_epoch_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of each fit (train+val) epoch
-"""
-self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
-
-def on_model_save(self, *args, **kwargs):
-"""
-Fires all registered callbacks after each model save
-Args:
-hook The name of the hook to check, defaults to all
-args Arguments to receive from YOLOv5
-kwargs Keyword Arguments to receive from YOLOv5
-"""
-self.run_callbacks('on_model_save', *args, **kwargs)
-
-def on_train_end(self, *args, **kwargs):
-"""
-Fires all registered callbacks at the end of training
-"""
-self.run_callbacks('on_train_end', *args, **kwargs)
-
-def teardown(self, *args, **kwargs):
-"""
-Fires all registered callbacks before teardown
-"""
-self.run_callbacks('teardown', *args, **kwargs)
+
+assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
+
+for logger in self._callbacks[hook]:
+logger['callback'](*args, **kwargs)
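
After the rewrite the class is fully data-driven: register_action appends a named entry under a hook in _callbacks, and run() first asserts that the hook name exists, then fires every registered action in registration order. A standalone sketch of the behaviour (the hook name and argument list are taken from the on_model_save call in train.py above; the values themselves are illustrative):

    from utils.callbacks import Callbacks

    cb = Callbacks()

    def notify(last, epoch, final_epoch, best_fitness, fi):
        print(f'saved {last} at epoch {epoch} (fitness {fi:.4f})')

    cb.register_action('on_model_save', name='notify', callback=notify)

    cb.run('on_model_save', 'last.pt', 10, False, 0.71, 0.71)  # fires notify
    cb.run('no_such_hook')  # raises AssertionError: hook 'no_such_hook' not found ...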
4 changes: 2 additions & 2 deletions val.py
@@ -216,7 +216,7 @@ def run(data,
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
if save_json:
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
-callbacks.on_val_image_end(pred, predn, path, names, img[si])
+callbacks.run('on_val_image_end', pred, predn, path, names, img[si])

# Plot images
if plots and batch_i < 3:
@@ -253,7 +253,7 @@ def run(data,
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-callbacks.on_val_end()
+callbacks.run('on_val_end')

# Save JSON
if save_json and len(jdict):
