Skip to content

Commit

Permalink
Merge branch 'main' into ljanflajnfljanfe
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 1, 2023
2 parents 1e7f83a + edde825 commit b8d0030
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 69 deletions.
5 changes: 1 addition & 4 deletions gallery/plot_transforms_v2_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

import pathlib
from collections import defaultdict

import PIL.Image

Expand Down Expand Up @@ -99,9 +98,7 @@ def load_example_coco_detection_dataset(**kwargs):
transform = transforms.Compose(
[
transforms.RandomPhotometricDistort(),
transforms.RandomZoomOut(
fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
),
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(),
transforms.ToImageTensor(),
Expand Down
4 changes: 1 addition & 3 deletions references/segmentation/presets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from collections import defaultdict

import torch


Expand Down Expand Up @@ -48,7 +46,7 @@ def __init__(
if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]

transforms += [T.RandomCrop(crop_size)]

Expand Down
4 changes: 2 additions & 2 deletions references/segmentation/v2_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)
self.fill = v2._utils._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
Expand All @@ -20,7 +20,7 @@ def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = v2._utils._get_fill(self.fill, type(inpt))
fill = v2._utils._convert_fill_arg(fill)

return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
Expand Down
3 changes: 1 addition & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import random
import textwrap
import warnings
from collections import defaultdict

import numpy as np

Expand Down Expand Up @@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor,
Expand Down
6 changes: 3 additions & 3 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import inspect
import random
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
Expand All @@ -30,6 +29,7 @@

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_size

Expand Down Expand Up @@ -1181,7 +1181,7 @@ def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = _get_fill(self.fill, type(inpt))
return prototype_F.pad(inpt, padding=params["padding"], fill=fill)


Expand Down Expand Up @@ -1243,7 +1243,7 @@ def check(self, t, t_ref, data_kwargs=None):
seg_transforms.RandomCrop(size=480),
v2_transforms.Compose(
[
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
v2_transforms.RandomCrop(size=480),
]
),
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size


class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
Expand Down Expand Up @@ -119,7 +119,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
)

if params["needs_pad"]:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

return inpt
18 changes: 16 additions & 2 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import functools
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union

import torch

from torchvision import datapoints
from torchvision.transforms.v2 import Transform

from torchvision.transforms.v2._utils import _get_defaultdict
from torchvision.transforms.v2.utils import is_simple_tensor


T = TypeVar("T")


def _default_arg(value: T) -> T:
return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))


class PermuteDimensions(Transform):
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size

from ._utils import _setup_fill_arg
from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_simple_tensor


Expand All @@ -20,7 +20,7 @@ def __init__(
self,
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
Expand Down Expand Up @@ -80,9 +80,9 @@ def _apply_image_or_video_transform(
transform_id: str,
magnitude: float,
interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints._FillTypeJIT],
fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = fill[type(image)]
fill_ = _get_fill(fill, type(image))

if transform_id == "Identity":
return image
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
Expand Down Expand Up @@ -550,7 +550,7 @@ def __init__(
alpha: float = 1.0,
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
Expand Down
29 changes: 15 additions & 14 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_get_fill,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
Expand Down Expand Up @@ -487,7 +488,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand All @@ -504,7 +505,7 @@ def __init__(
self.padding_mode = padding_mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]


Expand Down Expand Up @@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):

def __init__(
self,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
Expand Down Expand Up @@ -574,7 +575,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(padding=padding)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, **params, fill=fill)


Expand Down Expand Up @@ -620,7 +621,7 @@ def __init__(
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
Expand All @@ -640,7 +641,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(angle=angle)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.rotate(
inpt,
**params,
Expand Down Expand Up @@ -702,7 +703,7 @@ def __init__(
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -762,7 +763,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(angle=angle, translate=translate, scale=scale, shear=shear)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.affine(
inpt,
**params,
Expand Down Expand Up @@ -840,7 +841,7 @@ def __init__(
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand Down Expand Up @@ -918,7 +919,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

if params["needs_crop"]:
Expand Down Expand Up @@ -959,7 +960,7 @@ def __init__(
distortion_scale: float = 0.5,
p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__(p=p)

Expand Down Expand Up @@ -1002,7 +1003,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(coefficients=perspective_coeffs)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.perspective(
inpt,
None,
Expand Down Expand Up @@ -1061,7 +1062,7 @@ def __init__(
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
Expand Down Expand Up @@ -1095,7 +1096,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(displacement=displacement)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.elastic(
inpt,
**params,
Expand Down
Loading

0 comments on commit b8d0030

Please sign in to comment.