Delete tests.helpers.TrialMNISTDataModule (#5999)
* Remove TrialMNISTDataModule

* Allow using TrialMNIST in the MNISTDataModule

* Update tests/helpers/datasets.py
carmocca authored Feb 18, 2021
1 parent d2cd7cb commit bfcfac4
Showing 4 changed files with 62 additions and 143 deletions.
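
For callers, the change is a class swap plus an optional flag. A minimal migration sketch (names as in the diffs below; not verbatim repository code):

from tests.helpers.datamodules import MNISTDataModule

# Previously: dm = TrialMNISTDataModule(tmpdir)  -- the class this commit deletes.
# Now the constrained TrialMNIST dataset is opt-in on MNISTDataModule:
dm = MNISTDataModule(data_dir="./", batch_size=32, use_trials=True)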
81 changes: 16 additions & 65 deletions tests/helpers/datamodules.py
@@ -11,100 +11,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Optional

import torch
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, random_split
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities import _module_available
from tests.helpers.datasets import MNIST, SklearnDataset, TrialMNIST


class TrialMNISTDataModule(LightningDataModule):

def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.non_picklable = None
self.checkpoint_state: Optional[str] = None

def prepare_data(self):
TrialMNIST(self.data_dir, train=True, download=True)
TrialMNIST(self.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None):

if stage == "fit" or stage is None:
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
self.dims = self.mnist_train[0][0].shape

if stage == "test" or stage is None:
self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True)
self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)

self.non_picklable = lambda x: x**2

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint[self.__class__.__name__] = self.__class__.__name__

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.checkpoint_state = checkpoint.get(self.__class__.__name__)
_SKLEARN_AVAILABLE = _module_available("sklearn")
if _SKLEARN_AVAILABLE:
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split


class MNISTDataModule(LightningDataModule):

def __init__(self, data_dir: str = "./", batch_size: int = 32, dist_sampler: bool = False) -> None:
def __init__(self, data_dir: str = "./", batch_size: int = 32, use_trials: bool = False) -> None:
super().__init__()

self.dist_sampler = dist_sampler
self.data_dir = data_dir
self.batch_size = batch_size

# TrialMNIST is a constrained MNIST dataset
self.dataset_cls = TrialMNIST if use_trials else MNIST

# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)

def prepare_data(self):
# download only
MNIST(self.data_dir, train=True, download=True, normalize=(0.1307, 0.3081))
MNIST(self.data_dir, train=False, download=True, normalize=(0.1307, 0.3081))
self.dataset_cls(self.data_dir, train=True, download=True)
self.dataset_cls(self.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None):

# Assign train/val datasets for use in dataloaders
# TODO: need to split using random_split once updated to torch >= 1.6
if stage == "fit" or stage is None:
self.mnist_train = MNIST(self.data_dir, train=True, normalize=(0.1307, 0.3081))

# Assign test dataset for use in dataloader(s)
self.mnist_train = self.dataset_cls(self.data_dir, train=True)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, normalize=(0.1307, 0.3081))
self.mnist_test = self.dataset_cls(self.data_dir, train=False)

def train_dataloader(self):
dist_sampler = None
if self.dist_sampler:
dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)

return DataLoader(
self.mnist_train,
batch_size=self.batch_size,
sampler=dist_sampler,
shuffle=False,
)
return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=False)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
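
A short usage sketch of the refactored datamodule, assuming the MNIST files can be downloaded or are already cached; the batch shape follows from __getitem__ returning a (1, 28, 28) tensor per sample:

from tests.helpers.datamodules import MNISTDataModule

dm = MNISTDataModule(data_dir="./", batch_size=32, use_trials=True)
dm.prepare_data()    # downloads/caches the TrialMNIST train and test files
dm.setup("fit")      # assigns dm.mnist_train
images, targets = next(iter(dm.train_dataloader()))
print(images.shape)  # torch.Size([32, 1, 28, 28])
print(targets[:5])   # restricted to TrialMNIST's default digits (0, 1, 2)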
119 changes: 44 additions & 75 deletions tests/helpers/datasets.py
@@ -67,7 +67,7 @@ def __init__(
self,
root: str = PATH_DATASETS,
train: bool = True,
normalize: tuple = (0.5, 1.0),
normalize: tuple = (0.1307, 0.3081),
download: bool = True,
):
super().__init__()
@@ -77,18 +77,15 @@ def __init__(

self.prepare_data(download)

if not self._check_exists(self.cached_folder_path):
raise RuntimeError('Dataset not found.')

data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
self.data, self.targets = _try_load(os.path.join(self.cached_folder_path, data_file))
self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file))

def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
img = self.data[idx].float().unsqueeze(0)
target = int(self.targets[idx])

if self.normalize is not None:
img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])
if self.normalize is not None and len(self.normalize) == 2:
img = self.normalize_tensor(img, *self.normalize)

return img, target

@@ -105,67 +102,53 @@ def _check_exists(self, data_folder: str) -> bool:
existing = existing and os.path.isfile(os.path.join(data_folder, fname))
return existing

def prepare_data(self, download: bool):
if download:
def prepare_data(self, download: bool = True):
if download and not self._check_exists(self.cached_folder_path):
self._download(self.cached_folder_path)
if not self._check_exists(self.cached_folder_path):
raise RuntimeError('Dataset not found.')

def _download(self, data_folder: str) -> None:
"""Download the MNIST data if it doesn't exist in cached_folder_path already."""

if self._check_exists(data_folder):
return

os.makedirs(data_folder, exist_ok=True)

os.makedirs(data_folder)
for url in self.RESOURCES:
logging.info(f'Downloading {url}')
fpath = os.path.join(data_folder, os.path.basename(url))
urllib.request.urlretrieve(url, fpath)

@staticmethod
def _try_load(path_data, trials: int = 30, delta: float = 1.):
"""Retry loading to tolerate concurrent reads of the same file by multiple processes."""
res, exception = None, None
assert trials, "at least one trial has to be set"
assert os.path.isfile(path_data), f'missing file: {path_data}'
for _ in range(trials):
try:
res = torch.load(path_data)
# todo: specify the possible exception
except Exception as e:
exception = e
time.sleep(delta * random.random())
else:
break
if exception is not None:
# raise the caught exception
raise exception
return res

def _try_load(path_data, trials: int = 30, delta: float = 1.):
"""Resolving loading from the same time from multiple concurrent processes."""
res, exp = None, None
assert trials, "at least some trial has to be set"
assert os.path.isfile(path_data), 'missing file: %s' % path_data
for _ in range(trials):
try:
res = torch.load(path_data)
# todo: specify the possible exception
except Exception as ex:
exp = ex
time.sleep(delta * random.random())
else:
break
else:
# raise the caught exception if any
if exp:
raise exp
return res


def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
tensor = tensor.clone()
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
tensor.sub_(mean).div_(std)
return tensor
@staticmethod
def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
return tensor.sub(mean).div(std)


class TrialMNIST(MNIST):
"""Constrain image dataset
"""Constrained MNIST dataset
Args:
root: Root directory of dataset where ``MNIST/processed/training.pt``
and ``MNIST/processed/test.pt`` exist.
train: If ``True``, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
normalize: mean and std deviation of the MNIST dataset.
download: If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
num_samples: number of examples per selected class/digit
digits: list selected MNIST digits/classes
kwargs: Same as MNIST
Examples:
>>> dataset = TrialMNIST(download=True)
@@ -177,25 +160,15 @@ class TrialMNIST(MNIST):
tensor([100, 100, 100])
"""

def __init__(
self,
root: str = PATH_DATASETS,
train: bool = True,
normalize: tuple = (0.5, 1.0),
download: bool = False,
num_samples: int = 100,
digits: Optional[Sequence] = (0, 1, 2),
):

def __init__(self, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
# number of examples per class
self.num_samples = num_samples
# take just a subset of MNIST dataset
self.digits = digits if digits else list(range(10))
self.digits = sorted(digits) if digits else list(range(10))

self.cache_folder_name = 'digits-' + '-'.join(str(d) for d in sorted(self.digits)) \
+ f'_nb-{self.num_samples}'
self.cache_folder_name = f"digits-{'-'.join(str(d) for d in self.digits)}_nb-{self.num_samples}"

super().__init__(root, train=train, normalize=normalize, download=download)
super().__init__(normalize=(0.5, 1.0), **kwargs)

@staticmethod
def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
@@ -213,16 +186,12 @@ def _prepare_subset(full_data: torch.Tensor, full_targets: torch.Tensor, num_samples: int, digits: Sequence):
targets = full_targets[indexes]
return data, targets

def prepare_data(self, download: bool) -> None:
if self._check_exists(self.cached_folder_path):
return
if download:
self._download(super().cached_folder_path)

def _download(self, data_folder: str) -> None:
super()._download(data_folder)
for fname in (self.TRAIN_FILE_NAME, self.TEST_FILE_NAME):
path_fname = os.path.join(super().cached_folder_path, fname)
assert os.path.isfile(path_fname), 'Missing cached file: %s' % path_fname
data, targets = _try_load(path_fname)
path_fname = os.path.join(self.cached_folder_path, fname)
assert os.path.isfile(path_fname), f'Missing cached file: {path_fname}'
data, targets = self._try_load(path_fname)
data, targets = self._prepare_subset(data, targets, self.num_samples, self.digits)
torch.save((data, targets), os.path.join(self.cached_folder_path, fname))
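
The body of _prepare_subset is partially collapsed above; a hedged reconstruction of the per-digit selection it performs, matching the visible tail of the function (illustrative only, not the verbatim repository code):

import torch

def prepare_subset(full_data, full_targets, num_samples, digits):
    # Keep at most num_samples examples of each requested digit.
    indexes = []
    for d in digits:
        digit_indexes = torch.nonzero(full_targets == d, as_tuple=False).flatten()
        indexes.append(digit_indexes[:num_samples])
    indexes = torch.cat(indexes)
    return full_data[indexes], full_targets[indexes]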

4 changes: 2 additions & 2 deletions tests/models/test_torchscript.py
@@ -18,7 +18,7 @@

from tests.helpers import BoringModel
from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
from tests.helpers.datamodules import TrialMNISTDataModule
from tests.helpers.datamodules import MNISTDataModule


@pytest.mark.parametrize("modelclass", [
@@ -116,7 +116,7 @@ def test_torchscript_retain_training_state():
def test_torchscript_properties(tmpdir, modelclass):
""" Test that scripted LightningModule has unnecessary methods removed. """
model = modelclass()
model.datamodule = TrialMNISTDataModule(tmpdir)
model.datamodule = MNISTDataModule(tmpdir)
script = model.to_torchscript()
assert not hasattr(script, "datamodule")
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
1 change: 0 additions & 1 deletion tests/trainer/test_lr_finder.py
@@ -180,7 +180,6 @@ def test_call_to_trainer_method(tmpdir, optimizer):
def test_datamodule_parameter(tmpdir):
""" Test that the datamodule parameter works """

# trial datamodule
dm = ClassifDataModule()
model = ClassificationModel()


0 comments on commit bfcfac4

Please sign in to comment.