-
Notifications
You must be signed in to change notification settings - Fork 321
/
binary_emnist_datamodule.py
82 lines (73 loc) · 3.05 KB
/
binary_emnist_datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Any, Optional, Union
from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule
from pl_bolts.datasets import BinaryEMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
class BinaryEMNISTDataModule(EMNISTDataModule):
    """Data module that serves :class:`~pl_bolts.datasets.BinaryEMNIST`.

    .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png
        :width: 400
        :alt: EMNIST

    All loading and splitting behaviour is inherited; refer to
    :class:`~pl_bolts.datamodules.emnist_datamodule.EMNISTDataModule` for the details.

    Example::

        from pl_bolts.datamodules import BinaryEMNISTDataModule

        dm = BinaryEMNISTDataModule('.')
        model = LitModel()

        Trainer().fit(model, datamodule=dm)
    """

    # Identifier string for this data module.
    name = "binary_emnist"
    # Dataset class used in place of the one configured on the parent.
    dataset_cls = BinaryEMNIST
    # Presumably (channels, height, width) of a single sample -- confirm against the base class.
    dims = (1, 28, 28)

    def __init__(
        self,
        data_dir: Optional[str] = None,
        split: str = "mnist",
        val_split: Union[int, float] = 0.2,
        num_workers: int = 0,
        normalize: bool = False,
        batch_size: int = 32,
        seed: int = 42,
        shuffle: bool = True,
        pin_memory: bool = True,
        drop_last: bool = False,
        strict_val_split: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Check that torchvision is available, then defer to the parent initializer.

        Args:
            data_dir: Directory the data is saved to / loaded from.
            split: Which of the six dataset splits to use: ``byclass``, ``bymerge``,
                ``balanced``, ``letters``, ``digits`` or ``mnist``.
                The value is handed to :class:`torchvision.datasets.EMNIST`.
            val_split: Size of the validation split, given either as a
                fraction (float) or as a sample count (int).
            num_workers: Number of workers used for loading data.
            normalize: When ``True``, image normalization is applied.
            batch_size: Number of samples loaded per batch.
            seed: Random seed used for the train/val/test splits.
            shuffle: When ``True``, the train data is shuffled every epoch.
            pin_memory: When ``True``, the data loader copies Tensors into
                CUDA pinned memory before returning them.
            drop_last: When ``True``, the last incomplete batch is dropped.
            strict_val_split: When ``True``, the validation split defined in the paper is used and
                ``val_split`` is ignored. This is only supported for the ``"balanced"``,
                ``"digits"``, ``"letters"`` and ``"mnist"`` splits.
            *args: Extra positional arguments, forwarded unchanged to the parent initializer.
            **kwargs: Extra keyword arguments, forwarded unchanged to the parent initializer.

        Raises:
            ModuleNotFoundError: If ``_TORCHVISION_AVAILABLE`` is false.
        """
        # Fail early with a clear message instead of letting the parent hit a missing package.
        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError(
                "You want to use EMNIST dataset loaded from `torchvision` which is not installed yet."
            )

        # ``*args`` is written first because that is how Python binds it regardless of
        # where it appears in the call: unpacked positionals always precede keywords.
        # NOTE(review): every named option below is also passed by keyword, so a non-empty
        # ``args`` lands on the parent's leading positional parameters -- whether that
        # collides depends on the parent's signature; confirm before relying on ``*args``.
        super().__init__(  # type: ignore[misc]
            *args,
            data_dir=data_dir,
            split=split,
            val_split=val_split,
            strict_val_split=strict_val_split,
            normalize=normalize,
            seed=seed,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            drop_last=drop_last,
            **kwargs,
        )