
[LLAMA] KV Cache Injection #1709

Merged 7 commits into main from llama_update on Aug 29, 2023

Conversation

@dsikka (Contributor) commented on Aug 21, 2023

  • Note: LLAMA2 in general requires the most up-to-date nm-transformers

Summary

Add support for LLAMA2 KV Cache Injection:

  • Add a new config for the LLM in configs.py
  • The transforms added for the positions and causal mask reuse the same set of transforms already required/used by the other LLMs we currently support
  • One additional transform updates the slice nodes in the attention heads so that their ends values are rewritten. This is required for the positions injection to work properly (see the sketch after this list)
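
For illustration only, a minimal sketch of what a slice-ends update can look like as ONNX graph surgery; the node-matching criterion and the target value here are assumptions, not the actual transform in this PR:

import numpy as np
import onnx
from onnx import numpy_helper

def update_slice_ends(model: onnx.ModelProto, new_end: int) -> onnx.ModelProto:
    # Index initializers by name so Slice inputs can be resolved quickly
    inits = {init.name: init for init in model.graph.initializer}
    for node in model.graph.node:
        # In opset >= 10, `ends` is the third input of Slice rather than
        # an attribute, so look for an initializer feeding that input
        if node.op_type == "Slice" and len(node.input) >= 3:
            ends = inits.get(node.input[2])
            if ends is None:
                continue
            # Overwrite the ends value so the slice covers the
            # cache-extended sequence dimension
            updated = numpy_helper.from_array(
                np.array([new_end], dtype=np.int64), ends.name
            )
            ends.CopyFrom(updated)
    return model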

Testing:

import onnx
from pathlib import Path
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector

INJECT_NEW = True

MODEL_NAME = "deployment_sparseml/model.onnx"
root = Path("/home/dsikka/llama_run/deployment_sparseml")

if INJECT_NEW:
    # Load the original export (with external data) and inject the KV cache
    model = onnx.load(str(root / "model_og.onnx"), load_external_data=True)
    model = KeyValueCacheInjector(str(root)).apply(model)
    onnx.save(model, MODEL_NAME, all_tensors_to_one_file=True, save_as_external_data=True)

    # Checking by path lets the checker resolve the external tensor data
    try:
        onnx.checker.check_model(MODEL_NAME)
    except onnx.checker.ValidationError as e:
        print(e)
    else:
        print("Valid")
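
As a quick sanity check beyond the checker, the injected graph's inputs and outputs can be listed to confirm the new cache tensors appear (the exact cache tensor names depend on the injector's naming convention):

# Continuing from the snippet above
injected = onnx.load(MODEL_NAME)
print([inp.name for inp in injected.graph.input])
print([out.name for out in injected.graph.output])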

With this injected model, we can currently run the model in the pipeline using ORT:

from deepsparse import Pipeline

llama = Pipeline.create(
    task="text-generation",
    model_path="/home/dsikka/llama_run/deployment_sparseml",
    engine_type="onnxruntime",
)

inference = llama(sequences="Who is your favourite Toronto Raptor?")
print(inference)

Output:

sequences=["\n\nI'm a big fan of Kyle Lowry, he's a great point guard and leader on the team. But I also really enjoy watching Pascal Siakam, he's a talented young player with a lot of potential. And of course, I can't forget about Kawhi Leonard, he's a incredible player and a key part of the team's success."] logits=None session_id=None
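
For anyone who wants to exercise the injected ONNX directly, outside the pipeline, a rough ORT smoke test along these lines should work. The zero-filled dummy feeds and the substitution of 1 for symbolic dimensions are assumptions for the sketch, not part of the PR:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    "/home/dsikka/llama_run/deployment_sparseml/model.onnx",
    providers=["CPUExecutionProvider"],
)

# Build zero-filled feeds, substituting 1 for any symbolic dimension
feeds = {}
for inp in sess.get_inputs():
    shape = [dim if isinstance(dim, int) else 1 for dim in inp.shape]
    dtype = np.int64 if "int64" in inp.type else np.float32
    feeds[inp.name] = np.zeros(shape, dtype=dtype)

outputs = sess.run(None, feeds)
print([out.shape for out in outputs])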

@dbogunowicz (Contributor) commented:
This looks great @dsikka . We need to figure out the fix for the positions and then we are good to go!

@bfineran (Member) left a comment:

looks great - see comment

dbogunowicz previously approved these changes Aug 28, 2023
@dsikka dsikka merged commit df570d1 into main Aug 29, 2023
10 checks passed
@dsikka dsikka deleted the llama_update branch August 29, 2023 15:30