diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 38fc417204c..1117b6c555f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -489,24 +489,42 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}()" -class RandomTransforms: - """Base class for a list of transformations with randomness +class RandomApply(torch.nn.Module): + """Apply randomly a list of transformations with a given probability. + + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + + >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=0.3) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. Args: - transforms (sequence): list of transformations + transforms (sequence or torch.nn.Module): list of transformations + p (float): probability """ - def __init__(self, transforms): + def __init__(self, transforms, p=0.5): + super().__init__() _log_api_usage_once(self) - if not isinstance(transforms, Sequence): - raise TypeError("Argument transforms should be a sequence") self.transforms = transforms + self.p = p - def __call__(self, *args, **kwargs): - raise NotImplementedError() + def forward(self, img): + if self.p < torch.rand(1): + return img + for t in self.transforms: + img = t(img) + return img def __repr__(self) -> str: format_string = self.__class__.__name__ + "(" + format_string += f"\n p={self.p}" for t in self.transforms: format_string += "\n" format_string += f" {t}" @@ -514,16 +532,16 @@ def __repr__(self) -> str: return format_string -class RandomApply(torch.nn.Module): - """Apply randomly a list of transformations with a given probability. 
+class RandomOrder(torch.nn.Module): + """Apply a list of transformations in a random order. .. note:: In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of transforms as shown below: - >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms = transforms.RandomOrder(torch.nn.ModuleList([ >>> transforms.ColorJitter(), - >>> ]), p=0.3) + >>> ])) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require @@ -531,25 +549,21 @@ class RandomApply(torch.nn.Module): Args: transforms (sequence or torch.nn.Module): list of transformations - p (float): probability """ - def __init__(self, transforms, p=0.5): + def __init__(self, transforms): super().__init__() _log_api_usage_once(self) self.transforms = transforms - self.p = p def forward(self, img): - if self.p < torch.rand(1): - return img - for t in self.transforms: - img = t(img) + order = torch.randperm(len(self.transforms)) + for i in order: + img = self.transforms[i.item()](img) return img def __repr__(self) -> str: format_string = self.__class__.__name__ + "(" - format_string += f"\n p={self.p}" for t in self.transforms: format_string += "\n" format_string += f" {t}" @@ -557,32 +571,50 @@ def __repr__(self) -> str: return format_string -class RandomOrder(RandomTransforms): - """Apply a list of transformations in a random order. This transform does not support torchscript.""" +class RandomChoice(torch.nn.Module): + """Apply single transformation randomly picked from a list. - def __call__(self, img): - order = list(range(len(self.transforms))) - random.shuffle(order) - for i in order: - img = self.transforms[i](img) - return img + .. 
note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + >>> transforms = transforms.RandomChoice(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=torch.Tensor([0.3])) + >>> scripted_transforms = torch.jit.script(transforms) -class RandomChoice(RandomTransforms): - """Apply single transformation randomly picked from a list. This transform does not support torchscript.""" + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (optional, torch.Tensor): input tensor containing weights. Default: equal weights + """ - def __init__(self, transforms, p=None): - super().__init__(transforms) - if p is not None and not isinstance(p, Sequence): - raise TypeError("Argument p should be a sequence") + def __init__(self, transforms, p: Optional[torch.Tensor] = None): + super().__init__() + _log_api_usage_once(self) + if p is None: + p = torch.ones(len(transforms)) + self.transforms = transforms self.p = p - def __call__(self, *args): - t = random.choices(self.transforms, weights=self.p)[0] - return t(*args) + def forward(self, img): + i = torch.multinomial(self.p, 1) + # self.transforms[i.item()](img) gives Error: Expected integer literal for index, while JIT Scripting + # Workaround the ModuleList indexing issue: https://github.com/pytorch/pytorch/issues/16123 + for j,t in enumerate(self.transforms): + if i==j: + return t(img) def __repr__(self) -> str: - return f"{super().__repr__()}(p={self.p})" + format_string = self.__class__.__name__ + "(" + format_string += f"\n p={self.p}" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string class RandomCrop(torch.nn.Module):