Skip to content

Commit

Permalink
[Text Generation][Enhancement] If prompt_processing_sequence_length =…
Browse files Browse the repository at this point in the history
…= 1, do not initialize multitoken_engine (#1214)
  • Loading branch information
dbogunowicz committed Aug 30, 2023
1 parent 4aeb848 commit d8b63da
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,14 @@ def __init__(
_delay_engine_initialize=True,
_delay_overwriting_inputs=True,
)
self.enable_multitoken_prefill = self.causal_mask_input_present(
model_path=self.onnx_file_path
# enable multitoken prefill if
# - the model graph is supporting it (causal_mask input is present)
# - prompt_processing_sequence_length != 1 (identical to single-token prefill)
self.enable_multitoken_prefill = (
self.causal_mask_input_present(model_path=self.onnx_file_path)
and prompt_processing_sequence_length > 1
)

self.cache_support_enabled = self.is_cache_support_enabled()

if self.engine_type == DEEPSPARSE_ENGINE:
Expand Down Expand Up @@ -680,20 +685,14 @@ def engine_inputs_for_prefill(
# delay creation of the causal mask
continue
elif name == "positions":
if self.prompt_processing_sequence_length == 1:
# we need to treat `positions` as if we were in
# the autoregressive mode
engine_input = numpy.array([[idx]], dtype=numpy.int64)
else:
engine_input = (
numpy.arange(
num_cached_entries,
num_cached_entries
+ self.prompt_processing_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
engine_input = (
numpy.arange(
num_cached_entries,
num_cached_entries + self.prompt_processing_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
)

engine_inputs.append(engine_input)

Expand Down

0 comments on commit d8b63da

Please sign in to comment.