[fbsync] Add scale option to ToDtype. Remove ConvertDtype. (#7759)
Reviewed By: matteobettini

Differential Revision: D48642282

fbshipit-source-id: 95a2eea16407f17e1ebeb386cd5e2618a105450f

Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
3 people authored and facebook-github-bot committed Aug 25, 2023
1 parent: 7beb6a5 · commit: b54c9b4
Showing 15 changed files with 296 additions and 271 deletions.
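In practical terms, this commit removes the v2 ConvertDtype transform and replaces it with ToDtype, which now takes an explicit scale flag. A minimal usage sketch follows, assuming the torchvision.transforms.v2 API exactly as it appears in the diffs below; the example tensor and printed output are invented for illustration and are not part of the commit.

# Sketch only, not part of this commit; assumes the v2 API shown in the diffs below.
import torch
from torchvision.transforms import v2 as transforms

# Before this commit, dtype conversion and value rescaling were bundled together:
#     transform = transforms.ConvertDtype(torch.float32)   # removed by this commit
# After it, ToDtype converts the dtype, and value rescaling is opt-in via scale=True:
transform = transforms.ToDtype(torch.float32, scale=True)

image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # invented example input
out = transform(image)
print(out.dtype)  # torch.float32, with values rescaled from [0, 255] to [0.0, 1.0]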
docs/source/transforms.rst (1 change: 0 additions & 1 deletion)
@@ -234,7 +234,6 @@ Conversion
     v2.PILToTensor
     v2.ToImageTensor
     ConvertImageDtype
-    v2.ConvertDtype
     v2.ConvertImageDtype
     v2.ToDtype
     v2.ConvertBoundingBoxFormat
gallery/plot_transforms_v2_e2e.py (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@ def show(sample):
     image, target = sample
     if isinstance(image, PIL.Image.Image):
         image = F.to_image_tensor(image)
-    image = F.convert_dtype(image, torch.uint8)
+    image = F.to_dtype(image, torch.uint8, scale=True)
     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

     fig, ax = plt.subplots()
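The same change at the functional level: the gallery helper now calls F.to_dtype with scale=True instead of F.convert_dtype. A hedged sketch of an equivalent standalone call, with an input tensor invented for illustration:

import torch
from torchvision.transforms.v2 import functional as F

image = torch.rand(3, 64, 64)                           # float32 values in [0.0, 1.0]
image_u8 = F.to_dtype(image, torch.uint8, scale=True)   # rescales to [0, 255], then casts to uint8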
test/common_utils.py (4 changes: 2 additions & 2 deletions)
@@ -27,7 +27,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import datapoints, io
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor


 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -601,7 +601,7 @@ def fn(shape, dtype, device, memory_format):
             image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
         else:
             image_tensor = image_tensor.to(device=device)
-        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+        image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)

         return datapoints.Image(image_tensor)

test/test_transforms_v2.py (60 changes: 2 additions & 58 deletions)
@@ -1,7 +1,6 @@
 import itertools
 import pathlib
 import random
-import re
 import textwrap
 import warnings
 from collections import defaultdict
@@ -105,7 +104,7 @@ def normalize_adapter(transform, input, device):
             continue
         elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
             # normalize doesn't support integer images
-            value = F.convert_dtype(value, torch.float32)
+            value = F.to_dtype(value, torch.float32, scale=True)
         adapted_input[key] = value
     return adapted_input

@@ -146,7 +145,7 @@ class TestSmoke:
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
             (transforms.ClampBoundingBox(), None),
             (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
-            (transforms.ConvertDtype(), None),
+            (transforms.ConvertImageDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),
             (
                 transforms.LinearTransformation(
@@ -1326,61 +1325,6 @@ def test__transform(self, mocker):
         )


-class TestToDtype:
-    @pytest.mark.parametrize(
-        ("dtype", "expected_dtypes"),
-        [
-            (
-                torch.float64,
-                {
-                    datapoints.Video: torch.float64,
-                    datapoints.Image: torch.float64,
-                    datapoints.BoundingBox: torch.float64,
-                },
-            ),
-            (
-                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-            ),
-        ],
-    )
-    def test_call(self, dtype, expected_dtypes):
-        sample = dict(
-            video=make_video(dtype=torch.int64),
-            image=make_image(dtype=torch.uint8),
-            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
-            str="str",
-            int=0,
-        )
-
-        transform = transforms.ToDtype(dtype)
-        transformed_sample = transform(sample)
-
-        for key, value in sample.items():
-            value_type = type(value)
-            transformed_value = transformed_sample[key]
-
-            # make sure the transformation retains the type
-            assert isinstance(transformed_value, value_type)
-
-            if isinstance(value, torch.Tensor):
-                assert transformed_value.dtype is expected_dtypes[value_type]
-            else:
-                assert transformed_value is value
-
-    @pytest.mark.filterwarnings("error")
-    def test_plain_tensor_call(self):
-        tensor = torch.empty((), dtype=torch.float32)
-        transform = transforms.ToDtype({torch.Tensor: torch.float64})
-
-        assert transform(tensor).dtype is torch.float64
-
-    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
-    def test_plain_tensor_warning(self, other_type):
-        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
-            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})
-
-
 class TestUniformTemporalSubsample:
     @pytest.mark.parametrize(
         "inpt",
test/test_transforms_v2_consistency.py (2 changes: 1 addition & 1 deletion)
@@ -191,7 +191,7 @@ def __init__(
         closeness_kwargs=dict(rtol=None, atol=None),
     ),
     ConsistencyConfig(
-        v2_transforms.ConvertDtype,
+        v2_transforms.ConvertImageDtype,
         legacy_transforms.ConvertImageDtype,
         [
             ArgsKwargs(torch.float16),
test/test_transforms_v2_functional.py (23 changes: 2 additions & 21 deletions)
@@ -283,12 +283,12 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs):
         adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

         actual = info.kernel(
-            F.convert_dtype_image_tensor(input, dtype=torch.float32),
+            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
             *adapted_other_args,
             **adapted_kwargs,
         )

-        expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
+        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

         assert_close(
             actual,
@@ -538,7 +538,6 @@ def test_bounding_box_format_consistency(self, info, args_kwargs):
             (F.get_image_num_channels, F.get_num_channels),
             (F.to_pil_image, F.to_image_pil),
             (F.elastic_transform, F.elastic),
-            (F.convert_image_dtype, F.convert_dtype_image_tensor),
             (F.to_grayscale, F.rgb_to_grayscale),
         ]
     ],
@@ -547,24 +546,6 @@ def test_alias(alias, target):
     assert alias is target


-@pytest.mark.parametrize(
-    ("info", "args_kwargs"),
-    make_info_args_kwargs_params(
-        KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
-        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
-    ),
-)
-@pytest.mark.parametrize("device", cpu_and_cuda())
-def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
-    (input, *other_args), kwargs = args_kwargs.load(device)
-    dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)
-
-    output = info.kernel(input, dtype)
-
-    assert output.dtype == dtype
-    assert output.device == input.device
-
-
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("num_channels", [1, 3])
 def test_normalize_image_tensor_stats(device, num_channels):
(Diffs for the remaining 9 changed files are not shown.)
