give a more complete GAN example (#6294)
tchaton authored Mar 5, 2021
1 parent 2a3ab67 commit 4f391bc
Showing 1 changed file with 135 additions and 41 deletions: docs/source/common/optimizers.rst
Manual optimization
===================
For advanced research topics like reinforcement learning, sparse coding, or GANs, it may be desirable
to manually manage the optimization process. To do so, do the following:

* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule``'s ``__init__`` method.
* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.

.. testcode:: python

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: This property activates ``manual optimization`` for your model
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            loss = self.compute_loss(batch)
            self.manual_backward(loss)


.. note:: This is only recommended for experts who need ultimate flexibility. Lightning handles only the precision and accelerator logic; users are responsible for ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.

.. warning:: Before 1.2, ``optimizer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the user's expertise.

Here is a minimal manual optimization ``training_step``:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        loss = self.compute_loss(batch)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

Here is the same example as above using a ``closure``.

.. testcode:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def forward_and_backward():
            loss = self.compute_loss(batch)
            self.manual_backward(loss)

        opt.step(closure=forward_and_backward)
        opt.zero_grad()
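Passing the closure to ``opt.step`` is what allows optimizers such as ``torch.optim.LBFGS``, which need to re-evaluate the loss several times within a single step, to work with manual optimization.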


.. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.
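
Since ``zero_grad`` and ``step`` are entirely under your control in manual optimization, gradient accumulation reduces to choosing when to call them. Below is a minimal sketch, reusing the ``compute_loss`` helper from above with an illustrative accumulation factor of 2:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # scale the loss so the accumulated gradient matches a batch twice as large
        loss = self.compute_loss(batch) / 2
        self.manual_backward(loss)

        # step and reset the gradients only every 2 batches
        if (batch_idx + 1) % 2 == 0:
            opt.step()
            opt.zero_grad()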


Here is a complete example training a simple GAN with two optimizers:

.. testcode:: python

    import torch
    from torch import Tensor
    from pytorch_lightning import LightningModule

    class SimpleGAN(LightningModule):

        def __init__(self):
            super().__init__()
            # ``Generator`` and ``Discriminator`` are user-defined ``nn.Module``s
            self.G = Generator()
            self.D = Discriminator()

            # Assumption: a standard normal latent prior; any ``torch.distributions``
            # distribution with a matching shape works here
            self._Z = torch.distributions.Normal(0.0, 1.0)
            # binary cross-entropy, as in the DCGAN tutorial
            self.criterion = torch.nn.BCELoss()

            # Important: This property activates ``manual optimization`` for this model
            self.automatic_optimization = False

        def sample_z(self, n) -> Tensor:
            sample = self._Z.sample((n,))
            return sample

        def sample_G(self, n) -> Tensor:
            z = self.sample_z(n)
            return self.G(z)

        def training_step(self, batch, batch_idx):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            g_X = self.sample_G(batch_size)

            ###########################
            # Optimize Discriminator  #
            ###########################
            d_opt.zero_grad()

            d_x = self.D(X)
            errD_real = self.criterion(d_x, real_label)

            d_z = self.D(g_X.detach())
            errD_fake = self.criterion(d_z, fake_label)

            errD = errD_real + errD_fake

            self.manual_backward(errD)
            d_opt.step()

            #######################
            # Optimize Generator  #
            #######################
            g_opt.zero_grad()

            d_z = self.D(g_X)
            errG = self.criterion(d_z, real_label)

            self.manual_backward(errG)
            g_opt.step()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)

        def configure_optimizers(self):
            g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
            d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
            return g_opt, d_opt
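
Here is a minimal sketch of how such a module might be launched; ``train_dataloader`` is an assumed, user-provided ``DataLoader`` of real images, not part of the example above:

.. code-block:: python

    from pytorch_lightning import Trainer

    model = SimpleGAN()
    trainer = Trainer(max_epochs=5)
    # ``train_dataloader`` is assumed to yield (image, label) batches
    trainer.fit(model, train_dataloader)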

.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a context manager for advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting.

Considering the current optimizer as A and all other optimizers as B, toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``. Their original state is restored when exiting the context manager.
When performing gradient accumulation, there is no need to perform gradient synchronization during the accumulation phase.
Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed.
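
As a focused sketch of the pattern (assuming two optimizers and the ``compute_loss`` helper from above):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        opt_a, opt_b = self.optimizers()

        # Inside this block, parameters owned exclusively by ``opt_b`` have
        # ``requires_grad`` set to False; their state is restored on exit.
        # ``sync_grad=False`` skips inter-process gradient synchronization,
        # which is safe on batches where no optimizer step is taken.
        with opt_a.toggle_model(sync_grad=False):
            loss = self.compute_loss(batch)
            self.manual_backward(loss)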

Here is a complete example for an advanced use case:


.. testcode:: python

    # Scenario for a GAN with gradient accumulation every 2 batches, optimized for multiple GPUs.
    class SimpleGAN(LightningModule):

        ...

        def training_step(self, batch, batch_idx):
            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
            g_opt, d_opt = self.optimizers()

            X, _ = batch
            X.requires_grad = True
            batch_size = X.shape[0]

            real_label = torch.ones((batch_size, 1), device=self.device)
            fake_label = torch.zeros((batch_size, 1), device=self.device)

            # step (and sync gradients) only on every second batch
            accumulated_grad_batches = batch_idx % 2 == 0

            g_X = self.sample_G(batch_size)

            ###########################
            # Optimize Discriminator  #
            ###########################
            with d_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_x = self.D(X)
                errD_real = self.criterion(d_x, real_label)

                d_z = self.D(g_X.detach())
                errD_fake = self.criterion(d_z, fake_label)

                errD = errD_real + errD_fake

                self.manual_backward(errD)
                if accumulated_grad_batches:
                    d_opt.step()
                    d_opt.zero_grad()

            #######################
            # Optimize Generator  #
            #######################
            with g_opt.toggle_model(sync_grad=accumulated_grad_batches):
                d_z = self.D(g_X)
                errG = self.criterion(d_z, real_label)

                self.manual_backward(errG)
                if accumulated_grad_batches:
                    g_opt.step()
                    g_opt.zero_grad()

            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)
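
Note that ``sync_grad`` is tied to the same ``accumulated_grad_batches`` flag that gates ``step``: gradients are synchronized across processes only on batches where an optimizer step actually happens, avoiding redundant communication during the accumulation phase.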

------

