
Steps not incremented correctly with accumulate gradients #2242

Closed
mRcSchwering opened this issue Jun 18, 2020 · 6 comments
Labels: bug (Something isn't working) · help wanted (Open to be worked on)


mRcSchwering commented Jun 18, 2020

🐛 Bug

global_step and current_epoch no longer match up after more than one epoch when accumulate_grad_batches is set to a value greater than 1.
I think optimizer_step (and on_before_zero_grad) is not called at the end of each epoch in that case.

To Reproduce

  1. Create a pl.LightningModule that logs current_epoch and global_step in every training_step_end.
  2. Use a dataset with 20k samples, set the dataloader batch size to 1171, and set accumulate_grad_batches=7 in the trainer.
  3. Train for 100 steps.
  4. Look at the logs and at the overall number of steps and epochs.

Expected behavior

  • with 100 steps we should be at epoch 33
  • in the logs you should see that whenever current_epoch is incremented, global_step is incremented as well

Actual behavior

  • with 100 steps we are at epoch 50
  • in the logs you can see global_step incrementing throughout the epoch, but not when current_epoch is incremented
  • in this particular example global_step is effectively missing every 3rd increment (the arithmetic behind both numbers is sketched below)
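
For reference, a quick sanity check of the numbers above (a minimal sketch; the 20k samples, batch size 1171, and accumulation factor 7 are taken from the reproduction steps):

import math

n_samples = 20_000
batch_size = 1171
accumulate = 7

batches_per_epoch = math.ceil(n_samples / batch_size)       # 18 batches; the last one has only 93 samples
steps_expected = math.ceil(batches_per_epoch / accumulate)  # 3 optimizer steps: 7 + 7 + a final group of 4
steps_observed = batches_per_epoch // accumulate            # 2 steps if the trailing partial group never steps

print(100 / steps_expected)   # ~33 epochs after 100 steps (expected)
print(100 / steps_observed)   # 50 epochs after 100 steps (observed)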

Code sample

Below is basically what I have.
I adjust the learning rate on every global step.
Each learning rate adjustment and each training_step_end call is printed.

import torch
from torch import optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# model_fact, MyDataset, N_GPUS, and the criterion are defined elsewhere in my project

class MyModule(pl.LightningModule):

    def __init__(self, hparams: dict):
        super().__init__()
        self.hparams = hparams
        self.model = model_fact()

    def training_step_end(self, outputs: dict):
        print(f'Epoch: {self.current_epoch} Step: {self.global_step} Batch size: {len(outputs["logits"])}')

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer):
        current_lr = [d['lr'] for d in optimizer.param_groups][0]
        print(f'Step: {self.global_step} LR: {current_lr:.4e}')

    def train_dataloader(self):
        return DataLoader(MyDataset(), batch_size=1171, shuffle=False)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx: int) -> dict:
        inputs, targets = batch
        logits = self.forward(*inputs)
        loss = self.criterion(logits, targets)
        return {'loss': loss, 'logits': logits}

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters())
    
    def optimizer_step(self, epoch: int, batch_idx: int, optimizer: optim.Optimizer,
                       optimizer_idx: int, second_order_closure: callable = None):
        # modify learning rate...
        optimizer.step()
        self.on_before_zero_grad(optimizer)
        optimizer.zero_grad()

trainer = pl.Trainer(
        max_steps=100,
        max_epochs=int(1e6),
        gpus=N_GPUS if N_GPUS > 0 else None,
        distributed_backend=None if N_GPUS < 2 else 'dp',
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=0,
        accumulate_grad_batches=7,
        early_stop_callback=False)

Below is some sample output.
You can see the end of the epoch, where the last 93 samples are processed.
Then current_epoch increases, but global_step does not.
Additionally, the learning rate print is missing, so on_before_zero_grad was not called.

Step: 53 LR: 3.3861e-04
Epoch: 26 Step: 54 Batch size: 1171
Epoch: 26 Step: 54 Batch size: 1171
Epoch: 26 Step: 54 Batch size: 1171
Epoch: 26 Step: 54 Batch size: 93
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Epoch: 27 Step: 54 Batch size: 1171
Step: 54 LR: 3.3267e-04
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Epoch: 27 Step: 55 Batch size: 1171
Step: 55 LR: 3.2673e-04
Epoch: 27 Step: 56 Batch size: 1171
Epoch: 27 Step: 56 Batch size: 1171
Epoch: 27 Step: 56 Batch size: 1171
Epoch: 27 Step: 56 Batch size: 93
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
Epoch: 28 Step: 56 Batch size: 1171
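
A minimal sketch of accumulation logic that would produce exactly this pattern (my guess at the cause, not the actual Lightning source; train_dataloader, model, optimizer, and global_step stand in for the real internals):

for batch_idx, batch in enumerate(train_dataloader):
    loss = model.training_step(batch, batch_idx)['loss']
    loss.backward()
    # the optimizer only steps after a full group of `accumulate_grad_batches`
    # batches, so the trailing group of 4 batches in an 18-batch epoch
    # never reaches this branch
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
# the epoch ends here, and the gradients accumulated over the partial
# group never count as an optimizer step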

Environment

  • pytorch-lightning==0.7.5
  • torch==1.5.0 (conda)
  • Ubuntu 18.04.4 LTS
  • Python 3.7.7
  • CUDA/cuDNN version: Cuda compilation tools, release 10.1, V10.1.243
  • GPU models and configuration: 6 GPUs in data parallel mode
mRcSchwering added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Jun 18, 2020
@williamFalcon (Contributor)

@mRcSchwering can you try 0.8.0?

@mRcSchwering (Author)

Just tried it on 0.8.1 (hope that's just as good). The issue remains. E.g.:

Epoch: 43 Step: 85 Batch size: 1171
Epoch: 43 Step: 85 Batch size: 1171
Epoch: 43 Step: 85 Batch size: 1171
Step: 85 LR: 1.4851e-04
Epoch: 43 Step: 86 Batch size: 1171
Epoch: 43 Step: 86 Batch size: 1171
Epoch: 43 Step: 86 Batch size: 1171
Epoch: 43 Step: 86 Batch size: 93
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171
Epoch: 44 Step: 86 Batch size: 1171

@edenlightning edenlightning added this to the 0.9.0 milestone Jul 29, 2020
@edenlightning edenlightning modified the milestones: 0.9.0, 0.9.x Aug 20, 2020
@Borda (Member) commented Sep 15, 2020

@mRcSchwering mind checking it with our latest master? 🐰

@ydcjeff (Contributor) commented Oct 3, 2020

I guess this is solved by #2853, and I could reproduce the expected behavior. Can you please confirm, @mRcSchwering?
Here is the script, updated to be compatible with current master:

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

pl.seed_everything(666)

class MyModule(pl.LightningModule):

    def __init__(self, hparams: dict):
        super().__init__()
        self.hparams = hparams
        self.model = nn.Linear(28*28, 10)

    def training_step_end(self, outputs: dict):
        print(f'Epoch: {self.current_epoch} Step: {self.global_step} Batch size: {len(outputs["logits"])}')
        return outputs

    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer):
        current_lr = [d['lr'] for d in optimizer.param_groups][0]
        print(f'Step: {self.global_step} LR: {current_lr:.4e}')

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=1171, shuffle=False)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx: int) -> dict:
        inputs, targets = batch
        logits = self.forward(inputs.view(inputs.size(0), -1))
        loss = F.cross_entropy(logits, targets)
        return {'loss': loss, 'logits': logits}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=3e-4)
    
    def optimizer_step(self, epoch, batch_idx, optimizer, opt_idx, lambda_closure, using_native_amp, using_lbfgs):
        # modify learning rate...
        optimizer.step()
        self.on_before_zero_grad(optimizer)
        optimizer.zero_grad()

trainer = pl.Trainer(
        max_steps=100,
        max_epochs=int(1e6),
        gpus=-1,
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=0,
        accumulate_grad_batches=7,
        early_stop_callback=False)
model = MyModule({})
trainer.fit(model)

Output: the LR is printed after every 7 accumulated batches, and also after the last (partial) batch of the epoch. current_epoch and global_step are now incremented together. (The LR line appears twice per step, presumably once from Lightning's own on_before_zero_grad hook and once from the manual call inside optimizer_step.)

Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Epoch: 1 Step: 8 Batch size: 1171
Step: 8 LR: 3.0000e-04
Step: 8 LR: 3.0000e-04
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Epoch: 1 Step: 9 Batch size: 1171
Step: 9 LR: 3.0000e-04
Step: 9 LR: 3.0000e-04
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Epoch: 1 Step: 10 Batch size: 1171
Step: 10 LR: 3.0000e-04
Step: 10 LR: 3.0000e-04
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Epoch: 1 Step: 11 Batch size: 1171
Step: 11 LR: 3.0000e-04
Step: 11 LR: 3.0000e-04
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Epoch: 1 Step: 12 Batch size: 1171
Step: 12 LR: 3.0000e-04
Step: 12 LR: 3.0000e-04
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Epoch: 1 Step: 13 Batch size: 1171
Step: 13 LR: 3.0000e-04
Step: 13 LR: 3.0000e-04
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Epoch: 1 Step: 14 Batch size: 1171
Step: 14 LR: 3.0000e-04
Step: 14 LR: 3.0000e-04
Epoch: 1 Step: 15 Batch size: 1171
Epoch: 1 Step: 15 Batch size: 1171
Epoch: 1 Step: 15 Batch size: 279
Step: 15 LR: 3.0000e-04
Step: 15 LR: 3.0000e-04
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Epoch: 2 Step: 16 Batch size: 1171
Step: 16 LR: 3.0000e-04
Step: 16 LR: 3.0000e-04
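
In the same hypothetical style as the sketch further up, the fixed behavior corresponds to also stepping on the final partial group of the epoch (again only an assumption about the logic, not the actual change in #2853):

is_last_batch = (batch_idx + 1) == batches_per_epoch
if (batch_idx + 1) % accumulate_grad_batches == 0 or is_last_batch:
    optimizer.step()   # the trailing partial group now also triggers a step
    optimizer.zero_grad()
    global_step += 1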

@mRcSchwering (Author)

Cool, thx. And I learned something new (seed_everything)

@ydcjeff (Contributor) commented Oct 3, 2020

So I guess this can be closed.

@edenlightning edenlightning modified the milestones: 0.9.x, 1.0 Oct 4, 2020