
Commit

Fix, better tests, expose more stuff
NicolasHug committed Oct 3, 2024
1 parent e100702 commit 1f24d8f
Showing 3 changed files with 42 additions and 45 deletions.
84 changes: 41 additions & 43 deletions test/test_transforms_v2.py
@@ -6171,50 +6171,48 @@ def test_transform_invalid_quality_error(self, quality):
         transforms.JPEG(quality=quality)


-class TestQuerySize:
+class TestUtils:
+    # TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes
     @pytest.mark.parametrize(
-        "make_input, input_name",
-        [
-            (lambda: torch.rand(3, 32, 64), "pure_tensor"),
-            (lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"),
-            (lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"),
-            (lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"),
-            (lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"),
-        ],
-        ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"],
+        "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
     )
-    def test_functional(self, make_input, input_name):
-        input1 = make_input()
-        input2 = make_input()
-        # Both inputs should have the same size (32, 64)
-        assert transforms.query_size([input1, input2]) == (32, 64)
-
+    @pytest.mark.parametrize(
+        "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
+    )
+    @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
+    def test_query_size_and_query_chw(self, make_input1, make_input2, query):
+        size = (32, 64)
+        input1 = make_input1(size)
+        input2 = make_input2(size)
+
+        if query is transforms.query_chw and not any(
+            transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+            for inpt in (input1, input2)
+        ):
+            return
+
+        expected = size if query is transforms.query_size else ((3,) + size)
+        assert query([input1, input2]) == expected
+
     @pytest.mark.parametrize(
-        "make_input, input_name",
-        [
-            (lambda: torch.rand(3, 32, 64), "pure_tensor"),
-            (lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"),
-            (lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"),
-            (lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"),
-            (lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"),
-        ],
-        ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"],
+        "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
     )
-    def test_functional_mixed_types(self, make_input, input_name):
-        input1 = make_input()
-        input2 = make_input()
-        # Both inputs should have the same size (32, 64)
-        assert transforms.query_size([input1, input2]) == (32, 64)
-
-    def test_different_sizes(self):
-        img_tensor = torch.rand(3, 32, 64)  # (C, H, W)
-        img_tensor_different_size = torch.rand(3, 48, 96)  # (C, H, W)
-        # Should raise ValueError for different sizes
-        with pytest.raises(ValueError, match="Found multiple HxW dimensions"):
-            transforms.query_size([img_tensor, img_tensor_different_size])
-
-    def test_no_valid_image(self):
-        invalid_input = torch.rand(1, 10)  # Non-image tensor
-        # Should raise TypeError for invalid input
-        with pytest.raises(TypeError, match="No image, video, mask or bounding box was found"):
-            transforms.query_size([invalid_input])
+    @pytest.mark.parametrize(
+        "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask]
+    )
+    @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
+    def test_different_sizes(self, make_input1, make_input2, query):
+        input1 = make_input1((10, 10))
+        input2 = make_input2((20, 20))
+        if query is transforms.query_chw and not all(
+            transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+            for inpt in (input1, input2)
+        ):
+            return
+        with pytest.raises(ValueError, match="Found multiple"):
+            query([input1, input2])
+
+    @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw])
+    def test_no_valid_input(self, query):
+        with pytest.raises(TypeError, match="No image"):
+            query(["blah"])
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
@@ -55,5 +55,6 @@
 )
 from ._temporal import UniformTemporalSubsample
 from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
+from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size

 from ._deprecated import ToTensor  # usort: skip
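
The one-line addition above is what exposes these developer utilities from the public torchvision.transforms.v2 namespace. A hedged usage sketch of the newly exposed helpers follows; the concrete tensors and values are made up for illustration and are not part of the commit:

    # Illustrative only: the helpers operate on a flat list of sample components,
    # broadly the same way the v2 transforms use them internally.
    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2 as transforms

    img = tv_tensors.Image(torch.rand(3, 32, 64))
    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[2, 2, 20, 20]]), format="XYXY", canvas_size=(32, 64)
    )
    sample = [img, boxes]

    transforms.has_any(sample, tv_tensors.BoundingBoxes)                    # True
    transforms.has_all(sample, tv_tensors.Image, tv_tensors.BoundingBoxes)  # True
    transforms.check_type(img, (tv_tensors.Image, tv_tensors.Video))        # True
    transforms.get_bounding_boxes(sample)  # returns the BoundingBoxes entry
    transforms.query_size(sample)          # (32, 64)
    transforms.query_chw(sample)           # (3, 32, 64)
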
2 changes: 0 additions & 2 deletions torchvision/transforms/v2/functional/__init__.py
@@ -151,5 +151,3 @@
 from ._type_conversion import pil_to_tensor, to_image, to_pil_image

 from ._deprecated import get_image_size, to_tensor  # usort: skip
-
-from ._utils import query_size
