Add generator parameter to random transforms #7848

Closed · wants to merge 4 commits
10 changes: 10 additions & 0 deletions test/common_utils.py
@@ -97,6 +97,16 @@ def freeze_rng_state():
torch.set_rng_state(rng_state)


@contextlib.contextmanager
def assert_default_rng_is_unchanged():
default_state = torch.random.get_rng_state()
yield
try:
torch.testing.assert_close(default_state, torch.random.get_rng_state(), rtol=0, atol=0)
except AssertionError as e:
raise AssertionError("The default RNG got consumed.") from e


def cycle_over(objs):
for idx, obj1 in enumerate(objs):
for obj2 in objs[:idx] + objs[idx + 1 :]:
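
For context, a minimal sketch of how this helper is meant to be used in a test. It assumes the generator parameter this PR proposes on the transforms; the input image and jitter value are placeholders:

import torch
import torchvision.transforms.v2 as transforms
from common_utils import assert_default_rng_is_unchanged

img = torch.rand(3, 8, 8)  # created outside the block: torch.rand here draws from the default RNG
with assert_default_rng_is_unchanged():
    # All sampling inside the block must go through the explicit generator;
    # if the default RNG advances, the context manager raises on exit.
    transforms.ColorJitter(brightness=0.5, generator=torch.Generator())(img)
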
38 changes: 38 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -3,6 +3,7 @@
import inspect
import math
import re
import sys
from pathlib import Path
from unittest import mock

@@ -13,6 +14,7 @@
import torch
import torchvision.transforms.v2 as transforms
from common_utils import (
assert_default_rng_is_unchanged,
assert_equal,
assert_no_warnings,
cache,
@@ -255,6 +257,10 @@ def check_transform(transform_cls, input, *args, **kwargs):

_check_transform_v1_compatibility(transform, input)

if "generator" in inspect.signature(transform_cls).parameters:
with assert_default_rng_is_unchanged():
transform_cls(*args, generator=torch.Generator(), **kwargs)(input)


def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
def wrapper(input, *args, **kwargs):
@@ -2374,3 +2380,35 @@ def test_correctness(self):
assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
else:
assert isinstance(out_value, type(input_value))


@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="Windows doesn't support fork()")
def test_transforms_rng_with_dataloader():
    # This is more of a sanity test for torch core's handling of Generators within the DataLoader,
    # but it is worth having here as well, as a safeguard.
class MyTransform(torch.nn.Module):
def __init__(self, generator):
super().__init__()
self.generator = generator

def forward(self):
return torch.randint(0, 100_000, size=(1,), generator=self.generator).item()

class Dataset:
def __init__(self, transform):
self.transform = transform

def __getitem__(self, _):
return self.transform()

def __len__(self):
return 10

rng = torch.Generator().manual_seed(0)
t = MyTransform(rng)
ds = Dataset(t)

dl = DataLoader(ds, num_workers=2, multiprocessing_context="fork")
all_samples = [x.item() for x in dl]
# If the RNG were the same across workers, we would get duplicated samples here. We assert they're all unique.
assert len(set(all_samples)) == len(all_samples)
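
Should torch core's per-worker handling of Generators ever need to be made explicit, the conventional pattern is a worker_init_fn. The sketch below reuses ds and MyTransform from the test above and only standard DataLoader API; the reseeding scheme itself is illustrative, not part of this diff:

import torch
from torch.utils.data import DataLoader, get_worker_info

def seed_transform_generator(worker_id):
    # Each worker holds its own copy of the dataset; reseed that copy's
    # transform generator with the worker-unique seed torch core provides.
    info = get_worker_info()
    info.dataset.transform.generator.manual_seed(info.seed)

dl = DataLoader(ds, num_workers=2, multiprocessing_context="fork",
                worker_init_fn=seed_transform_generator)
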
10 changes: 6 additions & 4 deletions torchvision/transforms/v2/_augment.py
@@ -63,8 +63,9 @@ def __init__(
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0.0,
inplace: bool = False,
generator=None,
):
super().__init__(p=p)
super().__init__(p=p, generator=generator)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
@@ -111,11 +112,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

log_ratio = self._log_ratio
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=self.generator).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
generator=self.generator,
)
).item()

@@ -129,8 +131,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
else:
v = torch.tensor(self.value)[:, None, None]

i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
i = torch.randint(0, img_h - h + 1, size=(1,), generator=self.generator).item()
j = torch.randint(0, img_w - w + 1, size=(1,), generator=self.generator).item()
break
else:
i, j, h, w, v = 0, 0, img_h, img_w, None
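
With the parameter threaded through _get_params as above, one seeded generator can drive a whole augmentation pipeline. A minimal usage sketch, assuming the generator parameter proposed in this PR (arguments are placeholders):

import torch
import torchvision.transforms.v2 as transforms

rng = torch.Generator().manual_seed(0)
pipeline = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, generator=rng),
    transforms.RandomErasing(p=0.5, generator=rng),  # erase box location and size sampled from rng
])
out = pipeline(torch.rand(3, 32, 32))
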
38 changes: 24 additions & 14 deletions torchvision/transforms/v2/_auto_augment.py
@@ -25,14 +25,16 @@ def __init__(
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
generator=None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.generator = generator

def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
key = keys[int(torch.randint(len(keys), (), generator=self.generator))]
return key, dct[key]

def _flatten_and_extract_image_or_video(
@@ -219,8 +221,9 @@ def __init__(
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
generator=None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
self.policy = policy
self._policies = self._get_policies(policy)

@@ -318,18 +321,18 @@ def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)

policy = self._policies[int(torch.randint(len(self._policies), ()))]
policy = self._policies[int(torch.randint(len(self._policies), (), generator=self.generator))]

for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
if not torch.rand((), generator=self.generator) <= probability:
continue

magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
if signed and torch.rand((), generator=self.generator) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
@@ -399,8 +402,9 @@ def __init__(
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
generator=None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
@@ -414,7 +418,7 @@ def forward(self, *inputs: Any) -> Any:
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[self.magnitude])
if signed and torch.rand(()) <= 0.5:
if signed and torch.rand((), generator=self.generator) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
@@ -472,8 +476,9 @@ def __init__(
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
generator=None,
):
super().__init__(interpolation=interpolation, fill=fill)
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
self.num_magnitude_bins = num_magnitude_bins

def forward(self, *inputs: Any) -> Any:
@@ -484,8 +489,8 @@ def forward(self, *inputs: Any) -> Any:

magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, (), generator=self.generator))])
if signed and torch.rand((), generator=self.generator) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
@@ -555,8 +560,9 @@ def __init__(
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
generator=None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
super().__init__(interpolation=interpolation, fill=fill, generator=generator)
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
@@ -601,14 +607,18 @@ def forward(self, *inputs: Any) -> Any:
mix = m[:, 0].reshape(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
depth = (
self.chain_depth
if self.chain_depth > 0
else int(torch.randint(low=1, high=4, size=(1,), generator=self.generator).item())
)
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude = float(magnitudes[int(torch.randint(self.severity, (), generator=self.generator))])
if signed and torch.rand((), generator=self.generator) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0
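
The net effect for the auto-augment family is that every draw (policy choice, application probability, sign flip, magnitude bin, chain depth) goes through self.generator. A sketch of what that buys, assuming the proposed parameter:

import torch
import torchvision.transforms.v2 as transforms

rng = torch.Generator().manual_seed(0)
aa = transforms.AutoAugment(generator=rng)
img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)

state = torch.random.get_rng_state()
_ = aa(img)
# The global RNG should not have advanced: all randomness came from rng.
assert torch.equal(state, torch.random.get_rng_state())
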
50 changes: 38 additions & 12 deletions torchvision/transforms/v2/_color.py
@@ -96,12 +96,14 @@ def __init__(
contrast: Optional[Union[float, Sequence[float]]] = None,
saturation: Optional[Union[float, Sequence[float]]] = None,
hue: Optional[Union[float, Sequence[float]]] = None,
generator=None,
) -> None:
super().__init__()
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
self.generator = generator

def _check_input(
self,
@@ -131,16 +133,28 @@ def _check_input(
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

@staticmethod
def _generate_value(left: float, right: float) -> float:
return torch.empty(1).uniform_(left, right).item()
def _generate_value(left: float, right: float, generator=None) -> float:
return torch.empty(1).uniform_(left, right, generator=generator).item()

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)

b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])
fn_idx = torch.randperm(4, generator=self.generator)

b = (
None
if self.brightness is None
else self._generate_value(self.brightness[0], self.brightness[1], generator=self.generator)
)
c = (
None
if self.contrast is None
else self._generate_value(self.contrast[0], self.contrast[1], generator=self.generator)
)
s = (
None
if self.saturation is None
else self._generate_value(self.saturation[0], self.saturation[1], generator=self.generator)
)
h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1], generator=self.generator)

return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)

@@ -168,9 +182,13 @@ class RandomChannelPermutation(Transform):
.. v2betastatus:: RandomChannelPermutation transform
"""

def __init__(self, generator=None):
super().__init__()
self.generator = generator

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(permutation=torch.randperm(num_channels))
return dict(permutation=torch.randperm(num_channels, generator=self.generator))

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.permute_channels, inpt, params["permutation"])
@@ -209,27 +227,35 @@ def __init__(
saturation: Tuple[float, float] = (0.5, 1.5),
hue: Tuple[float, float] = (-0.05, 0.05),
p: float = 0.5,
generator=None,
):
super().__init__()
self.brightness = brightness
self.contrast = contrast
self.hue = hue
self.saturation = saturation
self.p = p
self.generator = generator

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
params: Dict[str, Any] = {
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
            key: ColorJitter._generate_value(range[0], range[1], generator=self.generator)
            if torch.rand(1, generator=self.generator) < self.p
            else None
for key, range in [
("brightness_factor", self.brightness),
("contrast_factor", self.contrast),
("saturation_factor", self.saturation),
("hue_factor", self.hue),
]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
params["contrast_before"] = bool(torch.rand((), generator=self.generator) < 0.5)
params["channel_permutation"] = (
torch.randperm(num_channels, generator=self.generator)
if torch.rand(1, generator=self.generator) < self.p
else None
)
return params

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
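
A reproducibility sketch for the color transforms, assuming the proposed parameter (seed and shape are arbitrary): rewinding the explicit generator replays exactly the same distortion, without ever calling torch.manual_seed:

import torch
import torchvision.transforms.v2 as transforms

rng = torch.Generator().manual_seed(0)
distort = transforms.RandomPhotometricDistort(p=0.5, generator=rng)
img = torch.rand(3, 16, 16)

out1 = distort(img)
rng.manual_seed(0)  # rewind only the explicit generator
out2 = distort(img)
torch.testing.assert_close(out1, out2)  # identical draws, identical output
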
10 changes: 6 additions & 4 deletions torchvision/transforms/v2/_container.py
@@ -85,7 +85,7 @@ class RandomApply(Transform):

_v1_transform_cls = _transforms.RandomApply

def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5, generator=None) -> None:
super().__init__()

if not isinstance(transforms, (Sequence, nn.ModuleList)):
@@ -95,14 +95,15 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: floa
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
self.p = p
self.generator = generator

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return {"transforms": self.transforms, "p": self.p}

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

if torch.rand(1) >= self.p:
if torch.rand(1, generator=self.generator) >= self.p:
return sample

for transform in self.transforms:
@@ -166,15 +167,16 @@ class RandomOrder(Transform):
transforms (sequence or torch.nn.Module): list of transformations
"""

def __init__(self, transforms: Sequence[Callable]) -> None:
def __init__(self, transforms: Sequence[Callable], generator=None) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
super().__init__()
self.transforms = transforms
self.generator = generator

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for idx in torch.randperm(len(self.transforms)):
for idx in torch.randperm(len(self.transforms), generator=self.generator):
transform = self.transforms[idx]
sample = transform(sample)
return sample
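
Finally, the containers: the RandomApply coin flip and the RandomOrder shuffle now draw from the supplied generator, so they can share one stream with the transforms they wrap. A sketch, assuming the proposed parameter:

import torch
import torchvision.transforms.v2 as transforms

rng = torch.Generator().manual_seed(0)
maybe_jitter = transforms.RandomApply(
    [transforms.ColorJitter(brightness=0.5, generator=rng)],
    p=0.3,
    generator=rng,
)
shuffled = transforms.RandomOrder(
    [maybe_jitter, transforms.RandomErasing(p=0.5, generator=rng)],
    generator=rng,
)
out = shuffled(torch.rand(3, 32, 32))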