Skip to content

Commit

Permalink
fix breaking tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Jul 27, 2023
1 parent 3987e1b commit e599583
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def prompt_inference(
if len(tokens) > self.prompt_processing_sequence_length:
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
- num_tokens_processed = self.prompt_processing_sequence_length
+ num_tokens_processed += self.prompt_processing_sequence_length
prompt_logits.append(new_logits)

if num_tokens_processed:
Expand Down Expand Up @@ -523,11 +523,16 @@ def engine_inputs_for_prefill(
# delay creation of the causal mask
continue
elif name == "positions":
- engine_input = (
- numpy.arange(self.prompt_processing_sequence_length)
- .reshape(1, -1)
- .astype(numpy.int64)
- )
+ if self.prompt_processing_sequence_length == 1:
+ # we need to treat `positions` as if we were in
+ # the autoregressive mode
+ engine_input = numpy.array([[idx]], dtype=numpy.int64)
+ else:
+ engine_input = (
+ numpy.arange(self.prompt_processing_sequence_length)
+ .reshape(1, -1)
+ .astype(numpy.int64)
+ )

engine_inputs.append(engine_input)

Expand Down
2 changes: 1 addition & 1 deletion src/deepsparse/transformers/utils/decoder_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def set_capacity(self, capacity: int):
state = self._add_entries(state, indices=indices)

else:
- pass
+ return

self._state = state

Expand Down

0 comments on commit e599583

Please sign in to comment.