[TextGeneration] Split up prep_for_generation operator, handle edge cases, handle kv_cache full during prefill #1562

Merged · 11 commits · Jan 26, 2024
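In short, this PR splits first-token generation out of the prep_for_generation operator, routes its output straight to generate_new_token, and makes a prompt that cannot fit in the kv cache fail fast during prefill (FinishReason.CAPACITY still covers the cache filling up mid-generation). A minimal sketch of the user-facing effect, assuming the Pipeline.create arguments used in the tests at the bottom of this diff:

from deepsparse import Pipeline

# Deliberately small sequence_length so the prompt overflows the kv cache during
# prefill (mirrors test_kv_cache_too_small_for_prefill below).
pipeline = Pipeline.create(
    task="text_generation",
    model_path="hf:mgoin/TinyStories-1M-deepsparse",
    sequence_length=25,
)

try:
    pipeline(prompt="Once upon a time there was a tiny model " * 20)
except RuntimeError as err:
    # Expected: "Not enough kv_cache capacity to run generation. Please use a
    # larger sequence_length or a shorter prompt"
    print(err)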
@@ -51,6 +51,12 @@ def can_operate(self, inp: Any) -> bool:
if inp.get("in_generation"):
return True

if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
raise RuntimeError(
"Not enough kv_cache capacity to run generation. Please use a larger "
"sequence_length or a shorter prompt"
)

remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
can_process = (
remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
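# Editor's note (sketch): both prefill operators touched by this PR add the same
# guard. Pulled out as a standalone helper for illustration only; the helper name
# assert_kv_cache_has_room is hypothetical, while the check and message are
# copied from the hunk above.
def assert_kv_cache_has_room(kv_cache) -> None:
    # Once every slot in the cache has been written, neither the remaining
    # prompt tokens nor any generated token can be stored, so fail fast.
    if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
        raise RuntimeError(
            "Not enough kv_cache capacity to run generation. Please use a larger "
            "sequence_length or a shorter prompt"
        )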
@@ -17,7 +17,6 @@
from pydantic import BaseModel, Field

from deepsparse.operators import Operator
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
from deepsparse.utils import InferenceState


@@ -43,9 +42,6 @@ def run(self, inference_state: InferenceState, **kwargs):
generated_logits = inference_state.current_state.get("generated_logits")
finished_reason = inference_state.current_state.get("finished_reason")

if len(finished_reason) == 0:
finished_reason.append(FinishReason.LENGTH)

generated_tokens = numpy.array([generated_tokens])
generated_logits = numpy.concatenate(generated_logits, axis=1)
return {
@@ -19,10 +19,7 @@
from deepsparse.transformers.pipelines.text_generation.nl_engine_operator import (
NLEngineOutputs,
)
from deepsparse.transformers.schemas.text_generation_schemas import (
FinishReason,
PromptLogitsNoKVCacheInference,
)
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
from deepsparse.utils import InferenceState


@@ -36,14 +33,16 @@ def __init__(
self.force_max_tokens = force_max_tokens
self.tokenizer = tokenizer

def can_operate(self, inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs]):
def can_operate(
self, inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"] # noqa: F821
):
if inp.in_generation:
return True
return False

def run(
self,
inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs],
inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"], # noqa: F821
inference_state: InferenceState,
**kwargs,
):
@@ -52,21 +51,26 @@ def run(
if isinstance(inp, NLEngineOutputs)
else inp.prompt_logits
)
kv_cache = inp.kv_cache if isinstance(inp, NLEngineOutputs) else None
kv_cache = inp.kv_cache

max_tokens = inference_state.current_state.get("max_tokens")
length_finish_reason = inference_state.current_state.get("length_finish_reason")
generated_tokens = inference_state.current_state.get("generated_tokens")
num_generated_tokens = len(generated_tokens)

token_generator = inference_state.current_state.get("token_generator")
token = token_generator.generate(logits=logits[0, -1, :])
finish_reason = None

callback = inference_state.current_state.get("callback")
stop = inference_state.current_state.get("stop")

if (
kv_cache is not None
and kv_cache.total_num_processed_tokens >= kv_cache.capacity
):
finish_reason = FinishReason.CAPACITY

callback = inference_state.current_state.get("callback")
stop = inference_state.current_state.get("stop")

if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
finish_reason = FinishReason.STOP

@@ -84,9 +88,8 @@
)
finish_reason = FinishReason.CALLBACK

max_tokens = inference_state.current_state.get("max_tokens")
if len(inference_state.current_state.get("generated_tokens")) + 1 >= max_tokens:
finish_reason = inference_state.current_state.get("length_finish_reason")
if num_generated_tokens + 1 == max_tokens:
finish_reason = length_finish_reason

state_update = {
"token_generator": token_generator,
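# Editor's note (sketch): the run() hunks above give the finish reason a fixed
# precedence, since later checks overwrite earlier ones. Condensed here for
# illustration; the stop-string and callback branches are only partially visible
# in this diff and are abbreviated to a comment.
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason


def resolve_finish_reason(
    kv_cache, token, eos_token_id, force_max_tokens,
    num_generated_tokens, max_tokens, length_finish_reason,
):
    finish_reason = None
    # 1. kv cache exhausted: generation must stop regardless of the token value
    if kv_cache is not None and kv_cache.total_num_processed_tokens >= kv_cache.capacity:
        finish_reason = FinishReason.CAPACITY
    # 2. the model emitted EOS and max tokens are not being forced
    if token == eos_token_id and not force_max_tokens:
        finish_reason = FinishReason.STOP
    # 3. stop strings / user callback would set FinishReason.CALLBACK here (elided)
    # 4. token budget reached; the +1 accounts for the token produced this step
    if num_generated_tokens + 1 == max_tokens:
        finish_reason = length_finish_reason
    return finish_reason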
@@ -42,6 +42,12 @@ def can_operate(self, inp: Any):
kv_cache = inp.get("kv_cache")
tokens = inp.get("tokens")

if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
raise RuntimeError(
"Not enough kv_cache capacity to run generation. Please use a larger "
"sequence_length or a shorter prompt"
)

if len(tokens) < self.prompt_sequence_length:
return False

@@ -239,7 +239,6 @@ def __init__(
sequence_length=sequence_length,
prompt_sequence_length=prompt_sequence_length,
token_generator=token_generator,
process_output_operator=process_output,
)

# TODO: do we want to support lists for different engines?
@@ -286,7 +285,7 @@ def __init__(
"compile_logits",
"generate_new_token",
],
"prep_for_generation": "autoregressive_preprocess",
"prep_for_generation": "generate_new_token",
"generate_new_token": "compile_generated_tokens",
}

@@ -19,6 +19,7 @@
from deepsparse.routers import GraphRouter
from deepsparse.schedulers import OperatorScheduler
from deepsparse.transformers.pipelines.text_generation import (
CompileGeneratedTokens,
CompileGenerations,
GenerateNewTokenOperator,
JoinOutput,
@@ -73,6 +74,7 @@ def __init__(
tokenizer=self.tokenizer, force_max_tokens=True
)
compile_generations = CompileGenerations()
compile_generated_tokens = CompileGeneratedTokens()
join_output = JoinOutput(tokenizer=self.tokenizer)
process_outputs = ProcessOutputs(tokenizer=self.tokenizer)

@@ -82,6 +84,7 @@
"engine_operator": engine_operator,
"prepare_generation": prepare_generation,
"generate_new_token": generate_new_token,
"compile_generated_tokens": compile_generated_tokens,
"compile_generations": compile_generations,
"join_output": join_output,
"process_outputs": process_outputs,
@@ -92,7 +95,8 @@
"SPLIT": "engine_operator",
"engine_operator": "prepare_generation",
"prepare_generation": "generate_new_token",
"generate_new_token": "compile_generations",
"generate_new_token": "compile_generated_tokens",
"compile_generated_tokens": "compile_generations",
"compile_generations": "JOIN",
"JOIN": "join_output",
"join_output": "process_outputs",
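# Editor's note (sketch): because removed and added route entries are interleaved
# above, this is the resolved no-kv-cache route after the PR. The surrounding
# dict name and the entries not visible in this diff are elided.
{
    "SPLIT": "engine_operator",
    "engine_operator": "prepare_generation",
    "prepare_generation": "generate_new_token",
    "generate_new_token": "compile_generated_tokens",   # new hop added in this PR
    "compile_generated_tokens": "compile_generations",  # new CompileGeneratedTokens node
    "compile_generations": "JOIN",
    "JOIN": "join_output",
    "join_output": "process_outputs",
    # remaining entries unchanged and not shown in the diff
}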
@@ -15,35 +15,38 @@
from typing import Any, Optional

import numpy
from pydantic import BaseModel, Field

from deepsparse.operators import Operator
from deepsparse.subgraph_execute import StreamingOutput
from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator
from deepsparse.transformers.schemas.text_generation_schemas import (
FinishReason,
PromptLogitsNoKVCacheInference,
)
from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
from deepsparse.transformers.utils.helpers import set_generated_length
from deepsparse.utils import InferenceState


__all__ = ["PrepareGeneration"]
__all__ = ["PrepareGeneration", "PrepareForGenerationOutput"]


class PrepareForGenerationOutput(BaseModel):
prompt_logits: Any = Field(
description="A set of prompt logits generated during prefill"
)
kv_cache: Optional[Any] = Field(description="kv cache")
in_generation: Optional[bool] = Field(description="in_generation flag")


class PrepareGeneration(Operator):
output_schema = PrepareForGenerationOutput

def __init__(
self,
token_generator: TokenGeneratorOperator,
prompt_sequence_length: int,
sequence_length: int,
process_output_operator: Optional[Operator] = None,
):
self.sequence_length = sequence_length
self.token_generator_creator = token_generator
self.prompt_sequence_length = prompt_sequence_length
# Needed for streaming as currently both setting up generation and generating
# Will split this up soon
self.process_output_operator = process_output_operator

def can_operate(self, inp: Any):
kv_cache = inp.get("kv_cache")
@@ -79,7 +82,6 @@ def run(
**inference_state.current_state,
)
token_generator = token_generator_creator_output.get("token_generator")
token_generator.generate(prompt_logits[0, -1, :])

max_tokens, length_finish_reason = set_generated_length(
max_length=generation_config.max_length,
@@ -93,43 +95,21 @@
state_update = {
"max_tokens": max_tokens,
"length_finish_reason": length_finish_reason,
"generated_tokens": [token_generator.tokens[-1]],
"generated_logits": [prompt_logits]
"generated_tokens": [],
"generated_logits": [prompt_logits[:, 0:-1, :]]
if include_prompt_logits
else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
else [],
"finished_reason": [],
"token_generator": token_generator,
}

if kv_cache is None:
output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits)
output = {"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0)}
else:
output = {
"tokens": token_generator.tokens,
"kv_cache": kv_cache,
"in_generation": True,
"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0),
}
# TODO: maybe break this operator up since it is both generating and setting
# up values needed for generation? Holding off on this as this will change
# routes slighty and want to confirm wont break anything for non-kv cache
if inference_state.current_state.get("streaming") and max_tokens >= 1:
finished_reason = [length_finish_reason] if max_tokens == 1 else [None]

if self.process_output_operator is None:
raise ValueError(
"An operator must be provided to process outputs"
"while streaming."
)
data_to_yield = self.process_output_operator.run(
generated_tokens=state_update.get("generated_tokens"),
finished_reason=finished_reason,
inference_state=inference_state,
generated_logits=prompt_logits[0, -1, :],
)
output = StreamingOutput(
data_to_yield=self.process_output_operator.output_schema(
**data_to_yield
),
data_to_return=output,
)

return output, state_update
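# Editor's note (sketch): the reseeding of "generated_logits" above changes how
# prompt scores are accounted for. A small, self-contained illustration with a
# dummy logits tensor (batch=1, 10 prompt positions, vocab=50), mirroring the
# slicing in the hunk: with include_prompt_logits=True the first 9 prompt
# positions are stored in state here, while the last prompt position is
# forwarded as "prompt_logits" for generate_new_token to sample from.
import numpy

prompt_logits = numpy.random.rand(1, 10, 50)

seeded_generated_logits = [prompt_logits[:, 0:-1, :]]                     # shape (1, 9, 50)
forwarded_prompt_logits = numpy.expand_dims(prompt_logits[:, -1, :], 0)   # shape (1, 1, 50)

assert seeded_generated_logits[0].shape == (1, 9, 50)
assert forwarded_prompt_logits.shape == (1, 1, 50)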
@@ -166,11 +166,3 @@ class TextGenerationOutput(BaseModel):
class Config:
arbitrary_types_allowed = True
extra = "allow"


class PromptLogitsNoKVCacheInference(BaseModel):
prompt_logits: Any = Field(
description="A set of prompt logits generated "
"during the inference pass with a "
"non-kv cache model"
)
4 changes: 3 additions & 1 deletion src/deepsparse/transformers/utils/helpers.py
@@ -104,8 +104,10 @@ def set_generated_length(
:param max_new_tokens: the max_new_tokens attribute, which may be provided
as part of the input during inference
"""
if max_length:
if max_length is not None:
# if max_length provided, use that to cap total tokens generated
if max_length == 0:
raise ValueError("max_length must be greater than 0")
max_tokens = max_length
finish_reason = finish_reason_choices.LENGTH
else:
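# Editor's note (sketch): the truthiness change is the subtle part of this hunk.
#   before: "if max_length:"             -> max_length=0 was falsy and silently fell
#                                            through to the max_new_tokens path
#   after:  "if max_length is not None:" -> max_length=0 now reaches this branch and
#                                            raises ValueError("max_length must be greater than 0")
# This is exercised in test_edge_cases below via pytest.raises(ValueError) on
# pipeline(prompt=prompt, max_length=0).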
46 changes: 45 additions & 1 deletion tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -29,7 +29,7 @@ def pipeline():
return Pipeline.create(
task="text_generation",
model_path="hf:mgoin/TinyStories-1M-deepsparse",
engine_type="onnxruntime",
engine_type="deepsparse",
)


Expand Down Expand Up @@ -143,6 +143,7 @@ def test_stop_inference_kv_cache_full(prompt):
expected_generated_tokens_length=max_new_tokens_plus_one,
expected_finished_reason="capacity",
)

"""
Check the following structure of the kv cache:
minus_one | full | plus_one | plus_two
@@ -152,6 +153,7 @@
[row B] | [row C] | [row D] | [row D]
... | ... | ... | ...
"""

# check for the "free" space in the kv cache
assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
# check for the row A
@@ -282,3 +284,45 @@ def test_streaming_non_streaming_generate_same_tokens(pipeline, prompt):
tokens.append(g.generations[0].text)
output_2 = "".join(tokens)
assert output_1 == output_2


def test_edge_cases(pipeline, prompt):
output = pipeline(prompt=prompt, max_length=1, output_scores=True)
assert len(output.generations[0].score) == 1

output = pipeline(
prompt=prompt, max_length=1, output_scores=True, include_prompt_logits=True
)
assert len(output.generations[0].score) == 11

output = pipeline(prompt=prompt, max_new_tokens=0, output_scores=True)
assert len(output.generations[0].score) == 1

output = pipeline(
prompt=prompt, max_new_tokens=0, output_scores=True, include_prompt_logits=True
)
assert len(output.generations[0].score) == 11

output = pipeline(prompt=prompt, max_new_tokens=1, output_scores=True)
assert len(output.generations[0].score) == 2

output = pipeline(
prompt=prompt, max_new_tokens=1, output_scores=True, include_prompt_logits=True
)
assert len(output.generations[0].score) == 12

with pytest.raises(ValueError):
pipeline(prompt=prompt, max_length=0)
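
# Editor's note: the score counts asserted above are consistent with a prompt
# fixture that tokenizes to 10 tokens (an assumption; the fixture itself is not
# shown in this diff). Under that assumption:
#   scores == generated scores (+ 10 prompt scores when include_prompt_logits=True)
#   max_length=1     -> 1 generated score  -> 1, or 1 + 10 = 11
#   max_new_tokens=0 -> 1 generated score  -> 1, or 1 + 10 = 11
#   max_new_tokens=1 -> 2 generated scores -> 2, or 2 + 10 = 12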


def test_kv_cache_too_small_for_prefill(prompt):
for i in range(10):
prompt += prompt

pipeline = Pipeline.create(
task="text_generation",
model_path="hf:mgoin/TinyStories-1M-deepsparse",
sequence_length=25,
)
with pytest.raises(RuntimeError):
pipeline(prompt=prompt)