Skip to content

Commit

Permalink
bugfix: batch_size for MNISTDataModule (#331)
Browse files: browse the repository at this point in the history
* bugfix: batch_size for MNISTDataModule

* fix MNISTDataModule *_dataloader() signatures
  • Loading branch information
hecoding authored Nov 6, 2020
1 parent a6bc807 commit ef34a17
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
self.num_workers = num_workers
self.normalize = normalize
self.seed = seed
self.batch_size = batch_size

@property
def num_classes(self):
Expand All @@ -92,15 +93,14 @@ def prepare_data(self):
MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

def train_dataloader(self, batch_size=32, transforms=None):
def train_dataloader(self):
"""
MNIST train set removes a subset to use for validation
Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = transforms or self.train_transforms or self._default_transforms()
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms

dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
Expand All @@ -109,55 +109,54 @@ def train_dataloader(self, batch_size=32, transforms=None):
)
loader = DataLoader(
dataset_train,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True,
)
return loader

def val_dataloader(self, batch_size=32, transforms=None):
def val_dataloader(self):
"""
MNIST val set uses a subset of the training set for validation
Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = transforms or self.val_transforms or self._default_transforms()
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_val,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True,
)
return loader

def test_dataloader(self, batch_size=32, transforms=None):
def test_dataloader(self):
"""
MNIST test set uses the test split
Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = transforms or self.val_transforms or self._default_transforms()
transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms

dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True
dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True,
pin_memory=True
)
return loader

def _default_transforms(self):
def default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
Expand Down

0 comments on commit ef34a17

Please sign in to comment.