Skip to content

Commit

Permalink
[Text Generation][MPT] Filter out the appropriate engine_inputs for autoregressive_inference (#1151)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Aug 1, 2023
1 parent 758fc05 commit 3254ca8
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def autoregressive_inference(
:return: The new, generated token and the logits for the new token
(with dimensions ['batch_size', 'num_tokens', 'vocab_size'])
"""

new_token = tokens[-1]
# padding is added to left, so attention mask is 1s from the
# right up to the number of total tokens (prompt + generated)
Expand All @@ -444,7 +445,17 @@ def autoregressive_inference(
positions -= 1
input_ids = numpy.array([[new_token]])
causal_mask = create_causal_mask(input_ids, attention_mask)
engine_inputs = [input_ids, attention_mask, positions, causal_mask]

# filter out the inputs that are not needed by the engine
engine_inputs_map = dict(
input_ids=input_ids,
attention_mask=attention_mask,
causal_mask=causal_mask,
positions=positions,
)
engine_inputs = [
engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache
]

generated_token, generated_logits = self.engine(engine_inputs)

Expand Down

0 comments on commit 3254ca8

Please sign in to comment.