Add generator parameter to random transforms
NicolasHug committed Aug 17, 2023
1 parent a7b52a6 commit 7e23523
Showing 3 changed files with 50 additions and 3 deletions.
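For orientation, a hedged sketch of how a caller might use the new parameter: `MyRandomFlip`, its import of the private base class, and the forwarding of `generator` to `super().__init__` are assumptions for illustration only; this commit only adds the parameter to `_RandomApplyTransform` itself.

import torch
import torchvision.transforms.v2.functional as F
from torchvision.transforms.v2._transform import _RandomApplyTransform  # private base class, used here only for illustration


class MyRandomFlip(_RandomApplyTransform):
    # Hypothetical subclass: forwards the new `generator` argument to the base class.
    def __init__(self, p=0.5, generator=None):
        super().__init__(p=p, generator=generator)

    def _transform(self, inpt, params):
        return F.horizontal_flip(inpt)


# The apply/skip coin flip now draws from this dedicated generator, so it is
# reproducible and does not consume the global RNG.
flip = MyRandomFlip(p=0.5, generator=torch.Generator().manual_seed(0))
out = flip(torch.rand(3, 8, 8))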
10 changes: 10 additions & 0 deletions test/common_utils.py
@@ -103,6 +103,16 @@ def freeze_rng_state():
    torch.set_rng_state(rng_state)


@contextlib.contextmanager
def assert_default_rng_is_unchanged():
    default_state = torch.random.get_rng_state()
    yield
    try:
        torch.testing.assert_close(default_state, torch.random.get_rng_state(), rtol=0, atol=0)
    except AssertionError as e:
        raise AssertionError("The default RNG got consumed.") from e


def cycle_over(objs):
    for idx, obj1 in enumerate(objs):
        for obj2 in objs[:idx] + objs[idx + 1 :]:
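As a brief illustration of how the new helper is meant to be used (the test function below is a sketch, not part of this commit): any randomness inside the block should come from a dedicated generator, since drawing from the global RNG would change the default state and trip the assertion.

import torch
from common_utils import assert_default_rng_is_unchanged


def test_generator_does_not_consume_default_rng():
    g = torch.Generator().manual_seed(0)
    with assert_default_rng_is_unchanged():
        # Draws only from the dedicated generator, so the default RNG state stays untouched.
        torch.rand(10, generator=g)
    # By contrast, calling torch.rand(10) without a generator inside the block would
    # advance the default RNG and make the context manager raise an AssertionError.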
36 changes: 36 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -13,6 +13,7 @@
import torch
import torchvision.transforms.v2 as transforms
from common_utils import (
    assert_default_rng_is_unchanged,
    assert_equal,
    assert_no_warnings,
    cache,
@@ -255,6 +256,10 @@ def check_transform(transform_cls, input, *args, **kwargs):

    _check_transform_v1_compatibility(transform, input)

    if "generator" in inspect.signature(transform_cls).parameters:
        with assert_default_rng_is_unchanged():
            transform_cls(*args, generator=torch.Generator(), **kwargs)(input)


def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
    def wrapper(input, *args, **kwargs):
@@ -2374,3 +2379,34 @@ def test_correctness(self):
                assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
            else:
                assert isinstance(out_value, type(input_value))


def test_transforms_rng_with_dataloader():
    # This is more of a sanity test for torch core's handling of Generators within the DataLoader,
    # but it is worth having here as well as a safeguard.
    class MyTransform(torch.nn.Module):
        def __init__(self, generator):
            super().__init__()
            self.generator = generator

        def forward(self):
            return torch.randint(0, 100_000, size=(1,), generator=self.generator).item()

    class Dataset:
        def __init__(self, transform):
            self.transform = transform

        def __getitem__(self, _):
            return self.transform()

        def __len__(self):
            return 10

    rng = torch.Generator().manual_seed(0)
    t = MyTransform(rng)
    ds = Dataset(t)

    # Note: multiprocessing_context is a DataLoader parameter, not a Dataset one.
    dl = DataLoader(ds, num_workers=2, multiprocessing_context="fork")
    all_samples = [x.item() for x in dl]
    # If the RNG were the same across workers, we would get duplicated samples here. We assert they're all unique.
    assert len(set(all_samples)) == len(all_samples)
7 changes: 4 additions & 3 deletions torchvision/transforms/v2/_transform.py
@@ -122,7 +122,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return {
            attr: value
            for attr, value in self.__dict__.items()
-            if not attr.startswith("_") and attr not in common_attrs
+            if not attr.startswith("_") and attr not in common_attrs and attr != "generator"
        }

    def __prepare_scriptable__(self) -> nn.Module:
@@ -143,11 +143,12 @@ def __prepare_scriptable__(self) -> nn.Module:


class _RandomApplyTransform(Transform):
-    def __init__(self, p: float = 0.5) -> None:
+    def __init__(self, p: float = 0.5, generator=None) -> None:
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
+        self.generator = generator
        self.p = p

    def forward(self, *inputs: Any) -> Any:
@@ -160,7 +161,7 @@ def forward(self, *inputs: Any) -> Any:

        self._check_inputs(flat_inputs)

-        if torch.rand(1) >= self.p:
+        if torch.rand(1, generator=self.generator) >= self.p:
            return inputs

        needs_transform_list = self._needs_transform_list(flat_inputs)
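Two properties of the `torch.rand(1, generator=self.generator)` change are worth spelling out; the snippet below is a standalone illustration of the underlying PyTorch behavior, not code from this repository. With `generator=None` (the default), `torch.rand` falls back to the global RNG, so existing behavior is unchanged; with a dedicated seeded generator, the apply/skip decisions are reproducible and do not advance the global RNG.

import torch

# Default path: generator=None falls back to the global RNG, matching the old behavior.
torch.manual_seed(0)
a = torch.rand(1, generator=None)
torch.manual_seed(0)
b = torch.rand(1)
assert torch.equal(a, b)

# Dedicated generator: decisions are reproducible and the global RNG state is untouched.
state_before = torch.random.get_rng_state()
g1 = torch.Generator().manual_seed(42)
decisions = [bool(torch.rand(1, generator=g1) < 0.5) for _ in range(5)]
g2 = torch.Generator().manual_seed(42)
assert decisions == [bool(torch.rand(1, generator=g2) < 0.5) for _ in range(5)]
assert torch.equal(state_before, torch.random.get_rng_state())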
