diff --git a/train.py b/train.py
index ff13f1e256ec..444316ca4c08 100644
--- a/train.py
+++ b/train.py
@@ -284,7 +284,16 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     model.hyp = hyp  # attach hyperparameters to model
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
     model.names = names
-
+
+    # Register tensor hooks that replace NaN/Inf gradients with zeros
+    if opt.use_tensor_hooks:
+        LOGGER.info('Registering tensor hooks to zero out NaN/Inf gradients...')
+        for name, params in model.named_parameters():
+            if params.requires_grad:  # hooks cannot be registered on frozen parameters
+                params.register_hook(lambda grad: torch.nan_to_num(grad, nan=0., neginf=0., posinf=0.))
+                params.retain_grad()
+    LOGGER.info(f'Using gradient clipping with max_norm={opt.clip_grad}')
+
     # Start training
     t0 = time.time()
     nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
@@ -361,6 +370,8 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio

             # Optimize
             if ni - last_opt_step >= accumulate:
+                scaler.unscale_(optimizer)  # unscale gradients before clipping
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=opt.clip_grad)
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
@@ -506,7 +517,9 @@ def parse_opt(known=False):
     parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
     parser.add_argument('--seed', type=int, default=0, help='Global training seed')
     parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
-
+    parser.add_argument('--clip_grad', type=float, default=10.0, help='Max norm for gradient clipping')
+    parser.add_argument('--use_tensor_hooks', action='store_true', help='Register tensor hooks that replace NaN/Inf gradients with zeros')
+
     # Weights & Biases arguments
     parser.add_argument('--entity', default=None, help='W&B: Entity')
    parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
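
For reference, here is a minimal, self-contained sketch of the same pattern outside train.py: tensor hooks that scrub NaN/Inf gradients, followed by unscale + clip before the AMP optimizer step. Only the `--use_tensor_hooks` / `--clip_grad` behavior mirrors the diff; the toy model, optimizer, and data below are illustrative placeholders, not YOLOv5 code. Note that `scaler.unscale_(optimizer)` is called before `clip_grad_norm_` so the clipping threshold applies to true (unscaled) gradient norms.

```python
# Minimal sketch of the pattern this PR adds (toy model/optimizer/data are placeholders).
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=device == 'cuda')
clip_grad = 10.0  # same role as the new --clip_grad argument

# Equivalent of the --use_tensor_hooks block: replace NaN/Inf gradients with zeros
for name, params in model.named_parameters():
    if params.requires_grad:
        params.register_hook(lambda grad: torch.nan_to_num(grad, nan=0., posinf=0., neginf=0.))
        params.retain_grad()

x, y = torch.randn(8, 10, device=device), torch.randn(8, 1, device=device)
with torch.cuda.amp.autocast(enabled=device == 'cuda'):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()

# Equivalent of the new lines in the Optimize step: unscale, clip, then step as usual
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```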