-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[KV Cache Injection] Causal Mask implementation for OPT and CodeGen #1677
Conversation
* initial implementation; testing now * fix a small blunder * cleanup --------- Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
* initial implementation; testing now * fix a small blunder * cleanup * initial implementation * on to testing with deepsparse --------- Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
7737994
to
224412f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - make sure to update the PR title since this includes the entire causal mask feature now
@@ -84,7 +81,7 @@ class Config: | |||
|
|||
OPT_CONFIG = KeyValueCacheConfig( | |||
model_name="opt", | |||
positions_adjustment_transform=PositionsAdjustmentOPT, | |||
additional_transforms=AdditionalTransformsOPT, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be great to support a list here eventually instead in case we don't want to squash many steps into a single transform or enable more code sharing between model specific additional transforms outside of inheritance
Introducing a gentle refactoring of
position_adjustment
transformations toadditional_transforms
transformations. The aim is to make thepositions_adjustment
transformation more general since this transformation will now also include causal mask injection.Note: This PR contains the CodeGen causal mask support: #1676 as well as OPT causal mask support: #1688
Note2: Also added more verbose logging to the kv cache injection process:
e.g. for OPT:
and for CodeGen