Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast image processor #28847

Merged
merged 40 commits into from
Jun 11, 2024
Merged

Conversation

amyeroberts
Copy link
Collaborator

@amyeroberts amyeroberts commented Feb 2, 2024

What does this PR do?

Adds a ViTImageProcessorFast class, which uses a torchvision backend to speed up image processor

import requests
from PIL import Image
from transformers import AutoImageProcessor

# Load the image
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# Set use_fast=True to select the fast implementation
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", use_fast=True)

outputs = image_processor(images=image, return_tensors="pt")

Benchmark comparing the two image processors:
2024_06_06_benchmark_fast_image_processor

List of PIL image input
Fast mean time: 0.005302833318710327 std: 0.0007948700795509047
Slow mean time: 0.010124003410339356 std: 0.0002542689804230619
1.9541 times faster

Torch tensor batch input
Fast mean time: 0.0016665704250335694 std: 5.2973509597780026e-05
Slow mean time: 0.05901132440567017 std: 0.0005533502207234909
35.4395 times faster. 35.4395 times faster

List of PIL image input
Fast mean time: 0.0017093017101287842 std: 4.127424976607243e-05
Slow mean time: 0.005274987697601318 std: 6.580574283983749e-05
3.0876 times faster

Torch tensor batch input
Fast mean time: 0.0009776098728179932 std: 1.986684460031034e-05
Slow mean time: 0.005883036375045776 std: 5.80707152345902e-05
6.0196 times faster. 6.0196 times faster
Script for replicating benchmark
import time

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image

from transformers import ViTImageProcessorFast, ViTImageProcessor

# Load the image
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

image_processor_fast = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224-in21k")
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

processed_image_fast = image_processor_fast(images=image, do_rescale=True, do_normalize=True, return_tensors="pt")
processed_image = image_processor(images=image, do_rescale=True,  do_normalize=True, return_tensors="pt")

print(processed_image['pixel_values'].dtype)  # torch.float32
print(processed_image_fast['pixel_values'].dtype)  # torch.float32
print((processed_image_fast['pixel_values'] - processed_image['pixel_values']).abs().max())  # tensor(0.)


N_ITERATIONS = 1000
DO_RESCALE = True

# # Benchmarking
fast_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor_fast(images=[image, image, image], return_tensors="pt")
    end = time.time()
    fast_times.append(end - start)

slow_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor(images=[image, image, image], return_tensors="pt")
    end = time.time()
    slow_times.append(end - start)

fast_times = np.array(fast_times)
slow_times = np.array(slow_times)
print("List of PIL image input")
print("Fast mean time:", np.mean(fast_times), "std:", np.std(fast_times))
print("Slow mean time:", np.mean(slow_times), "std:", np.std(slow_times))
times_faster = np.mean(slow_times/fast_times)
print(f"{times_faster:.4f} times faster")

fig, ax = plt.subplots(2, 2, figsize=(10, 5), sharex=True)

