Skip to content

Commit

Permalink
make datapoint methods private
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 11, 2023
1 parent 08c9938 commit 58b8eba
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 130 deletions.
10 changes: 5 additions & 5 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,12 +1417,12 @@ def test_antialias_warning():
transforms.RandomResize(10, 20)(tensor_img)

with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
datapoints.Image(tensor_img)._resized_crop(0, 0, 10, 10, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resize((20, 20))
datapoints.Video(tensor_video)._resize((20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
datapoints.Video(tensor_video)._resized_crop(0, 0, 10, 10, (20, 20))

with warnings.catch_warnings():
warnings.simplefilter("error")
Expand All @@ -1436,8 +1436,8 @@ def test_antialias_warning():
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)

datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Image(tensor_img)._resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video)._resized_crop(0, 0, 10, 10, (20, 20), antialias=True)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def test_pil_output_type(self, info, args_kwargs):
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()

method_name = info.id
method_name = f"_{info.id}"
method = getattr(datapoint, method_name)
datapoint_type = type(datapoint)
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _check_dispatcher_datapoint_signature_match(dispatcher):
dispatcher_signature = inspect.signature(dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

datapoint_method = getattr(datapoints._datapoint.Datapoint, dispatcher.__name__)
datapoint_method = getattr(datapoints._datapoint.Datapoint, f"_{dispatcher.__name__}")
datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]

Expand Down
22 changes: 11 additions & 11 deletions torchvision/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,19 @@ def wrap_like(
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, spatial_size=self.spatial_size)

def horizontal_flip(self) -> BoundingBox:
def _horizontal_flip(self) -> BoundingBox:
output = self._F.horizontal_flip_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
)
return BoundingBox.wrap_like(self, output)

def vertical_flip(self) -> BoundingBox:
def _vertical_flip(self) -> BoundingBox:
output = self._F.vertical_flip_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
)
return BoundingBox.wrap_like(self, output)

def resize( # type: ignore[override]
def _resize(
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
Expand All @@ -125,19 +125,19 @@ def resize( # type: ignore[override]
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
def _crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
output, spatial_size = self._F.crop_bounding_box(
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def center_crop(self, output_size: List[int]) -> BoundingBox:
def _center_crop(self, output_size: List[int]) -> BoundingBox:
output, spatial_size = self._F.center_crop_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def resized_crop(
def _resized_crop(
self,
top: int,
left: int,
Expand All @@ -152,7 +152,7 @@ def resized_crop(
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def pad(
def _pad(
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, List[float]]] = None,
Expand All @@ -167,7 +167,7 @@ def pad(
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def rotate(
def _rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
Expand All @@ -185,7 +185,7 @@ def rotate(
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

def affine(
def _affine(
self,
angle: Union[int, float],
translate: List[float],
Expand All @@ -207,7 +207,7 @@ def affine(
)
return BoundingBox.wrap_like(self, output)

def perspective(
def _perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
Expand All @@ -225,7 +225,7 @@ def perspective(
)
return BoundingBox.wrap_like(self, output)

def elastic(
def _elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
Expand Down
53 changes: 27 additions & 26 deletions torchvision/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,13 @@ def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
# `BoundingBox.clone()`.
return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value]

def horizontal_flip(self) -> Datapoint:
def _horizontal_flip(self) -> Datapoint:
return self

def vertical_flip(self) -> Datapoint:
def _vertical_flip(self) -> Datapoint:
return self

# TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
# https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
def resize( # type: ignore[override]
def _resize(
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
Expand All @@ -159,13 +157,13 @@ def resize( # type: ignore[override]
) -> Datapoint:
return self

def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
def _crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
return self

def center_crop(self, output_size: List[int]) -> Datapoint:
def _center_crop(self, output_size: List[int]) -> Datapoint:
return self

def resized_crop(
def _resized_crop(
self,
top: int,
left: int,
Expand All @@ -177,15 +175,15 @@ def resized_crop(
) -> Datapoint:
return self

def pad(
def _pad(
self,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Datapoint:
return self

def rotate(
def _rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
Expand All @@ -195,7 +193,7 @@ def rotate(
) -> Datapoint:
return self

def affine(
def _affine(
self,
angle: Union[int, float],
translate: List[float],
Expand All @@ -207,7 +205,7 @@ def affine(
) -> Datapoint:
return self

def perspective(
def _perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
Expand All @@ -217,51 +215,54 @@ def perspective(
) -> Datapoint:
return self

def elastic(
def _elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> Datapoint:
return self

def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
def _rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
return self

def adjust_brightness(self, brightness_factor: float) -> Datapoint:
def _adjust_brightness(self, brightness_factor: float) -> Datapoint:
return self

def adjust_saturation(self, saturation_factor: float) -> Datapoint:
def _adjust_saturation(self, saturation_factor: float) -> Datapoint:
return self

def adjust_contrast(self, contrast_factor: float) -> Datapoint:
def _adjust_contrast(self, contrast_factor: float) -> Datapoint:
return self

def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
def _adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
return self

def adjust_hue(self, hue_factor: float) -> Datapoint:
def _adjust_hue(self, hue_factor: float) -> Datapoint:
return self

def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
def _adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
return self

def posterize(self, bits: int) -> Datapoint:
def _posterize(self, bits: int) -> Datapoint:
return self

def solarize(self, threshold: float) -> Datapoint:
def _solarize(self, threshold: float) -> Datapoint:
return self

def autocontrast(self) -> Datapoint:
def _autocontrast(self) -> Datapoint:
return self

def equalize(self) -> Datapoint:
def _equalize(self) -> Datapoint:
return self

def invert(self) -> Datapoint:
def _invert(self) -> Datapoint:
return self

def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
def _gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
return self

def _normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Datapoint:
return self


Expand Down
Loading

0 comments on commit 58b8eba

Please sign in to comment.