Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KV Cache Injection] Causal Mask implementation for OPT and CodeGen #1677

Merged
merged 12 commits into from
Jul 27, 2023

Conversation

dbogunowicz
Copy link
Contributor

@dbogunowicz dbogunowicz commented Jul 20, 2023

Introducing a gentle refactoring of position_adjustment transformations to additional_transforms transformations. The aim is to make the positions_adjustment transformation more general since this transformation will now also include causal mask injection.

Note: This PR contains the CodeGen causal mask support: #1676 as well as OPT causal mask support: #1688
Note2: Also added more verbose logging to the kv cache injection process:
e.g. for OPT:

2023-07-27 14:06:19 sparseml.exporters.transforms.kv_cache.configs INFO     Loaded config file deployment/config.json for model: opt
2023-07-27 14:06:19 sparseml.exporters.transforms.kv_cache.configs INFO     Properly configured arguments for KV Cache Transformation
2023-07-27 14:06:24 sparseml.exporters.transforms.onnx_transform INFO     [CacheKeysAndValues] Transformed 48 matches
2023-07-27 14:06:32 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Inserted positions input to the ONNX model
2023-07-27 14:06:32 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Inserted causal_mask input to the ONNX model
2023-07-27 14:06:33 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Successfully swapped 1 nodes for input 'positions'
2023-07-27 14:06:33 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Successfully swapped 1 nodes for input 'causal_mask'
2023-07-27 14:06:34 sparseml.exporters.transforms.kv_cache.transforms_opt INFO     Successfully adjusted the causal_mask input
2023-07-27 14:06:34 sparseml.exporters.transforms.onnx_transform INFO     [AdditionalTransformsOPT] Transformed 5 matches

and for CodeGen

2023-07-27 14:25:03 sparseml.exporters.transforms.kv_cache.configs INFO     Loaded config file deployment/config.json for model: codegen
2023-07-27 14:25:03 sparseml.exporters.transforms.kv_cache.configs INFO     Properly configured arguments for KV Cache Transformation
2023-07-27 14:25:08 sparseml.exporters.transforms.onnx_transform INFO     [CacheKeysAndValues] Transformed 40 matches
2023-07-27 14:25:17 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Inserted positions input to the ONNX model
2023-07-27 14:25:17 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Inserted causal_mask input to the ONNX model
2023-07-27 14:25:18 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Successfully swapped 1 nodes for input 'positions'
2023-07-27 14:25:18 sparseml.exporters.transforms.kv_cache.transforms_base INFO     Successfully swapped 20 nodes for input 'causal_mask'
2023-07-27 14:25:19 sparseml.exporters.transforms.kv_cache.transforms_codegen INFO     Successfully adjusted the causal_mask input
2023-07-27 14:25:19 sparseml.exporters.transforms.onnx_transform INFO     [AdditionalTransformsCodeGen] Transformed 22 matches

@dbogunowicz dbogunowicz marked this pull request as ready for review July 20, 2023 06:46
bfineran
bfineran previously approved these changes Jul 25, 2023
* initial implementation; testing now

* fix a small blunder

* cleanup

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
bfineran
bfineran previously approved these changes Jul 25, 2023
* initial implementation; testing now

* fix a small blunder

* cleanup

* initial implementation

* on to testing with deepsparse

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
@dbogunowicz dbogunowicz force-pushed the feature/damian/refactor_injection branch from 7737994 to 224412f Compare July 27, 2023 12:41
bfineran
bfineran previously approved these changes Jul 27, 2023
Copy link
Member

@bfineran bfineran left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - make sure to update the PR title since this includes the entire causal mask feature now

@@ -84,7 +81,7 @@ class Config:

OPT_CONFIG = KeyValueCacheConfig(
model_name="opt",
positions_adjustment_transform=PositionsAdjustmentOPT,
additional_transforms=AdditionalTransformsOPT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be great to support a list here eventually instead in case we don't want to squash many steps into a single transform or enable more code sharing between model specific additional transforms outside of inheritance

@dbogunowicz dbogunowicz changed the title [KV Cache Injection] Refactor before Causal Mask implementation [KV Cache Injection] Causal Mask implementation for OPT and CodeGen Jul 27, 2023
@bfineran bfineran merged commit b1d5ea2 into main Jul 27, 2023
10 checks passed
@bfineran bfineran deleted the feature/damian/refactor_injection branch July 27, 2023 15:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants