From 4f391bce7c419646893a6d79247f3fc2d5d7beca Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 5 Mar 2021 22:54:09 +0000
Subject: [PATCH] give a more complete GAN example (#6294)

---
 docs/source/common/optimizers.rst | 176 +++++++++++++++++++++++-------
 1 file changed, 135 insertions(+), 41 deletions(-)

diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst
index 22898e4f1a1b2..3b29fd4c08f13 100644
--- a/docs/source/common/optimizers.rst
+++ b/docs/source/common/optimizers.rst
@@ -21,10 +21,26 @@ Manual optimization
 For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable
 to manually manage the optimization process. To do so, do the following:
 
-* Override your LightningModule ``automatic_optimization`` property to return ``False``
-* Drop or ignore the optimizer_idx argument
+* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule``'s ``__init__``
 * Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.
 
+.. testcode:: python
+
+    from pytorch_lightning import LightningModule
+
+    class MyModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            # Important: this property activates ``manual optimization`` for your model
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            opt = self.optimizers()
+            opt.zero_grad()
+            loss = self.compute_loss(batch)
+            self.manual_backward(loss)
+            opt.step()
+
+
 .. note:: This is only recommended for experts who need ultimate flexibility. Lightning handles only the precision and accelerator logic. Users are responsible for ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.
 
 .. warning:: Before 1.2, ``optimizer.step()`` called ``optimizer.zero_grad()`` internally. From 1.2, it is left to the user's expertise.
@@ -35,7 +51,7 @@
 
 .. code-block:: python
 
-    def training_step(batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         opt = self.optimizers()
 
         loss = self.compute_loss(batch)
@@ -51,9 +67,9 @@
 
 Here is the same example as above using a ``closure``. Optimizers such as ``torch.optim.LBFGS`` require one, as they need to re-evaluate the loss several times per step.
 
-.. code-block:: python
+.. testcode:: python
 
-    def training_step(batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         opt = self.optimizers()
 
         def forward_and_backward():
@@ -67,28 +83,78 @@ Here is the same example as above using a ``closure``.
 
         opt.zero_grad()
 
 
-.. code-block:: python
+.. tip:: Be careful where you call ``zero_grad``, or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.
+
+
+.. testcode:: python
+
+    import torch
+    from torch import Tensor
+    from pytorch_lightning import LightningModule
+
+    class SimpleGAN(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.G = Generator()
+            self.D = Discriminator()
+
+            # Important: this property activates ``manual optimization`` for this model
+            self.automatic_optimization = False
+
+        def sample_z(self, n) -> Tensor:
+            sample = self._Z.sample((n,))
+            return sample
+
+        def sample_G(self, n) -> Tensor:
+            z = self.sample_z(n)
+            return self.G(z)
+
+        def training_step(self, batch, batch_idx):
+            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
+            g_opt, d_opt = self.optimizers()
 
-    # Scenario for a GAN.
-    def training_step(...):
-        opt_gen, opt_dis = self.optimizers()
+            X, _ = batch
+            batch_size = X.shape[0]
 
-        # compute generator loss
-        loss_gen = self.compute_generator_loss(...)
+            real_label = torch.ones((batch_size, 1), device=self.device)
+            fake_label = torch.zeros((batch_size, 1), device=self.device)
 
-        # zero_grad needs to be called before backward
-        opt_gen.zero_grad()
-        self.manual_backward(loss_gen)
-        opt_gen.step()
+            g_X = self.sample_G(batch_size)
 
-        # compute discriminator loss
-        loss_dis = self.compute_discriminator_loss(...)
+            ###########################
+            # Optimize Discriminator #
+            ###########################
+            d_opt.zero_grad()
 
-        # zero_grad needs to be called before backward
-        opt_dis.zero_grad()
-        self.manual_backward(loss_dis)
-        opt_dis.step()
+            d_x = self.D(X)
+            errD_real = self.criterion(d_x, real_label)
+
+            d_z = self.D(g_X.detach())
+            errD_fake = self.criterion(d_z, fake_label)
+
+            errD = errD_real + errD_fake
+
+            self.manual_backward(errD)
+            d_opt.step()
+
+            #######################
+            # Optimize Generator #
+            #######################
+            g_opt.zero_grad()
+
+            d_z = self.D(g_X)
+            errG = self.criterion(d_z, real_label)
+
+            self.manual_backward(errG)
+            g_opt.step()
+
+            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)
+
+        def configure_optimizers(self):
+            g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
+            d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
+            return g_opt, d_opt
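 
+The ``SimpleGAN`` example above assumes that ``Generator``, ``Discriminator``, ``self.criterion``
+and the latent prior ``self._Z`` are defined elsewhere. Here is a minimal, hypothetical sketch of
+these assumed pieces; the exact architectures, loss, and latent distribution are up to you:
+
+.. code-block:: python
+
+    import torch
+    from torch import nn
+
+    class Generator(nn.Module):
+        # Hypothetical generator: maps a latent vector to a flattened 28x28 image
+        def __init__(self, latent_dim: int = 100):
+            super().__init__()
+            self.model = nn.Sequential(
+                nn.Linear(latent_dim, 256), nn.ReLU(),
+                nn.Linear(256, 28 * 28), nn.Tanh(),
+            )
+
+        def forward(self, z):
+            return self.model(z)
+
+    class Discriminator(nn.Module):
+        # Hypothetical discriminator: maps an image to a probability of being real
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(
+                nn.Linear(28 * 28, 256), nn.LeakyReLU(0.2),
+                nn.Linear(256, 1), nn.Sigmoid(),
+            )
+
+        def forward(self, x):
+            return self.model(x.flatten(start_dim=1))
+
+    # The remaining assumed attributes could be set inside ``SimpleGAN.__init__``:
+    #     self.criterion = nn.BCELoss()
+    #     self._Z = torch.distributions.Normal(torch.zeros(100), torch.ones(100))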
 
 .. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a context manager for advanced users. It can be useful when performing gradient accumulation with several optimizers, or when training in a distributed setting.
 
@@ -100,36 +166,64 @@ Toggling means that all parameters from B exclusive to A will have their ``requi
 When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed.
 
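+For instance, with a generator optimizer ``g_opt`` and a discriminator optimizer ``d_opt``, toggling
+``g_opt`` sets ``requires_grad = False`` on every parameter that belongs exclusively to ``d_opt``,
+and restores the previous state when the context manager exits. Here is a minimal, hypothetical
+sketch (``compute_generator_loss`` is an assumed helper):
+
+.. code-block:: python
+
+    def training_step(self, batch, batch_idx):
+        g_opt, d_opt = self.optimizers()
+
+        # While toggled, parameters owned only by ``d_opt`` have requires_grad=False,
+        # so this backward pass cannot write gradients into the discriminator.
+        with g_opt.toggle_model(sync_grad=False):  # skip DDP grad sync while accumulating
+            loss_g = self.compute_generator_loss(batch)  # hypothetical helper
+            self.manual_backward(loss_g)
+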
-Here is an example on how to use it:
 
-.. code-block:: python
+Here is a full example for an advanced use case:
+
+
+.. testcode:: python
 
     # Scenario for a GAN with gradient accumulation every 2 batches, optimized for multiple GPUs.
-    def training_step(self, batch, batch_idx, ...):
-        opt_gen, opt_dis = self.optimizers()
+    class SimpleGAN(LightningModule):
+
+        ...
+
+        def training_step(self, batch, batch_idx):
+            # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
+            g_opt, d_opt = self.optimizers()
+
+            X, _ = batch
+            X.requires_grad = True
+            batch_size = X.shape[0]
+
+            real_label = torch.ones((batch_size, 1), device=self.device)
+            fake_label = torch.zeros((batch_size, 1), device=self.device)
+
+            # Step (and sync gradients in DDP) only every second batch
+            accumulated_grad_batches = batch_idx % 2 == 0
+
+            g_X = self.sample_G(batch_size)
+
+            ###########################
+            # Optimize Discriminator #
+            ###########################
+            with d_opt.toggle_model(sync_grad=accumulated_grad_batches):
+                d_x = self.D(X)
+                errD_real = self.criterion(d_x, real_label)
+
+                d_z = self.D(g_X.detach())
+                errD_fake = self.criterion(d_z, fake_label)
 
-        accumulated_grad_batches = batch_idx % 2 == 0
+                errD = errD_real + errD_fake
 
-        # compute generator loss
-        def closure_gen():
-            loss_gen = self.compute_generator_loss(...)
-            self.manual_backward(loss_gen)
-            if accumulated_grad_batches:
-                opt_gen.zero_grad()
+                self.manual_backward(errD)
+                if accumulated_grad_batches:
+                    d_opt.step()
+                    d_opt.zero_grad()
 
-        with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
-            opt_gen.step(closure=closure_gen)
+            #######################
+            # Optimize Generator #
+            #######################
+            with g_opt.toggle_model(sync_grad=accumulated_grad_batches):
+                d_z = self.D(g_X)
+                errG = self.criterion(d_z, real_label)
 
-        def closure_dis():
-            loss_dis = self.compute_discriminator_loss(...)
-            self.manual_backward(loss_dis)
-            if accumulated_grad_batches:
-                opt_dis.zero_grad()
+                self.manual_backward(errG)
+                if accumulated_grad_batches:
+                    g_opt.step()
+                    g_opt.zero_grad()
 
-        with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
-            opt_dis.step(closure=closure_dis)
+            self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True)
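+
+Assuming the ``SimpleGAN`` module above, here is a minimal sketch of how it could be launched with
+the Lightning 1.2 ``Trainer`` (``train_loader`` is an assumed, user-provided ``DataLoader``):
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+
+    model = SimpleGAN()
+    # 2 GPUs with DDP: ``toggle_model(sync_grad=...)`` then controls gradient synchronization
+    trainer = Trainer(gpus=2, accelerator="ddp", max_epochs=5)
+    trainer.fit(model, train_loader)
------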