Skip to content

Commit

Permalink
[Add] sequence length to text gen pipelines (#1518)
Browse files Browse the repository at this point in the history
* Add sequence length to text gen pipelines

* Revert try except block as `sequence_length` is now expoosed by the pipeline
  • Loading branch information
rahul-tuli committed Jan 9, 2024
1 parent e1bae5c commit 5004edd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,7 @@ def __init__(
self.tokenizer = tokenizer if tokenizer else self.model.tokenizer

self._batch_size = batch_size
try:
self._max_length = pipeline.sequence_length
except Exception:
# workaround until the DeepSparse pipeline exposes the sequence_length
self._max_length = pipeline.ops["single_engine"].sequence_length

self._max_length = pipeline.sequence_length
self._max_gen_toks = max_gen_toks or 256

self.vocab_size = self.tokenizer.vocab_size
Expand Down
10 changes: 10 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ def expand_inputs(self, items, batch_size):
def condense_inputs(self, *args, **kwargs):
return args[0], kwargs

@property
def sequence_length(self) -> int:
"""
Property to return the sequence length for the pipeline.
(relies on the single engine operator)
:return: the sequence length for the pipeline
"""
return self.ops["single_engine"].sequence_length

def _get_continuous_batching_scheduler(
self, batch_sizes: List[int], engines: List[EngineOperator]
) -> ContinuousBatchingScheduler:
Expand Down

0 comments on commit 5004edd

Please sign in to comment.