
[KV Cache Injection] Causal Mask for OPT #1688

Merged

Conversation

@dbogunowicz dbogunowicz commented Jul 25, 2023

Add causal mask support for OPT models to enable multitoken prefill in the Deepsparse pipeline.
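To illustrate what "causal mask for multitoken prefill" means here, below is a minimal NumPy sketch (names and shapes are illustrative, not the PR's actual code): when a chunk of new prompt tokens is processed at once, each new position may attend to all cached positions plus the new positions up to and including itself.

```python
import numpy as np

def causal_mask(num_new_tokens: int, num_cached_tokens: int) -> np.ndarray:
    """Boolean causal mask for a prefill chunk of `num_new_tokens` queries
    attending over `num_cached_tokens` cached keys plus the new keys.
    Shape: (num_new_tokens, num_cached_tokens + num_new_tokens)."""
    total_keys = num_cached_tokens + num_new_tokens
    # absolute position of each new query token in the full sequence
    query_pos = np.arange(num_new_tokens)[:, None] + num_cached_tokens
    key_pos = np.arange(total_keys)[None, :]
    # a query may attend to any key at or before its own position
    return key_pos <= query_pos

m = causal_mask(3, 2)
```

With 2 cached tokens and a 3-token chunk, the first new token sees only the cache and itself, while the last sees every position.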

Manual Testing

1. Export the OPT model.

2. Inject the KV cache:

```
python kv_cache_injector.py --input-file deployment/model.onnx --output-file deployment/model_kvcache.onnx
2023-07-25 13:18:44 sparseml.exporters.transforms.kv_cache.configs INFO     Loaded config file deployment/config.json for model: opt
2023-07-25 13:18:44 sparseml.exporters.transforms.kv_cache.configs INFO     Properly configured arguments for KV Cache Transformation
2023-07-25 13:18:46 sparseml.exporters.transforms.onnx_transform INFO     [CacheKeysAndValues] Transformed 48 matches
2023-07-25 13:18:49 sparseml.exporters.transforms.onnx_transform INFO     [AdditionalTransformsOPT] Transformed 2 matches
```

3. Run inference (using this branch: [Text Generation] Causal Mask Feature Branch deepsparse#1126):
```python
from deepsparse import Pipeline

def _test_pipeline(engine_type):
    opt = Pipeline.create(
        task="opt",
        model_path="/home/ubuntu/damian/sparseml/deployment",
        engine_type=engine_type,
        prompt_processing_sequence_length=64,
        use_deepsparse_cache=False,
        max_generated_tokens=32,
    )
    print("----------")
    # the prompt is short; it will not be processed by self.multitoken_engine
    prompt = "Who is the president of the United States?"
    out = opt(sequences=prompt, return_logits=True)
    print(out.sequences[0])
    print("---------")
    # the prompt is long; it will be processed by self.multitoken_engine
    prompt = "Who is the president of the United States?" * 20
    out = opt(sequences=prompt, return_logits=True)
    print(out.sequences[0])

_test_pipeline(engine_type="onnxruntime")
_test_pipeline(engine_type="deepsparse")
```
```
2023-07-25 13:23:35 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-25 13:23:41 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-25 13:23:51 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-25 13:23:57 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
----------


The president of the United States is the head of the executive branch of government. The president is the head of the executive branch of government, and the
---------
Who is the president of the United States?Who is the president of the United States?Who is the president of the United States?Who is the president of
/home/ubuntu/damian/deepsparse/src/deepsparse/transformers/pipelines/text_generation.py:137: UserWarning: AVX512 support not detected, disabling internal management of KV cache which may affect performance. To enable full performance, deploy on an AVX512-compatible system.
  warnings.warn(
2023-07-25 13:24:16 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230725 COMMUNITY | (f26e1c2e) (release) (optimized) (system=avx2, binary=avx2)
2023-07-25 13:24:31 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
----------


The president of the United States is the head of the executive branch of government. The president is the head of the executive branch of government, and the
---------
Who is the president of the United States?Who is the president of the United States?Who is the president of the United States?Who is the president of
```
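The test comments note that the short prompt bypasses `self.multitoken_engine` while the long one uses it. A simplified sketch of that dispatch, assuming the behavior implied by `prompt_processing_sequence_length=64` (this is an illustration of the idea, not deepsparse's actual implementation):

```python
def split_prefill(num_prompt_tokens: int, chunk_len: int = 64) -> list:
    """Illustrative dispatch: full chunks of `chunk_len` tokens go to the
    multitoken engine; any remainder is processed one token at a time."""
    chunks = []
    pos = 0
    while num_prompt_tokens - pos >= chunk_len:
        chunks.append(("multitoken", chunk_len))
        pos += chunk_len
    # leftover tokens (or a prompt shorter than chunk_len) go single-token
    chunks.extend(("single", 1) for _ in range(num_prompt_tokens - pos))
    return chunks
```

A ~10-token prompt yields only single-token steps, while a 20x-repeated prompt produces several multitoken chunks, matching the two cases exercised above.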

@dbogunowicz dbogunowicz changed the base branch from main to feature/damian/causal_mask_codegen July 25, 2023 08:42
@dbogunowicz dbogunowicz marked this pull request as ready for review July 25, 2023 13:17
Base automatically changed from feature/damian/causal_mask_codegen to feature/damian/refactor_injection July 25, 2023 14:35
@bfineran bfineran (Member) left a comment

LGTM pending rebase and the comment regarding a strong preference for using cast over where


```
@classmethod
def add_positions_input(cls, model: ModelProto) -> ModelProto:
def add_causal_mask_input(self, model: ModelProto) -> ModelProto:
```
looks like this needs rebase?

```
@@ -12,78 +12,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from onnx import ModelProto, NodeProto
```
rebase?

```
| causal_mask
| |
| Where
```
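The diagram above shows the causal mask feeding a `Where` node. In attention, `Where` typically selects the raw score where the mask allows attention and a large negative value elsewhere, so softmax drives the disallowed positions to zero. A NumPy sketch of that pattern (toy values, not the PR's graph):

```python
import numpy as np

# Toy attention scores for 3 query positions over 3 key positions.
scores = np.array([[0.5, 1.0, 2.0],
                   [0.1, 0.2, 0.3],
                   [1.0, 1.0, 1.0]])

# Lower-triangular causal mask: True where attention is allowed.
mask = np.tril(np.ones_like(scores, dtype=bool))

# Where(mask, scores, very_negative): future positions get a value that
# softmax maps to (effectively) zero probability.
masked = np.where(mask, scores, np.finfo(scores.dtype).min)
```

Row `i` of `masked` keeps scores for keys `0..i` and floors everything after, which is exactly the select-by-condition behavior the `Where` node provides in the ONNX graph.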
Why are we using a Where instead of a Cast to bool? Can you check with runtime which would work better, if either is fine? Additionally, a Cast would seem to read better in ONNX than a Where, which involves a condition.
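For the Cast-vs-Where question: when the mask holds only 0s and 1s, the two are numerically interchangeable ways to obtain a boolean mask. A NumPy sketch of the equivalence (ONNX `Cast(to=BOOL)` maps nonzero to True; `Where(cond, X, Y)` selects elementwise):

```python
import numpy as np

# An int64 causal mask with 1 = attend, 0 = masked.
mask_i64 = np.array([[1, 0, 0],
                     [1, 1, 0],
                     [1, 1, 1]], dtype=np.int64)

# Cast route: reinterpret nonzero entries as True.
via_cast = mask_i64.astype(bool)

# Where route: explicit elementwise select reproducing the cast.
via_where = np.where(mask_i64 != 0, True, False)

assert np.array_equal(via_cast, via_where)
```

Since the outputs agree, the choice comes down to runtime support and graph readability, which is what the review comment asks to verify.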

@dbogunowicz dbogunowicz merged commit db62ca0 into feature/damian/refactor_injection Jul 26, 2023
@dbogunowicz dbogunowicz deleted the feature/damian/causal_mask_opt branch July 26, 2023 06:31
bfineran pushed a commit that referenced this pull request Jul 27, 2023
…1677)

* initial commit

* [KV Cache Injection] Causal Mask for CodeGen (#1676)

* initial implementation; testing now

* fix a small blunder

* cleanup

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>

* [KV Cache Injection] Causal Mask for OPT (#1688)

* initial implementation; testing now

* fix a small blunder

* cleanup

* initial implementation

* on to testing with deepsparse

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>

* replace boolean causal mask for int64 causal mask

* better logging info

* allow transformations to be also a list

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>