diff --git a/train.py b/train.py
index dc93c22d621a..6ada2a2f121b 100644
--- a/train.py
+++ b/train.py
@@ -131,6 +131,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
     for k, v in model.named_parameters():
         v.requires_grad = True  # train all layers
+        v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0.0
         if any(x in k for x in freeze):
             LOGGER.info(f'freezing {k}')
             v.requires_grad = False
@@ -334,8 +335,10 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
             # Backward
             scaler.scale(loss).backward()
 
-            # Optimize
+            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
             if ni - last_opt_step >= accumulate:
+                scaler.unscale_(optimizer)  # unscale gradients
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
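
The two hunks add (1) a per-parameter gradient hook that replaces NaN gradient entries with 0.0, and (2) gradient unscaling plus norm clipping before the optimizer step, following the pattern in the linked PyTorch AMP notes. Below is a minimal, self-contained sketch of the same two techniques outside the train loop; the model, optimizer, data, and loss here are placeholders for illustration, not from this PR:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=device == 'cuda')

# First hunk: tensor hooks fire with the incoming gradient during backward;
# returning a new tensor replaces it, so NaNs become 0.0 before the step.
for p in model.parameters():
    p.register_hook(lambda grad: torch.nan_to_num(grad))

x, y = torch.randn(4, 10, device=device), torch.randn(4, 1, device=device)
with torch.cuda.amp.autocast(enabled=device == 'cuda'):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()

# Second hunk: unscale before clipping so the norm is measured on true
# gradient values rather than loss-scaled ones; scaler.step() then skips
# the optimizer step if any gradient still contains inf/NaN.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```

The ordering matters: `clip_grad_norm_` must run after `scaler.unscale_(optimizer)`, otherwise `max_norm=10.0` would be compared against gradients inflated by the loss scale and the threshold would be effectively meaningless.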