From 496de9339f23686bd29019fe0e2f4969c4faf024 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 2 Sep 2021 16:33:12 +0100
Subject: [PATCH] Merge branch 'main' into
 SamuelGabriel_trivialaugment_implementation

---
 docs/requirements.txt                         |   2 +-
 docs/source/transforms.rst                    |  13 ++-
 gallery/plot_transforms.py                    |  16 +++
 gallery/plot_visualization_utils.py           |   2 +-
 references/classification/presets.py          |   4 +-
 references/classification/train.py            |   5 +-
 test/test_transforms.py                       |  12 ++
 test/test_transforms_tensor.py                |  21 +++-
 torchvision/csrc/io/image/cpu/encode_jpeg.cpp |   6 +-
 torchvision/datasets/caltech.py               |   7 +-
 torchvision/models/detection/_utils.py        |  12 +-
 torchvision/transforms/autoaugment.py         | 110 ++++++++++++++++--
 12 files changed, 176 insertions(+), 34 deletions(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 68efe2cb639..44132ef3375 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,4 @@
-sphinx==2.4.4
+sphinx==3.5.4
 sphinx-gallery>=0.9.0
 sphinx-copybutton>=0.3.1
 matplotlib
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 854c869a0de..f184d5da30e 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -214,8 +214,8 @@ Generic Transforms
    :members:


-AutoAugment Transforms
-----------------------
+Automatic Augmentation Transforms
+---------------------------------

 `AutoAugment <https://arxiv.org/abs/1805.09501>`_ is a common Data Augmentation technique that can improve the accuracy of Image
 Classification models. Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
@@ -229,6 +229,15 @@ The new transform can be used standalone or mixed-and-matched with existing tran
 .. autoclass:: AutoAugment
    :members:

+`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple, high-performing Data Augmentation technique which improves the accuracy of Image Classification models.
+
+.. autoclass:: RandAugment
+   :members:
+
+`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.
+
+.. autoclass:: TrivialAugmentWide
+   :members:


 .. _functional_transforms:
diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py
index 032dd584c26..68ffae16a0f 100644
--- a/gallery/plot_transforms.py
+++ b/gallery/plot_transforms.py
@@ -245,6 +245,22 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
 row_title = [str(policy).split('.')[-1] for policy in policies]
 plot(imgs, row_title=row_title)

+####################################
+# RandAugment
+# ~~~~~~~~~~~
+# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
+augmenter = T.RandAugment()
+imgs = [augmenter(orig_img) for _ in range(4)]
+plot(imgs)
+
+####################################
+# TrivialAugmentWide
+# ~~~~~~~~~~~~~~~~~~
+# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
+augmenter = T.TrivialAugmentWide()
+imgs = [augmenter(orig_img) for _ in range(4)]
+plot(imgs)
+
 ####################################
 # Randomly-applied transforms
 # ---------------------------
diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py
index feedee4e3cf..e4219f7325d 100644
--- a/gallery/plot_visualization_utils.py
+++ b/gallery/plot_visualization_utils.py
@@ -343,7 +343,7 @@ def show(imgs):
 print(dog1_output['scores'])

 #####################################
-# Clearly the model is less confident about the dog detection than it is about
+# Clearly the model is more confident about the dog detection than it is about
 # the people detections. That's good news. When plotting the masks, we can ask
 # for only those that have a good score. Let's use a score threshold of .75
 # here, and also plot the masks of the second dog.
diff --git a/references/classification/presets.py b/references/classification/presets.py
index 0ccf835e7c3..981dbd6ed9e 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -9,7 +9,9 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
-            if auto_augment_policy == "ta_wide":
+            if auto_augment_policy == "ra":
+                trans.append(autoaugment.RandAugment())
+            elif auto_augment_policy == "ta_wide":
                 trans.append(autoaugment.TrivialAugmentWide())
             else:
                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
diff --git a/references/classification/train.py b/references/classification/train.py
index 9ba99b3dc54..79b99156a05 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -175,7 +175,7 @@ def main(args):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

-    criterion = nn.CrossEntropyLoss()
+    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

     opt_name = args.opt.lower()
     if opt_name == 'sgd':
@@ -256,6 +256,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--label-smoothing', default=0.0, type=float,
+                        help='label smoothing (default: 0.0)',
+                        dest='label_smoothing')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 2b15c6afdd0..675e79ac3ba 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
         transform.__repr__()


+@pytest.mark.parametrize('num_ops', [1, 2, 3])
+@pytest.mark.parametrize('magnitude', [7, 9, 11])
+@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
+def test_randaugment(num_ops, magnitude, fill):
+    random.seed(42)
+    img = Image.open(GRACE_HOPPER)
+    transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
+    for _ in range(100):
+        img = transform(img)
+    transform.__repr__()
+
+
 @pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
 @pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
 def test_trivialaugmentwide(fill, num_magnitude_bins):
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index a057e193d8a..aaf7880f124 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -525,7 +525,6 @@ def test_autoaugment(device, policy, fill):
     tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
     batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

-    s_transform = None
     transform = T.AutoAugment(policy=policy, fill=fill)
     s_transform = torch.jit.script(transform)
     for _ in range(25):
@@ -533,10 +532,19 @@ def test_autoaugment(device, policy, fill):
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


-def test_autoaugment_save(tmpdir):
-    transform = T.AutoAugment()
+@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize('num_ops', [1, 2, 3])
+@pytest.mark.parametrize('magnitude', [7, 9, 11])
+@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
+def test_randaugment(device, num_ops, magnitude, fill):
+    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
+    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
+
+    transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
     s_transform = torch.jit.script(transform)
-    s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
+    for _ in range(25):
+        _test_transform_vs_scripted(transform, s_transform, tensor)
+        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -552,8 +560,9 @@ def test_trivialaugmentwide(device, fill):
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


-def test_trivialaugmentwide_save(tmpdir):
-    transform = T.TrivialAugmentWide()
+@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
+def test_autoaugment_save(augmentation, tmpdir):
+    transform = augmentation()
     s_transform = torch.jit.script(transform)
     s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))

diff --git a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp
index 89d05ea4079..3b669b9906c 100644
--- a/torchvision/csrc/io/image/cpu/encode_jpeg.cpp
+++ b/torchvision/csrc/io/image/cpu/encode_jpeg.cpp
@@ -14,12 +14,12 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
 #else

 // For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is
-// defined as unsigned long, where as in later version, it is defined as size_t.
+// defined as unsigned long, whereas in later versions, it is defined as size_t.
 // For windows backward compatibility, we define JpegSizeType as different types
-// according to the libjpeg version used, in order to prevent compilcation
+// according to the libjpeg version used, in order to prevent compilation
 // errors.
 #if defined(_WIN32) || !defined(JPEG_LIB_VERSION_MAJOR) || \
-    (JPEG_LIB_VERSION_MAJOR < 9) || \
+    JPEG_LIB_VERSION_MAJOR < 9 || \
     (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2)
 using JpegSizeType = unsigned long;
 #else
diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py
index 1a254edb430..a99e6fde948 100644
--- a/torchvision/datasets/caltech.py
+++ b/torchvision/datasets/caltech.py
@@ -18,9 +18,10 @@ class Caltech101(VisionDataset):
         root (string): Root directory of dataset where directory
             ``caltech101`` exists or will be saved to if download is set to True.
         target_type (string or list, optional): Type of target to use, ``category`` or
-            ``annotation``. Can also be a list to output a tuple with all specified target types.
-            ``category`` represents the target class, and ``annotation`` is a list of points
-            from a hand-generated outline. Defaults to ``category``.
+            ``annotation``. Can also be a list to output a tuple with all specified
+            target types. ``category`` represents the target class, and
+            ``annotation`` is a list of points from a hand-generated outline.
+            Defaults to ``category``.
         transform (callable, optional): A function/transform that takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py
index 40281b39b6b..1d3bcdba7fe 100644
--- a/torchvision/models/detection/_utils.py
+++ b/torchvision/models/detection/_utils.py
@@ -216,10 +216,14 @@ def decode_single(self, rel_codes, boxes):
         pred_w = torch.exp(dw) * widths[:, None]
         pred_h = torch.exp(dh) * heights[:, None]

-        pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
-        pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
-        pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
-        pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
         pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
         return pred_boxes
diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py
index 8013d637ae1..117030d3a50 100644
--- a/torchvision/transforms/autoaugment.py
+++ b/torchvision/transforms/autoaugment.py
@@ -7,7 +7,7 @@
 from . import functional as F, InterpolationMode


-__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugmentWide"]
+__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]


 def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
     SVHN = "svhn"


+# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving them to a base class
 class AutoAugment(torch.nn.Module):
     r"""AutoAugment data augmentation method based on
     `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
         self.policy = policy
         self.interpolation = interpolation
         self.fill = fill
-        self.transforms = self._get_transforms(policy)
+        self.policies = self._get_policies(policy)

-    def _get_transforms(
+    def _get_policies(
         self,
         policy: AutoAugmentPolicy
     ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
         else:
             raise ValueError("The provided policy {} is not recognized.".format(policy))

-    def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {
-            # name: (magnitudes, signed)
+            # op_name: (magnitudes, signed)
             "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
             "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
             elif fill is not None:
                 fill = [float(f) for f in fill]

-        transform_id, probs, signs = self.get_params(len(self.transforms))
+        transform_id, probs, signs = self.get_params(len(self.policies))

-        for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
+        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
             if probs[i] <= p:
-                op_meta = self._get_magnitudes(10, F.get_image_size(img))
+                op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
@@ -241,6 +242,91 @@ def __repr__(self) -> str:
         return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)


+class RandAugment(torch.nn.Module):
+    r"""RandAugment data augmentation method based on
+    `"RandAugment: Practical automated data augmentation with a reduced search space"
+    <https://arxiv.org/abs/1909.13719>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_ops (int): Number of augmentation transformations to apply sequentially.
+        magnitude (int): Magnitude for all the transformations.
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands.
+ """ + + def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + "Invert": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + for _ in range(self.num_ops): + op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_ops={num_ops}' + s += ', magnitude={magnitude}' + s += ', num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__) + + class TrivialAugmentWide(torch.nn.Module): r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. 
@@ -264,9 +350,9 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod
         self.interpolation = interpolation
         self.fill = fill

-    def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
         return {
-            # name: (magnitudes, signed)
+            # op_name: (magnitudes, signed)
             "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
             "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
@@ -283,7 +369,7 @@ def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
             "Invert": (torch.tensor(0.0), False),
         }

-    def forward(self, img: Tensor):
+    def forward(self, img: Tensor) -> Tensor:
         """
             img (PIL Image or Tensor): Image to be transformed.

@@ -297,7 +383,7 @@ def forward(self, img: Tensor):
             elif fill is not None:
                 fill = [float(f) for f in fill]

-        op_meta = self._get_magnitudes(self.num_magnitude_bins)
+        op_meta = self._augmentation_space(self.num_magnitude_bins)
         op_index = int(torch.randint(len(op_meta), (1,)).item())
         op_name = list(op_meta.keys())[op_index]
         magnitudes, signed = op_meta[op_name]
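
To exercise the new transforms this patch exposes, here is a minimal usage sketch; it assumes a torchvision build with this patch applied, and the input shapes are illustrative:

    import torch
    import torchvision.transforms as T

    # The new augmentations compose like any other transform.
    transform = T.Compose([
        T.RandAugment(num_ops=2, magnitude=9),  # or T.TrivialAugmentWide()
        T.ToTensor(),
    ])

    # They also accept uint8 tensors directly, as the new tests exercise.
    img = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
    augmented = T.RandAugment()(img)

In the classification references, the same code paths are reached through the existing auto-augment policy string, which now additionally accepts "ra" and "ta_wide" as handled in presets.py.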
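
The signed flag in each _augmentation_space() row drives the random negation in RandAugment.forward(). A small sketch of how a fixed magnitude bin becomes a concrete op parameter, using the "Rotate" row and the class defaults:

    import torch

    num_bins, magnitude = 30, 9                       # RandAugment defaults
    magnitudes = torch.linspace(0.0, 30.0, num_bins)  # the "Rotate" row
    value = float(magnitudes[magnitude].item())       # ~9.31 degrees
    if torch.randint(2, (1,)):                        # coin flip: negate signed ops
        value *= -1.0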
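
The train.py hunk threads the new --label-smoothing flag into the criterion. A self-contained sketch of the resulting loss; note that the label_smoothing parameter of nn.CrossEntropyLoss requires PyTorch 1.10 or later, and the shapes here are illustrative:

    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    logits = torch.randn(4, 100)           # batch of 4, 100 classes
    targets = torch.randint(0, 100, (4,))
    loss = criterion(logits, targets)      # targets become a mix of one-hot and uniform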
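
The _utils.py refactor only names the center-to-corner distances before reusing them; the decoded boxes are unchanged. A tiny worked example for a single box:

    import torch

    pred_ctr_x, pred_ctr_y = torch.tensor(10.0), torch.tensor(20.0)
    pred_w, pred_h = torch.tensor(4.0), torch.tensor(6.0)

    c_to_c_w, c_to_c_h = 0.5 * pred_w, 0.5 * pred_h        # center-to-corner distances
    x1, y1 = pred_ctr_x - c_to_c_w, pred_ctr_y - c_to_c_h  # (8., 17.)
    x2, y2 = pred_ctr_x + c_to_c_w, pred_ctr_y + c_to_c_h  # (12., 23.)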