-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add RandomMixupCutmix. * Add test with real data. * Use dataloader and collate in the test. * Making RandomMixupCutmix JIT scriptable. * Move out label_smoothing and try roll instead of flip * Adding mixup/cutmix in references script. * Handle one-hot encoded target in accuracy. * Add support of devices on tests. * Separate Mixup from Cutmix. * Add check for floats. * Adding device on expect value. * Remove hardcoded weights. * One-hot only when necessary. * Fix linter. * Moving mixup and cutmix to references. * Final code clean up.
- Loading branch information
Showing
5 changed files
with
210 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import math | ||
import torch | ||
|
||
from typing import Tuple | ||
from torch import Tensor | ||
from torchvision.transforms import functional as F | ||
|
||
|
||
class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.

    Raises:
        ValueError: if ``num_classes`` is not positive or ``alpha`` is not positive.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        # Validate with explicit exceptions rather than assert: asserts are
        # stripped under `python -O`, which would silently disable these checks.
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and the
            corresponding (possibly mixed) one-hot encoded target.

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # target is guaranteed 1-D by the validation above; the guard is kept
        # so the branch structure stays TorchScript-friendly.
        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # Even when the transform is skipped the target is returned one-hot
        # encoded, so downstream losses see a consistent target shape.
        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3: lambda ~ Beta(alpha, alpha),
        # sampled here via a 2-component Dirichlet (equivalent distribution).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
|
||
|
||
class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.

    Raises:
        ValueError: if ``num_classes`` is not positive or ``alpha`` is not positive.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        # Validate with explicit exceptions rather than assert: asserts are
        # stripped under `python -O`, which would silently disable these checks.
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and the
            corresponding (possibly mixed) one-hot encoded target.

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # target is guaranteed 1-D by the validation above; the guard is kept
        # so the branch structure stays TorchScript-friendly.
        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # Even when the transform is skipped the target is returned one-hot
        # encoded, so downstream losses see a consistent target shape.
        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        # batch is validated 4-D (B, C, H, W) above, so reading W/H from the
        # shape matches F.get_image_size(batch) for tensor input without the
        # extra helper call.
        W, H = batch.shape[-1], batch.shape[-2]

        # Centre of the cut box, sampled uniformly over the image.
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # Clamp the box to the image bounds.
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        # Recompute lambda from the actual (clamped) box area so the label
        # weights match the pixel proportions exactly.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters