introduce heuristic for simple tensor handling of transforms v2 #7170

Merged 13 commits on Feb 8, 2023

@pmeier pmeier commented Feb 3, 2023

Addresses the thread in #6663 (comment).

cc @vfdev-5 @bjuncek

@NicolasHug NicolasHug left a comment

Thanks Philip, gave it a quick look

@NicolasHug NicolasHug left a comment

Thanks Philip, some minor comments but it looks great

pmeier commented Feb 6, 2023

There are three transforms for which this heuristic is somewhat awkward, all for the same reason:

  • ToDtype
  • PermuteDimensions
  • TransposeDimensions

They all take a parameter that can be a dictionary, and the transform selects the appropriate value based on the type of each input. For example:

import torch
from torchvision.prototype import datapoints, transforms

sample = dict(
    image=datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)),
    boxes=datapoints.BoundingBox(torch.randint(0, 32, (5,)), format="xyxy", spatial_size=(32, 32)),
)

dtype = {
    datapoints.Image: torch.float32,
    datapoints.BoundingBox: torch.float64,
}
transform = transforms.ToDtype(dtype)

transformed_sample = transform(sample)

assert transformed_sample["image"].dtype is torch.float32
assert transformed_sample["boxes"].dtype is torch.float64

sample["tensor"] = torch.rand((3, 16, 16), dtype=torch.float16)
dtype[torch.Tensor] = torch.int32

transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)

assert transformed_sample["image"].dtype is torch.float32
assert transformed_sample["boxes"].dtype is torch.float64
assert transformed_sample["tensor"].dtype is torch.int32  # boom: the plain tensor is passed through unchanged

As the last assertion shows, under the heuristic introduced in this PR the transform is not applied to the plain tensor. That is somewhat awkward, since we explicitly specified a dtype for torch.Tensor in the parameter. The example works on main.
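To make the behavior concrete, here is a minimal sketch of the part of the heuristic discussed in this thread. It is not the torchvision implementation: the classes are hypothetical stand-ins for torch.Tensor and the datapoints types, and select_inputs_to_transform is an invented helper. It only models the rule that plain tensors are passed through once an Image or Video is present in the sample.

```python
# Hypothetical stand-ins; the real types live in torch and torchvision.
class Tensor:
    """Stand-in for torch.Tensor."""


class Image(Tensor):
    """Stand-in for datapoints.Image."""


class Video(Tensor):
    """Stand-in for datapoints.Video."""


class BoundingBox(Tensor):
    """Stand-in for datapoints.BoundingBox."""


def is_plain_tensor(obj):
    # A plain tensor is exactly a Tensor, not one of its datapoint subclasses.
    return type(obj) is Tensor


def select_inputs_to_transform(flat_inputs):
    """Return one flag per input: True if the transform should touch it."""
    has_image_or_video = any(isinstance(x, (Image, Video)) for x in flat_inputs)
    # Plain tensors are skipped only when an explicit Image or Video exists.
    return [not (is_plain_tensor(x) and has_image_or_video) for x in flat_inputs]
```

With a sample of an Image, a BoundingBox, and a plain tensor, the last flag is False, which is why the torch.Tensor entry in the dtype dictionary has no effect in the example.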

I guess one way to fix this is to disallow torch.Tensor in the parameter dictionary. Better yet, only allow datapoints. Meaning, if someone wants to use this fine-grained control, they'll have to wrap their inputs.

Thoughts?
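The stricter option suggested above (only allow datapoints as dictionary keys) could look roughly like the following. This is a sketch under assumptions, not code from the PR: Datapoint and the other classes are hypothetical stand-ins, validate_type_mapping is an invented name, and the dtype values are plain strings to keep the example free of dependencies.

```python
# Hypothetical stand-ins; the real types live in torch and torchvision.
class Tensor:
    """Stand-in for torch.Tensor."""


class Datapoint(Tensor):
    """Stand-in for a common datapoint base class (assumed)."""


class Image(Datapoint):
    """Stand-in for datapoints.Image."""


class BoundingBox(Datapoint):
    """Stand-in for datapoints.BoundingBox."""


def validate_type_mapping(mapping):
    """Reject any key that is not a datapoint class, including plain Tensor."""
    for key in mapping:
        if not (isinstance(key, type) and issubclass(key, Datapoint)):
            raise TypeError(
                f"Only datapoint types are allowed as keys, but got {key!r}. "
                "Wrap plain tensors in a datapoint to control them individually."
            )
    return mapping
```

Users who want per-type control over a plain tensor would then have to wrap it first, which keeps the dictionary lookup and the heuristic from contradicting each other.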

@NicolasHug NicolasHug left a comment

Thanks Philip, NIT but LGTM regardless

@NicolasHug

Sorry I had missed your comment above before I approved.

This is unrelated to this issue, but I'm tempted to keep PermuteDimensions and TransposeDimensions in the prototype area for now, because they break a lot of assumptions; so I'll just focus on ToDtype here.

> I guess one way to fix this is to disallow torch.Tensor in the parameter dictionary. Better yet, only allow datapoints. Meaning, if someone wants to use this fine-grained control, they'll have to wrap their inputs.

Can we just raise a warning specific to ToDtype() if the Tensor key is specified alongside Image or Video, saying

Hey, you passed Tensors and Images (or Videos), but we won't be transforming the tensors

and still support the Tensor key if neither Image nor Video is specified?
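A rough sketch of that warning, assuming the check runs once when the transform is constructed. The classes are hypothetical stand-ins for torch.Tensor and the datapoints types, warn_on_ignored_tensor_key is an invented name, and the message paraphrases the wording proposed above rather than quoting anything from torchvision.

```python
import warnings


# Hypothetical stand-ins; the real types live in torch and torchvision.
class Tensor:
    """Stand-in for torch.Tensor."""


class Image(Tensor):
    """Stand-in for datapoints.Image."""


class Video(Tensor):
    """Stand-in for datapoints.Video."""


def warn_on_ignored_tensor_key(mapping):
    """Warn when a Tensor key is combined with an Image or Video key."""
    if Tensor in mapping and (Image in mapping or Video in mapping):
        warnings.warn(
            "Got a dtype for plain tensors together with one for Image or "
            "Video. Plain tensors will not be transformed in this case.",
            UserWarning,
        )
    # A Tensor key on its own stays supported and triggers no warning.
    return mapping
```

This keeps the Tensor key usable for tensor-only pipelines while making the pass-through visible in the mixed case.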

@NicolasHug NicolasHug left a comment

Still LGTM!

@pmeier pmeier merged commit 1120aa9 into pytorch:main Feb 8, 2023
@pmeier pmeier deleted the tensor-fallback-heuristic branch February 8, 2023 19:02
facebook-github-bot pushed a commit that referenced this pull request Mar 28, 2023
… v2 (#7170)

Reviewed By: vmoens

Differential Revision: D44416271

fbshipit-source-id: 20c92067665ea106550bc29947f2596a36000025