Skip to content

Commit

Permalink
Add --use-v2 support to classification references (#7724)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Jul 7, 2023
1 parent 23b0938 commit 08c9938
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 26 deletions.
68 changes: 42 additions & 26 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import torch
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode


def get_module(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.transforms.v2

return torchvision.transforms.v2
else:
import torchvision.transforms

return torchvision.transforms


class ClassificationPresetTrain:
def __init__(
self,
Expand All @@ -17,41 +28,44 @@ def __init__(
augmix_severity=3,
random_erase_prob=0.0,
backend="pil",
use_v2=False,
):
trans = []
module = get_module(use_v2)

transforms = []
backend = backend.lower()
if backend == "tensor":
trans.append(transforms.PILToTensor())
transforms.append(module.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
transforms.append(module.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))

if backend == "pil":
trans.append(transforms.PILToTensor())
transforms.append(module.PILToTensor())

trans.extend(
transforms.extend(
[
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
trans.append(transforms.RandomErasing(p=random_erase_prob))
transforms.append(module.RandomErasing(p=random_erase_prob))

self.transforms = transforms.Compose(trans)
self.transforms = module.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
Expand All @@ -67,28 +81,30 @@ def __init__(
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
backend="pil",
use_v2=False,
):
trans = []
module = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
trans.append(transforms.PILToTensor())
transforms.append(module.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

trans += [
transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
transforms.CenterCrop(crop_size),
transforms += [
module.Resize(resize_size, interpolation=interpolation, antialias=True),
module.CenterCrop(crop_size),
]

if backend == "pil":
trans.append(transforms.PILToTensor())
transforms.append(module.PILToTensor())

trans += [
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
transforms += [
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
]

self.transforms = transforms.Compose(trans)
self.transforms = module.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
3 changes: 3 additions & 0 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
backend=args.backend,
use_v2=args.use_v2,
),
)
if args.cache_dataset:
Expand Down Expand Up @@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
resize_size=val_resize_size,
interpolation=interpolation,
backend=args.backend,
use_v2=args.use_v2,
)

dataset_test = torchvision.datasets.ImageFolder(
Expand Down Expand Up @@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser


Expand Down

0 comments on commit 08c9938

Please sign in to comment.