[Text Generation] Multitoken prefill enablement #1130

Merged

Conversation

@dbogunowicz (Contributor) commented on Jul 20, 2023

Enables running the pipeline in a mode where the prompt is processed (the prefill scenario) through multiple consecutive passes of the multitoken engine. The goal is to achieve optimal inference speed with the deepsparse engine.
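
For illustration, a minimal sketch of the prefill flow this change enables, assuming hypothetical engine callables (the real pipeline internals differ; `prefill`, `multitoken_engine`, and `single_token_engine` below are illustrative names only):

from typing import Callable, List


def prefill(
    tokens: List[int],
    window: int,
    multitoken_engine: Callable[[List[int]], None],
    single_token_engine: Callable[[int], None],
) -> None:
    # Process the prompt in fixed-size windows through the multitoken
    # engine, then feed any leftover tokens one at a time through the
    # single-token (autoregressive) engine.
    num_windows = len(tokens) // window
    for i in range(num_windows):
        # one multitoken pass per full window of prompt tokens
        multitoken_engine(tokens[i * window:(i + 1) * window])
    for token in tokens[num_windows * window:]:
        # leftover prompt tokens that do not fill a full window
        single_token_engine(token)


# Example: a 20-token prompt with a window of 8 yields two multitoken
# passes (16 tokens) followed by 4 single-token passes.
prefill(
    list(range(20)),
    window=8,
    multitoken_engine=lambda chunk: print("multitoken pass:", chunk),
    single_token_engine=lambda tok: print("single-token pass:", tok),
)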

Manual Testing

from deepsparse import Pipeline


def _test_pipeline(engine_type, prompt_processing_sequence_length):
    opt = Pipeline.create(
        task="opt",
        model_path="/home/ubuntu/damian/sparseml/deployment",
        engine_type=engine_type,
        use_deepsparse_cache=False,
        prompt_processing_sequence_length=prompt_processing_sequence_length,
        max_generated_tokens=32,
    )
    # long prompt, so it gets processed by the multitoken engine
    prompt = "def hello_world():" * 20
    out = opt(sequences=prompt, return_logits=True)
    print(out.sequences[0])


for prompt_processing_sequence_length in (8, 16, 55, 128):
    _test_pipeline(
        engine_type="onnxruntime",
        prompt_processing_sequence_length=prompt_processing_sequence_length,
    )

Results:

# Results are identical to the PyTorch baseline

2023-07-20 15:58:14 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:58:24 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:58:32 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:58:39 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello

2023-07-20 15:58:53 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:58:59 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:59:12 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:59:20 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello

2023-07-20 15:59:33 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 15:59:40 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 16:01:12 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 16:01:18 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello

2023-07-20 16:02:38 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 16:02:44 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 16:02:54 deepsparse.transformers.engines.nl_decoder_engine INFO     Overwriting in-place the input shapes of the transformer model at /home/ubuntu/damian/sparseml/deployment/model.onnx
2023-07-20 16:03:00 deepsparse.utils.onnx INFO     Overwriting in-place the batch size of the model at /home/ubuntu/damian/sparseml/deployment/model.onnx
def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello_world():def hello

Process finished with exit code 0

@dbogunowicz changed the base branch from main to feature/damian/causal_mask_support on July 20, 2023 14:41
@dbogunowicz marked this pull request as ready for review on July 20, 2023 15:46

@bfineran (Member) left a comment:

Took a deeper look following our offline conversation and I understand why you had to go this route. LGTM, but let's update an existing diagram or add a new one to explain the relationship between the decoder engine, cache, state, state transfer, and capacity.

num_non_blank_cache_entries = min(
    num_non_blank_cache_entries,
    self.sequence_length - self.prompt_processing_sequence_length,
)

A Member commented:

Shouldn't this be the total remaining tokens, i.e. something like self.sequence_length - idx * self.prompt_processing_sequence_length, or am I missing something?


@dbogunowicz (Contributor, Author) replied:

We are essentially talking about the same thing, but my logic was overcomplicated. I refactored the function, so hopefully anyone reading it can now grasp what's going on.
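
To make the thread above concrete, here is a minimal sketch of the kind of bookkeeping being discussed (not the exact refactored code): after `idx` completed multitoken passes, the number of populated (non-blank) KV-cache entries is the number of prompt tokens processed so far, capped at the cache capacity left after reserving room for the window currently being processed.

def num_non_blank_cache_entries(
    idx: int,
    prompt_processing_sequence_length: int,
    sequence_length: int,
) -> int:
    # Tokens already written to the cache by the previous multitoken passes.
    processed = idx * prompt_processing_sequence_length
    # Cache capacity available for past tokens, with room reserved for the
    # window currently being processed.
    capacity = sequence_length - prompt_processing_sequence_length
    return min(processed, capacity)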

Base automatically changed from feature/damian/causal_mask_support to feature/damian/causal_mask_fb July 25, 2023 08:22
@dbogunowicz merged commit e324cdc into feature/damian/causal_mask_fb on Jul 25, 2023
@dbogunowicz deleted the feature/damian/multitoken_prefill branch on July 25, 2023 12:25
bfineran pushed a commit that referenced this pull request Jul 27, 2023
* Update helpers.py
* correct implementation of the mapping from inputs to causal mask
* [Text Generation] Causal Mask Support (#1127)
* initial commit
* clean up the PR
* working implementation
* Ben's review comments
* [Text Generation] Multitoken prefill enablement (#1130)
* initial commit
* clean up the PR
* working implementation
* initial implementation, hacky lets clean it up
* ready for review
* few tiny quality improvements
* simplify the logic for computing num of unmasked bits for creating attention_mask for the multitoken prefill
* replace boolean causal mask for int64 causal mask
* fix breaking tests