# Plotting
ax[0][0].hist(fast_times, bins=max(N_ITERATIONS // 10, 10), label=f"fast {np.mean(fast_times):.4f} ± {np.std(fast_times):.4f}")
ax[0][0].hist(slow_times, bins=max(N_ITERATIONS // 10, 10), label=f"slow {np.mean(slow_times):.4f} ± {np.std(slow_times):.4f}")
ax[0][0].set_xlabel("Time (s)")
ax[0][0].set_ylabel("Frequency")
ax[0][0].legend()
ax[0][0].set_title(f"List of PIL image input.\n{times_faster:.4f} times faster")

# Benchmarking - torch
image_pt = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
# FIXME - how to handle the rescaling?
if not DO_RESCALE:
    image_pt /= 255
images = torch.vstack([image_pt, image_pt, image_pt])

fast_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor_fast(images=images, return_tensors="pt", do_rescale=DO_RESCALE)
    end = time.time()
    fast_times.append(end - start)

slow_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor(images=images, return_tensors="pt", do_rescale=DO_RESCALE)
    end = time.time()
    slow_times.append(end - start)

fast_times = np.array(fast_times)
slow_times = np.array(slow_times)
print("Torch tensor batch input")
print("Fast mean time:", np.mean(fast_times), "std:", np.std(fast_times))
print("Slow mean time:", np.mean(slow_times), "std:", np.std(slow_times))
times_faster = np.mean(slow_times/fast_times)
print(f"{times_faster:.4f} times faster.")

# Plotting
ax[0][1].hist(fast_times, bins=max(N_ITERATIONS // 10, 10), label=f"fast {np.mean(fast_times):.4f} ± {np.std(fast_times):.4f}")
ax[0][1].hist(slow_times, bins=max(N_ITERATIONS // 10, 10), label=f"slow {np.mean(slow_times):.4f} ± {np.std(slow_times):.4f}")
ax[0][1].set_xlabel("Time (s)")
ax[0][1].set_ylabel("Frequency")
ax[0][1].legend()
ax[0][1].set_title(f"Torch tensor batch input.\n{times_faster:.4f} times faster")

# # Benchmarking PIL images, different sizes
image_small = image.resize((64, 48), Image.Resampling.NEAREST)
image_medium = image.resize((128, 96), Image.Resampling.NEAREST)
image_large = image.resize((256, 192), Image.Resampling.NEAREST)

images = [image_small, image_medium, image_large]

fast_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor_fast(images=images, return_tensors="pt")
    end = time.time()
    fast_times.append(end - start)

slow_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor(images=images, return_tensors="pt")
    end = time.time()
    slow_times.append(end - start)

fast_times = np.array(fast_times)
slow_times = np.array(slow_times)
print("List of PIL image input")
print("Fast mean time:", np.mean(fast_times), "std:", np.std(fast_times))
print("Slow mean time:", np.mean(slow_times), "std:", np.std(slow_times))
times_faster = np.mean(slow_times/fast_times)
print(f"{times_faster:.4f} times faster")

ax[1][0].hist(fast_times, bins=max(N_ITERATIONS // 10, 10), label=f"fast {np.mean(fast_times):.4f} ± {np.std(fast_times):.4f}")
ax[1][0].hist(slow_times, bins=max(N_ITERATIONS // 10, 10), label=f"slow {np.mean(slow_times):.4f} ± {np.std(slow_times):.4f}")
ax[1][0].set_xlabel("Time (s)")
ax[1][0].set_ylabel("Frequency")
ax[1][0].legend()
ax[1][0].set_title(f"List of PIL different sized images input.\n{times_faster:.4f} times faster")

# Benchmarking - torch different sizes
images = [
    torch.tensor(np.array(image_small)).permute(2, 0, 1).float(),
    torch.tensor(np.array(image_medium)).permute(2, 0, 1).float(),
    torch.tensor(np.array(image_large)).permute(2, 0, 1).float(),
]

fast_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor_fast(images=images, return_tensors="pt", do_rescale=DO_RESCALE)
    end = time.time()
    fast_times.append(end - start)

slow_times = []
for _ in range(N_ITERATIONS):
    start = time.time()
    image_processor(images=images, return_tensors="pt", do_rescale=DO_RESCALE)
    end = time.time()
    slow_times.append(end - start)

fast_times = np.array(fast_times)
slow_times = np.array(slow_times)
print("Torch tensor batch input")
print("Fast mean time:", np.mean(fast_times), "std:", np.std(fast_times))
print("Slow mean time:", np.mean(slow_times), "std:", np.std(slow_times))
times_faster = np.mean(slow_times/fast_times)
print(f"{times_faster:.4f} times faster. {times_faster:.4f} times faster")

# Plotting
ax[1][1].hist(fast_times, bins=max(N_ITERATIONS // 10, 10), label=f"fast {np.mean(fast_times):.4f} ± {np.std(fast_times):.4f}")
ax[1][1].hist(slow_times, bins=max(N_ITERATIONS // 10, 10), label=f"slow {np.mean(slow_times):.4f} ± {np.std(slow_times):.4f}")
ax[1][1].set_xlabel("Time (s)")
ax[1][1].set_ylabel("Frequency")
ax[1][1].legend()
ax[1][1].set_title(f"Torch tensor list of different sized images input.\n{times_faster:.4f} times faster")

fig.tight_layout()
fig.savefig("benchmark_fast_image_processor.png")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts amyeroberts added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Feb 12, 2024
@amyeroberts amyeroberts force-pushed the fast-image-processor branch 2 times, most recently from 9ceab5d to ffdadd2 Compare May 16, 2024 16:08
@amyeroberts amyeroberts marked this pull request as ready for review May 16, 2024 20:08
@@ -0,0 +1,542 @@
# coding=utf-8
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just moving objects from image_processing_utils.py to this new file, which will share the common objects between the base fast and slow image processors in image_processing_utils.py and image_processing_utils_fast.py.

This is to match the structure of the tokenizers files

@@ -313,6 +322,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
use_fast (`bool`, *optional*, defaults to `False`):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defaulting to false for now, so users have to actively opt-in, to avoid any surprises

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add a warning for supported models that the user is using a slow version (+ default to None, if None set to False and warn for model in the list of supported models)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Added in f6a0847

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very thorough. I just left minor feedback.

Would be interesting to benchmark how does the speedup hold across higher resolutions and batch sizes (i.e., when we increase both batch size and resolution).

src/transformers/models/vit/image_processing_vit_fast.py Outdated Show resolved Hide resolved
src/transformers/models/vit/image_processing_vit_fast.py Outdated Show resolved Hide resolved
Comment on lines 75 to 77
PIL = "pillow"
TORCH = "torch"
NUMPY = "numpy"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We used to support TensorFlow, too, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the slow image processors we support tensorflow and jax as well. The usage of these frameworks compared to pytorch is tiny. As this uses a torchvision backend, I'm not sure how much sense it is to convert to torch tensors, run in pytorch and the convert back, as this would require TF users to install torchvision.

I've added an exception if jax or TF arrays are passed in.

We can add support in the future if we end up getting lots of requests for it.

Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, with the same interface but faster, and the code looks much easier!
I added some comments

src/transformers/models/vit/image_processing_vit_fast.py Outdated Show resolved Hide resolved
Comment on lines 184 to 188
# Regardless of whether we rescale, all PIL and numpy values need to be converted to a torch tensor
# to keep cross compatibility with slow image processors
convert_to_tensor = image_type in (ImageType.PIL, ImageType.NUMPY)
if convert_to_tensor:
transforms.append(ToTensor())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth writing a custom ToTensor without rescaling logic to avoid the complex rules below? We can just copy-modify the original one

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found PILToTensor which thankfully does the conversion without scaling.

I added three custom transforms regarding this comment and your other comment about fusing the operations:

  • NumpyToTensor: numpy equivalent for PILToTensor
  • Rescale: class which will rescale pixel values by rescale_factor
  • FusedRescaleNormalize: which will do as you suggested and combine the rescale and normalize operations. Only tricky thig here is we need to make sure to convert to a floating tensor first

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing with and without the fused normalize and rescale, the speed ups are between 2 - 13 %

With separate normaliza and rescale

List of PIL image input
Fast mean time: 0.006879540920257568 std: 0.00035099185559956325
Slow mean time: 0.010017627239227295 std: 0.00016269613600058844
1.4601 times faster
Torch tensor batch input
Fast mean time: 0.0017743961811065674 std: 6.134420692941824e-05
Slow mean time: 0.05170198488235474 std: 0.0006948744101280261
29.1670 times faster
List of PIL image input
Fast mean time: 0.002653493165969849 std: 5.281411681712124e-05
Slow mean time: 0.004650842666625976 std: 2.603507052302019e-05
1.7534 times faster
Torch tensor batch input
Fast mean time: 0.001086679458618164 std: 2.1535627834848004e-05
Slow mean time: 0.006056704044342041 std: 4.2741895905449867e-05
5.5755 times faster

With fused normalize and rescale:

List of PIL image input
Fast mean time: 0.006737879753112793 std: 0.0004794828832439867
Slow mean time: 0.009744790077209473 std: 0.0001803401154248809
1.4535 times faster
Torch tensor batch input
Fast mean time: 0.0017259061336517335 std: 0.0001853211607801645
Slow mean time: 0.050941661834716795 std: 0.0006355194041148941
29.7736 times faster
List of PIL image input
Fast mean time: 0.0025912017822265626 std: 4.1097075065458606e-05
Slow mean time: 0.005103119373321533 std: 2.3521387703663177e-05
1.9698 times faster
Torch tensor batch input
Fast mean time: 0.0009581971168518066 std: 2.3126728878460678e-05
Slow mean time: 0.006041062116622925 std: 3.5341232719903447e-05
6.3075 times faster

Speed ups:

List of PIL image input: 1.0210245911674707 times faster
Torch tensor batch input: 1.0280954140606922 times faster
List of PIL image input: 1.024039572745956 times faster
Torch tensor batch input: 1.134087589606292 times faster

src/transformers/models/vit/image_processing_vit_fast.py Outdated Show resolved Hide resolved
Comment on lines 329 to 340
self._maybe_update_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
size=size,
resample=resample,
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
image_type=image_type,
)
transformed_images = [self._transforms(image) for image in images]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks a bit tricky changing _transforms, because stored _transforms and class parameters became inconsistent. I don't see any side effects here, but probably more safe will be to return new transforms instead of updating

        transforms = self._maybe_update_transforms(
            do_resize=do_resize,
            do_rescale=do_rescale,
            do_normalize=do_normalize,
            size=size,
            resample=resample,
            rescale_factor=rescale_factor,
            image_mean=image_mean,
            image_std=image_std,
            image_type=image_type,
        )
        transformed_images = [transforms(image) for image in images]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because stored _transforms and class parameters became inconsistent

Which class params do you mean here - _transform_params?

I've changed it so we return transforms which we use. I've kept it as storing though - this is part of the trick that helps make this class fast when it's called multiple times: we don't have to recompose the transforms.

I might not be understanding the concern about class parameters and _transforms becoming inconsistent. Let me know if there's a particularly risky code path I should try and protect against here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably I missed some design nuances, now I see that _transform_settings are updated accordingly 🙂

The general idea was to use the method without saving the state in an instance. Can we use lru_cache for _build_transform function without storing _transforms and _transforms_settings? or is there any reason to save them?

Here is how I see it, in case lru_cache works fine - transform is not going to be recomposed for the same parameters. Does it make sense?

class BaseImageProcessorFast(BaseImageProcessor):
    _transform_params = None

    def _build_transforms(self, **kwargs) -> "Compose":
        """
        Given the input settings e.g. do_resize, build the image transforms.
        """
        raise NotImplementedError

    def _validate_params(self, **kwargs) -> None:
        for k, v in kwargs.items():
            if k not in self._transform_params:
                raise ValueError(f"Invalid transform parameter {k}={v}.")

    @functools.lru_cache(maxsize=1)
    def get_transforms(self, **kwargs) -> "Compose":
        self._validate_params(**kwargs)
        return self._build_transforms(**kwargs)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. Yes, absolutely we can do that!

My only thought is that have no transforms stored makes it harder to inspect if debugging, and to obtain you have to pass in all the do_resize, size etc. arguments, which might be annoying. As I'll probably be the person maintain for the foreseeable, I'll push this change and add back some explicit setting if it ends up being a pain :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very compelling argument! I just checked and there is no direct way to inspect the functools cache, so that's might be a pain for debugging :)

Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments about numpy images

src/transformers/models/vit/image_processing_vit_fast.py Outdated Show resolved Hide resolved
Comment on lines 147 to 142
# Do we want to permute the channels here?
transforms.append(NumpyToTensor())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect NumpyToTensor() returns the same CHW format as PILToTensor(), and some torchvision transforms expect CHW format too, for example normalize

Copy link
Collaborator Author

@amyeroberts amyeroberts May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, in terms of what the transforms accept, I think we're best working in and assuming channels first. The reason I'm not sure about permuting is because of ToTensor's behaviour, which when given a numpy array of 3 dimensions assumes it's in (H, W, C) format i.e. should we make the same assumption?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it! I expect that numpy images come in HWC format because any library, as far as I know, that reads the image and returns numpy does so in HWC format. The same applies to third-party libraries for image processing/augmentation; if they work with numpy images, it is assumed to have the HWC format.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated NumpyToTensor to transpose to CHW, assuming a HWC input: 6fb7df2

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥 Amazing work and results, can't wait for this to be democratized to other models

src/transformers/image_processing_utils_fast.py Outdated Show resolved Hide resolved
src/transformers/image_transforms.py Show resolved Hide resolved
src/transformers/models/auto/image_processing_auto.py Outdated Show resolved Hide resolved
@@ -313,6 +322,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
use_fast (`bool`, *optional*, defaults to `False`):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add a warning for supported models that the user is using a slow version (+ default to None, if None set to False and warn for model in the list of supported models)

src/transformers/models/auto/image_processing_auto.py Outdated Show resolved Hide resolved
tests/test_image_processing_common.py Show resolved Hide resolved
tests/test_image_processing_common.py Show resolved Hide resolved
@amyeroberts amyeroberts force-pushed the fast-image-processor branch 2 times, most recently from b426fff to 4b6cad3 Compare June 5, 2024 18:13
@amyeroberts
Copy link
Collaborator Author

Thanks everyone for your extensive reviews! I think I've addressed everything and have updated the graph on the description to reflect the current benchmarks.

@qubvel Would be great to have a final review before merging

Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 awesome!

I have just two questions/suggestions regarding testing slow vs fast:

  1. Do we claim that fast and slow image processor's results are somehow equivalent? Probably, they will not match exactly due to some difference in resizing algorithm, but we can test mean absolute error.
  2. Worth it to add a slow test to benchmark slow vs fast image processors, just to be sure that fast is really faster and we will not break anything in the future?

src/transformers/models/vit/image_processing_vit_fast.py Outdated Show resolved Hide resolved
@@ -29,3 +29,4 @@ timm
albumentations >= 1.4.5
torchmetrics
pycocotools
Pillow>=10.0.1,<=15.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why 15.0 here? :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a security thing :) It matches the pin we have in setup.py which was set in #27409

@amyeroberts
Copy link
Collaborator Author

@qubvel Thanks for your review!

Regarding your questions:

Do we claim that fast and slow image processor's results are somehow equivalent? Probably, they will not match exactly due to some difference in resizing algorithm, but we can test mean absolute error.

I don't think we can claim they'll be exactly the same, but they should be similar. I can add an equivalence test that has a certain tolerance.

Worth it to add a slow test to benchmark slow vs fast image processors, just to be sure that fast is really faster and we will not break anything in the future?

Sure - good idea. I can add a test to make sure a single pass (or maybe series of passes as the fast benefits from the cache) are faster for the fast class.

@amyeroberts amyeroberts merged commit f53fe35 into huggingface:main Jun 11, 2024
23 checks passed
@amyeroberts amyeroberts deleted the fast-image-processor branch June 11, 2024 14:47
Comment on lines +103 to +104
("mobilevit", ("MobileViTImageProcessor",)),
("mobilevit", ("MobileViTImageProcessor",)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts Is this typo or is there any specific reason for duplication?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have found this when I try to resolve conflict from local branch

Copy link
Collaborator Author

@amyeroberts amyeroberts Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a typo - we can remove one of the lines. Thanks for catching!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it must have been added when mobilevit was first added, as it was present in the previous automap list. I've opened a PR here to remove it: #31383

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 14, 2024
* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
itazap pushed a commit that referenced this pull request Jun 17, 2024
* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
itazap pushed a commit that referenced this pull request Jun 17, 2024
* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
itazap pushed a commit that referenced this pull request Jun 17, 2024
* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
itazap pushed a commit that referenced this pull request Jun 18, 2024
* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
itazap pushed a commit that referenced this pull request Jun 20, 2024
* Draft fast image processors

* Draft working fast version

* py3.8 compatible cache

* Enable loading fast image processors through auto

* Tidy up; rescale behaviour based on input type

* Enable tests for fast image processors

* Smarter rescaling

* Don't default to Fast

* Safer imports

* Add necessary Pillow requirement

* Woops

* Add AutoImageProcessor test

* Fix up

* Fix test for imagegpt

* Fix test

* Review comments

* Add warning for TF and JAX input types

* Rearrange

* Return transforms

* NumpyToTensor transformation

* Rebase - include changes from upstream in ImageProcessingMixin

* Safe typing

* Fix up

* convert mean/std to tesnor to rescale

* Don't store transforms in state

* Fix up

* Update src/transformers/image_processing_utils_fast.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/auto/image_processing_auto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Warn if fast image processor available

* Update src/transformers/models/vit/image_processing_vit_fast.py

* Transpose incoming numpy images to be in CHW format

* Update mapping names based on packages, auto set fast to None

* Fix up

* Fix

* Add AutoImageProcessor.from_pretrained(checkpoint, use_fast=True) test

* Update src/transformers/models/vit/image_processing_vit_fast.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Add equivalence and speed tests

* Fix up

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
run-slow WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants