diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py
index 53c461959a8..981b1e58832 100644
--- a/gallery/plot_transforms_v2_e2e.py
+++ b/gallery/plot_transforms_v2_e2e.py
@@ -10,7 +10,6 @@
 """
 
 import pathlib
-from collections import defaultdict
 
 import PIL.Image
 
@@ -99,9 +98,7 @@ def load_example_coco_detection_dataset(**kwargs):
 transform = transforms.Compose(
     [
         transforms.RandomPhotometricDistort(),
-        transforms.RandomZoomOut(
-            fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
-        ),
+        transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
         transforms.RandomIoUCrop(),
         transforms.RandomHorizontalFlip(),
         transforms.ToImageTensor(),
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
index abb70d8d0db..e62fd5ae301 100644
--- a/references/segmentation/presets.py
+++ b/references/segmentation/presets.py
@@ -1,5 +1,3 @@
-from collections import defaultdict
-
 import torch
 
 
@@ -48,7 +46,7 @@ def __init__(
         if use_v2:
             # We need a custom pad transform here, since the padding we want to perform here is fundamentally
             # different from the padding in `RandomCrop` if `pad_if_needed=True`.
-            transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
+            transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]
 
         transforms += [T.RandomCrop(crop_size)]
 
diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py
index c69827c22e7..f21799e86c8 100644
--- a/references/segmentation/v2_extras.py
+++ b/references/segmentation/v2_extras.py
@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
     def __init__(self, size, fill=0):
         super().__init__()
         self.size = size
-        self.fill = v2._geometry._setup_fill_arg(fill)
+        self.fill = v2._utils._setup_fill_arg(fill)
 
     def _get_params(self, sample):
         _, height, width = v2.utils.query_chw(sample)
@@ -20,7 +20,7 @@ def _transform(self, inpt, params):
         if not params["needs_padding"]:
             return inpt
 
-        fill = self.fill[type(inpt)]
+        fill = v2._utils._get_fill(self.fill, type(inpt))
         fill = v2._utils._convert_fill_arg(fill)
 
         return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 4c1815fddea..d5f448b09aa 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -3,7 +3,6 @@
 import random
 import textwrap
 import warnings
-from collections import defaultdict
 
 import numpy as np
 
@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     elif data_augmentation == "ssd":
         t = [
             transforms.RandomPhotometricDistort(p=1),
-            transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
+            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
             transforms.RandomIoUCrop(),
             transforms.RandomHorizontalFlip(p=1),
             to_tensor,
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 47a0b05b511..f5ea69279a1 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -4,7 +4,6 @@
 import inspect
 import random
 import re
-from collections import defaultdict
 from pathlib import Path
 
 import numpy as np
@@ -30,6 +29,7 @@
 
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
+from torchvision.transforms.v2._utils import _get_fill
 from torchvision.transforms.v2.functional import to_image_pil
 from torchvision.transforms.v2.utils import query_size
 
@@ -1181,7 +1181,7 @@ def _transform(self, inpt, params):
         if not params["needs_padding"]:
             return inpt
 
-        fill = self.fill[type(inpt)]
+        fill = _get_fill(self.fill, type(inpt))
 
         return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
 
@@ -1243,7 +1243,7 @@ def check(self, t, t_ref, data_kwargs=None):
             seg_transforms.RandomCrop(size=480),
             v2_transforms.Compose(
                 [
-                    PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
+                    PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
                     v2_transforms.RandomCrop(size=480),
                 ]
             ),
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 28aff8416d2..a4023ca2108 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -6,7 +6,7 @@
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
 from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
 
 
@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
@@ -119,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         )
 
         if params["needs_pad"]:
-            fill = self._fill[type(inpt)]
+            fill = _get_fill(self._fill, type(inpt))
             inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         return inpt
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index 3a4e6e956f3..51a2ea9074a 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -1,15 +1,29 @@
+import functools
 import warnings
-from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
+from collections import defaultdict
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import torch
 
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2._utils import _get_defaultdict
 from torchvision.transforms.v2.utils import is_simple_tensor
 
 
+T = TypeVar("T")
+
+
+def _default_arg(value: T) -> T:
+    return value
+
+
+def _get_defaultdict(default: T) -> Dict[Any, T]:
+    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
+    # If it were possible, we could replace this with `defaultdict(lambda: default)`
+    return defaultdict(functools.partial(_default_arg, default))
+
+
 class PermuteDimensions(Transform):
     _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py
index 2921903da8f..146c8c236ef 100644
--- a/torchvision/transforms/v2/_auto_augment.py
+++ b/torchvision/transforms/v2/_auto_augment.py
@@ -11,7 +11,7 @@
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
 
-from ._utils import _setup_fill_arg
+from ._utils import _get_fill, _setup_fill_arg
 from .utils import check_type, is_simple_tensor
 
 
@@ -20,7 +20,7 @@ def __init__(
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
@@ -80,9 +80,9 @@ def _apply_image_or_video_transform(
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Type, datapoints._FillTypeJIT],
+        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
     ) -> Union[datapoints._ImageType, datapoints._VideoType]:
-        fill_ = fill[type(image)]
+        fill_ = _get_fill(fill, type(image))
 
         if transform_id == "Identity":
             return image
@@ -214,7 +214,7 @@ def __init__(
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -394,7 +394,7 @@ def __init__(
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -467,7 +467,7 @@ def __init__(
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -550,7 +550,7 @@ def __init__(
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 9e7ca64d41c..c7a1e39286f 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -17,6 +17,7 @@
     _check_padding_arg,
     _check_padding_mode_arg,
     _check_sequence_input,
+    _get_fill,
     _setup_angle,
     _setup_fill_arg,
     _setup_float_or_seq,
@@ -487,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -504,7 +505,7 @@ def __init__(
         self.padding_mode = padding_mode
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
 
 
@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
 
     def __init__(
         self,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -574,7 +575,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(padding=padding)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.pad(inpt, **params, fill=fill)
 
 
@@ -620,7 +621,7 @@ def __init__(
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -640,7 +641,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(angle=angle)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.rotate(
             inpt,
             **params,
@@ -702,7 +703,7 @@ def __init__(
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -762,7 +763,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(angle=angle, translate=translate, scale=scale, shear=shear)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.affine(
             inpt,
             **params,
@@ -840,7 +841,7 @@ def __init__(
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -918,7 +919,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
-            fill = self._fill[type(inpt)]
+            fill = _get_fill(self._fill, type(inpt))
             inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         if params["needs_crop"]:
@@ -959,7 +960,7 @@ def __init__(
         distortion_scale: float = 0.5,
         p: float = 0.5,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
     ) -> None:
         super().__init__(p=p)
 
@@ -1002,7 +1003,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(coefficients=perspective_coeffs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.perspective(
             inpt,
             None,
@@ -1061,7 +1062,7 @@ def __init__(
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
+        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
@@ -1095,7 +1096,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(displacement=displacement)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        fill = self._fill[type(inpt)]
+        fill = _get_fill(self._fill, type(inpt))
         return F.elastic(
             inpt,
             **params,
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index 2915910ea05..a7826a6645f 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -1,9 +1,7 @@
 import collections.abc
-import functools
 import numbers
-from collections import defaultdict
 from contextlib import suppress
-from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
 
 import torch
 
@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
     return arg
 
 
-def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
+def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
     if isinstance(fill, dict):
-        for key, value in fill.items():
-            # Check key for type
+        for value in fill.values():
             _check_fill_arg(value)
-        if isinstance(fill, defaultdict) and callable(fill.default_factory):
-            default_value = fill.default_factory()
-            _check_fill_arg(default_value)
     else:
         if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
             raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
 
 
-T = TypeVar("T")
-
-
-def _default_arg(value: T) -> T:
-    return value
-
-
-def _get_defaultdict(default: T) -> Dict[Any, T]:
-    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
-    # If it were possible, we could replace this with `defaultdict(lambda: default)`
-    return defaultdict(functools.partial(_default_arg, default))
-
-
 def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
     return fill  # type: ignore[return-value]
 
 
-def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
+def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
     _check_fill_arg(fill)
 
     if isinstance(fill, dict):
         for k, v in fill.items():
             fill[k] = _convert_fill_arg(v)
-        if isinstance(fill, defaultdict) and callable(fill.default_factory):
-            default_value = fill.default_factory()
-            sanitized_default = _convert_fill_arg(default_value)
-            fill.default_factory = functools.partial(_default_arg, sanitized_default)
         return fill  # type: ignore[return-value]
+    else:
+        return {"others": _convert_fill_arg(fill)}
 
-    return _get_defaultdict(_convert_fill_arg(fill))
+
+def _get_fill(fill_dict, inpt_type):
+    if inpt_type in fill_dict:
+        return fill_dict[inpt_type]
+    elif "others" in fill_dict:
+        return fill_dict["others"]
+    else:
+        raise RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
 
 
 def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
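
Usage sketch (not part of the patch): the diff replaces `defaultdict`-based fill arguments with plain dicts whose string key "others" acts as the fallback, so that transforms stay picklable, and routes lookups through the new `_get_fill` helper. The snippet below is only an illustration of that convention under the API state shown in this diff; `_setup_fill_arg`/`_get_fill` are private helpers, the concrete fill values come from the gallery/test changes above, and the comment on scalar handling is an assumption about `_convert_fill_arg`, whose body is not part of this diff.

    import PIL.Image

    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms
    from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg

    # Old style: fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
    # New style: a plain dict; "others" covers every type without an explicit entry.
    transform = transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0})

    # Internally, _setup_fill_arg normalizes the argument and _get_fill resolves a concrete
    # input type against it (assumes _convert_fill_arg passes scalar numbers through unchanged).
    fill = _setup_fill_arg({datapoints.Mask: 255, "others": 0})
    assert _get_fill(fill, datapoints.Mask) == 255   # explicit entry wins
    assert _get_fill(fill, datapoints.Image) == 0    # anything else falls back to "others"
    assert _setup_fill_arg(0) == {"others": 0}       # scalar fills are wrapped the same way

User code only needs to pass the `{type: value, "others": value}` dict to the public transforms; the helpers in `torchvision.transforms.v2._utils` remain internal.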