From 7e235234aa4d3e988b39dbb5b09360d875a809af Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Aug 2023 07:23:09 -0700 Subject: [PATCH 1/3] Add generator parameter to random transforms --- test/common_utils.py | 10 +++++++ test/test_transforms_v2_refactored.py | 36 +++++++++++++++++++++++++ torchvision/transforms/v2/_transform.py | 7 ++--- 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 9713901bdcf..48f8801ec6f 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -103,6 +103,16 @@ def freeze_rng_state(): torch.set_rng_state(rng_state) +@contextlib.contextmanager +def assert_default_rng_is_unchanged(): + default_state = torch.random.get_rng_state() + yield + try: + torch.testing.assert_close(default_state, torch.random.get_rng_state(), rtol=0, atol=0) + except AssertionError as e: + raise AssertionError("The default RNG got consumed.") from e + + def cycle_over(objs): for idx, obj1 in enumerate(objs): for obj2 in objs[:idx] + objs[idx + 1 :]: diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c51b7c7555f..2d687340fde 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -13,6 +13,7 @@ import torch import torchvision.transforms.v2 as transforms from common_utils import ( + assert_default_rng_is_unchanged, assert_equal, assert_no_warnings, cache, @@ -255,6 +256,10 @@ def check_transform(transform_cls, input, *args, **kwargs): _check_transform_v1_compatibility(transform, input) + if "generator" in inspect.signature(transform_cls).parameters: + with assert_default_rng_is_unchanged(): + transform_cls(*args, generator=torch.Generator(), **kwargs)(input) + def transform_cls_to_functional(transform_cls, **transform_specific_kwargs): def wrapper(input, *args, **kwargs): @@ -2374,3 +2379,34 @@ def test_correctness(self): assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, 
datapoints.Datapoint) else: assert isinstance(out_value, type(input_value)) + + +def test_transforms_rng_with_dataloader(): + # This is more of a sanity test for torch core's handling of Generators within the Dataloader + # But worth having it here as well for security. + class MyTransform(torch.nn.Module): + def __init__(self, generator): + super().__init__() + self.generator = generator + + def forward(self): + return torch.randint(0, 100_000, size=(1,), generator=self.generator).item() + + class Dataset: + def __init__(self, transform): + self.transform = transform + + def __getitem__(self, _): + return self.transform() + + def __len__(self): + return 10 + + rng = torch.Generator().manual_seed(0) + t = MyTransform(rng) + ds = Dataset(t, multiprocessing_context="fork") + + dl = DataLoader(ds, num_workers=2) + all_samples = [x.item() for x in dl] + # If the RNG were the same across workers, we would get duplicated samples here. We assert they're all unique. + assert len(set(all_samples)) == len(all_samples) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index e9af4b426fa..969c6d3cec8 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -122,7 +122,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: return { attr: value for attr, value in self.__dict__.items() - if not attr.startswith("_") and attr not in common_attrs + if not attr.startswith("_") and attr not in common_attrs and attr != "generator" } def __prepare_scriptable__(self) -> nn.Module: @@ -143,11 +143,12 @@ def __prepare_scriptable__(self) -> nn.Module: class _RandomApplyTransform(Transform): - def __init__(self, p: float = 0.5) -> None: + def __init__(self, p: float = 0.5, generator=None) -> None: if not (0.0 <= p <= 1.0): raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") super().__init__() + self.generator = generator self.p = p def forward(self, *inputs: 
Any) -> Any: @@ -160,7 +161,7 @@ def forward(self, *inputs: Any) -> Any: self._check_inputs(flat_inputs) - if torch.rand(1) >= self.p: + if torch.rand(1, generator=self.generator) >= self.p: return inputs needs_transform_list = self._needs_transform_list(flat_inputs) From 2d83526607e31a504aa2b9c87dc969425e19698b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 21 Aug 2023 10:31:56 +0100 Subject: [PATCH 2/3] fix --- test/test_transforms_v2_refactored.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 5ae3afae4ac..5c24ee4def4 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2404,9 +2404,9 @@ def __len__(self): rng = torch.Generator().manual_seed(0) t = MyTransform(rng) - ds = Dataset(t, multiprocessing_context="fork") + ds = Dataset(t) - dl = DataLoader(ds, num_workers=2) + dl = DataLoader(ds, num_workers=2, multiprocessing_context="fork") all_samples = [x.item() for x in dl] # If the RNG were the same across workers, we would get duplicated samples here. We assert they're all unique. 
assert len(set(all_samples)) == len(all_samples) From c9a3e9fe5b119c1b0428f5b01b616bc77f64cedb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 21 Aug 2023 06:44:22 -0700 Subject: [PATCH 3/3] Some more --- test/test_transforms_v2_refactored.py | 2 + torchvision/transforms/v2/_augment.py | 10 +-- torchvision/transforms/v2/_auto_augment.py | 38 ++++++---- torchvision/transforms/v2/_color.py | 50 +++++++++---- torchvision/transforms/v2/_container.py | 10 +-- torchvision/transforms/v2/_geometry.py | 84 ++++++++++++++-------- torchvision/transforms/v2/_misc.py | 8 ++- 7 files changed, 135 insertions(+), 67 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 5c24ee4def4..817e873ca89 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -3,6 +3,7 @@ import inspect import math import re +import sys from pathlib import Path from unittest import mock @@ -2381,6 +2382,7 @@ def test_correctness(self): assert isinstance(out_value, type(input_value)) +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="Windows doesn't support fork()") def test_transforms_rng_with_dataloader(): # This is more of a sanity test for torch core's handling of Generators within the Dataloader # But worth having it here as well for security. 
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index a6af96a5ef6..f5f3c2ab84b 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -63,8 +63,9 @@ def __init__( ratio: Tuple[float, float] = (0.3, 3.3), value: float = 0.0, inplace: bool = False, + generator=None, ): - super().__init__(p=p) + super().__init__(p=p, generator=generator) if not isinstance(value, (numbers.Number, str, tuple, list)): raise TypeError("Argument value should be either a number or str or a sequence") if isinstance(value, str) and value != "random": @@ -111,11 +112,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: log_ratio = self._log_ratio for _ in range(10): - erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=self.generator).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=self.generator, ) ).item() @@ -129,8 +131,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: else: v = torch.tensor(self.value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1,)).item() - j = torch.randint(0, img_w - w + 1, size=(1,)).item() + i = torch.randint(0, img_h - h + 1, size=(1,), generator=self.generator).item() + j = torch.randint(0, img_w - w + 1, size=(1,), generator=self.generator).item() break else: i, j, h, w, v = 0, 0, img_h, img_w, None diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 097e90fc4ab..2aae6965c89 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -25,14 +25,16 @@ def __init__( *, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + 
generator=None, ) -> None: super().__init__() self.interpolation = _check_interpolation(interpolation) self.fill = _setup_fill_arg(fill) + self.generator = generator def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]: keys = tuple(dct.keys()) - key = keys[int(torch.randint(len(keys), ()))] + key = keys[int(torch.randint(len(keys), (), generator=self.generator))] return key, dct[key] def _flatten_and_extract_image_or_video( @@ -219,8 +221,9 @@ def __init__( policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + generator=None, ) -> None: - super().__init__(interpolation=interpolation, fill=fill) + super().__init__(interpolation=interpolation, fill=fill, generator=generator) self.policy = policy self._policies = self._get_policies(policy) @@ -318,10 +321,10 @@ def forward(self, *inputs: Any) -> Any: flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) height, width = get_size(image_or_video) - policy = self._policies[int(torch.randint(len(self._policies), ()))] + policy = self._policies[int(torch.randint(len(self._policies), (), generator=self.generator))] for transform_id, probability, magnitude_idx in policy: - if not torch.rand(()) <= probability: + if not torch.rand((), generator=self.generator) <= probability: continue magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] @@ -329,7 +332,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes = magnitudes_fn(10, height, width) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) - if signed and torch.rand(()) <= 0.5: + if signed and torch.rand((), generator=self.generator) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -399,8 +402,9 @@ def __init__( num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, fill: 
Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + generator=None, ) -> None: - super().__init__(interpolation=interpolation, fill=fill) + super().__init__(interpolation=interpolation, fill=fill, generator=generator) self.num_ops = num_ops self.magnitude = magnitude self.num_magnitude_bins = num_magnitude_bins @@ -414,7 +418,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: magnitude = float(magnitudes[self.magnitude]) - if signed and torch.rand(()) <= 0.5: + if signed and torch.rand((), generator=self.generator) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -472,8 +476,9 @@ def __init__( num_magnitude_bins: int = 31, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + generator=None, ): - super().__init__(interpolation=interpolation, fill=fill) + super().__init__(interpolation=interpolation, fill=fill, generator=generator) self.num_magnitude_bins = num_magnitude_bins def forward(self, *inputs: Any) -> Any: @@ -484,8 +489,8 @@ def forward(self, *inputs: Any) -> Any: magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) if magnitudes is not None: - magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) - if signed and torch.rand(()) <= 0.5: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, (), generator=self.generator))]) + if signed and torch.rand((), generator=self.generator) <= 0.5: magnitude *= -1 else: magnitude = 0.0 @@ -555,8 +560,9 @@ def __init__( all_ops: bool = True, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + generator=None, ) -> None: - super().__init__(interpolation=interpolation, fill=fill) + super().__init__(interpolation=interpolation, fill=fill, generator=generator) self._PARAMETER_MAX = 10 if not (1 
<= severity <= self._PARAMETER_MAX): raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") @@ -601,14 +607,18 @@ def forward(self, *inputs: Any) -> Any: mix = m[:, 0].reshape(batch_dims) * batch for i in range(self.mixture_width): aug = batch - depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + depth = ( + self.chain_depth + if self.chain_depth > 0 + else int(torch.randint(low=1, high=4, size=(1,), generator=self.generator).item()) + ) for _ in range(depth): transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width) if magnitudes is not None: - magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) - if signed and torch.rand(()) <= 0.5: + magnitude = float(magnitudes[int(torch.randint(self.severity, (), generator=self.generator))]) + if signed and torch.rand((), generator=self.generator) <= 0.5: magnitude *= -1 else: magnitude = 0.0 diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index a3792797959..1dd390be9a1 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -96,12 +96,14 @@ def __init__( contrast: Optional[Union[float, Sequence[float]]] = None, saturation: Optional[Union[float, Sequence[float]]] = None, hue: Optional[Union[float, Sequence[float]]] = None, + generator=None, ) -> None: super().__init__() self.brightness = self._check_input(brightness, "brightness") self.contrast = self._check_input(contrast, "contrast") self.saturation = self._check_input(saturation, "saturation") self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.generator = generator def _check_input( self, @@ -131,16 +133,28 @@ def _check_input( return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) @staticmethod - def 
_generate_value(left: float, right: float) -> float: - return torch.empty(1).uniform_(left, right).item() + def _generate_value(left: float, right: float, generator=None) -> float: + return torch.empty(1).uniform_(left, right, generator=generator).item() def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - fn_idx = torch.randperm(4) - - b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) - c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) - s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) - h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + fn_idx = torch.randperm(4, generator=self.generator) + + b = ( + None + if self.brightness is None + else self._generate_value(self.brightness[0], self.brightness[1], generator=self.generator) + ) + c = ( + None + if self.contrast is None + else self._generate_value(self.contrast[0], self.contrast[1], generator=self.generator) + ) + s = ( + None + if self.saturation is None + else self._generate_value(self.saturation[0], self.saturation[1], generator=self.generator) + ) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1], generator=self.generator) return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) @@ -168,9 +182,13 @@ class RandomChannelPermutation(Transform): .. 
v2betastatus:: RandomChannelPermutation transform """ + def __init__(self, generator=None): + super().__init__() + self.generator = generator + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) - return dict(permutation=torch.randperm(num_channels)) + return dict(permutation=torch.randperm(num_channels, generator=self.generator)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.permute_channels, inpt, params["permutation"]) @@ -209,6 +227,7 @@ def __init__( saturation: Tuple[float, float] = (0.5, 1.5), hue: Tuple[float, float] = (-0.05, 0.05), p: float = 0.5, + generator=None, ): super().__init__() self.brightness = brightness @@ -216,11 +235,14 @@ def __init__( self.hue = hue self.saturation = saturation self.p = p + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) params: Dict[str, Any] = { - key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None + key: ColorJitter._generate_value(range[0], range[1], generator=self.generator) + if torch.rand(1, generator=self.generator) < self.p + else None for key, range in [ ("brightness_factor", self.brightness), ("contrast_factor", self.contrast), @@ -228,8 +250,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: ("hue_factor", self.hue), ] } - params["contrast_before"] = bool(torch.rand(()) < 0.5) - params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None + params["contrast_before"] = bool(torch.rand((), generator=self.generator) < 0.5) + params["channel_permutation"] = ( + torch.randperm(num_channels, generator=self.generator) + if torch.rand(1, generator=self.generator) < self.p + else None + ) return params def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index
8f591c49707..fa8331faad9 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -85,7 +85,7 @@ class RandomApply(Transform): _v1_transform_cls = _transforms.RandomApply - def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None: + def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5, generator=None) -> None: super().__init__() if not isinstance(transforms, (Sequence, nn.ModuleList)): @@ -95,6 +95,7 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: floa if not (0.0 <= p <= 1.0): raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") self.p = p + self.generator = generator def _extract_params_for_v1_transform(self) -> Dict[str, Any]: return {"transforms": self.transforms, "p": self.p} @@ -102,7 +103,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - if torch.rand(1) >= self.p: + if torch.rand(1, generator=self.generator) >= self.p: return sample for transform in self.transforms: @@ -166,15 +167,16 @@ class RandomOrder(Transform): transforms (sequence or torch.nn.Module): list of transformations """ - def __init__(self, transforms: Sequence[Callable]) -> None: + def __init__(self, transforms: Sequence[Callable], generator=None) -> None: if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") super().__init__() self.transforms = transforms + self.generator = generator def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - for idx in torch.randperm(len(self.transforms)): + for idx in torch.randperm(len(self.transforms), generator=self.generator): transform = self.transforms[idx] sample = transform(sample) return sample diff --git a/torchvision/transforms/v2/_geometry.py 
b/torchvision/transforms/v2/_geometry.py index a442b2d4be0..daaf398fcbb 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -249,6 +249,7 @@ def __init__( ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", + generator=None, ) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -268,6 +269,7 @@ def __init__( self.antialias = antialias self._log_ratio = torch.log(torch.tensor(self.ratio)) + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) @@ -275,11 +277,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: log_ratio = self._log_ratio for _ in range(10): - target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=self.generator).item() aspect_ratio = torch.exp( torch.empty(1).uniform_( log_ratio[0], # type: ignore[arg-type] log_ratio[1], # type: ignore[arg-type] + generator=self.generator, ) ).item() @@ -287,8 +290,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: - i = torch.randint(0, height - h + 1, size=(1,)).item() - j = torch.randint(0, width - w + 1, size=(1,)).item() + i = torch.randint(0, height - h + 1, size=(1,), generator=self.generator).item() + j = torch.randint(0, width - w + 1, size=(1,), generator=self.generator).item() break else: # Fallback to central crop @@ -532,8 +535,9 @@ def __init__( fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, + generator=None, ) -> None: - super().__init__(p=p) + super().__init__(p=p, 
generator=generator) self.fill = fill self._fill = _setup_fill_arg(fill) @@ -547,11 +551,11 @@ def __init__( def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + r = self.side_range[0] + torch.rand(1, generator=self.generator) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) canvas_height = int(orig_h * r) - r = torch.rand(2) + r = torch.rand(2, generator=self.generator) left = int((canvas_width - orig_w) * r[0]) top = int((canvas_height - orig_h) * r[1]) right = canvas_width - (left + orig_w) @@ -608,6 +612,7 @@ def __init__( expand: bool = False, center: Optional[List[float]] = None, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + generator=None, ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) @@ -621,9 +626,10 @@ def __init__( _check_sequence_input(center, "center", req_sizes=(2,)) self.center = center + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=self.generator).item() return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -692,6 +698,7 @@ def __init__( interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, center: Optional[List[float]] = None, + generator=None, ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) @@ -721,30 +728,31 @@ def __init__( _check_sequence_input(center, "center", req_sizes=(2,)) self.center = center + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) - angle = 
torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=self.generator).item() if self.translate is not None: max_dx = float(self.translate[0] * width) max_dy = float(self.translate[1] * height) - tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) - ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=self.generator).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=self.generator).item())) translate = (tx, ty) else: translate = (0, 0) if self.scale is not None: - scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=self.generator).item() else: scale = 1.0 shear_x = shear_y = 0.0 if self.shear is not None: - shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=self.generator).item() if len(self.shear) == 4: - shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=self.generator).item() shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) @@ -831,6 +839,7 @@ def __init__( pad_if_needed: bool = False, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + generator=None, ) -> None: super().__init__() @@ -846,6 +855,7 @@ def __init__( self.fill = fill self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: padded_height, padded_width = query_size(flat_inputs) @@ -885,12 +895,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: needs_pad = 
any(padding) needs_vert_crop, top = ( - (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=self.generator))) if padded_height > cropped_height else (False, 0) ) needs_horz_crop, left = ( - (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=self.generator))) if padded_width > cropped_width else (False, 0) ) @@ -972,20 +982,20 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: bound_height = int(distortion_scale * half_height) + 1 bound_width = int(distortion_scale * half_width) + 1 topleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=self.generator)), + int(torch.randint(0, bound_height, size=(1,), generator=self.generator)), ] topright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(0, bound_height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=self.generator)), + int(torch.randint(0, bound_height, size=(1,), generator=self.generator)), ] botright = [ - int(torch.randint(width - bound_width, width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(width - bound_width, width, size=(1,), generator=self.generator)), + int(torch.randint(height - bound_height, height, size=(1,), generator=self.generator)), ] botleft = [ - int(torch.randint(0, bound_width, size=(1,))), - int(torch.randint(height - bound_height, height, size=(1,))), + int(torch.randint(0, bound_width, size=(1,), generator=self.generator)), + int(torch.randint(height - bound_height, height, size=(1,), generator=self.generator)), ] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, 
botright, botleft] @@ -1054,6 +1064,7 @@ def __init__( sigma: Union[float, Sequence[float]] = 5.0, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + generator=None, ) -> None: super().__init__() self.alpha = _setup_float_or_seq(alpha, "alpha", 2) @@ -1062,11 +1073,12 @@ def __init__( self.interpolation = _check_interpolation(interpolation) self.fill = fill self._fill = _setup_fill_arg(fill) + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = list(query_size(flat_inputs)) - dx = torch.rand([1, 1] + size) * 2 - 1 + dx = torch.rand([1, 1] + size, generator=self.generator) * 2 - 1 if self.sigma[0] > 0.0: kx = int(8 * self.sigma[0] + 1) # if kernel size is even we have to make it odd @@ -1075,7 +1087,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) dx = dx * self.alpha[0] / size[0] - dy = torch.rand([1, 1] + size) * 2 - 1 + dy = torch.rand([1, 1] + size, generator=self.generator) * 2 - 1 if self.sigma[1] > 0.0: ky = int(8 * self.sigma[1] + 1) # if kernel size is even we have to make it odd @@ -1134,6 +1146,7 @@ def __init__( max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40, + generator=None, ): super().__init__() # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 @@ -1145,6 +1158,7 @@ def __init__( sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] self.options = sampler_options self.trials = trials + self.generator = generator def _check_inputs(self, flat_inputs: List[Any]) -> None: if not ( @@ -1162,14 +1176,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: while True: # sample an option - idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + idx = int(torch.randint(low=0, high=len(self.options), 
size=(1,), generator=self.generator)) min_jaccard_overlap = self.options[idx] if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option return dict() for _ in range(self.trials): # check the aspect ratio limitations - r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=self.generator) new_w = int(orig_w * r[0]) new_h = int(orig_h * r[1]) aspect_ratio = new_w / new_h @@ -1177,7 +1191,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: continue # check for 0 area crops - r = torch.rand(2) + r = torch.rand(2, generator=self.generator) left = int((orig_w - new_w) * r[0]) top = int((orig_h - new_h) * r[1]) right = left + new_w @@ -1271,17 +1285,21 @@ def __init__( scale_range: Tuple[float, float] = (0.1, 2.0), interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", + generator=None, ): super().__init__() self.target_size = target_size self.scale_range = scale_range self.interpolation = _check_interpolation(interpolation) self.antialias = antialias + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + scale = self.scale_range[0] + torch.rand(1, generator=self.generator) * ( + self.scale_range[1] - self.scale_range[0] + ) r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale new_width = int(orig_width * r) new_height = int(orig_height * r) @@ -1338,17 +1356,19 @@ def __init__( max_size: Optional[int] = None, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", + generator=None, ): super().__init__() self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) 
self.max_size = max_size self.interpolation = _check_interpolation(interpolation) self.antialias = antialias + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) - min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=self.generator))] r = min_size / min(orig_height, orig_width) if self.max_size is not None: r = min(r, self.max_size / max(orig_height, orig_width)) @@ -1419,15 +1439,17 @@ def __init__( max_size: int, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", + generator=None, ) -> None: super().__init__() self.min_size = min_size self.max_size = max_size self.interpolation = _check_interpolation(interpolation) self.antialias = antialias + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - size = int(torch.randint(self.min_size, self.max_size, ())) + size = int(torch.randint(self.min_size, self.max_size, (), generator=self.generator)) return dict(size=[size]) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 405fbc6c43a..239ecb4022b 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -191,7 +191,10 @@ class GaussianBlur(Transform): _v1_transform_cls = _transforms.GaussianBlur def __init__( - self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) + self, + kernel_size: Union[int, Sequence[int]], + sigma: Union[int, float, Sequence[float]] = (0.1, 2.0), + generator=None, ) -> None: super().__init__() self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") @@ -210,9 +213,10 @@ def __init__( raise TypeError("sigma should be a single int or 
float or a list/tuple with length 2 floats.") self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + self.generator = generator def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1], generator=self.generator).item() return dict(sigma=[sigma, sigma]) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: