[FEEDBACK] Transforms V2 API #6753

Closed
datumbox opened this issue Oct 12, 2022 · 89 comments · Fixed by #7860

Comments

@datumbox
Contributor

datumbox commented Oct 12, 2022

🚀 The feature

This issue is dedicated to collecting community feedback on the Transforms V2 API. Please review the dedicated blog post, where we describe the API in detail and provide an overview of its features.

We would love to get your thoughts, comments and input in order to improve the API and graduate it from prototype in the near future.

Please also check out #7319, where we collect feedback on some specific design decisions and also document which APIs may change in the future!


Code example using this image:

import PIL.Image  # only needed for the commented-out PIL alternative below
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F


# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])


# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)


# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
@rsokl

rsokl commented Oct 18, 2022

Could you post the link to the blog post in this thread when it becomes available? I am not sure where else to look to find it.

@jangop

jangop commented Oct 21, 2022

Can you provide a good starting point for getting an overview of the current state of Transforms V2?

Or is reading through https://github.com/pytorch/vision/tree/main/torchvision/prototype/transforms and https://github.com/pytorch/vision/projects/5 the best approach for now?

@datumbox
Contributor Author

@rsokl @jangop We have a blog post in the pipeline that provides an overview; we are waiting for marketing to publish it. I'll post the link here once it is out. Until then, those two references are the best places to look. There is no documentation yet because the API is still being modified, but all transforms accept exactly the same parameters as in V1.

@datumbox
Contributor Author

datumbox commented Nov 4, 2022

@rsokl @jangop we published the blogpost. Looking forward to your input.

@rsokl

rsokl commented Nov 4, 2022

Hi all,

I just read the blog post and have been exploring the prototype's internals for the last few days: torchvision.transforms v2 looks great! I'd like to share my thoughts with you all.

First, let me summarize some of the primary features that are offered by v2. This will, in part, help me make sure that I have a clear understanding of things. v2 includes:

  • Native support for transforming videos, bounding boxes, segmentation masks, and labels via the _Feature class.
  • The ability to describe metadata for these features. E.g., the Image class has a mechanism for signaling the image's color space (e.g. RGB vs RGBA).
  • A flexible API for mapping transformations over general pytrees of tensors / features, making it easy to incorporate arbitrary numbers/configurations of features in each transformation.
  • A simple API for implementing transforms via the Transform class. Here, an implementation draws random parameters within _get_params, and an individual tensor/feature is transformed by _transform.
  • Continued support for batches of tensors/features as well as CUDA support.
  • Improved performance over v1.
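To make the `_get_params` / `_transform` split concrete, here is a minimal, standard-library-only sketch of the pattern summarized above. It is not torchvision's implementation: `Transform`, `RandomHorizontalFlip`, and `Grid` (a 2D list standing in for an image or mask) are simplified stand-ins, and `_leaves` / `_tree_map` only handle tuples, lists and dicts rather than general pytrees. The key point is that random parameters are drawn once per call and then applied to every input.

```python
import random
from typing import Any, Callable, Dict, Iterator, List


def _leaves(obj: Any) -> Iterator[Any]:
    # Yield the leaves of nested tuples/lists/dicts (simplified stand-in for pytree flattening).
    if isinstance(obj, (tuple, list)):
        for item in obj:
            yield from _leaves(item)
    elif isinstance(obj, dict):
        for item in obj.values():
            yield from _leaves(item)
    else:
        yield obj


def _tree_map(fn: Callable[[Any], Any], obj: Any) -> Any:
    # Apply fn to every leaf while preserving the container structure.
    if isinstance(obj, (tuple, list)):
        return type(obj)(_tree_map(fn, item) for item in obj)
    if isinstance(obj, dict):
        return {key: _tree_map(fn, value) for key, value in obj.items()}
    return fn(obj)


class Grid:
    """Hypothetical stand-in for an image or mask: a 2D list of pixel values."""

    def __init__(self, rows: List[List[int]]) -> None:
        self.rows = rows


class Transform:
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return {}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def __call__(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        # Random parameters are drawn ONCE for the whole sample...
        params = self._get_params(list(_leaves(sample)))
        # ...and applied to every leaf, so all inputs are transformed consistently.
        return _tree_map(lambda inpt: self._transform(inpt, params), sample)


class RandomHorizontalFlip(Transform):
    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return {"flip": random.random() < self.p}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if not isinstance(inpt, Grid) or not params["flip"]:
            return inpt  # no-op for unsupported types or when not flipping
        return Grid([row[::-1] for row in inpt.rows])


image, mask, label = RandomHorizontalFlip(p=1.0)(Grid([[1, 2, 3]]), Grid([[0, 0, 1]]), 7)
```

With `p=1.0` both grids are flipped together (`[[3, 2, 1]]` and `[[1, 0, 0]]`), while the integer label passes through untouched.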

I had been doing a review of augmentation libraries, including kornia, albumentations, and augly, when I came across v2 for torchvision. To me, your new API is the simplest, most capable, and the easiest to extend. I particularly like how simple Transform is and that you do not rely on sprawling class hierarchies to dispatch functionality.

So first of all, thank you for all of the hard work that you have been doing on this effort. The PyTorch community is fortunate to benefit from this excellent work ❤️.

There are two features that I would like to propose. I would be happy to assist with these if there is interest.

Enabling local reproducibility by passing torch.Generator to transforms.

In recent years, NumPy has completely revised its PRNG API to avoid global random state (here is a great post on good practices with NumPy's generators). JAX avoids mutable RNG objects altogether. PyTorch provides torch.Generator to users to make randomness local and "non-spooky", but many libraries prevent users from utilizing this capability.

I am proposing that Transform enable users to optionally pass in a Generator to the forward pass so that torchvision transform pipelines can be made to be isolated from global entropy and thus support more reproducible workflows. This reproducibility is especially useful in the context of performing testing & evaluation – the specific sequence of data transformations performed should be able to be isolated from whether or not a model is using dropout in its forward pass.

The local generator (whose default value is PyTorch's global generator) could be passed to the forward method and piped to the _get_params method. This assumes that all random number generation occurs in self._get_params. The following would be a compatibility-preserving modification to Transform; by default the global seed would still control stochasticity in the transforms.

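from typing import Any, Dict, List

import torch.nn as nn
from torch import Generator, default_generator

class Transform(nn.Module):
    def _get_params(self, flat_inputs: List[Any], *, generator: Generator = default_generator) -> Dict[str, Any]:
        ...

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        # ... flatten `inputs` into `flat_inputs` exactly as the current implementation does ...
        # the only modification: forward the generator to _get_params
        params = self._get_params(flat_inputs, generator=generator)
        # ... transform each flat input with `params` and unflatten, unchanged ...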

and transforms that implement _get_params would replace calls like

# e.g. replace calls like
angle = float(torch.empty(1).uniform_(0.0, 180.).item())

with

# specifying the device is, unfortunately, necessary: https://github.com/pytorch/pytorch/issues/79018
angle = float(torch.empty(1, device=generator.device).uniform_(0.0, 180., generator=generator).item())

A transform like Compose would have to be modified as well. Currently, it supports a sequence of callables that are assumed to accept a single positional argument. It could be assumed that only instances of Transform involve stochasticity and will be passed the random generator. In this case, Compose would look like:

class Compose(Transform):
    # __init__ is unchanged

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample) if not isinstance(transform, Transform) else transform(sample, generator=generator)
        return sample

It would be straightforward to document this behavior to users – that only instances of Transform are passed the generator – so that they know how to opt-in to having the generator be passed to their custom transforms. And, again, this would be compatible with the old nn.Module transforms.
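Here is a minimal, standard-library-only sketch of that opt-in behavior. `random.Random` stands in for torch.Generator, and `RandomRotationAngle` is a hypothetical transform that only reports the angle it would use; neither `Transform` nor `Compose` below is torchvision's class. Only `Transform` instances receive the generator, plain callables get the sample alone, and a seeded local generator makes the result independent of the global random state.

```python
import random
from typing import Any, Callable, Sequence

# Module-level stand-in for torch.default_generator.
_default_generator = random.Random()


class Transform:
    def forward(self, sample: Any, *, generator: random.Random = _default_generator) -> Any:
        raise NotImplementedError

    def __call__(self, sample: Any, *, generator: random.Random = _default_generator) -> Any:
        return self.forward(sample, generator=generator)


class RandomRotationAngle(Transform):
    """Hypothetical transform that only reports the angle it would rotate by."""

    def forward(self, sample: Any, *, generator: random.Random = _default_generator) -> Any:
        angle = generator.uniform(0.0, 180.0)  # all randomness comes from `generator`
        return (sample, angle)


class Compose(Transform):
    def __init__(self, transforms: Sequence[Callable[[Any], Any]]) -> None:
        self.transforms = list(transforms)

    def forward(self, sample: Any, *, generator: random.Random = _default_generator) -> Any:
        for t in self.transforms:
            # Only Transform instances opt in to receiving the generator.
            sample = t(sample, generator=generator) if isinstance(t, Transform) else t(sample)
        return sample


pipeline = Compose([str.upper, RandomRotationAngle()])
random.seed(123)  # global state does not influence the pipeline
first = pipeline("img", generator=random.Random(0))
random.seed(456)
second = pipeline("img", generator=random.Random(0))
```

`first` and `second` are identical because both runs draw from a freshly seeded local generator.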

An example of this in practice would be:

from torch import Generator

rng = Generator().manual_seed(0)

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels, generator=rng)

Another nice thing about this is that specific failure cases that occur during training/testing can be reproduced in an isolated way; _get_params([dummy_img], generator=rng) can be used to advance the generator's state to "replay" a sequence of transformations without having to redo all of the compute. This would not work if the model and the transforms both affect and derive from global state.
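The replay idea can be sketched with the standard library alone, assuming (as the proposal does) that all random draws happen in `_get_params`. `random.Random` again stands in for torch.Generator and `RandomAngle` is hypothetical. Calling only `_get_params` fast-forwards the generator past earlier samples, so the failing sample is reproduced without running any expensive kernels.

```python
import random
from typing import Any, Dict, List


class RandomAngle:
    """Hypothetical transform whose randomness lives entirely in _get_params."""

    def _get_params(self, flat_inputs: List[Any], *, generator: random.Random) -> Dict[str, Any]:
        return {"angle": generator.uniform(0.0, 180.0)}

    def __call__(self, inpt: Any, *, generator: random.Random) -> Any:
        params = self._get_params([inpt], generator=generator)
        # An expensive kernel would run here; this sketch just returns the parameter.
        return params["angle"]


t = RandomAngle()

# Original run: the angle used for each of 10 samples.
rng = random.Random(0)
angles = [t(f"sample-{i}", generator=rng) for i in range(10)]

# Replay: advance the generator with _get_params only, then rerun sample 5.
rng = random.Random(0)
dummy = object()
for _ in range(5):
    t._get_params([dummy], generator=rng)
replayed = t("sample-5", generator=rng)
```

`replayed` equals `angles[5]`, even though the first five samples were never actually transformed.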

Adding Feature to the public API and adding a protocol for dispatching transforms on 3rd party features

This request is a bit more ambitious, but I think that it could have a big pay-off.

The crown jewel of torchvision.transforms v2 is its added support for features like bounding boxes and segmentation masks. Presently, the development of new features and transforms is gated on the development efforts of the intrepid torchvision team. It would be great to enable users to implement their own features (e.g. Polygon or ImageTrack) and for torchvision's transforms (or 3rd party transforms) to be able to dispatch on those features. This would reduce the burden on the torchvision team in that it enables users to develop bespoke features that have "drop-in" compatibility with standard torchvision pipelines, without you needing to officially support the feature.

Thus I am proposing to a) make _Feature part of the public API (take away that leading underscore 😄 ) so that 3rd parties can create custom features, and b) provide a dispatch mechanism by which a 3rd party feature can coerce a transform (either 1st party or 3rd party transform) into dispatching to its own implementation of said transform. E.g. I could write a Polygon feature that would be able to implement its own affine transformation that both torchvision.transforms.affine and torchvision.transforms.RandomAffine dispatch to when they operate on a Polygon tensor.

There is already some degree of dispatching going on among some transforms. E.g., for affine and RandomAffine, any _Feature subclass can simply implement a Feature.affine method. The affine function checks isinstance(inp, Feature) and then calls inp.affine(...). And RandomAffine._transform simply calls affine under the hood, which performs the aforementioned dispatching.
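A minimal, standard-library-only sketch of that method-based dispatch: the dispatcher checks `isinstance(inpt, Feature)` and defers to the feature's own `affine`, otherwise it falls back to the image kernel. `Feature`, `Polygon`, and `_affine_image_kernel` are hypothetical stand-ins, and `affine` here only supports integer translation (the real function also takes angle, scale and shear).

```python
from typing import Any, List, Sequence, Tuple


class Feature:
    """Hypothetical stand-in for the _Feature base class."""


class Polygon(Feature):
    """Hypothetical third-party feature: a list of integer (x, y) vertices."""

    def __init__(self, vertices: Sequence[Tuple[int, int]]) -> None:
        self.vertices = [tuple(v) for v in vertices]

    def affine(self, *, translate: Tuple[int, int]) -> "Polygon":
        # The feature supplies its own implementation of the operation.
        dx, dy = translate
        return Polygon([(x + dx, y + dy) for x, y in self.vertices])


def _affine_image_kernel(rows: List[List[int]], *, translate: Tuple[int, int]) -> List[List[int]]:
    # Translation-only "image" kernel on a 2D list; vacated pixels are filled with 0.
    dx, dy = translate
    h, w = len(rows), len(rows[0])
    out = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            sx, sy = x - dx, y - dy
            if 0 <= sx < w and 0 <= sy < h:
                out[y][x] = rows[sy][sx]
    return out


def affine(inpt: Any, *, translate: Tuple[int, int]) -> Any:
    # Dispatcher: features handle themselves, anything else is treated as an image.
    if isinstance(inpt, Feature):
        return inpt.affine(translate=translate)
    return _affine_image_kernel(inpt, translate=translate)


moved = affine(Polygon([(0, 0), (2, 1)]), translate=(1, 2))
shifted = affine([[1, 2], [3, 4]], translate=(1, 0))
```

`moved.vertices` is `[(1, 2), (3, 3)]` and `shifted` is `[[0, 1], [0, 3]]`; a random transform that calls `affine` would pick up the `Polygon` implementation for free.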

That being said, users can't rely on all functions and transformations to be implemented this way. Indeed, Transform._transformed_types will often preclude such synergistic dispatching (RandomAffine._transformed_types works here because it only contains Tensor). It seems to me that, to support this proposal, Transform._transformed_types would need to be replaced by a Feature-centric way of documenting compatibility with transforms.

Given that torchvision's transforms are all unary functions, I believe that the dispatch mechanism could be kept relatively simple. That being said, I don't have deep insight into how JIT compatibility might impact the feasibility of some of the design ideas that I have been playing around with. I also only have a rudimentary understanding of __torch_function__ and __torch_dispatch__, but I get the sense that torchvision's transforms would not be able to leverage these.

Ultimately, I do not have a concrete design for this dispatch protocol, but would be happy to help develop it with you.

Thank you very much for your consideration. And thank you again for your excellent work on v2. My colleagues and I are excited to start using it.

@pmeier
Contributor

pmeier commented Nov 4, 2022

Thanks a ton for the kind words @rsokl! I think your summary is on point. Let me share my thoughts on your proposals.

Allow torch.Generator

+1 from me on basically everything you proposed. I think the design is simple enough that this can be done quickly in case we decide to adopt it. pytorch/pytorch#79018 is unfortunate here, but if we as torchvision need that, we might be able to speed things up on core.

The only thing that I would change about your proposal is the existence of a default value for generator on _get_params. During normal operation, _get_params should not be called directly and thus we can be explicit about it. IIUC, this would also not interfere with your use case of fast-forwarding the RNG without doing the actual computation since there you also don't want to use the default generator.

Regarding this use case, please be aware that it will only work for most transforms. Some more elaborate transforms like the AutoAugment family or SimpleCopyPaste don't fit into the regular Transform scheme that we devised. They overwrite forward directly and perform the random parameter generation there.

Make features._Feature public

Let's go through the proposal point by point

a) make _Feature part of the public API (take away that leading underscore 😄 ) so that 3rd parties can create custom features

Opinions about this changed over the course of the development. We started with it being public, but made it private at some point. Recently, we found a use case for it to be public again (#6663 (comment)). A cautious +1 from my side on this.

b) provide a dispatch mechanism by which a 3rd party feature can coerce a transform (either 1st party or 3rd party transform) into dispatching to its own implementation of said transform.

As you have explained above, this is already happening for some dispatchers like F.affine. However, there are a few caveats that need more careful design:

  1. Not all dispatchers are available as methods on the features. In the beginning, our rule of thumb was to only add a method if, in addition to images, bounding boxes and masks are supported as well. This is not a design limitation though, so we only need to "fix" our rule if we want this.

  2. Not all kernels have dispatchers. For example, convert_format_bounding_box has none, because there would only be a single kernel. Even if we added such a dispatcher for consistency, we would get into hot water, since it would violate the rule "for BC, plain tensors need to be treated as an image or video": there is no image or video kernel to dispatch to. This rule is currently kept by all our other dispatchers, which allows us to keep them BC even with regard to JIT. Of course, there is no BC concern for new functions, but it is a lot easier to say "the functional API is JIT compatible" than "the functional API that was already there for v1 is JIT compatible, but for the remainder it depends".

  3. Internally, a lot of transformations look up the spatial size from the inputs with

    height, width = query_spatial_size(flat_inputs)

    This is already pretty flexible and can pull the information from a lot of types, but there is no "official" protocol yet. For now we are using the spatial_size attribute

    elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
        return list(inpt.spatial_size)

    but this is only a convention.

  4. Taking 3. one step further, some transformations also query bounding boxes like that:

    def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
        bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
        if not bounding_boxes:
            raise TypeError("No bounding box was found in the sample")

    Without a more elaborate protocol, any bounding-box-like feature like RotatedBoundingBox proposed in [RFC] Rotated Bounding Boxes #2761 cannot be a free feature, but has to subclass features.BoundingBox. Otherwise it won't be picked up and in turn the transformation will fail.
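The limitation in point 4 can be reproduced with a small, standard-library-only sketch. `BoundingBox` stands in for features.BoundingBox, both rotated-box classes are hypothetical, and the lookup mirrors the isinstance-based snippet above (this sketch simply returns the first match). A free feature is invisible to the lookup, while a subclass is found.

```python
from typing import Any, List, Sequence, Tuple


class BoundingBox:
    """Stand-in for features.BoundingBox."""

    def __init__(self, boxes: Sequence[Sequence[float]], spatial_size: Tuple[int, int]) -> None:
        self.boxes = [list(b) for b in boxes]
        self.spatial_size = spatial_size


class FreeRotatedBoundingBox:
    """Hypothetical third-party feature that does NOT subclass BoundingBox."""

    def __init__(self, boxes: Sequence[Sequence[float]], spatial_size: Tuple[int, int]) -> None:
        self.boxes = [list(b) for b in boxes]
        self.spatial_size = spatial_size


class SubclassedRotatedBoundingBox(BoundingBox):
    """Hypothetical feature that opts in by subclassing BoundingBox."""


def query_bounding_box(flat_inputs: List[Any]) -> BoundingBox:
    # isinstance-based lookup: only BoundingBox (sub)classes are recognized.
    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, BoundingBox)]
    if not bounding_boxes:
        raise TypeError("No bounding box was found in the sample")
    return bounding_boxes[0]


try:
    query_bounding_box(["image", FreeRotatedBoundingBox([[10, 10, 4, 2, 30]], (64, 64))])
    missed = False
except TypeError:
    missed = True  # the free feature is not picked up

found = query_bounding_box(["image", SubclassedRotatedBoundingBox([[10, 10, 4, 2, 30]], (64, 64))])
```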

These are just the limitations off the top of my head. There will likely be more if we take a closer look. Thus, the design for this will be more complex.

That being said, some of the points mentioned above are already in our backlog, but we have focused on performance for the past few weeks. I agree that allowing users to extend our API like that would be very powerful, but we need to smooth out all the rough edges before we can promote this as an official part of the API. Let's hear what @datumbox and @vfdev-5 say about this before I cast my vote.

@jangop

jangop commented Nov 5, 2022

I had been doing a review of augmentation libraries, including kornia, albumentations, and augly, when I came across v2 for torchvision. To me, your new API is the simplest, most capable, and the easiest to extend. I particularly like how simple Transform is and that you do not rely on sprawling class hierarchies to dispatch functionality.

@rsokl Would you perhaps share a few words on why / how / in which regard Transforms V2 is “the simplest, most capable, and the easiest to extend”? Reading through the blog post, the one section I missed was related work for context. Personally, I only have experience with albumentations. Coming from that, I am inclined to agree with your assessment.

@pmeier pmeier pinned this issue Nov 7, 2022
@FlorinAndrei

FlorinAndrei commented Nov 8, 2022

Is there an estimate in terms of either time, or version number, or both, for when the v2 transforms will be included in pytorch-stable? The blog says "planned for Q1" but I wonder if there's a better estimate than that somewhere.

I was about to implement, on a small scale, functions that basically do what v2 does, but if your prototype works well then I will use it for my project.

Should I expect any issues if I train a model on the nightly builds with the v2 transforms, but then run it for predictions on pytorch-stable 1.13? Other than the v2 transforms, I do not use any unusual features.

@FlorinAndrei

Test case: image segmentation with SegFormer. My dataset has images and masks organized in tuples. I am following this example:

https://huggingface.co/blog/fine-tune-segformer

I need to make sure that geometric transforms (horizontal flip, rotation, affine, perspective) are applied randomly, but are applied the exact same way to each image and its corresponding mask. I am testing with horizontal flip and the image and the mask are flipped in an uncoordinated fashion.

Code example (admittedly inefficient):

from torchvision.prototype.transforms import (
    Compose,
    RandomApply,
    ColorJitter,
    RandomRotation,
    RandomCrop,
    RandomAffine,
    RandomHorizontalFlip,
    RandomPerspective,
)
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
import numpy as np
from transformers import SegformerFeatureExtractor

feature_extractor = SegformerFeatureExtractor()

augmentations = Compose(
    [
        # RandomApply([ColorJitter(brightness=0.5, contrast=0.5)], p=0.75),
        RandomHorizontalFlip(p=0.5),
        # RandomApply([RandomRotation(degrees=45)], p=0.75),
        # RandomApply([RandomAffine(degrees=0.0, scale=(0.5, 2.0))], p=0.25),
        # RandomPerspective(distortion_scale=0.5, p=0.25),
    ]
)


def train_transforms(example_batch):
    # original
    images = [augmentations(x) for x in example_batch["pixel_values"]]
    labels = [x for x in example_batch["label"]]
    inputs = feature_extractor(images, labels)
    return inputs


def train_transforms2(example_batch):
    # my version
    batch_items = list(
        zip(
            [x for x in example_batch["pixel_values"]],
            [x for x in example_batch["label"]],
        )
    )
    batch_items_aug = [
        augmentations(
            features.Image(
                np.swapaxes(np.array(x[0]), 0, -1), color_space=ColorSpace.RGB
            ),
            features.Mask(np.swapaxes(np.array(x[1]), 0, -1)),
        )
        for x in batch_items
    ]
    images, labels = map(list, zip(*batch_items_aug))
    inputs = feature_extractor(images, labels)
    return inputs


train_ds.set_transform(train_transforms2)

This is one entry in the train dataset that the transforms are applied to:

{'dataset': 0,
 'pixel_values': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=400x322>,
 'tumor': 1,
 'dataset_tumor': 1,
 'label': <PIL.PngImagePlugin.PngImageFile image mode=L size=400x322>}

If I repeatedly visualize the image and the mask in train_ds[0], they are flipped randomly, but the flip is not applied the same to the image and its mask - they are each flipped as a completely separate random process.

What I expect:

  • images and masks are geometrically transformed in a coordinated fashion (random transforms are applied the same way to an image and to its mask)
  • pixel-value transforms such as brightness are only applied to the image (this seems to work fine)

@pmeier
Contributor

pmeier commented Nov 8, 2022

Hey @FlorinAndrei

Is there an estimate in terms of either time, or version number, or both, for when the v2 transforms will be included in pytorch-stable? The blog says "planned for Q1" but I wonder if there's a better estimate than that somewhere.

We are aiming to publish it with the next release, i.e. torchvision==0.15.0. The exact date is not yet available. We operate on a 3-4 month release cycle and the last release just happened on the 28th of October, so you can expect this to be released in February or March 2023.

Should I expect any issues if I train a model on the nightly builds with the v2 transforms, but then run it for predictions on pytorch-stable 1.13? Other than the v2 transforms, I do not use any unusual features.

No. Apart from JIT, the transforms are backward compatible (BC), so there should be no complications. You can have a look at #6433 where we do just that.

I am testing with horizontal flip and the image and the mask are flipped in an uncoordinated fashion.

You are not flipping the masks at all:

images = [augmentations(x) for x in example_batch["pixel_values"]]
labels = [x for x in example_batch["label"]]

The horizontal flip happens in augmentations, but this is never called on any example_batch["label"]. Even if you were calling it on the labels in the line above, the augmentation could not be the same for images and masks, since you apply it to them separately.

Transforms v2 can handle arbitrary input structures and so you don't need to handle images and masks separately. You can just pass them into the transforms together like:

augmented_batch = [augmentations(sample) for sample in zip(example_batch["pixel_values"], example_batch["label"])]
images, labels = zip(*augmented_batch)
inputs = feature_extractor(images, labels)

The only caveat is that example_batch["label"] will be treated as an image for BC, meaning the color transformations will also be applied to it. Thus, we need to communicate to the transformations that it is a mask.
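The effect of passing image and mask together, and of marking the mask with a type, can be sketched with the standard library alone. `Image` and `Mask` are list-based stand-ins for the feature classes, and `RandomHorizontalFlip`, `AddBrightness` and `compose` are hypothetical simplifications. The flip decision is drawn once per sample and shared by all inputs, the color transform skips `Mask`, and unwrapped data is treated as an image, mirroring the BC caveat.

```python
import random
from typing import Any, Callable, Tuple


class Image(list):
    """Stand-in for features.Image: a 2D list of pixel values."""


class Mask(list):
    """Stand-in for features.Mask."""


def _is_image_like(inpt: Any) -> bool:
    # Plain (unwrapped) data is treated as an image, mirroring the BC rule.
    return isinstance(inpt, Image) or type(inpt) is list


class RandomHorizontalFlip:
    def __init__(self, p: float = 0.5) -> None:
        self.p = p

    def __call__(self, *inputs: Any) -> Tuple[Any, ...]:
        flip = random.random() < self.p  # one draw shared by every input in the sample
        return tuple(type(x)(row[::-1] for row in x) if flip else x for x in inputs)


class AddBrightness:
    def __init__(self, delta: int) -> None:
        self.delta = delta

    def __call__(self, *inputs: Any) -> Tuple[Any, ...]:
        # Color transforms only touch image-like inputs; masks pass through unchanged.
        return tuple(
            type(x)([v + self.delta for v in row] for row in x) if _is_image_like(x) else x
            for x in inputs
        )


def compose(*transforms: Callable[..., Tuple[Any, ...]]) -> Callable[..., Tuple[Any, ...]]:
    def apply(*inputs: Any) -> Tuple[Any, ...]:
        for t in transforms:
            inputs = t(*inputs)
        return inputs
    return apply


pipeline = compose(RandomHorizontalFlip(p=1.0), AddBrightness(10))
image, mask = pipeline(Image([[1, 2], [3, 4]]), Mask([[0, 1], [0, 0]]))
_, unwrapped = pipeline(Image([[1, 2]]), [[0, 1]])
```

`image` becomes `[[12, 11], [14, 13]]` and `mask` is only flipped to `[[1, 0], [0, 0]]`; the unwrapped label `[[0, 1]]`, however, is also brightened to `[[11, 10]]`, which is exactly why it needs to be wrapped as a mask.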

In transforms v2 this is done by wrapping the data into custom tensor classes located under torchvision.prototype.features:

from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F

class WrapIntoFeatures(transforms.Transform):
    def forward(self, sample):
        image, label = sample
        label = features.Mask(F.pil_to_tensor(label))
        return image, label

augmentations = transforms.Compose([WrapIntoFeatures(), ...])

For custom datasets, this will have to happen manually. For our builtin ones, we are currently exploring the options. See #6662 for the current favorite.

I've put together a small notebook for you. I hope it helps, but let us know if you encounter anything.

@maxoppelt

I understand the need for wrapping the torch.Tensor class to mimic the behavior of "matching" transformations to a specific input type using the _transformed_types class variable. However, as far as I understand, the user directly handles / sets these types. This might lead to inconsistent behavior with functionality commonly used and provided through inheritance from the th.Tensor class. I have attached a simple multiplication of two image-like feature instances below to highlight my concerns. I see two directions:

(a) It is fine for "transforms v2" that an operation on two features.Image inputs falls back to a plain tensor instead of a features.Image. I think this behavior might feel inconsistent to some users?
(b) The operations (multiplication, addition, subtraction, etc.) need to be implemented. However, this opens the gates of hell to a lot of manual implementation overhead, e.g. what is the product of a black-and-white image and an RGB image? Or what is the product of a features.Image and a features.Video? --> Implementation hell. Not to mention what would happen if a features.Image were passed to a neural network... features.Image * th.Tensor..

It might be that my concerns are irrelevant: I have just read through the code briefly. However, I think that this behavior needs to be specified / discussed upfront.

import torch as th
from torchvision.prototype.features import Image

img1 = Image(th.rand(3, 256, 256))
img2 = Image(th.rand(3, 256, 256))

type(img1), type(img2)
>> torchvision.prototype.features._image.Image, torchvision.prototype.features._image.Image

img = img1 * img2
type(img)
>> torch.Tensor

@pmeier
Contributor

pmeier commented Nov 9, 2022

@maxoppelt

mimic the behavior of "matching" transformations for a specific input type using _transformed_types class variable.

Not sure I understand. Transform._transformed_types is a way to specify which types you want to transform or on the flip side for which you want no-op behavior. Most transformations use the default

_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)

meaning all tensors and PIL images will be transformed.
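A standard-library-only sketch of how a _transformed_types-style filter produces op vs. no-op behavior. `Tensor`, `Image` and `Mask` are list-based stand-ins, `is_simple_tensor` mirrors the "plain tensor" predicate, and `Blur` is a hypothetical transform that handles images and plain tensors but not masks. The `|` operator deliberately returns a plain `Tensor`, imitating the unwrapping behavior of operations on feature subclasses.

```python
from typing import Any, Callable, Tuple, Type, Union


class Tensor(list):
    """Stand-in for torch.Tensor (a flat list of values)."""

    def __or__(self, other: "Tensor") -> "Tensor":
        # Like most operations on feature subclasses, the result is a plain Tensor.
        return Tensor(a | b for a, b in zip(self, other))


class Image(Tensor):
    pass


class Mask(Tensor):
    pass


def is_simple_tensor(inpt: Any) -> bool:
    # True only for plain tensors, not for feature subclasses.
    return type(inpt) is Tensor


class Transform:
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (Tensor,)

    def _should_transform(self, inpt: Any) -> bool:
        # Entries are either types (isinstance check) or predicates (called on the input).
        return any(
            isinstance(inpt, check) if isinstance(check, type) else check(inpt)
            for check in self._transformed_types
        )

    def _transform(self, inpt: Any) -> Any:
        raise NotImplementedError

    def __call__(self, inpt: Any) -> Any:
        return self._transform(inpt) if self._should_transform(inpt) else inpt


class Blur(Transform):
    """Hypothetical blur: transforms images and plain tensors, no-op for masks."""

    _transformed_types = (Image, is_simple_tensor)

    def _transform(self, inpt: Any) -> Any:
        mean = sum(inpt) / len(inpt)
        return type(inpt)(mean for _ in inpt)


blur = Blur()
smask, tmask = Mask([1, 0, 0, 1]), Mask([0, 1, 0, 0])
combined = smask | tmask  # plain Tensor, so it is no longer skipped
```

`blur(smask)` returns `smask` unchanged, but `blur(combined)` blurs it to `[0.75, 0.75, 0.75, 0.75]`; re-wrapping with `Mask(combined)` restores the no-op.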

This might lead to inconsistent behavior with functionalities commonly used and provided by inheritance of the th.Tensor class.

Inside our transformations, we make sure that there is no inconsistent behavior. If you pass in a features.Image you will also get one back. That is handled on the functional dispatcher level, i.e. F.resize or the like will give you this behavior. There are a few exceptions like

# Image or Video type should not be retained after normalization due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)

I think that this behavior needs to be specified / discussed upfront.

The behavior is specified and encoded in the __torch_function__ method of the feature base class.

The function has extensive comments explaining what it is doing, but let me give you the TL;DR: except for very few operators, namely

_NO_WRAPPING_EXCEPTIONS = {
    torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
    torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
    # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
    # retains the type automatically
    torch.Tensor.requires_grad_: lambda cls, input, output: output,
}

there is no fast and safe way for us to determine if the result of the operation should retain the feature type or not. You already highlighted a few cases above. Thus, any operation except for the ones mentioned above will "unwrap" the result, i.e. give you a plain torch.Tensor back.

Now, there are times where the result should retain the input type. This is happening all over our transforms. In that case you will need to wrap again, like

output = ...
return features.Image(output)

As a shorthand if you don't want to copy the metadata manually, you can use features.Image.wrap_like(input, output), e.g.

def horizontal_flip(self) -> Image:
    output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
    return Image.wrap_like(self, output)

(Note that the .as_subclass(torch.Tensor) call before we call the kernel is a micro-optimization and not required)
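The unwrap-then-rewrap idea behind wrap_like can be illustrated without torch. In this standard-library-only sketch, `Image` is a hypothetical stand-in holding pixel data plus one metadata field (`color_space`), and `horizontal_flip_kernel` is a hypothetical kernel on plain data. The kernel's output has lost both the type and the metadata, and `wrap_like` restores them by copying the metadata from the input.

```python
from typing import List


class Image:
    """Hypothetical stand-in for features.Image: pixel data plus color_space metadata."""

    def __init__(self, data: List[float], *, color_space: str = "RGB") -> None:
        self.data = list(data)
        self.color_space = color_space

    @classmethod
    def wrap_like(cls, other: "Image", data: List[float]) -> "Image":
        # Re-wrap plain data, copying metadata from `other` instead of passing it by hand.
        return cls(data, color_space=other.color_space)


def horizontal_flip_kernel(data: List[float]) -> List[float]:
    # Hypothetical kernel that operates on plain (unwrapped) data.
    return data[::-1]


img = Image([1.0, 2.0, 3.0], color_space="GRAY")
plain = horizontal_flip_kernel(img.data)  # plain list: type and metadata are gone
flipped = Image.wrap_like(img, plain)     # back to Image, color_space preserved
```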

@maxoppelt

maxoppelt commented Nov 9, 2022

@pmeier: Yes, I agree, and I think it's great that you can define transformations with op or no-op behavior depending on the specified feature. Thanks for laying out the implementation again, but my concern is more about the user interface and the way these types are specified. Let me give you a small example that a user (without reading our last two comments in this feedback ticket) would assume to work:

As an addition to your example above, here are two masks for that image.

[Images: strawberry mask, tomato mask]

import numpy as np
import torch as th
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F

# Loading data
img = features.Image(io.read_image('COCO_val2014_000000418825.jpg'), color_space=features.ColorSpace.RGB)
smask = features.Mask((io.read_image('strawberry.png') != 0).type(th.uint8))   # has a no-op for GaussianBlur
tmask = features.Mask((io.read_image('tomato.png') != 0).type(th.uint8))       # has a no-op for GaussianBlur
sotmask = smask | tmask  # is now a simple tensor, i.e. _isinstance(..., _is_simple_tensor) --> has op for GaussianBlur

# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.GaussianBlur(kernel_size=7),
    ]
)

np.all((trans(img) == img).numpy())
>> False   # Expected
np.all((trans(smask) == smask).numpy())
>> True # Expected
np.all((trans(sotmask) == sotmask).numpy())
>> False # This is what I am talking about...

Edit: Fixed the code based on @pmeier's comment below.

@vfdev-5
Collaborator

vfdev-5 commented Nov 9, 2022

@maxoppelt thanks for your feedback and the discussion! Concerning your point about

operation on two feature.Image types as input defaults back to a non features.Image tensor. I think this behavior might feel inconsistent for some users?

and your examples:

type(img1), type(img2)
>> torchvision.prototype.features._image.Image, torchvision.prototype.features._image.Image

img = img1 * img2
type(img)
>> torch.Tensor

and

np.all((trans(sotmask) == sotmask).numpy())
>> False # This is what I am talking about...

What do you think would be a consistent/expected behaviour? Should they raise an error instead?

@pmeier
Contributor

pmeier commented Nov 9, 2022

@maxoppelt

but my concern is more about the user interface and the way of specifying these types. Let me give you a small example, a user would assume (without reading our last two comments in this feedback ticket) would work:

You are right that the communication of this needs to be clear. Right now documentation is scarce, but this will change before the release. Basically the rule is: "if you perform any operation on tensor subclasses outside of our builtin transforms, you need to wrap again". We will have detailed documentation about how to write a custom transform that will cover this.

We thought about this very issue a lot during the design phase. In the end, we decided against automatically detecting whether or not the result can retain the type, due to the issues listed above. The behavior you mention is just the logical consequence of this.

Unless we missed something obvious, there is no better way to do this if we want to use the type of the input as dispatch mechanism. The solutions of comparable libraries also require you to be explicit either by calling the transform like transform(..., mask=sotmask) or providing this information when the transform is created transform = MyTransform(..., order=["image", "mask"]). Thus, we are not an outlier when it comes to being explicit with transform(..., features.Mask(sot_mask)). Do you have a better idea?
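For comparison, the "declare the order up front" style mentioned above can be sketched with the standard library alone. `TypedCompose` is hypothetical, not a torchvision or third-party API, and `Image` / `Mask` are list-based stand-ins. The pipeline wraps positional inputs according to a declared list of types, so the transforms can still dispatch on type while the caller passes plain data.

```python
from typing import Any, Callable, Sequence, Tuple, Type


class Image(list):
    """Stand-in for features.Image."""


class Mask(list):
    """Stand-in for features.Mask."""


class TypedCompose:
    """Hypothetical pipeline that wraps positional inputs by a declared order."""

    def __init__(self, transforms: Sequence[Callable[..., Tuple[Any, ...]]], input_types: Sequence[Type]) -> None:
        self.transforms = list(transforms)
        self.input_types = list(input_types)

    def __call__(self, *inputs: Any) -> Tuple[Any, ...]:
        if len(inputs) != len(self.input_types):
            raise ValueError(f"expected {len(self.input_types)} inputs, got {len(inputs)}")
        sample = tuple(cls(x) for cls, x in zip(self.input_types, inputs))
        for t in self.transforms:
            sample = t(*sample)
        return sample


def brighten(*inputs: Any) -> Tuple[Any, ...]:
    # Type-based dispatch: only Image inputs are brightened, Mask inputs pass through.
    return tuple(Image(v + 1 for v in x) if isinstance(x, Image) else x for x in inputs)


pipeline = TypedCompose([brighten], input_types=[Image, Mask])
image, mask = pipeline([1, 2], [0, 1])
```

The trade-off is the one described above: the type information moves from the data into the pipeline definition, so it has to be kept in sync with the call site.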

P.S.: features.Mask is the right type for masks. features.Label is for numeric labels like

>>> label = features.Label([0, 1, 1], categories=["strawberry", "tomato"])
>>> label
Label([0, 1, 1])
>>> label.to_categories()
['strawberry', 'tomato', 'tomato']

@maxoppelt

maxoppelt commented Nov 9, 2022

@vfdev-5

What do you think would be a consistent/expected behavior ? Should they raise an error instead ?

Depends, I could imagine something like

type(img)
> torchvision.prototype.features._image.Image
type(mask)
> torchvision.prototype.features._mask.Mask

result = image * mask
type(result)
> torchvision.prototype.features._image.Image

would be self-explanatory and result in something like this:

[image: masked]

Raising an error is probably too much, as it might be a valid operation... It is just not clear that any operation using a feature class will result in a plain th.Tensor instead of a feature class.

However, I kind of agree with @pmeier, as these new feature classes are mainly there to define op vs. no-op for inputs in torchvision transformations v2. My question is only pointing in the direction of unspecified usage, which, I might add, will probably happen very often in the future. For example: the Dataset class provides two variables of type torchvision.prototype.features._image.Image and torchvision.prototype.features._label.Label that should be transformed using the torchvision transformation pipeline v2. However, in between these two steps, it is very likely that the user accesses the data, e.g. for a special normalization or a reshape operation.

An alternative approach could look like this:

type(img1), type(img2), type(mask)
> th.Tensor, th.Tensor, th.Tensor

trans = T.Compose(
    [
        T.GaussianBlur(kernel_size=7),
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480)
    ],
    input_types=[Image, Image, Mask],
)

trans(img1, img2, mask)

instead of this:

type(img1), type(img2), type(mask)
> torchvision.prototype.features._image.Image, torchvision.prototype.features._image.Image, torchvision.prototype.features._mask.Mask

trans = T.Compose(
    [
        T.GaussianBlur(kernel_size=7),
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480)
    ],
)

trans(img1, img2, mask)

However, this approach might have some disadvantages, too: type and value are separated, and why not use tensor subclassing (https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) as is currently done? In the end, this is probably an issue that could be resolved/prevented in the future by writing good documentation, as @pmeier pointed out. Maybe one could use Python warnings to inform the user...

Btw, why is the new torchvision type system called features? In most cases torchvision transforms will work on the inputs and not on latent representations (usually called features).

@FlorinAndrei

@pmeier In the example above, I actually use train_transforms2() which is applied to both images and masks. However, that is not why my code was failing to apply the transforms in a coordinated fashion. I've tried Albumentations instead, and I got the same bad result.

The reason is that I was invoking the dataset twice: once to get the image, and again to get the mask. Those were two separate applications of the augmentation function, and of course they were not coordinated in terms of random geometric transforms.

All that was fixed when I extracted that single item from the dataset once, and then extracted the image and the mask from the item.

Apologies for the non-issue.

@datumbox
Contributor Author

@FlorinAndrei Thanks for taking the time to provide in depth feedback. We really appreciate it. Please keep it coming; perhaps the confusion caused indicates there are still rough edges on the API that we need to fix. Or perhaps we need to document the gotchas better. If you have other ideas on the overall API (both public and developer) or on the naming conventions please let us know. :)

@datumbox
Contributor Author

@maxoppelt

Btw, why is the new torchvision type system called features? In most cases torchvision transforms will work on the inputs and not on latent representations (usually called features).

You are absolutely right. This is very confusing. This is a placeholder name until we find something better. We need a concept for the base tensor class that can be reused for Images, Videos, BoundingBoxes, Labels, Masks etc. Naming is NP-hard, any help on that front would be highly appreciated...

@rsokl

rsokl commented Nov 11, 2022

We need a concept for the base tensor class that can be reused for Images, Videos, BoundingBoxes, Labels, Masks etc.

Perhaps DataTensor? It risks being a bit vague, but it does convey "these are all forms of training/testing data".

@pmeier
Contributor

pmeier commented Nov 11, 2022

The naming issue is not new. See for example the thread in #5045. We used "feature" in the beginning since that is what tensorflow-datasets used. X-posting some of the comments / suggestions from the other thread here

@CedricPicron

CedricPicron commented Nov 15, 2022

Hi. As a PhD student working on object detection and instance segmentation, I'm very happy to see this added to torchvision. I'm eager to use this new transforms API in my own code base.

Possible inconsistencies between boxes and masks

When doing (box-based) instance segmentation, both target boxes and target masks need to be provided to the model. As written in the blog post, one would use

img, bboxes, masks, labels = trans(img, bboxes, masks, labels)

to apply the transforms to the various data structures.

By applying each of the transforms on the boxes and masks individually (as I believe is done in Transforms V2), one might expect the transformed bounding boxes to tightly fit the transformed object masks, as was the case before applying the transforms.

However, some operations like the crop operation might result in inconsistencies between the cropped boxes and masks, with the cropped boxes no longer tightly fitting the cropped masks. See https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py#L135 for a toy example. I believe the same might happen with other transforms, such as rotation-based transforms.

I'm not sure how much these inconsistencies affect the performance (I expect not much), but I think this issue needs to be addressed. If not, some users might think there's some kind of bug as they would expect the transformed bounding boxes to always tightly fit the transformed masks.

Possible solutions

Regardless of the precise solution to this problem, I think a get_bounding_boxes(bbox_format) method should be added to the Mask structure.

Solution 1: Change nothing to the API. It's up to the users to decide on how they want to transform the boxes. If the user wants the transformed boxes to always tightly fit the transformed masks, then the user can proceed as follows:

img, masks, labels = trans(img, masks, labels)
bboxes = masks.get_bounding_boxes(bboxes.format)

Solution 2: Change the API. I'm not sure what the cleanest solution would be. Maybe something like

img, bboxes, masks, labels = trans(img, bboxes, masks, labels, get_bboxes_from_masks=True)

could be considered where internally something like

if get_bboxes_from_masks:
    # TODO: Remove bboxes from pytree and apply transforms as usual
    bboxes = masks.get_bounding_boxes(bboxes.format)
    # TODO: Add bboxes to output
else:
    # The usual
    ...

would be used.

@pmeier
Contributor

pmeier commented Nov 16, 2022

Hey @CedricPicron and thanks for your feedback. I agree that we need to touch on this in the documentation to explain what is happening.

Regardless of the precise solution to this problem, I think a get_bounding_boxes(bbox_format) method should be added to the Mask structure.

There is already an operator for it: torchvision.ops.masks_to_boxes. We haven't wrapped it inside a transform or a method on Mask yet. This is still up for discussion.

As for the proposed solutions, with the op from above, you can easily add a custom transform that does what you want:

from typing import Any, Dict, List

from torchvision.prototype import transforms, features
from torchvision.prototype.transforms import functional as F
from torchvision.ops import masks_to_boxes

# We are currently debating whether we should make this public
from torchvision.prototype.transforms._utils import has_all


# This is modeled after `query_bounding_box` located in `torchvision.prototype.transforms._utils`
def query_mask(flat_inputs: List[Any]) -> features.Mask:
    masks = [inpt for inpt in flat_inputs if isinstance(inpt, features.Mask)]
    if not masks:
        raise TypeError("No mask was found in the sample")
    elif len(masks) > 1:
        raise ValueError("Found multiple masks in the sample")
    return masks.pop()


class TightenBoundingBoxes(transforms.Transform):
    _transformed_types = (features.BoundingBox,)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # Of course, we could also make this a no-op in case we don't find both
        if not has_all(flat_inputs, features.Mask, features.BoundingBox):
            raise TypeError("TightenBoundingBoxes needs masks and bounding boxes")

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        mask = query_mask(flat_inputs)
        tight_bounding_box = masks_to_boxes(mask)
        return dict(tight_bounding_box=tight_bounding_box)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return features.BoundingBox.wrap_like(
            inpt,
            F.convert_format_bounding_box(
                params["tight_bounding_box"].to(inpt),
                old_format=features.BoundingBoxFormat.XYXY,
                new_format=inpt.format,
                inplace=True,
            ),
        )

Stealing the toy example from above:

import torch

mask1 = features.Mask(torch.triu(torch.ones(4, 4), diagonal=1).unsqueeze(0))
box1 = features.BoundingBox(
    [[1, 0, 3, 2]], format=features.BoundingBoxFormat.XYXY, spatial_size=mask1.shape[-2:]
)
print("mask1", mask1)
print("box1", box1)

# Emulating a fixed crop transform
mask2 = F.crop(mask1, top=0, left=0, height=4, width=3)
box2 = F.crop(box1, top=0, left=0, height=4, width=3)
print("mask2", mask2)
print("box2", box2)

transform = TightenBoundingBoxes()
mask3, box3 = transform(mask2, box2)
print("mask3", mask3)
print("box3", box3)

mask1 Mask([[[0., 1., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 0.]]])
box1 BoundingBox([[1, 0, 3, 2]], format=BoundingBoxFormat.XYXY, spatial_size=torch.Size([4, 4]))
mask2 Mask([[[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.],
       [0., 0., 0.]]])
box2 BoundingBox([[1, 0, 3, 2]], format=BoundingBoxFormat.XYXY, spatial_size=(4, 3))
mask3 Mask([[[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.],
       [0., 0., 0.]]])
box3 BoundingBox([[1, 0, 2, 1]], format=BoundingBoxFormat.XYXY, spatial_size=(4, 3))

Now you can drop this in wherever you need it in your pipeline. Given that this has a performance implication and we currently don't have evidence that this actually impacts the performance of the model, IMO it is best to keep this behavior manual. Plus, it is not as easy as the check you proposed in solution 2 above since these transforms need to work regardless of the task, but object detection may have no masks available.

If you have another look at the output from the example above, there is another case where some manual action would be required by the user if we didn't have TightenBoundingBoxes: after the crop, box2 has the right spatial_size, but the second coordinate now lies outside of the image. This is intentional since we don't know if the user wants to clamp the bounding boxes or not. Plus, always doing so again has some performance implications.

We already have two builtin transforms that function similarly to the one I proposed above:

  • ClampBoundingBoxes: clamp bounding boxes to spatial size. This does not remove bounding boxes that are empty after the operation.
  • RemoveSmallBoundingBoxes: remove small (default empty) bounding boxes and do the same to corresponding masks and labels.

@CedricPicron

Hi @pmeier. Thanks a lot for your detailed response!

I like the proposed solution based on the TightenBoundingBoxes transform. I think it's a flexible solution, similar to the existing ClampBoundingBoxes and RemoveSmallBoundingBoxes transforms. It would be nice to see TightenBoundingBoxes added to Transforms V2.

Some additional comments:

  1. I think it's important that users can easily find these transforms. Currently, these are found/implemented in _meta.py and _misc.py. I was wondering if some of these transforms could not be grouped in a different way, such that they can more easily be found. Maybe the transforms ClampBoundingBoxes and TightenBoundingBoxes can be grouped under bbox transforms, and RemoveSmallBoundingBoxes under filter transforms.
  2. When using TightenBoundingBoxes, the input bounding boxes do not matter and only the bbox format is used. Transforms prior to the TightenBoundingBoxes transform hence do not need to be applied on the boxes input. However, currently they are. Avoiding this unnecessary computation could make the pipeline with TightenBoundingBoxes possibly even faster than without. If desired, this could be mitigated by making _transformed_types an instance attribute (instead of a class attribute) and allow the user to set a custom _transformed_types attribute different from the default one.

@pmeier
Contributor

pmeier commented Nov 16, 2022

  1. I think it's important that users can easily find these transforms. Currently, these are found/implemented in _meta.py and _misc.py. I was wondering if some of these transforms could not be grouped in a different way, such that they can more easily be found. Maybe the transforms ClampBoundingBoxes and TightenBoundingBoxes can be grouped under bbox transforms, and RemoveSmallBoundingBoxes under filter transforms.

I agree the naming of the modules is not perfect here. On the flip side, I'd wager that regardless of what scheme you choose, you will always have outliers that don't fit anywhere.

Plus, you are looking at internal / private namespaces here. I'm fully aware that this is the only thing you can do right now due to the lack of other documentation, but this will change before the release. Once it does, you will be able to discover everything there instead of going through the source. Since this bounding box behavior might indeed be unexpected, I think it would be good to add a small gallery example to show the effects.

If you do want to look at the source, I suggest having a look at the __init__ files for a better overview:

  2. When using TightenBoundingBoxes, the input bounding boxes do not matter and only the bbox format is used. Transforms prior to the TightenBoundingBoxes transform hence do not need to be applied on the boxes input. However, currently they are. Avoiding this unnecessary computation could make the pipeline with TightenBoundingBoxes possibly even faster than without. If desired, this could be mitigated by making _transformed_types an instance attribute (instead of a class attribute) and allow the user to set a custom _transformed_types attribute different from the default one.

I agree with the statement, but I think there is a far easier solution: don't pass the bounding box. Unless you have a transformation in your pipeline that requires a bounding box to be present, e.g. RandomIoUCrop, you can simply not pass it and thus save all the computation. If you have transforms that need bounding boxes, I think you can write a thin transforms.Compose wrapper to achieve this locally instead of for the whole pipeline:

from torch.utils._pytree import tree_flatten, tree_unflatten


class NoBoundingBoxesContainer(transforms.Compose):
    def forward(self, *inputs):
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        indexed_mask = None
        indexed_bounding_box = None
        everything_else = []
        for idx, inpt in enumerate(flat_inputs):
            if isinstance(inpt, features.BoundingBox):
                indexed_bounding_box = (idx, inpt)
            elif isinstance(inpt, features.Mask):
                indexed_mask = (idx, inpt)
            else:
                everything_else.append((idx, inpt))
        # do proper error checking here that a mask and a box are available

        transformed_indexed_mask, transformed_everything_else = super().forward(
            indexed_mask, everything_else
        )

        transformed_indexed_bounding_box = (
            indexed_bounding_box[0],
            features.BoundingBox.wrap_like(
                indexed_bounding_box[1],
                F.convert_format_bounding_box(
                    masks_to_boxes(transformed_indexed_mask[1]).to(indexed_bounding_box[1]),
                    old_format=features.BoundingBoxFormat.XYXY,
                    new_format=indexed_bounding_box[1].format,
                    inplace=True,
                ),
            ),
        )

        flat_outputs = list(
            zip(
                *sorted(
                    [
                        transformed_indexed_bounding_box,
                        transformed_indexed_mask,
                        *transformed_everything_else,
                    ],
                    key=lambda indexed_output: indexed_output[0],
                )
            )
        )[1]

        return tree_unflatten(flat_outputs, spec)

Although this looks quite daunting, it actually doesn't do anything complicated. Basically we fish out the mask and bounding box from the input, transform the mask as well as everything else, create a new bounding box from the transformed mask, and assemble everything back together. With this you can do

pipeline = transforms.Compose(
    [
        NoBoundingBoxesContainer(
            [
                transforms.RandomRotation(...),
                transforms.RandomCrop(...),
            ]
        ),
        transforms.RandomIoUCrop(...),
        NoBoundingBoxesContainer(...),
    ]
)

Still, IMO we are going deep into specialized transforms here. Unless there is significant demand for something like this in the library, I think you are better off defining such a transform yourself.

@CedricPicron

Yes, using NoBoundingBoxesContainer is a nice solution.

I guess the key for users will be to have good documentation and examples regarding the implementation of custom container-like and transform-like transforms, but I'm sure you guys already have things planned.

Thanks @pmeier for the quick responses and good luck finalizing this project. I hope the feedback was (somewhat) useful.

@pmeier
Contributor

pmeier commented Aug 17, 2023

@csmotion The code you posted looks correct. I suspect what is missing is that you didn't convert the items inside target to datapoints? That is what is needed for the transform to handle them correctly. Have a look at the datapoints FAQ and an end-to-end example using transforms v2.

@csmotion

@pmeier You are 100% correct, I missed it in the blog post (RIP). Much appreciated!

@pmeier
Contributor

pmeier commented Aug 24, 2023

@eirikeve With the upcoming release the v2 wrapper is now also pickleable and thus will work with a multiprocessing spawning context as is the default on macOS. The patch should be available as nightly release in a few hours.

@fvgt

fvgt commented Aug 25, 2023

I am trying to use the CutMix augmentation following the guide on the web page: https://pytorch.org/vision/main/auto_examples/v2_transforms/plot_cutmix_mixup.html#sphx-glr-auto-examples-v2-transforms-plot-cutmix-mixup-py
However, I get the error:
module 'torchvision.transforms.v2' has no attribute 'CutMix'

@vfdev-5
Collaborator

vfdev-5 commented Aug 25, 2023

@fvgt you may need to install torchvision from source: https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md#development-installation

@pmeier
Contributor

pmeier commented Aug 25, 2023

@fvgt you are looking at the documentation for the main branch, but you are likely using a stable release. CutMix and MixUp will only become available in the next release. If you are not restricted by the version you use, you can install a nightly release that already has this implemented.

@fvgt

fvgt commented Aug 25, 2023

That was my first intuition as well and I tried using the nightly version, using the following command:

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118

Unfortunately, I still got the same error.

Edit: I will follow the guide posted by @vfdev-5 and will check again. Thank you very much for the quick replies!

@pmeier
Contributor

pmeier commented Aug 25, 2023

@fvgt This is most likely an environment issue. Of course it should work as well, but there is no need to build from source. The nightly build is sufficient. Please open a separate issue and post the results of the environment collection script.

@tianlianghai

CutMix only supports classification tasks for now; I hope v2.CutMix will support segmentation tasks too! Thanks.

@orena1

orena1 commented Oct 15, 2023

I have a batch of videos and I want to apply v2.RandomPerspective to each video differently.
Currently, if I use this format:

        self.train_video_pipeline = torch.nn.Sequential(
            v2.RandomPerspective(0.5,1),
            torchvision.transforms.Normalize(0.15228, 0.0794))

All videos will have the same transformation.

import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision.transforms import v2

train_video_pipeline = torch.nn.Sequential(
    v2.RandomPerspective(0.2, 0.6),
    torchvision.transforms.Normalize(0.15228, 0.0794),
)

# e.g. 10 videos, each with 123 time steps and one channel
out = train_video_pipeline(1 + torch.zeros((10, 123, 1, 100, 100)))

axs = plt.subplots(1, 3, figsize=(13, 4))[1]
axs[0].imshow(out[0, 7, 0].detach().numpy())
axs[1].imshow(out[1, 7, 0].detach().numpy())
axs[2].imshow(out[2, 7, 0].detach().numpy())

Is there any way to apply a different transformation to each video?


@NicolasHug
Member

Hi @orena1, the main way to do that is to unbatch, call the random transforms individually on each sample (or use .get_params + the functional API), and then re-batch the samples.

This is something we'd like to support more transparently, perhaps at least by providing some kind of UnBatchThenCallThenRebatch transform helper (name TBD). But because of the way random parameters are sampled, and because each randomization leads to different parametrizations, there is often no way to process an entire batch efficiently.

@Axel-Jacobsen

Axel-Jacobsen commented Oct 25, 2023

Howdy!

The TVTensors + V2 transforms are a pretty cool addition. I'm finding it easy to integrate into one of my current projects, which is great.

I found and am using v2.ConvertBoundingBoxFormat, but haven't found anything that would e.g. normalize bounding box coordinates to the size of the image. E.g., if the image is (100px, 100px) and a box is centered at (50px, 50px) with (w, h) = (25px, 25px), the normalized coordinates would be (xc, yc, w, h) = (0.5, 0.5, 0.25, 0.25). Normalization of bbox coordinates is common in object detection, e.g. in the YOLO family of networks.

  • Is this already implemented somewhere?
  • If not, do you think there is any appetite for it in transforms v2? Perhaps there is a good reason why it isn't in there, but if not, I'd happily do the work.

@pmeier
Contributor

pmeier commented Oct 26, 2023

@Axel-Jacobsen

Is this already implemented somewhere?

No.

Perhaps there is a good reason why it isn't in there

Yup. Right now we hard-assume that bounding boxes are in absolute coordinates. This makes it easier to implement the corresponding kernels:

  1. We don't need an extra flag on the kernel and subsequently on the bounding box instance that indicates whether the coordinates are absolute or relative.
  2. We don't need extra branching logic inside the kernel to account for both use cases.

From your comment I get that normalized bounding boxes are only required for the model. If that is true, I suggest you implement a custom NormalizeBoundingBoxes transform and just put it at the end of your pipeline. Something along the lines of

import torch
from torchvision import tv_tensors


def normalize_bounding_boxes(bounding_boxes: tv_tensors.BoundingBoxes, dtype=torch.float32) -> torch.Tensor:
    canvas_height, canvas_width = bounding_boxes.canvas_size
    # The .as_subclass(torch.Tensor) is not required, but only a performance improvement
    # See https://pytorch.org/vision/stable/auto_examples/transforms/plot_tv_tensors.html#why-is-this-happening
    return (
        bounding_boxes.as_subclass(torch.Tensor)
        .to(dtype)
        .div_(
            torch.tensor(
                [canvas_width, canvas_height, canvas_width, canvas_height],
                dtype=dtype,
                device=bounding_boxes.device,
            )
        )
    )


class NormalizeBoundingBoxes(torch.nn.Module):
    def forward(self, image, target):
        target["boxes"] = normalize_bounding_boxes(target["boxes"])
        return image, target


image = tv_tensors.Image(torch.rand(3, 100, 100))
bounding_boxes = tv_tensors.BoundingBoxes(
    [[50, 50, 25, 25]], 
    format=tv_tensors.BoundingBoxFormat.CXCYWH, 
    canvas_size=(100, 100),
)
target = {"boxes": bounding_boxes}

transform = NormalizeBoundingBoxes()
transformed_sample = transform(image, target)

torch.testing.assert_close(
    transformed_sample[1]["boxes"],
    torch.tensor([[0.5, 0.5, 0.25, 0.25]]),
)

This requires you to hardcode the schema of the samples that you want to pass. If you need a version of the transform that works for arbitrary sample schemas, as is the default for all builtin v2 transforms, you can do:

from torchvision.transforms import v2 as transforms

class NormalizeBoundingBoxes(transforms.Transform):
    _transformed_types = (tv_tensors.BoundingBoxes,)

    def _transform(self, input, params):
        return normalize_bounding_boxes(input)

But be aware that we are using private parts of the API here and there are no BC guarantees for them.

@Axel-Jacobsen

Axel-Jacobsen commented Oct 26, 2023

@pmeier sounds good! I appreciate the quick and thorough reply. I'll give this a go in my project.

@EricThomson

EricThomson commented Jan 10, 2024

Thanks for making this new API for transformations; it's great!

I was sent here from a link on the ToDtype page, as I'm trying to figure out the intent and consequences of the scale param.

My understanding was that (for instance for a float32 tv_tensor) it was supposed to scale values to [0,1]. This is based partly on the page's description of the scale arg, which links to the section in the docs on Dtype and expected value range (0-1 for float32). But when I feed it a tensor with values outside of that range, it returns values in the same deviant range.

I dug around in the implementation a bit, and while there is some checking to see if the data types support scaling, I'm not seeing any actual computational consequences of the scale param:

class ToDtype(Transform):

But there is a good chance I'm just missing something too 😄 . At any rate, I can implement it myself easily enough, but I was confused by scale and what it is doing.

@vfdev-5
Collaborator

vfdev-5 commented Jan 10, 2024

@EricThomson ToDtype calls functional implementation: F.to_dtype here: https://github.com/pytorch/vision/blob/d23430765b5df76cd1267f438f129f51b7d6e3e1/torchvision/transforms/v2/_misc.py#L275C40-L275C40 where scale arg is used and finally, this code is run for images:

def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:

There you can see an early return when scale is False:

elif not scale:
    return image.to(dtype)

@NicolasHug
Member

But when I feed it a tensor with values outside of that range, it returns values in the same deviant range

@EricThomson if you pass a torch.float32 tensor to ToDtype(torch.float32), then scaling won't happen regardless of whether you set scale=True. The conversion (and scaling) will only happen if the input tensor is not of float dtype.

To convert a float tensor from an arbitrary scale to another, you could use Normalize instead.

@EricThomson

EricThomson commented Jan 10, 2024

Thanks @vfdev-5 for pointing out in more detail how kernel dispatching works (I'm embarrassed I didn't dig deep enough 😳). The logic becomes clear in to_dtype_image().

@NicolasHug thanks for explaining in more detail and for the suggestion. I'm not sure Normalize is what I want, as that would push the data to a certain mean/std, while what I really want is scaling to [0,1].

Clearly I was trying to get ToDtype to do something outside its current use: I can create a transform to scale my data for floats. That said, not sure if folks would be against adding scaling for floats in to_dtype_image() when scale is set to True at some point in the future?

@NicolasHug
Member

NicolasHug commented Jan 10, 2024

what I really want is scaling to [0,1].

Normalize just returns (x - mean) / std so you can use it to linearly map any interval [a, b] into [c, d]. But I acknowledge that it can potentially be counter-intuitive to use.
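Spelled out, since Normalize computes y = (x - m) / s, the following choice of m and s maps a to c and b to d (a short derivation added here for illustration, assuming a < b and c < d):

```latex
y = \frac{x - m}{s}, \qquad s = \frac{b - a}{d - c}, \qquad m = a - c\,s
\quad\Longrightarrow\quad
\frac{a - m}{s} = \frac{c\,s}{s} = c, \qquad
\frac{b - m}{s} = \frac{b - a}{s} + c = (d - c) + c = d.
```

For the [0, 1] target this reduces to m = a and s = b - a.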

adding scaling for floats in to_dtype_image() when scale is set to True at some point in the future?

To clarify the feature request: you mean converting from an arbitrary scale into [0, 1], where the arbitrary scale of the input x is determined by x.min() and x.max()?

@EricThomson

@NicolasHug nice! I can piggyback on Normalize, with mean = min() and std = max() - min(). Thanks for the suggestion -- I'll just put a comment in my code about what I'm doing so people aren't confused. 😄

In terms of the feature request, yes that is what I was suggesting.

@NicolasHug
Member

Thank you so much everyone for your input and feedback. The V2 transforms are now stable and part of the latest torchvision release https://github.com/pytorch/vision/releases/tag/v0.17.0.

I'll close this issue as it's getting quite big and somewhat outdated now, but we'd still love to hear from you! Please feel free to open new issues with any feedback or feature requests you may have!
