Skip to content

Commit

Permalink
[CIFAR] pass through required image_size arg and warn that it is not …
Browse files Browse the repository at this point in the history
…used (#1671)
  • Loading branch information
bfineran committed Jul 13, 2023
1 parent a0fcc1f commit a8991c0
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/sparseml/pytorch/datasets/classification/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
`here <https://www.cs.toronto.edu/~kriz/cifar.html>`__.
"""

import logging


try:
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
Expand All @@ -36,6 +39,8 @@
__all__ = ["CIFAR10Dataset", "CIFAR100Dataset"]


_LOGGER = logging.getLogger(__file__)

_CIFAR10_RGB_MEANS = [0.491, 0.482, 0.447]
_CIFAR10_RGB_STDS = [0.247, 0.243, 0.262]

Expand Down Expand Up @@ -64,7 +69,12 @@ def __init__(
root: str = default_dataset_path("cifar10"),
train: bool = True,
rand_trans: bool = False,
image_size: None = None,
):
if image_size is not None:
_LOGGER.warning(
"image_size not supported for CIFAR dataset, using default size"
)
if torchvision_import_error is not None:
raise torchvision_import_error

Expand Down Expand Up @@ -114,7 +124,12 @@ def __init__(
root: str = default_dataset_path("cifar100"),
train: bool = True,
rand_trans: bool = False,
image_size: None = None,
):
if image_size is not None:
_LOGGER.warning(
"image_size not supported for CIFAR dataset, using default size"
)
normalize = transforms.Normalize(
mean=_CIFAR100_RGB_MEANS, std=_CIFAR100_RGB_STDS
)
Expand Down

0 comments on commit a8991c0

Please sign in to comment.