[TextGeneration Pipeline] argument renaming #1194

Merged · 7 commits · Sep 6, 2023

This PR renames two TextGeneration pipeline arguments across the examples, pipeline, engine, KV-cache utilities, and tests: `prompt_processing_sequence_length` becomes `prompt_sequence_length`, and `use_deepsparse_cache` becomes `internal_kv_cache`.
examples/openai-server/README.md (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ Set up the server:
```
python examples/openai-server/server.py --model zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
-2023-08-07 17:18:32 __main__ INFO args: Namespace(model='zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none', max_model_len=512, prompt_processing_sequence_length=1, use_deepsparse_cache=False, host='localhost', port=8000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], served_model_name=None)
+2023-08-07 17:18:32 __main__ INFO args: Namespace(model='zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none', max_model_len=512, prompt_sequence_length=1, internal_kv_cache=False, host='localhost', port=8000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], served_model_name=None)
2023-08-07 17:18:32 deepsparse.transformers WARNING The neuralmagic fork of transformers may not be installed. It can be installed via `pip install nm_transformers`
Using pad_token, but it is not set yet.
2023-08-07 17:18:34 deepsparse.transformers.engines.nl_decoder_engine INFO Overwriting in-place the input shapes of the transformer model at /home/mgoin/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/model.onnx
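For readers following the rename, here is a minimal sketch (not part of this PR) of creating the pipeline directly with the new argument names; the keyword arguments mirror the `deepsparse.Pipeline.create` call in `examples/openai-server/server.py` below, while the `sequences` input field name is an assumption about the pipeline's input schema.

```python
import deepsparse

# Sketch only: keyword arguments mirror the renamed parameters in this PR.
pipeline = deepsparse.Pipeline.create(
    task="text-generation",
    model_path="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none",
    sequence_length=512,        # maximum context length
    prompt_sequence_length=16,  # formerly prompt_processing_sequence_length
    internal_kv_cache=True,     # formerly use_deepsparse_cache
)
# `sequences` as the input field name is assumed; adjust to the pipeline's input schema.
print(pipeline(sequences="def fibonacci(n):"))
```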
examples/openai-server/server.py (12 changes: 6 additions & 6 deletions)
@@ -78,15 +78,15 @@ def __init__(
self,
model: str,
sequence_length: int = 512,
-prompt_processing_sequence_length: int = 64,
-use_deepsparse_cache: bool = False,
+prompt_sequence_length: int = 64,
+internal_kv_cache: bool = False,
):
self.engine = deepsparse.Pipeline.create(
task="text-generation",
model_path=model,
sequence_length=sequence_length,
-prompt_processing_sequence_length=prompt_processing_sequence_length,
-use_deepsparse_cache=use_deepsparse_cache,
+prompt_sequence_length=prompt_sequence_length,
+internal_kv_cache=internal_kv_cache,
)

def tokenize(self, text: str) -> List[int]:
@@ -751,8 +751,8 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
engine = DeepSparseOpenAIEngine(
model=args.model,
sequence_length=max_model_len,
-prompt_processing_sequence_length=args.prompt_processing_sequence_length,
-use_deepsparse_cache=args.use_deepsparse_cache,
+prompt_sequence_length=args.prompt_sequence_length,
+internal_kv_cache=args.internal_kv_cache,
)
tokenizer = engine.engine.tokenizer

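As a usage note, a hedged sketch of constructing the example's `DeepSparseOpenAIEngine` wrapper directly with the renamed keyword arguments; the `from server import ...` path assumes the snippet runs from `examples/openai-server/` and is illustrative only.

```python
from typing import List

# Illustrative import path: DeepSparseOpenAIEngine lives in examples/openai-server/server.py.
from server import DeepSparseOpenAIEngine

engine = DeepSparseOpenAIEngine(
    model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none",
    sequence_length=512,
    prompt_sequence_length=1,  # formerly prompt_processing_sequence_length
    internal_kv_cache=False,   # formerly use_deepsparse_cache
)
token_ids: List[int] = engine.tokenize("def hello_world():")  # tokenize() appears in this diff
```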
src/deepsparse/transformers/engines/nl_decoder_engine.py (10 changes: 4 additions & 6 deletions)
@@ -49,7 +49,7 @@ class NLDecoderEngine:
:param deterministic: Whether to use deterministic sampling
:param tokenizer: The tokenizer to used for engine inputs
:param engine_context: The context to run the engine in
-:param use_deepsparse_cache: Whether to use the deepsparse
+:param internal_kv_cache: Whether to use the deepsparse
kv cache in the DecoderKVCache object or not
"""

@@ -64,7 +64,7 @@ def __init__(
sampling_temperature: float = 1.0,
deterministic: bool = True,
engine_context: Optional[Context] = None,
-use_deepsparse_cache: bool = False,
+internal_kv_cache=False,
):
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None
@@ -83,7 +83,7 @@ def __init__(
if sum(output_indices_to_be_cached):
kv_cache_enabled = True
self.kv_cache_data_type = kv_cache_data_type
-if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
+if internal_kv_cache and engine_type == DEEPSPARSE_ENGINE:
# inform the engine, that are using the kv cache
engine_args["cached_outputs"] = output_indices_to_be_cached

@@ -100,9 +100,7 @@ def __init__(
self.input_ids_length = input_ids_length
self.cache_length = sequence_length - input_ids_length
self.kv_cache_enabled = kv_cache_enabled
-self.kv_cache = (
-DecoderKVCache(use_deepsparse_cache) if kv_cache_enabled else None
-)
+self.kv_cache = DecoderKVCache(internal_kv_cache) if kv_cache_enabled else None
self._freeze_first_position = self._should_freeze_first_position(tokenizer)
self._session_id = generate_session_id()
self._engine_type = engine_type
src/deepsparse/transformers/pipelines/text_generation.py (55 changes: 25 additions & 30 deletions)
@@ -158,12 +158,12 @@ class TextGenerationPipeline(TransformersPipeline):
sequence is reached.
:param sequence_length: sequence length to compile model and tokenizer for.
This controls the maximum context length of the pipeline. Default is 512
-:param prompt_processing_sequence_length: For large prompts, the prompt is
+:param prompt_sequence_length: For large prompts, the prompt is
processed in chunks of this length. This is to maximize the inference
speed. By default, this is set to 64.
:param force_max_tokens: if True, the pipeline will generate the maximum number
of tokens supplied even if the stop token is reached.
-:param use_deepsparse_cache: if True, the pipeline will use the deepsparse kv cache
+:param internal_kv_cache: if True, the pipeline will use the deepsparse kv cache
for caching the model outputs.
:param kwargs: kwargs to pass to the TransformersPipeline
"""
@@ -173,23 +173,23 @@ def __init__(
deterministic: bool = True,
sampling_temperature: float = 1.0,
max_generated_tokens: Optional[int] = 1024,
+prompt_sequence_length: int = 64,
sequence_length: int = 512,
-prompt_processing_sequence_length: int = 64,
force_max_tokens: bool = False,
-use_deepsparse_cache: bool = True,
+internal_kv_cache: bool = True,
**kwargs,
):
kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)

-if use_deepsparse_cache:
+if internal_kv_cache:
if kwargs_engine_type != DEEPSPARSE_ENGINE:
_LOGGER.warning(
-"`use_deepsparse_cache` is set to True "
+"`internal_kv_cache` is set to True "
"but the chosen `engine_type` "
f"is {kwargs_engine_type}. "
f"The optimized kv cache management is disabled."
)
-use_deepsparse_cache = False
+internal_kv_cache = False

super().__init__(
**kwargs,
@@ -199,10 +199,10 @@ def __init__(
)
# enable multitoken prefill if
# - the model graph is supporting it (causal_mask input is present)
-# - prompt_processing_sequence_length != 1 (identical to single-token prefill)
+# - prompt_sequence_length != 1 (identical to single-token prefill)
self.enable_multitoken_prefill = (
self.causal_mask_input_present(model_path=self.onnx_file_path)
-and prompt_processing_sequence_length > 1
+and prompt_sequence_length > 1
)

self.cache_support_enabled = self.is_cache_support_enabled()
@@ -221,9 +221,9 @@ def __init__(
self.deterministic = deterministic
self.sampling_temperature = sampling_temperature
self.max_generated_tokens = max_generated_tokens
-self.prompt_processing_sequence_length = prompt_processing_sequence_length
+self.prompt_sequence_length = prompt_sequence_length
self.force_max_tokens = force_max_tokens
-self.use_deepsparse_cache = use_deepsparse_cache
+self.internal_kv_cache = internal_kv_cache

# override tokenizer to pad to left
self.tokenizer.padding_side = "left"
@@ -259,15 +259,15 @@ def initialize_engines(
if self.cache_support_enabled:
if (
self.engine_type == DEEPSPARSE_ENGINE
-and self.sequence_length <= self.prompt_processing_sequence_length
+and self.sequence_length <= self.prompt_sequence_length
and self.enable_multitoken_prefill
):
raise ValueError(
"Attempting to initialize auxiliary DeepSparse engine to "
"process a prompt with a larger processing length. "
-"However, it is assumed that `prompt_processing_sequence_length` "
+"However, it is assumed that `prompt_sequence_length` "
"is smaller than the `sequence_length`. "
-"Adjust the `prompt_processing_sequence_length` "
+"Adjust the `prompt_sequence_length` "
"argument accordingly."
)

@@ -299,9 +299,9 @@ def initialize_engines(
sampling_temperature=self.sampling_temperature,
deterministic=self.deterministic,
sequence_length=self.sequence_length,
-input_ids_length=self.prompt_processing_sequence_length,
+input_ids_length=self.prompt_sequence_length,
tokenizer=self.tokenizer,
-use_deepsparse_cache=self.use_deepsparse_cache,
+internal_kv_cache=self.internal_kv_cache,
)

if self.cache_support_enabled:
@@ -315,7 +315,7 @@ def initialize_engines(
sequence_length=self.sequence_length,
input_ids_length=1,
tokenizer=self.tokenizer,
-use_deepsparse_cache=self.use_deepsparse_cache,
+internal_kv_cache=self.internal_kv_cache,
)

assert (engine is not None) or (
@@ -542,14 +542,11 @@ def prompt_inference(
new_token = None
num_tokens_processed = 0

-if (
-len(tokens) > self.prompt_processing_sequence_length
-and self.enable_multitoken_prefill
-):
+if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill:
self.multitoken_engine.reset_kv_cache()
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
-num_tokens_processed += self.prompt_processing_sequence_length
+num_tokens_processed += self.prompt_sequence_length
prompt_logits.append(new_logits)

self.engine.reset_kv_cache()
@@ -620,7 +617,7 @@ def engine_inputs_for_prefill(
of engine_inputs for the multitoken engine.

1. The input tokens first get batched into chunks of
-size self.prompt_processing_sequence_length. This is to
+size self.prompt_sequence_length. This is to
ensure that they match the expected input size by the
multitoken engine. Any remaining tokens are discarded.

@@ -632,7 +629,7 @@
that will have the amount of unmasked entries equal to
the sum of:
a) the number of tokens in the batch
-(self.prompt_processing_sequence_length)
+(self.prompt_sequence_length)
b) the number of non-blank cache entries
(num_non_blank_cache_entries)
so that the attention_mask properly attends to the
@@ -647,13 +644,11 @@
:return: a generator of engine inputs
"""

-num_batches = len(tokens) // self.prompt_processing_sequence_length
+num_batches = len(tokens) // self.prompt_sequence_length

token_batches = [
tokens[
-i
-* self.prompt_processing_sequence_length : (i + 1)
-* self.prompt_processing_sequence_length
+i * self.prompt_sequence_length : (i + 1) * self.prompt_sequence_length
]
for i in range(0, num_batches)
]
@@ -676,7 +671,7 @@ def engine_inputs_for_prefill(
:,
-(
# ...the number of current input tokens...
-self.prompt_processing_sequence_length
+self.prompt_sequence_length
# ...and the number of the previous cache entries
+ num_cached_entries
) :,
Expand All @@ -688,7 +683,7 @@ def engine_inputs_for_prefill(
engine_input = (
numpy.arange(
num_cached_entries,
-num_cached_entries + self.prompt_processing_sequence_length,
+num_cached_entries + self.prompt_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
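The behaviour that `prompt_sequence_length` controls can be summarized with a standalone sketch (a paraphrase of `engine_inputs_for_prefill` and `prompt_inference`, not the pipeline code itself): full chunks of the prompt go through the multitoken engine, and the leftover tokens are handled one at a time by the single-token engine.

```python
def split_prompt(tokens, prompt_sequence_length=64):
    """Paraphrase of the chunking logic above; returns (full_chunks, remainder)."""
    num_batches = len(tokens) // prompt_sequence_length
    chunks = [
        tokens[i * prompt_sequence_length : (i + 1) * prompt_sequence_length]
        for i in range(num_batches)
    ]
    remainder = tokens[num_batches * prompt_sequence_length :]
    return chunks, remainder


chunks, remainder = split_prompt(list(range(150)), prompt_sequence_length=64)
print(len(chunks), len(remainder))  # 2 full chunks for the multitoken engine, 22 tokens left over
```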
src/deepsparse/transformers/utils/decoder_kv_cache.py (11 changes: 5 additions & 6 deletions)
@@ -26,14 +26,13 @@


class DecoderKVCache:
-def __init__(self, use_deepsparse_cache: bool = False):
+def __init__(self, internal_kv_cache: bool = False):
"""
The goal this object is to handle the manipulation
of the key value cache.

-:param use_deepsparse_cache: If set to True, the
-`kv_cache` object from the deepsparse.LIB will
-be loaded as an engine_internal_cache attribute.
+:param internal_kv_cache: If set to True, the `kv_cache` object
+from the deepsparse.LIB will be loaded as an attribute.
This object is used to handle the manipulation of the
key/value buffers on the DeepSparse engine side.
"""
@@ -42,7 +41,7 @@ def __init__(self, use_deepsparse_cache: bool = False):
# assuming that kv cache arrays are of shape
# [batch_size, num_heads, sequence_length, hidden_size]
self._sequence_len_axis = SEQUENCE_LENGTH_AXIS
-self._use_deepsparse_cache = use_deepsparse_cache
+self._internal_kv_cache = internal_kv_cache
self._session_id = None
self._freeze_first_position = None
self._state = None
@@ -80,7 +79,7 @@ def setup(
self._freeze_first_position = freeze_first_position
self.total_num_processed_tokens = num_processed_tokens

-if self._use_deepsparse_cache:
+if self._internal_kv_cache:
prev_num_tokens = self.total_num_processed_tokens
num_frozen_tokens = int(self._freeze_first_position)
self.engine_internal_cache = LIB.kv_cache(
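A minimal sketch, assuming `deepsparse` is importable at this module path, of how the renamed constructor flag is consumed; only the argument name changed, the semantics did not.

```python
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache

# internal_kv_cache=False keeps the key/value buffers as numpy arrays managed in Python;
# internal_kv_cache=True makes setup() create the engine-internal LIB.kv_cache object.
external_cache = DecoderKVCache(internal_kv_cache=False)
internal_cache = DecoderKVCache(internal_kv_cache=True)
```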
tests/deepsparse/transformers/pipelines/test_text_generation.py (16 changes: 8 additions & 8 deletions)
@@ -56,7 +56,7 @@ def _initialize_kv_cache_state(model, length=0):


@pytest.mark.parametrize(
-"use_deepsparse_cache",
+"internal_kv_cache",
[True, False],
)
@pytest.mark.parametrize(
@@ -82,19 +82,19 @@
)
class TestTextGenerationPipeline:
@pytest.fixture
-def setup(self, model_stub, model_name, uses_bos_token, use_deepsparse_cache):
+def setup(self, model_stub, model_name, uses_bos_token, internal_kv_cache):

self.max_generated_tokens = 16
self.model = Model(model_stub)
-self.use_deepsparse_cache = use_deepsparse_cache
+self.internal_kv_cache = internal_kv_cache

pipeline = Pipeline.create(
task="text_generation",
model_path=model_stub,
sequence_length=32,
-prompt_processing_sequence_length=4,
+prompt_sequence_length=4,
max_generated_tokens=self.max_generated_tokens,
-use_deepsparse_cache=self.use_deepsparse_cache,
+internal_kv_cache=self.internal_kv_cache,
)
short_prompt = "this"
long_prompt = "this is a sample prompt that we will use to test the pipeline"
@@ -104,14 +104,14 @@ def setup(self, model_stub, model_name, uses_bos_token, use_deepsparse_cache):
# (DISABLED FOR NOW UNTIL WE HAVE ZOO CAUSAL MASK SUPPORT)
# assert (
# len(pipeline.tokenizer.tokenize(short_prompt)) + int(uses_bos_token)
-# < pipeline.prompt_processing_sequence_length
+# < pipeline.prompt_sequence_length
# )
# make sure that the long prompt will be processed by
# single token and multiple token engines
# (DISABLED FOR NOW UNTIL WE HAVE ZOO CAUSAL MASK SUPPORT)
# assert (
# len(pipeline.tokenizer.tokenize(long_prompt)) + int(uses_bos_token)
-# > pipeline.prompt_processing_sequence_length * 3
+# > pipeline.prompt_sequence_length * 3
# )

yield pipeline, model_name, uses_bos_token, short_prompt, long_prompt
Expand All @@ -137,7 +137,7 @@ def test_model_output_sequences(self, setup):

def test_model_output_cache(self, setup):
pipeline, model_name, _, short_prompt, long_prompt = setup
-if self.use_deepsparse_cache:
+if self.internal_kv_cache:
pytest.skip(
"Running pipeline with internal "
"deepsparse cache will not result "