Skip to content

Commit

Permalink
Add tensor hooks and gradient clipping ultralytics#8578
Browse files Browse the repository at this point in the history
  • Loading branch information
UnglvKitDe committed Jul 16, 2022
1 parent 6e86af3 commit ae398f7
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,16 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
model.hyp = hyp # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names


# Add tensor hooks that replace NaN/+inf/-inf gradient values with zeros
if opt.use_tensor_hooks:
LOGGER.info('Add tensor hooks...')
for name, params in model.named_parameters():
if params is not None:
params.register_hook(lambda grad: torch.nan_to_num(grad, nan=0., neginf=0., posinf=0.))
params.retain_grad()
LOGGER.info(f'Use gradient clipping with max_norm={opt.clip_grad}')

# Start training
t0 = time.time()
nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
Expand Down Expand Up @@ -361,6 +370,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio

# Optimize
if ni - last_opt_step >= accumulate:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=opt.clip_grad)
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
Expand Down Expand Up @@ -506,7 +517,9 @@ def parse_opt(known=False):
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')

parser.add_argument('--clip_grad', type=int, default=10, help='Max norm for gradient clipping')
parser.add_argument('--use_tensor_hooks', action='store_true', help='Use tensor hooks to remove infs and nans')

# Weights & Biases arguments
parser.add_argument('--entity', default=None, help='W&B: Entity')
parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
Expand Down

0 comments on commit ae398f7

Please sign in to comment.