making optimization steps for hooks (#2363)
* simplified optimizer step and zero grad overriding
williamFalcon authored Jun 25, 2020
1 parent d221817 commit 0a092f6
Showing 4 changed files with 68 additions and 35 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -143,13 +143,15 @@ As you see, you're just organizing your PyTorch code - there's no abstraction.

And for the stuff that the Trainer abstracts out, you can [override any part](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#extensibility) you want to do things like implement your own distributed training, 16-bit precision, or even a custom backward pass.

For example, here you could do your own backward pass
For example, here you could do your own backward pass without worrying about GPUs, TPUs, or 16-bit precision, since we already handle it.

```python
class LitModel(LightningModule):
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
second_order_closure=None):
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
optimizer.step()

def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
optimizer.zero_grad()
```
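The overridden hooks are picked up automatically once the model goes through a `Trainer`. A minimal usage sketch, not part of this commit, assuming `LitModel` also defines the usual `training_step`/`configure_optimizers` and that a `train_loader` exists:

```python
from pytorch_lightning import Trainer

# hypothetical usage; LitModel is the class sketched above,
# train_loader is an assumed PyTorch DataLoader
model = LitModel()
trainer = Trainer(max_epochs=1)

# the training loop calls the overridden optimizer_step / optimizer_zero_grad
trainer.fit(model, train_loader)
```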

10 changes: 6 additions & 4 deletions docs/source/optimizers.rst
@@ -83,12 +83,14 @@ For example, here step optimizer A every 2 batches and optimizer B every 4 batches

.. testcode::

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
optimizer.step()
optimizer.zero_grad()

def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
optimizer.zero_grad()

# Alternating schedule for optimizer steps (ie: GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
# update generator opt every 2 steps
if optimizer_i == 0:
if batch_nb % 2 == 0 :
@@ -109,7 +111,7 @@ Here we add a learning-rate warm up
.. testcode::

# learning rate warm-up
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
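The warm-up snippet above is cut off by the diff view; a complete version of the pattern might look like the sketch below. The 500-step ramp comes from the snippet itself, while `self.hparams.learning_rate` and the closing `optimizer.step()` are assumptions about how the rest of the example is usually written.

```python
# learning-rate warm-up with the new hook signature (sketch; the tail of the
# original snippet is not shown in the diff above)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
    # linearly ramp the learning rate over the first 500 global steps
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.hparams.learning_rate  # assumed hparam name

    # step as usual after adjusting the learning rate
    optimizer.step()
```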
43 changes: 19 additions & 24 deletions pytorch_lightning/core/lightning.py
@@ -1133,6 +1133,9 @@ def optimizer_step(
optimizer: Optimizer,
optimizer_idx: int,
second_order_closure: Optional[Callable] = None,
on_tpu: bool = False,
using_native_amp: bool = False,
using_lbfgs: bool = False,
) -> None:
r"""
Override this method to adjust the default way the
@@ -1146,19 +1149,21 @@ def optimizer_step(
optimizer: A PyTorch optimizer
optimizer_idx: If you used multiple optimizers this indexes into that list.
second_order_closure: closure for second order methods
on_tpu: True if TPU backward is required
using_native_amp: True if using native amp
using_lbfgs: True if the matching optimizer is LBFGS
Examples:
.. code-block:: python
# DEFAULT
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
second_order_closure=None):
second_order_closure, on_tpu, using_native_amp, using_lbfgs):
optimizer.step()
optimizer.zero_grad()
# Alternating schedule for optimizer steps (i.e.: GANs)
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
second_order_closure=None):
second_order_closure, on_tpu, using_native_amp, using_lbfgs):
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_idx % 2 == 0 :
@@ -1182,7 +1187,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_idx, optimizer,
optimizer_idx, second_order_closure=None):
optimizer_idx, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
@@ -1198,30 +1203,20 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.
"""
if self.trainer.use_tpu and XLA_AVAILABLE:
if on_tpu:
xm.optimizer_step(optimizer)
elif isinstance(optimizer, torch.optim.LBFGS):

# native amp + lbfgs is a no go right now
if self.trainer.use_amp and self.trainer.use_native_amp:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
elif using_native_amp:
self.trainer.scaler.step(optimizer)
elif using_lbfgs:
optimizer.step(second_order_closure)
else:
if self.trainer.use_amp and self.trainer.use_native_amp:
self.trainer.scaler.step(optimizer)
else:
optimizer.step()

# in native 16-bit we need to update scaler after optimizer step
if self.trainer.use_amp and self.trainer.use_native_amp:
self.trainer.scaler.update()

# model hook
self.on_before_zero_grad(optimizer)
optimizer.step()

# clear gradients
def optimizer_zero_grad(self,
epoch: int,
batch_idx: int,
optimizer: Optimizer,
optimizer_idx: int):
optimizer.zero_grad()

def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
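With the flags now passed into the hook, a user override is responsible for choosing the backend-specific step itself. Below is a sketch that simply mirrors the new default shown above; it assumes `xm` (`torch_xla.core.xla_model`) is importable only when `on_tpu` is true.

```python
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                   second_order_closure=None, on_tpu=False,
                   using_native_amp=False, using_lbfgs=False):
    if on_tpu:
        xm.optimizer_step(optimizer)           # torch_xla performs the TPU-aware step
    elif using_native_amp:
        self.trainer.scaler.step(optimizer)    # GradScaler unscales grads, then steps
    elif using_lbfgs:
        optimizer.step(second_order_closure)   # LBFGS needs the loss closure
    else:
        optimizer.step()
```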
42 changes: 38 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -716,7 +716,15 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# ------------------
# .STEP + ZERO_GRAD
# ------------------
self.call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch)

return grad_norm_dic

def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
# calls .step(), .zero_grad()
# override function to modify this behavior
model = self.get_model()

with self.profiler.profile('optimizer_step'):
lambda_closure = lambda: self.optimizer_closure(
split_batch,
@@ -725,11 +733,37 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
optimizer,
self.hiddens
).loss
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx,
lambda_closure)

return grad_norm_dic
# apply TPU optimizer
if self.use_tpu and XLA_AVAILABLE:
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx, lambda_closure, on_tpu=True)

# for LBFGS do something a bit different
elif isinstance(optimizer, torch.optim.LBFGS):

# native amp + lbfgs is a no go right now
if self.use_amp and self.use_native_amp:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
using_lbfgs=True)

# when using 16-bit
else:
native_amp = self.use_amp and self.use_native_amp
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
using_native_amp=native_amp)

# in native 16-bit we need to update scaler after optimizer step
if self.use_amp and self.use_native_amp:
self.scaler.update()

# model hook
model.on_before_zero_grad(optimizer)

# clear gradients
model.optimizer_zero_grad(self.current_epoch, batch_idx, optimizer, opt_idx)

def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
"""
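For context on the LBFGS branch: `torch.optim.LBFGS` may re-evaluate the model several times per step, so `optimizer.step()` must receive a callable that recomputes the loss, which is what `lambda_closure` wraps here. A standalone plain-PyTorch sketch of that contract, independent of this commit:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

def closure():
    # LBFGS calls this (possibly several times) to re-evaluate loss and gradients
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

optimizer.step(closure)
```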
