Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow catch-all 'others' key in fill dicts. Avoid need for defaultdict. #7779

Merged
merged 6 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions gallery/plot_transforms_v2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

import pathlib
from collections import defaultdict

import PIL.Image

Expand Down Expand Up @@ -99,9 +98,7 @@ def load_example_coco_detection_dataset(**kwargs):
transform = transforms.Compose(
[
transforms.RandomPhotometricDistort(),
transforms.RandomZoomOut(
fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
),
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(),
transforms.ToImageTensor(),
Expand Down
4 changes: 1 addition & 3 deletions references/segmentation/presets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from collections import defaultdict

import torch


Expand Down Expand Up @@ -48,7 +46,7 @@ def __init__(
if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]

transforms += [T.RandomCrop(crop_size)]

Expand Down
4 changes: 2 additions & 2 deletions references/segmentation/v2_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)
self.fill = v2._utils._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
Expand All @@ -20,7 +20,7 @@ def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = v2._utils._get_fill(self.fill, type(inpt))
fill = v2._utils._convert_fill_arg(fill)

return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
Expand Down
3 changes: 1 addition & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import random
import textwrap
import warnings
from collections import defaultdict

import numpy as np

Expand Down Expand Up @@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor,
Expand Down
6 changes: 3 additions & 3 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import inspect
import random
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
Expand All @@ -30,6 +29,7 @@

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_size

Expand Down Expand Up @@ -1181,7 +1181,7 @@ def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = _get_fill(self.fill, type(inpt))
return prototype_F.pad(inpt, padding=params["padding"], fill=fill)


Expand Down Expand Up @@ -1243,7 +1243,7 @@ def check(self, t, t_ref, data_kwargs=None):
seg_transforms.RandomCrop(size=480),
v2_transforms.Compose(
[
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
v2_transforms.RandomCrop(size=480),
]
),
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size


class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
Expand Down Expand Up @@ -119,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
)

if params["needs_pad"]:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

return inpt
18 changes: 16 additions & 2 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import functools
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union

import torch

from torchvision import datapoints
from torchvision.transforms.v2 import Transform

from torchvision.transforms.v2._utils import _get_defaultdict
from torchvision.transforms.v2.utils import is_simple_tensor


T = TypeVar("T")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are prototype transforms that rely on _get_defaultdict() for another parameter. I couldn't be bothered to update them so I just ported the code here.



def _default_arg(value: T) -> T:
    """Return ``value`` unchanged.

    This is a module-level identity function. ``_get_defaultdict`` wraps it in
    ``functools.partial`` to build a ``defaultdict`` factory without using a
    ``lambda``.
    """
    return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
    """Return an empty ``defaultdict`` whose missing keys map to ``default``.

    Args:
        default: Value produced for any key that is not present.

    Returns:
        A ``defaultdict`` whose factory returns ``default`` itself (the same
        object on every miss, not a copy).
    """
    # `defaultdict(lambda: default)` would be the obvious spelling, but `lambda`s cannot be serialized
    # by pickle. A `functools.partial` over the module-level `_default_arg` is used as the factory instead.
    return defaultdict(functools.partial(_default_arg, default))


class PermuteDimensions(Transform):
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size

from ._utils import _setup_fill_arg
from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_simple_tensor


Expand All @@ -20,7 +20,7 @@ def __init__(
self,
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
Expand Down Expand Up @@ -80,9 +80,9 @@ def _apply_image_or_video_transform(
transform_id: str,
magnitude: float,
interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints._FillTypeJIT],
fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = fill[type(image)]
fill_ = _get_fill(fill, type(image))

if transform_id == "Identity":
return image
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
Expand Down Expand Up @@ -550,7 +550,7 @@ def __init__(
alpha: float = 1.0,
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
Expand Down
29 changes: 15 additions & 14 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_get_fill,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
Expand Down Expand Up @@ -487,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand All @@ -504,7 +505,7 @@ def __init__(
self.padding_mode = padding_mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]


Expand Down Expand Up @@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):

def __init__(
self,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
Expand Down Expand Up @@ -574,7 +575,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(padding=padding)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, **params, fill=fill)


Expand Down Expand Up @@ -620,7 +621,7 @@ def __init__(
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
Expand All @@ -640,7 +641,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(angle=angle)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.rotate(
inpt,
**params,
Expand Down Expand Up @@ -702,7 +703,7 @@ def __init__(
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -762,7 +763,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(angle=angle, translate=translate, scale=scale, shear=shear)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.affine(
inpt,
**params,
Expand Down Expand Up @@ -840,7 +841,7 @@ def __init__(
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand Down Expand Up @@ -918,7 +919,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

if params["needs_crop"]:
Expand Down Expand Up @@ -959,7 +960,7 @@ def __init__(
distortion_scale: float = 0.5,
p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__(p=p)

Expand Down Expand Up @@ -1002,7 +1003,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(coefficients=perspective_coeffs)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.perspective(
inpt,
None,
Expand Down Expand Up @@ -1061,7 +1062,7 @@ def __init__(
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
Expand Down Expand Up @@ -1095,7 +1096,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(displacement=displacement)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.elastic(
inpt,
**params,
Expand Down
Loading
Loading