Skip to content

Commit

Permalink
allow dispatch to PIL image subclasses (#7835)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Aug 16, 2023
1 parent c1592f9 commit e0e6f7e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
33 changes: 23 additions & 10 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import math
import re
from pathlib import Path
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -2126,16 +2127,10 @@ class TestGetKernel:
datapoints.Video: F.resize_video,
}

def test_unsupported_types(self):
class MyTensor(torch.Tensor):
pass

class MyPILImage(PIL.Image.Image):
pass

for input_type in [str, int, object, MyTensor, MyPILImage]:
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, input_type)
@pytest.mark.parametrize("input_type", [str, int, object])
def test_unsupported_types(self, input_type):
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, input_type)

def test_exact_match(self):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
Expand Down Expand Up @@ -2197,6 +2192,24 @@ def resize_my_datapoint():

assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint

def test_pil_image_subclass(self):
opened_image = PIL.Image.open(Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")
loaded_image = opened_image.convert("RGB")

# check the assumptions
assert isinstance(opened_image, PIL.Image.Image)
assert type(opened_image) is not PIL.Image.Image

assert type(loaded_image) is PIL.Image.Image

size = [17, 11]
for image in [opened_image, loaded_image]:
kernel = _get_kernel(F.resize, type(image))

output = kernel(image, size=size)

assert F.get_size(output) == size


class TestPermuteChannels:
_DEFAULT_PERMUTATION = [2, 0, 1]
Expand Down
23 changes: 8 additions & 15 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
if not registry:
raise ValueError(f"No kernel registered for functional {functional.__name__}.")

# In case we have an exact type match, we take a shortcut.
if input_type in registry:
return registry[input_type]

# In case of datapoints, we check if we have a kernel for a superclass registered
if issubclass(input_type, datapoints.Datapoint):
# Since we have already checked for an exact match above, we can start the traversal at the superclass.
for cls in input_type.__mro__[1:]:
if cls is datapoints.Datapoint:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
break
elif cls in registry:
return registry[cls]
for cls in input_type.__mro__:
if cls in registry:
return registry[cls]
elif cls is datapoints.Datapoint:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
break

if allow_passthrough:
return lambda inpt, *args, **kwargs: inpt
Expand Down

0 comments on commit e0e6f7e

Please sign in to comment.