
Fix bloom KV cache usage in ORTForCausalLM #1152

Merged: 5 commits, Jul 6, 2023

Conversation

fxmarty (Contributor) commented Jun 30, 2023

Fixes #1150

@michaelbenayoun @echarlaix Basically not all models share the exact same `_reorder_cache` and `prepare_inputs_for_generation`, in particular bloom here. What do you think of this solution? @echarlaix I would guess the same bug exists in optimum-intel.
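For context, a minimal sketch of why bloom needs its own `_reorder_cache` (illustrative code, not this PR's implementation; `bloom_reorder_cache` and the explicit `num_heads` argument are hypothetical): most decoders cache keys and values as `[batch, num_heads, seq_len, head_dim]`, so beam search can reorder the cache with a single `index_select` on dim 0, whereas bloom fuses batch and heads into the leading dimension, so the cache has to be unfused first.

```python
import torch

def bloom_reorder_cache(past_key_values, beam_idx, num_heads):
    # Bloom caches keys as [batch * num_heads, head_dim, seq_len] and values as
    # [batch * num_heads, seq_len, head_dim]: unfuse the leading dim, reorder
    # the batch for the selected beams, then re-fuse.
    reordered = []
    for key, value in past_key_values:
        batch_size = key.shape[0] // num_heads
        key = key.reshape(batch_size, num_heads, *key.shape[1:]).index_select(0, beam_idx)
        value = value.reshape(batch_size, num_heads, *value.shape[1:]).index_select(0, beam_idx)
        reordered.append((key.flatten(0, 1), value.flatten(0, 1)))
    return tuple(reordered)
```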

I think my solution is very ugly (now `ORTBloomForCausalLM` and `ORTModelForCausalLM` need to be in the same file forever). Another approach is to move all shared methods to `ORTModelDecoder` (effectively making it a mixin class) and have `ORTBloomForCausalLM` not inherit from `ORTModelForCausalLM`, but that does not solve the "all classes in one file" issue, and more importantly I believe changing the inheritance is too breaking a change (i.e. `isinstance(ort_bloom_model, ORTModelForCausalLM)` would not work anymore).
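Concretely, the shape of the solution under discussion might be sketched like this (class names follow the discussion; the bodies are placeholders, not the merged code). Keeping the subclass in the same module and inheriting from `ORTModelForCausalLM` preserves `isinstance` checks for user code:

```python
class ORTModelForCausalLM:
    """Generic ONNX Runtime decoder; the hooks below fit most architectures."""

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Generic layout [batch, num_heads, seq_len, head_dim]: reorder dim 0.
        return tuple(
            tuple(tensor.index_select(0, beam_idx) for tensor in layer_past)
            for layer_past in past_key_values
        )


class ORTBloomForCausalLM(ORTModelForCausalLM):
    """Bloom-specific overrides, defined next to the parent class so that
    isinstance(ort_bloom_model, ORTModelForCausalLM) keeps working."""

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Bloom-specific: unfuse [batch * num_heads, ...], reorder, re-fuse
        # (see the earlier sketch).
        ...
```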

Another solution: have a single `prepare_inputs_for_generation` and `_reorder_cache`, and dispatch to the relevant function from a dictionary. This adds dynamism, which I think is better to avoid.
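For comparison, a sketch of that dispatch alternative (hypothetical helper names, not what was merged), where a single class resolves the hook from the model type at call time; the runtime lookup is the added dynamism:

```python
def _generic_reorder_cache(past_key_values, beam_idx):
    return tuple(
        tuple(tensor.index_select(0, beam_idx) for tensor in layer_past)
        for layer_past in past_key_values
    )


def _bloom_reorder_cache(past_key_values, beam_idx):
    ...  # bloom-specific unfuse/reorder/re-fuse, as sketched earlier


_REORDER_CACHE_DISPATCH = {"bloom": _bloom_reorder_cache}


class ORTModelForCausalLM:
    def _reorder_cache(self, past_key_values, beam_idx):
        # Resolve the hook from the model type at call time; self.config is
        # assumed to carry a model_type string, as in transformers configs.
        fn = _REORDER_CACHE_DISPATCH.get(self.config.model_type, _generic_reorder_cache)
        return fn(past_key_values, beam_idx)
```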

Note: tests for `num_beams > 1` should be added in this PR as well; see the sketch below.
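A minimal beam-search check along those lines might look like the following sketch (the checkpoint and exact `from_pretrained` kwargs are assumptions for illustration, not the PR's test code); `num_beams > 1` forces beam search, which is what exercises `_reorder_cache` on the ONNX KV cache:

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "bigscience/bloom-560m"  # assumed checkpoint for illustration
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# Beam search reorders the past key/values at every step.
outputs = model.generate(**inputs, num_beams=2, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```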

regisss (Contributor) left a comment


This solution looks good enough to me 👍

michaelbenayoun (Member) left a comment


Looks good for now. If this conditioning on the model type grows, maybe we can find a nicer way of doing it, but right now it seems acceptable to me.

Three review comments on optimum/onnxruntime/modeling_decoder.py (outdated, resolved)
echarlaix (Collaborator) left a comment


Thanks for the fix @fxmarty !

fxmarty and others added 4 commits on July 6, 2023 (co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>)
HuggingFaceDocBuilderDev commented Jul 6, 2023

The documentation is not available anymore as the PR was closed or merged.

fxmarty merged commit 2eab7ab into huggingface:main on Jul 6, 2023
62 of 64 checks passed

Successfully merging this pull request may close: ORTModelForCausalLM - ONNX - Merge model fails to run (#1150)
5 participants