Skip to content

Commit

Permalink
Add callbacks (#7315)
Browse files Browse the repository at this point in the history
* Add `on_train_start()` callback

* Update

* Update
  • Loading branch information
glenn-jocher authored Apr 6, 2022
1 parent 32661f7 commit 245d645
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
callbacks.run('on_pretrain_routine_start')

# Directories
w = save_dir / 'weights' # weights dir
Expand Down Expand Up @@ -291,11 +292,13 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
scaler = amp.GradScaler(enabled=cuda)
stopper = EarlyStopping(patience=opt.patience)
compute_loss = ComputeLoss(model) # init loss class
callbacks.run('on_train_start')
LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
f"Logging results to {colorstr('bold', save_dir)}\n"
f'Starting training for {epochs} epochs...')
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
callbacks.run('on_train_epoch_start')
model.train()

# Update image weights (optional, single-GPU only)
Expand All @@ -317,6 +320,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
callbacks.run('on_train_batch_start')
ni = i + nb * epoch # number integrated batches (since train start)
imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0

Expand Down
4 changes: 4 additions & 0 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
else:
self.wandb = None

def on_train_start(self):
# Callback runs on train start
pass

def on_pretrain_routine_end(self):
# Callback runs on pre-train routine end
paths = self.save_dir.glob('*labels*.jpg') # training labels
Expand Down
4 changes: 4 additions & 0 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,10 @@ def run(
dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], []
callbacks.run('on_val_start')
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
callbacks.run('on_val_batch_start')
t1 = time_sync()
if cuda:
im = im.to(device, non_blocking=True)
Expand Down Expand Up @@ -260,6 +262,8 @@ def run(
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()

callbacks.run('on_val_batch_end')

# Compute metrics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
Expand Down

0 comments on commit 245d645

Please sign in to comment.