
[KV cache] Base, generic KV Cache injection support #1559

Merged
merged 26 commits into feature/damian/fb_kv_cache from kv-cache-injection on Jun 12, 2023

Conversation

bfineran
Member

@bfineran bfineran commented May 11, 2023

Feature Preview:

import onnx
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector

model = onnx.load("deployment/model.onnx")
model = KeyValueCacheInjector(model_type="opt").apply(model)
onnx.save(model, "deployment/model_kvcache.onnx")

The implementation above guarantees full "onnx checker safety".
However, to let the user run the transform faster / in isolation, the following
path is also enabled:

import onnx
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
model = onnx.load("deployment/model.onnx", load_external_data=False) # we operate only on the model graph
model = KeyValueCacheInjector(model_type="opt").apply(model)
onnx.save(model, "deployment/model_kvcache.onnx")

Note: this will raise multiple warnings, making the user aware that models loaded with load_external_data=False cannot be properly validated.

Additional changes:

  • makes MatMulAddToMatMulIntegerAddCastMul a bit more generic by making the Add portion optional; we should look into renaming this transform...
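The "optional Add" change can be sketched with a generic matcher (the helper and node representation below are hypothetical illustrations, not sparseml's actual pattern-matching API):

```python
# Hypothetical sketch: match a MatMul followed by an *optional* Add, so the
# same transform covers parameterized matmuls both with and without a bias.
def match_matmul_optional_add(nodes, idx):
    """Return (matmul, add_or_None) if nodes[idx] starts the pattern, else None."""
    matmul = nodes[idx]
    if matmul["op"] != "MatMul":
        return None
    nxt = nodes[idx + 1] if idx + 1 < len(nodes) else None
    # The Add is optional: absence of a trailing Add still yields a match.
    add = nxt if nxt is not None and nxt["op"] == "Add" else None
    return matmul, add

graph = [{"op": "MatMul"}, {"op": "Add"}, {"op": "MatMul"}, {"op": "Relu"}]
print(match_matmul_optional_add(graph, 0))  # MatMul with a bias Add
print(match_matmul_optional_add(graph, 2))  # MatMul without a bias Add
```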

Testing

This functionality has been tested multiple times on dense/sparse/quantized OPT models:

model = onnx.load("/network/damian/sparsegpt_webinar_fp32/sparsegpt_1.3b/model.onnx")
model = KeyValueCacheInjector(model_type="opt").apply(model)
onnx.save(model, "temp.onnx")
onnx.checker.check_model("temp.onnx")
2023-05-29 11:26:27 sparseml.exporters.transforms.onnx_transform INFO     [CacheKeysAndValues] Transformed 48 matches
2023-05-29 11:26:29 sparseml.exporters.transforms.onnx_transform INFO     [PositionEmbeddingsAdjustment] Transformed 5 matches

With quantized models, we have observed that the KV cache export leads to "dead MatMuls" in the exported graph:

(screenshot: dead MatMuls)

However, the current head of the branch no longer produces "dead MatMuls". Does that mean the problem has been solved along the way?

(screenshots: exported graph without dead MatMuls)

Also, the quantized models run fine in the pipeline (tested with ORT).

@bfineran bfineran requested a review from dbogunowicz May 11, 2023 22:57
@bfineran bfineran self-assigned this May 11, 2023
@bfineran
Member Author

Status so far: confirmed the 'generic' pattern matching works for OPT; next step is adding the injection.

@bfineran
Member Author

@dbogunowicz initial implementation of KV cache concats + OPT Cache length adjustment completed

for some reason the exporter I wrote was cleared in my environment, will get that up at a later date
Sample code to try:

import onnx
from sparseml.exporters.transforms.kv_cache import *
model = onnx.load("/home/benjamin/tmp-models/small_decoder_opt.onnx", load_external_data=False)
model = CacheKeysAndValues().transform(model)
model = OPTCacheLengthAdjustment().transform(model)
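Conceptually, the Concat nodes that CacheKeysAndValues injects append the current step's keys/values to the cached ones along the sequence axis. A plain-numpy illustration (shapes here are hypothetical, not taken from the actual graph):

```python
import numpy as np

batch, heads, past_len, head_dim = 1, 2, 7, 4

# Cache carried between decoding steps (a graph input after injection).
past_key = np.zeros((batch, heads, past_len, head_dim), dtype=np.float32)
# Key produced for the single new token at this step.
new_key = np.ones((batch, heads, 1, head_dim), dtype=np.float32)

# What the injected Concat does: append along the sequence-length axis.
present_key = np.concatenate([past_key, new_key], axis=2)
print(present_key.shape)  # (1, 2, 8, 4)
```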

Example Cached MatMul: (screenshot)

Example Cache Length adjustment: (screenshot)

Member Author

@bfineran bfineran left a comment


@dbogunowicz quick comments, pushing up adjustment to reshape

# no great way to generically infer this from the graph since transposes can
# be used to place it on either side of the matmul
# hardcoding for now, will update to have a hardcoded value for each model type
_KEY_NODE_INPUT_IDX = 0
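The ambiguity the comment describes comes from the transpose identity (Q·Kᵀ)ᵀ = K·Qᵀ: with Transpose nodes around it, the key tensor can legally sit on either side of the MatMul, so the input index cannot be inferred purely from graph structure. A numpy sketch of the equivalence (hypothetical shapes):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((3, 8))  # queries: (num_queries, head_dim)
k = rng.standard_normal((5, 8))  # keys:    (cache_len,  head_dim)

# Keys as the *second* MatMul input...
scores_a = q @ k.T
# ...or as the *first* input, with the transposes rearranged around it.
scores_b = (k @ q.T).T

print(np.allclose(scores_a, scores_b))  # True
```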
Member Author


why did we take this out?

Contributor


for now it can be hard-coded as an argument to the function - no changes between CodeGen and OPT

@dbogunowicz dbogunowicz changed the title [WIP] 'Generic' KV cache injection support 'Generic' KV cache injection support May 22, 2023
@dbogunowicz dbogunowicz marked this pull request as ready for review May 24, 2023 12:56
KSGulin
KSGulin previously approved these changes May 24, 2023
Contributor

@KSGulin KSGulin left a comment


This looks great. Building KV cache injection in ONNX feels like the ML equivalent of Chris Sawyer creating RollerCoaster Tycoon in assembly, but this was clean and easy to follow. Left a couple non-blocking comments

"""
graph = ONNXGraph(model)

if node.op_type == "QLinearMatMul" and cache_input_idx == 1:
Contributor


Is there a reason to check if cache_input_idx == 1 here instead of setting to 3 regardless?
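For context (input layout reproduced from the ONNX operator spec; treat it as an assumption to verify): QLinearMatMul takes eight inputs, so the B operand of a plain MatMul (input index 1) lands at input index 3 once scales and zero points are interleaved:

```python
# QLinearMatMul input layout per the ONNX spec (listed from memory).
QLINEAR_MATMUL_INPUTS = [
    "a", "a_scale", "a_zero_point",
    "b", "b_scale", "b_zero_point",
    "y_scale", "y_zero_point",
]

# A MatMul cache input at index 1 ("b") sits at index 3 in QLinearMatMul,
# which is why cache_input_idx == 1 gets remapped to 3.
print(QLINEAR_MATMUL_INPUTS.index("b"))  # 3
```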

@dbogunowicz dbogunowicz changed the title 'Generic' KV cache injection support [KV cache] Base, generic KV Cache injection support May 29, 2023
@dbogunowicz dbogunowicz changed the base branch from main to feature/damian/fb_kv_cache May 29, 2023 10:05
@dbogunowicz dbogunowicz merged commit d4bd539 into feature/damian/fb_kv_cache Jun 12, 2023
@dbogunowicz dbogunowicz deleted the kv-cache-injection branch June 12, 2023 14:54
dbogunowicz added a commit that referenced this pull request Jun 16, 2023
* Update __init__.py

* [KV cache] Base, generic KV Cache injection support (#1559)

* [WIP] Inject core cache ops - pattern matching + base export

* complete initial implementation of CacheKeysAndValues

* documentation + suggestions from Damian

* Cache length adjustment - ABC + OPT impl

* typo

* little cleanup, more importantly, started testing

* stuck on testing

* move KV cache concat to before transpose where applicable

* working for dynamic seq len

* add support for slicing cache by actual length

* add position_embeddings_adjustment

* quantized model support

* Support Q/DQ folding of Parameterized matmuls w/o bias add

* delete Exporter for KV cache for now - since checker doesn't pass, will do transforms ad-hoc

* quality

* refactor

* fix docstrings

* hardening the validation

* validator not needed

---------

Co-authored-by: Damian <damian@neuralmagic.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* [KV cache] Properly set the static dimensions of the kv cache inputs/outputs (#1573)

* [WIP] Inject core cache ops - pattern matching + base export

* complete initial implementation of CacheKeysAndValues

* documentation + suggestions from Damian

* Cache length adjustment - ABC + OPT impl

* typo

* little cleanup, more importantly, started testing

* stuck on testing

* move KV cache concat to before transpose where applicable

* working for dynamic seq len

* add support for slicing cache by actual length

* add position_embeddings_adjustment

* quantized model support

* Support Q/DQ folding of Parameterized matmuls w/o bias add

* delete Exporter for KV cache for now - since checker doesn't pass, will do transforms ad-hoc

* quality

* refactor

* initial commit

* fix docstrings

* hardening the validation

* validator not needed

* adressing PR comments

---------

Co-authored-by: Benjamin <ben@neuralmagic.com>
Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>

* [KV cache] Input/output KV cache to include `batch` dimension. (#1589)

* [WIP] Inject core cache ops - pattern matching + base export

* complete initial implementation of CacheKeysAndValues

* documentation + suggestions from Damian

* Cache length adjustment - ABC + OPT impl

* typo

* little cleanup, more importantly, started testing

* stuck on testing

* move KV cache concat to before transpose where applicable

* working for dynamic seq len

* add support for slicing cache by actual length

* add position_embeddings_adjustment

* quantized model support

* Support Q/DQ folding of Parameterized matmuls w/o bias add

* delete Exporter for KV cache for now - since checker doesn't pass, will do transforms ad-hoc

* quality

* refactor

* initial commit

* fix docstrings

* initial commit

* hardening the validation

* validator not needed

* tested with deepsparse

* ready for reviews

* adressing PR comments

* addressing PR comments

* ready to land

---------

Co-authored-by: Benjamin <ben@neuralmagic.com>
Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>

* remove  changes

---------

Co-authored-by: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>
Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>