Simple tensor -> pure tensor #7846

Merged · 4 commits · Aug 17, 2023
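This PR renames `is_simple_tensor` to `is_pure_tensor` (along with the matching `simple_tensor*` variable names and test names) across the v2 transforms, the prototype transforms, and their tests. As a rough mental model of the predicate (a simplified sketch, not the actual implementation in `torchvision.transforms.v2.utils`): a "pure" tensor is a plain `torch.Tensor` that is not one of the datapoint subclasses.

```python
import torch
from torchvision import datapoints


def is_pure_tensor_sketch(inpt) -> bool:
    # Simplified: the real check excludes the datapoint base class rather
    # than comparing types exactly, but the effect is the same for the
    # types that appear in this PR.
    return type(inpt) is torch.Tensor


assert is_pure_tensor_sketch(torch.rand(3, 4, 4))
assert not is_pure_tensor_sketch(datapoints.Image(torch.rand(3, 4, 4)))
```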
10 changes: 5 additions & 5 deletions test/test_prototype_datasets_builtin.py
@@ -25,7 +25,7 @@
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 def assert_samples_equal(*args, msg=None, **kwargs):
@@ -140,18 +140,18 @@ def make_msg_and_close(head):
         raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_pure_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
         sample = next_consume(iter(dataset))

-        simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
+        pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}

-        if simple_tensors and not any(
+        if pure_tensors and not any(
             isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
         ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"{sequence_to_str(sorted(pure_tensors), separate_last='and ')} contained pure tensors, "
                 f"but didn't find any (encoded) image or video."
             )

6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
@@ -18,7 +18,7 @@
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]

@@ -296,7 +296,7 @@ def test_call(self, dims, inverse_dims):
             value_type = type(value)
             transformed_value = transformed_sample[key]

-            if check_type(value, (Image, is_simple_tensor, Video)):
+            if check_type(value, (Image, is_pure_tensor, Video)):
                 if transform.dims.get(value_type) is not None:
                     assert transformed_value.permute(inverse_dims[value_type]).equal(value)
                 assert type(transformed_value) == torch.Tensor
@@ -341,7 +341,7 @@ def test_call(self, dims):
             transformed_value = transformed_sample[key]

             transposed_dims = transform.dims.get(value_type)
-            if check_type(value, (Image, is_simple_tensor, Video)):
+            if check_type(value, (Image, is_pure_tensor, Video)):
                 if transposed_dims is not None:
                     assert transformed_value.transpose(*transposed_dims).equal(value)
                 assert type(transformed_value) == torch.Tensor
38 changes: 19 additions & 19 deletions test/test_transforms_v2.py
@@ -29,7 +29,7 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw


 def make_vanilla_tensor_images(*args, **kwargs):
@@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device):
        if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
            # AA transforms don't support bounding boxes or masks
            continue
-       elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+       elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
            if image_or_video_found:
                # AA transforms only support a single image or video
                continue
@@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device):
        if isinstance(value, PIL.Image.Image):
            # normalize doesn't support PIL images
            continue
-       elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+       elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
            # normalize doesn't support integer images
            value = F.to_dtype(value, torch.float32, scale=True)
        adapted_input[key] = value
@@ -357,19 +357,19 @@ def test_random_resized_crop(self, transform, input):
         3,
     ),
 )
-def test_simple_tensor_heuristic(flat_inputs):
-    def split_on_simple_tensor(to_split):
+def test_pure_tensor_heuristic(flat_inputs):
+    def split_on_pure_tensor(to_split):
         # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
-        # 1. The first simple tensor. If none is present, this will be `None`
-        # 2. A list of the remaining simple tensors
+        # 1. The first pure tensor. If none is present, this will be `None`
+        # 2. A list of the remaining pure tensors
         # 3. A list of all other items
-        simple_tensors = []
+        pure_tensors = []
         others = []
         # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
         # affect the splitting.
         for item, inpt in zip(to_split, flat_inputs):
-            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
-        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+            (pure_tensors if is_pure_tensor(inpt) else others).append(item)
+        return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others

     class CopyCloneTransform(transforms.Transform):
         def _transform(self, inpt, params):
@@ -385,20 +385,20 @@ def was_applied(output, inpt):
             assert_equal(output, inpt)
             return True

-    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+    first_pure_tensor_input, other_pure_tensor_inputs, other_inputs = split_on_pure_tensor(flat_inputs)

     transform = CopyCloneTransform()
     transformed_sample = transform(flat_inputs)

-    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+    first_pure_tensor_output, other_pure_tensor_outputs, other_outputs = split_on_pure_tensor(transformed_sample)

-    if first_simple_tensor_input is not None:
+    if first_pure_tensor_input is not None:
         if other_inputs:
-            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+            assert not transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)
         else:
-            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+            assert transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)

-    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+    for output, inpt in zip(other_pure_tensor_outputs, other_pure_tensor_inputs):
         assert not transform.was_applied(output, inpt)

     for input, output in zip(other_inputs, other_outputs):
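For context, the heuristic this test exercises can be seen in isolation with a real transform. A minimal sketch, assuming the documented v2 behavior: when a sample contains no image or video datapoint and no PIL image, the first pure tensor is treated as the image; otherwise pure tensors pass through unchanged.

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

flip = transforms.RandomHorizontalFlip(p=1.0)

pure = torch.rand(3, 8, 8)
# Alone, the pure tensor plays the role of the image and gets flipped.
assert not torch.equal(flip(pure), pure)

# Alongside an Image datapoint, the pure tensor is passed through untouched.
image = datapoints.Image(torch.rand(3, 8, 8))
out_image, out_pure = flip(image, pure)
assert torch.equal(out_pure, pure)
```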
@@ -1004,7 +1004,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
         image = image.as_subclass(torch.Tensor)
-        assert is_simple_tensor(image)
+        assert is_pure_tensor(image)

     label = 1 if label_type is int else torch.tensor([1])

@@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
         image = image.as_subclass(torch.Tensor)
-        assert is_simple_tensor(image)
+        assert is_pure_tensor(image)

     label = torch.randint(0, 10, size=(num_boxes,))

@@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     out = t(sample)

     if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
-        assert is_simple_tensor(out["image"])
+        assert is_pure_tensor(out["image"])
     else:
         assert isinstance(out["image"], datapoints.Image)
     assert isinstance(out["label"], type(sample["label"]))
2 changes: 1 addition & 1 deletion test/test_transforms_v2_consistency.py
@@ -602,7 +602,7 @@ def check_call_consistency(
             raise AssertionError(
                 f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                 f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                f"`is_simple_tensor` path in `_transform`."
+                f"`is_pure_tensor` path in `_transform`."
             ) from exc

     assert_close(
26 changes: 13 additions & 13 deletions test/test_transforms_v2_functional.py
@@ -24,7 +24,7 @@
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS

@@ -168,7 +168,7 @@ def _unbatch(self, batch, *, data_dims):
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
+        datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
         # This dictionary contains the number of rightmost dimensions that contain the actual data.
         # Everything to the left is considered a batch dimension.
         data_dims = {
@@ -333,9 +333,9 @@ def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)

         (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
-        image_simple_tensor = torch.Tensor(image_datapoint)
+        image_pure_tensor = torch.Tensor(image_datapoint)

-        dispatcher(image_simple_tensor, *other_args, **kwargs)
+        dispatcher(image_pure_tensor, *other_args, **kwargs)

     # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
     # replaces this test for them.
@@ -358,11 +358,11 @@ def test_scriptable(self, dispatcher):
         script(dispatcher)

     @image_sample_inputs
-    def test_simple_tensor_output_type(self, info, args_kwargs):
+    def test_pure_tensor_output_type(self, info, args_kwargs):
         (image_datapoint, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
+        image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)

-        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
+        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

         # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor
@@ -505,11 +505,11 @@ class TestClampBoundingBoxes:
             dict(canvas_size=(1, 1)),
         ],
     )
-    def test_simple_tensor_insufficient_metadata(self, metadata):
-        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+    def test_pure_tensor_insufficient_metadata(self, metadata):
+        pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
-            F.clamp_bounding_boxes(simple_tensor, **metadata)
+            F.clamp_bounding_boxes(pure_tensor, **metadata)

     @pytest.mark.parametrize(
         "metadata",
Expand Down Expand Up @@ -538,11 +538,11 @@ def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_boxes(inpt, old_format)

def test_simple_tensor_insufficient_metadata(self):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
def test_pure_tensor_insufficient_metadata(self):
pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

def test_datapoint_explicit_metadata(self):
datapoint = next(make_bounding_boxes())
Expand Down
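The two `test_pure_tensor_insufficient_metadata` tests above pin down the flip side of the API: a pure tensor carries no bounding-box metadata, so `format` and `canvas_size` (or `old_format`) must be supplied explicitly. A usage sketch with hypothetical box values, assuming XYXY coordinates on a 10x10 canvas:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = torch.tensor([[-2.0, 3.0, 6.0, 14.0]])  # pure tensor, no metadata

# Omitting format/canvas_size raises the ValueError asserted above; with
# them, coordinates are clamped into the canvas: [[0., 3., 6., 10.]]
clamped = F.clamp_bounding_boxes(
    boxes,
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(10, 10),
)
```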
6 changes: 3 additions & 3 deletions test/test_transforms_v2_utils.py
@@ -37,15 +37,15 @@
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
             True,
         ),
     ],
2 changes: 1 addition & 1 deletion test/transforms_v2_dispatcher_infos.py
@@ -107,7 +107,7 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
        ("TestDispatchers", test_name),
        pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
    )
-   for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+   for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
 ]
 multi_crop_skips.append(skip_dispatch_datapoint)

6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_augment.py
@@ -9,7 +9,7 @@
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform

 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 class SimpleCopyPaste(Transform):
@@ -109,7 +109,7 @@ def _extract_image_targets(
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
-            if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
+            if isinstance(obj, datapoints.Image) or is_pure_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
                 images.append(F.to_image(obj))
@@ -146,7 +146,7 @@ def _insert_outputs(
            elif isinstance(obj, PIL.Image.Image):
                flat_sample[i] = F.to_pil_image(output_images[c0])
                c0 += 1
-           elif is_simple_tensor(obj):
+           elif is_pure_tensor(obj):
                flat_sample[i] = output_images[c0]
                c0 += 1
            elif isinstance(obj, datapoints.BoundingBoxes):
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -7,7 +7,7 @@
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size


 class FixedSizeCrop(Transform):
@@ -32,7 +32,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
             flat_inputs,
             PIL.Image.Image,
             datapoints.Image,
-            is_simple_tensor,
+            is_pure_tensor,
             datapoints.Video,
         ):
             raise TypeError(
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_misc.py
@@ -8,7 +8,7 @@
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform

-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 T = TypeVar("T")
@@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:


 class PermuteDimensions(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
         super().__init__()
@@ -47,7 +47,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:


 class TransposeDimensions(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
         super().__init__()
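A usage sketch for the prototype transform above, assuming `PermuteDimensions` is re-exported from `torchvision.prototype.transforms`: permuting a CHW image to HWC. The output is deliberately a plain tensor: after an arbitrary permutation the datapoint's layout contract no longer holds, which is also why `is_pure_tensor` appears in `_transformed_types`.

```python
import torch
from torchvision import datapoints
from torchvision.prototype import transforms

permute = transforms.PermuteDimensions(dims=(1, 2, 0))  # CHW -> HWC

image = datapoints.Image(torch.rand(3, 8, 8))
out = permute(image)
assert type(out) is torch.Tensor  # demoted from datapoints.Image
assert out.shape == (8, 8, 3)
```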
6 changes: 3 additions & 3 deletions torchvision/transforms/v2/_augment.py
@@ -12,7 +12,7 @@

 from ._transform import _RandomApplyTransform, Transform
 from ._utils import _parse_labels_getter
-from .utils import has_any, is_simple_tensor, query_chw, query_size
+from .utils import has_any, is_pure_tensor, query_chw, query_size


 class RandomErasing(_RandomApplyTransform):
@@ -243,7 +243,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=lam)
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
@@ -310,7 +310,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=params["lam_adjusted"])
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             x1, y1, x2, y2 = params["box"]
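On the pure-tensor branch of the MixUp hunk above, each batch element is blended with its roll-by-one neighbour. The arithmetic in isolation, as a sketch with a fixed `lam` (the transform itself samples it from a Beta distribution):

```python
import torch

lam = 0.7
batch = torch.rand(4, 3, 8, 8)  # pure tensor inputs must be batched here

mixed = batch.roll(1, 0).mul(1.0 - lam).add(batch.mul(lam))
# mixed[i] == lam * batch[i] + (1 - lam) * batch[i - 1]
assert torch.allclose(mixed[1], lam * batch[1] + (1 - lam) * batch[0])
```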
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_auto_augment.py
@@ -13,7 +13,7 @@
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

 from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_simple_tensor
+from .utils import check_type, is_pure_tensor


 ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
@@ -50,7 +50,7 @@ def _flatten_and_extract_image_or_video(
            (
                datapoints.Image,
                PIL.Image.Image,
-               is_simple_tensor,
+               is_pure_tensor,
                datapoints.Video,
            ),
        ):
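Since `_flatten_and_extract_image_or_video` accepts a pure tensor as its single image or video, the AutoAugment family keeps working on plain tensors after the rename. A quick smoke test, a sketch assuming a uint8 CHW image input:

```python
import torch
from torchvision.transforms import v2 as transforms

aa = transforms.AutoAugment()
out = aa(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
assert type(out) is torch.Tensor  # pure tensor in, pure tensor out
```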