[WIP] Reduction when batch size < num gpus (#1609)
* reduce if <= num_gpus

* add test with explanation

* chlog

* fix changelog

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
awaelchli and Borda authored May 2, 2020
1 parent fafe5d6 commit e6b34ef
Showing 3 changed files with 51 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))

 ### Changed
+
+- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))

 ### Deprecated

8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/logging.py
@@ -196,8 +196,8 @@ def reduce_distributed_output(self, output, num_gpus):
             elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
                 pass

-            # reduce only metrics that have the same number of gpus
-            elif output[k].size(0) == num_gpus:
-                reduced = torch.mean(output[k])
-                output[k] = reduced
+            # do not reduce metrics that have batch size > num gpus
+            elif output[k].size(0) <= num_gpus:
+                output[k] = torch.mean(output[k])

         return output
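
For context, here is a minimal standalone sketch of the patched reduction (a simplification assuming a flat dict of tensors; the real method also recurses into nested dicts and short-circuits when num_gpus <= 1). Under DataParallel, dimension 0 of each gathered metric holds one entry per participating GPU, and the last batch of an epoch can occupy fewer GPUs than are configured, which is why `==` became `<=`:

    import torch

    def reduce_metrics(output, num_gpus):
        # sketch of reduce_distributed_output after this commit
        for k in output:
            if output[k].dim() == 0:
                continue  # already a scalar, nothing to reduce
            if output[k].size(0) <= num_gpus:
                # '<=' rather than '==': the last batch may yield fewer
                # per-GPU entries than there are GPUs
                output[k] = torch.mean(output[k])
        return output

    # 2 samples scattered over 3 configured GPUs come back as 2 entries
    print(reduce_metrics({'loss': torch.tensor([0.5, 0.7])}, num_gpus=3))
    # prints {'loss': tensor(0.6000)}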
45 changes: 45 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -2,6 +2,8 @@

 import pytest
 import torch
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Subset

 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
@@ -482,3 +484,46 @@ class CustomDummyObj:
     assert isinstance(result, torch.utils.data.DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')


+@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
+def test_batch_size_smaller_than_num_gpus():
+    # we need at least 3 gpus for this test
+    num_gpus = 3
+    batch_size = 3
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        TestModelBase,
+    ):
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.c_d1_bn = torch.nn.ReLU()
+
+        def train_dataloader(self):
+            dataloader = super().train_dataloader()
+            # construct a dataset with a size that is not divisible by num_gpus
+            # therefore the last batch will have a size < num_gpus
+            size = num_gpus * batch_size + (num_gpus - 1)
+            dataset = Subset(dataloader.dataset, range(size))
+            dataloader = DataLoader(
+                dataset,
+                batch_size=self.hparams.batch_size,
+                drop_last=False,
+            )
+            return dataloader
+
+    hparams = tutils.get_default_hparams()
+    hparams.batch_size = batch_size
+    model = CurrentTestModel(hparams)
+
+    trainer = Trainer(
+        max_epochs=1,
+        gpus=num_gpus,
+    )
+
+    # we expect the reduction for the metrics also to happen on the last batch
+    # where we will get fewer metrics than gpus
+    result = trainer.fit(model)
+    assert 1 == result
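
For orientation, a quick check of the arithmetic the test relies on (plain Python, using the same num_gpus and batch_size values as the test above):

    num_gpus, batch_size = 3, 3
    size = num_gpus * batch_size + (num_gpus - 1)  # 11 samples
    batch_sizes = [min(batch_size, size - start) for start in range(0, size, batch_size)]
    print(batch_sizes)  # [3, 3, 3, 2]
    # DataParallel scatters the final 2-sample batch across only 2 of the
    # 3 GPUs, so the gathered metrics have size 2 < num_gpus, exercising
    # the new '<=' branch in reduce_distributed_output

The test also swaps the model's `c_d1_bn` layer for a `ReLU`, presumably because batch norm cannot compute statistics over the 1-sample per-GPU batches this setup produces.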
