Skip to content

Commit

Permalink
Add support for CIFAR10 Dataset in the DCGAN Module
Browse files Browse the repository at this point in the history
  • Loading branch information
ishandutta0098 committed Oct 5, 2023
1 parent 70cc5f5 commit e127400
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import LSUN, MNIST
from torchvision.datasets import LSUN, MNIST, CIFAR10
else: # pragma: no cover
warn_missing_pkg("torchvision")

Expand All @@ -35,8 +35,7 @@ class DCGAN(LightningModule):
python dcgan_module.py --gpus 1
# cifar10
python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3
python dcgan_module.py --gpus 1 --dataset cifar10
"""

def __init__(
Expand Down Expand Up @@ -174,7 +173,7 @@ def cli_main(args=None):

parser = ArgumentParser()
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist"])
parser.add_argument("--dataset", default="mnist", type=str, choices=["lsun", "mnist", "cifar10"])
parser.add_argument("--data_dir", default="./", type=str)
parser.add_argument("--image_size", default=64, type=int)
parser.add_argument("--num_workers", default=8, type=int)
Expand Down Expand Up @@ -202,6 +201,16 @@ def cli_main(args=None):
)
dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms)
image_channels = 1
elif script_args.dataset == "cifar10":
transforms = transform_lib.Compose(
[
transform_lib.Resize(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
]
)
dataset = CIFAR10(root=script_args.data_dir, download=True, transform=transforms)
image_channels = 3

dataloader = DataLoader(
dataset, batch_size=script_args.batch_size, shuffle=True, num_workers=script_args.num_workers
Expand Down

0 comments on commit e127400

Please sign in to comment.