Skip to content

Commit

Permalink
Add tensor hooks and 10.0 gradient clipping (ultralytics#8598)
Browse files Browse the repository at this point in the history
* Add tensor hooks and gradient clipping ultralytics#8578

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove retain_grad(), because its not necessary

* Update train.py

* Simplify

* Update train.py

* Update train.py

* Update train.py

* Update train.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people authored and Clay Januhowski committed Sep 8, 2022
1 parent fbe8dbb commit 5e7b060
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0.0
if any(x in k for x in freeze):
LOGGER.info(f'freezing {k}')
v.requires_grad = False
Expand Down Expand Up @@ -334,8 +335,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# Backward
scaler.scale(loss).backward()

# Optimize
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= accumulate:
scaler.unscale_(optimizer) # unscale gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
Expand Down

0 comments on commit 5e7b060

Please sign in to comment.