Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow FP16 or other precision inference for Pipelines #31342

Merged
merged 19 commits into from
Jul 5, 2024

Conversation

aliencaocao
Copy link
Contributor

@aliencaocao aliencaocao commented Jun 10, 2024

What does this PR do?

Currently, if you pass torch_dtype=torch.float16 or set model=AutoModel.from_pretrained(..., torch_dtype=torch.float16) in hope to use FP16 for inference in a Pipeline, it will fail because although the model is casted to FP16, the inputs like image features stays in fp32 as the default torch dtype.

This PR converts them accordingly. It only convert those that comes out of a image_processor and with type torch.float32 so to not accidently touch things like token ids or boxes which may be in torch.int by intention.

I have not checked pipelines involving audio inputs but I would imagine some of them also having the same issue.

I originally found this issue when using ZeroShotImageClassificationPipeline like this:

ZeroShotImageClassificationPipeline(model=AutoModelForZeroShotImageClassification.from_pretrained(clip_path, torch_dtype=torch.float16), tokenizer=AutoTokenizer.from_pretrained(clip_path), image_processor=AutoImageProcessor.from_pretrained(clip_path), device='cuda')

Note that I have yet to write tests for it as I want to make sure this is a valid issue and I am not just using Pipelines wrongly.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@Narsil

@aliencaocao
Copy link
Contributor Author

aliencaocao commented Jun 10, 2024

Anyone know how do I fix the code quality errors from here instead of running ruff locally? I don't have one setup now...
but strange thing I followed all the existing imports formats

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this feature!

You should be able to directly use .to on the image processor outputs.

All of the pipelines should have tests added to check they can accept and run with fp16 inputs

src/transformers/pipelines/depth_estimation.py Outdated Show resolved Hide resolved
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@aliencaocao
Copy link
Contributor Author

And for tests, do you think I can just add a dtype=torch.float16 in the pipeline init method in existing tests, or must I keep the fp32 and do a new fp16 run? I feel the latter is unnecessary

@amyeroberts
Copy link
Collaborator

And for tests, do you think I can just add a dtype=torch.float16 in the pipeline init method in existing tests

Certainly not 😄 fp32 is the default for pipelines and so this should be tested by default. We'll need to add tests which set torch_dtype=torch.float16

@aliencaocao
Copy link
Contributor Author

Sure, i'll add that

@aliencaocao
Copy link
Contributor Author

aliencaocao commented Jun 25, 2024

@amyeroberts do I need to test for numerical similarity in fp16 or just make sure the inference runs?

Testing for numerical is quite some work and slow to run as I have to download each model in each pipeline (i changed 14 of them) then get the expected logits then check allclose. However this can break for indiv models depending on their size and sensitively to numerical precision, and the threshold for allclose will vary between task and models too. Ultimately, even if a model does not work well when using FP16 VS FP32, there is really nothing we can do here and I think it should be up to the users to evaluate themselves.

@aliencaocao aliencaocao changed the title Allow FP16 or other precision inference for Pipelines involving image features Allow FP16 or other precision inference for Pipelines Jun 25, 2024
@aliencaocao
Copy link
Contributor Author

I have pushed the tests for inference and not check for numerical stability by using the existing code in pipeline test mixin, except for a few models where no mixin common methods (get_test_pipeline, run_pipeline_test) were declared so I added them in the respective test scripts on the test_small_model_pt methods.

image to image test has @slow, would need the run slow tag to run that.

@amyeroberts
Copy link
Collaborator

@aliencaocao Great!

At the moment, there's a few pipeline tests failing which will need to be resolved. To run slow tests locally, you can set the RUN_SLOW flag with RUN_SLOW=1 pytest ...

@aliencaocao
Copy link
Contributor Author

aliencaocao commented Jun 25, 2024

there's a few pipeline tests failing which will need to be resolved

yes i will be resolving them but i see 1 error with owlvit (zero-shot-object-detection): value cannot be converted to type at::Half without overflow. This indicates that the weights exceed range of fp16. Normally we dont see that in models. I think I have to skip this test unless we can switch to another model like OWLv2-base.

@amyeroberts
Copy link
Collaborator

@aliencaocao For this test, as it's testing the pipeline rather than the model itself, we can change the checkpoint/architecture used. It will likely need other values to be updated alongside.

@aliencaocao
Copy link
Contributor Author

aliencaocao commented Jun 25, 2024

The image to image slow test fails because swin2sr impl has a issue where it does not cast an intermediate tensor to the type of other model parameters.
Specifically, I have to modify

relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
to add

relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)

As relative_coords_table is being passed into this MLP on this line.

Should I make a new PR for this change or add it here?

@amyeroberts
Copy link
Collaborator

Could you do this in a separate PR please? It'll be easier to track this way

@aliencaocao
Copy link
Contributor Author

PR made #31589

@aliencaocao
Copy link
Contributor Author

4 other failed tests can be fixed by #31590 - a small QoL improvement

@aliencaocao
Copy link
Contributor Author

aliencaocao commented Jun 26, 2024

For the failing owlvit test, it is not because weight overflow, but implementation:

pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)

-1e6 is simply out of range for fp16, min is -65504.

Should we correct the impl? I can do a check on other models for similar issues.

The fix should be quite simple, just torch.finfo(pred_logits.dtype).min. Only thing is it does affect all existing model's outputs - but they are masked and meant to be ignored anyways

@amyeroberts
Copy link
Collaborator

@aliencaocao Yes, let's use torch.finfo(pred_logits.dtype).min 👍

Only thing is it does affect all existing model's outputs

Here does "all models" refer to all owlvit checkpoints? If that's the case, then that's OK!

@aliencaocao
Copy link
Contributor Author

Here does "all models" refer to all owlvit checkpoints?

What I meant was all torch impl of models in HF transformers as more may be using a hard-coded out of range value for masking logits like owlvit. So for those that end up changing, then outputs may be affected.

Existing tests already check for logits and we will know if anything gets affected. Ideally, none of them should if the masking is working as intended.

Do you want a new PR to update the code for owlvit and potentially other models that use the same, or do I change it here?

@amyeroberts
Copy link
Collaborator

Do you want a new PR to update the code for owlvit and potentially other models that use the same, or do I change it here?

Up to you. Having it correct across all models if of course the dream, but it can be a bit laborious making sure this is correct everywhere + tests and might not be worth it for low-use models. I'm happy to just have the change made for owlvit here, and then we can think about other models' compatibility with fp16 if users raise it

# Conflicts:
#	tests/pipelines/test_pipelines_feature_extraction.py
#	tests/pipelines/test_pipelines_zero_shot_audio_classification.py
@aliencaocao
Copy link
Contributor Author

@amyeroberts the failing tf and onnx tests are due to some keras changes in https://github.com/keras-team/keras/releases/tag/v3.4.1

The failing torch pipeline test is due to network timeout

@amyeroberts
Copy link
Collaborator

@aliencaocao Yes, unfortunately the keras update has broken everything 😭

We're working on a fix. I'll ping once resolved and hopefully then we can successfully re-run the CI for this PR

@amyeroberts
Copy link
Collaborator

@aliencaocao There's been fixes for keras and some of the timeout errors. Could you rebase to include these - should then make all the CIs green

@aliencaocao
Copy link
Contributor Author

@amyeroberts CI all green now

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work - thanks for adding!

@amyeroberts amyeroberts merged commit ac26260 into huggingface:main Jul 5, 2024
18 checks passed
@aliencaocao aliencaocao deleted the fix-pipeline-dtype branch July 5, 2024 16:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants