Skip to content

Commit

Permalink
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Sep 2, 2021
1 parent fa8a6d5 commit 496de93
Show file tree
Hide file tree
Showing 12 changed files with 176 additions and 34 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sphinx==2.4.4
sphinx==3.5.4
sphinx-gallery>=0.9.0
sphinx-copybutton>=0.3.1
matplotlib
Expand Down
13 changes: 11 additions & 2 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ Generic Transforms
:members:


AutoAugment Transforms
----------------------
Automatic Augmentation Transforms
---------------------------------

`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
Expand All @@ -229,6 +229,15 @@ The new transform can be used standalone or mixed-and-matched with existing tran
.. autoclass:: AutoAugment
:members:

`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models.

.. autoclass:: RandAugment
:members:

`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.

.. autoclass:: TrivialAugmentWide
:members:

.. _functional_transforms:

Expand Down
16 changes: 16 additions & 0 deletions gallery/plot_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,22 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)

####################################
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# TrivialAugmentWide
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
Expand Down
2 changes: 1 addition & 1 deletion gallery/plot_visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def show(imgs):
print(dog1_output['scores'])

#####################################
# Clearly the model is less confident about the dog detection than it is about
# Clearly the model is more confident about the dog detection than it is about
# the people detections. That's good news. When plotting the masks, we can ask
# for only those that have a good score. Let's use a score threshold of .75
# here, and also plot the masks of the second dog.
Expand Down
4 changes: 3 additions & 1 deletion references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ta_wide":
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide())
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
Expand Down
5 changes: 4 additions & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def main(args):
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

criterion = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

opt_name = args.opt.lower()
if opt_name == 'sgd':
Expand Down Expand Up @@ -256,6 +256,9 @@ def get_args_parser(add_help=True):
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--label-smoothing', default=0.0, type=float,
help='label smoothing (default: 0.0)',
dest='label_smoothing')
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
Expand Down
12 changes: 12 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
transform.__repr__()


@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
def test_randaugment(num_ops, magnitude, fill):
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
for _ in range(100):
img = transform(img)
transform.__repr__()


@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
Expand Down
21 changes: 15 additions & 6 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,18 +525,26 @@ def test_autoaugment(device, policy, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

s_transform = None
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


def test_autoaugment_save(tmpdir):
transform = T.AutoAugment()
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_randaugment(device, num_ops, magnitude, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('device', cpu_and_gpu())
Expand All @@ -552,8 +560,9 @@ def test_trivialaugmentwide(device, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


def test_trivialaugmentwide_save(tmpdir):
transform = T.TrivialAugmentWide()
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation()
s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))

Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/io/image/cpu/encode_jpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {

#else
// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is
// defined as unsigned long, where as in later version, it is defined as size_t.
// defined as unsigned long, whereas in later version, it is defined as size_t.
// For windows backward compatibility, we define JpegSizeType as different types
// according to the libjpeg version used, in order to prevent compilcation
// according to the libjpeg version used, in order to prevent compilation
// errors.
#if defined(_WIN32) || !defined(JPEG_LIB_VERSION_MAJOR) || \
(JPEG_LIB_VERSION_MAJOR < 9) || \
JPEG_LIB_VERSION_MAJOR < 9 || \
(JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2)
using JpegSizeType = unsigned long;
#else
Expand Down
7 changes: 4 additions & 3 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ class Caltech101(VisionDataset):
root (string): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified target types.
``category`` represents the target class, and ``annotation`` is a list of points
from a hand-generated outline. Defaults to ``category``.
``annotation``. Can also be a list to output a tuple with all specified
target types. ``category`` represents the target class, and
``annotation`` is a list of points from a hand-generated outline.
Defaults to ``category``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
Expand Down
12 changes: 8 additions & 4 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,14 @@ def decode_single(self, rel_codes, boxes):
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]

pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
# Distance from center to box's corner.
c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

pred_boxes1 = pred_ctr_x - c_to_c_w
pred_boxes2 = pred_ctr_y - c_to_c_h
pred_boxes3 = pred_ctr_x + c_to_c_w
pred_boxes4 = pred_ctr_y + c_to_c_h
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
return pred_boxes

Expand Down
110 changes: 98 additions & 12 deletions torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugmentWide"]
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]


def _apply_op(img: Tensor, op_name: str, magnitude: float,
Expand Down Expand Up @@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn"


# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
Expand Down Expand Up @@ -85,9 +86,9 @@ def __init__(
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = self._get_transforms(policy)
self.policies = self._get_policies(policy)

def _get_transforms(
def _get_policies(
self,
policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
Expand Down Expand Up @@ -178,9 +179,9 @@ def _get_transforms(
else:
raise ValueError("The provided policy {} is not recognized.".format(policy))

def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
return {
# name: (magnitudes, signed)
# op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
Expand Down Expand Up @@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
elif fill is not None:
fill = [float(f) for f in fill]

transform_id, probs, signs = self.get_params(len(self.transforms))
transform_id, probs, signs = self.get_params(len(self.policies))

for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
if probs[i] <= p:
op_meta = self._get_magnitudes(10, F.get_image_size(img))
op_meta = self._augmentation_space(10, F.get_image_size(img))
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
if signed and signs[i] == 0:
Expand All @@ -241,6 +242,91 @@ def __repr__(self) -> str:
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)


class RandAugment(torch.nn.Module):
r"""RandAugment data augmentation method based on
`"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_ops (int): Number of augmentation transformations to apply sequentially.
magnitude (int): Magnitude for all the transformations.
num_magnitude_bins (int): The number of different magnitude values.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""

def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
super().__init__()
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
return {
# op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
"Invert": (torch.tensor(0.0), False),
}

def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

for _ in range(self.num_ops):
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
if signed and torch.randint(2, (1,)):
magnitude *= -1.0
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

return img

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_ops={num_ops}'
s += ', magnitude={magnitude}'
s += ', num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)


class TrivialAugmentWide(torch.nn.Module):
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
Expand All @@ -264,9 +350,9 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod
self.interpolation = interpolation
self.fill = fill

def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
return {
# name: (magnitudes, signed)
# op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
Expand All @@ -283,7 +369,7 @@ def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
"Invert": (torch.tensor(0.0), False),
}

def forward(self, img: Tensor):
def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Expand All @@ -297,7 +383,7 @@ def forward(self, img: Tensor):
elif fill is not None:
fill = [float(f) for f in fill]

op_meta = self._get_magnitudes(self.num_magnitude_bins)
op_meta = self._augmentation_space(self.num_magnitude_bins)
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
Expand Down

0 comments on commit 496de93

Please sign in to comment.