
Commit

Merge remote-tracking branch 'origin/main' into feature/damian/benchmark_llm
dbogunowicz committed Jul 25, 2023
2 parents 0fe9f7e + 2be7087 commit f5a06ba
Showing 6 changed files with 48 additions and 26 deletions.
31 changes: 29 additions & 2 deletions src/deepsparse/pipeline.py
@@ -241,15 +241,15 @@ def __call__(self, *args, **kwargs) -> BaseModel:
# ------ INFERENCE ------
# split inputs into batches of size `self._batch_size`
timer.start(InferenceStages.ENGINE_FORWARD)
batches, orig_batch_size = split_engine_inputs(
batches, orig_batch_size = self.split_engine_inputs(
engine_inputs, self._batch_size
)

# submit split batches to engine threadpool
batch_outputs = list(self.executor.map(self.engine_forward, batches))

# join together the batches of size `self._batch_size`
engine_outputs = join_engine_outputs(batch_outputs, orig_batch_size)
engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
timer.stop(InferenceStages.ENGINE_FORWARD)

self.log(
@@ -458,6 +458,33 @@ def to_config(self) -> "PipelineConfig":
kwargs=kwargs,
)

def join_engine_outputs(
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
) -> List[numpy.ndarray]:
"""
Joins list of engine outputs together into one list.
This is the opposite of `split_engine_inputs` and is meant to be used in tandem.
:param batch_outputs: list of engine outputs
:param orig_batch_size: original batch size of the inputs
:return: list of engine outputs joined together
"""
return join_engine_outputs(batch_outputs, orig_batch_size)

def split_engine_inputs(
self, items: List[numpy.ndarray], batch_size: int
) -> List[List[numpy.ndarray]]:
"""
Splits each item into numpy arrays with the first dimension == `batch_size`.
This is the opposite of `join_engine_outputs` and is meant to be used in tandem.
:param items: list of numpy arrays to split into batches
:param batch_size: size of each batch to enforce
:return: list of batches, where each batch is a list of numpy arrays
"""
return split_engine_inputs(items, batch_size)

def engine_forward(self, engine_inputs: List[numpy.ndarray]) -> List[numpy.ndarray]:
"""
:param engine_inputs: list of numpy inputs to Pipeline engine forward
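For context: this change routes batch splitting and joining through overridable instance methods instead of calling the module-level helpers directly, which is what lets the text-generation pipeline further below supply its own join logic. A minimal sketch of what that enables (MyPipeline is a hypothetical subclass, not part of this commit):

import numpy
from typing import List
from deepsparse import Pipeline

class MyPipeline(Pipeline):
    def join_engine_outputs(
        self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
    ) -> List[numpy.ndarray]:
        # inspect mini-batch shapes, then delegate to the default joining
        for batch in batch_outputs:
            print([array.shape for array in batch])
        return super().join_engine_outputs(batch_outputs, orig_batch_size)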
5 changes: 3 additions & 2 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -160,7 +160,8 @@ def __call__(
else:
logits = out[0]

token = self.generate_token(logits=logits[:, -1, :])
# select batch index 0; the batch size is always 1
token = self.generate_token(logits=logits[0, -1, :])

return token, logits

@@ -193,7 +194,7 @@ def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:

logits /= self.sampling_temperature

probs = softmax(logits)
probs = numpy_softmax(logits)

return numpy.random.choice(len(probs), p=probs)

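For context: logits has shape [batch, seq_len, vocab] and this decoder always runs with batch size 1, so logits[0, -1, :] yields the 1-D vector of next-token logits, which is what numpy.random.choice needs. A rough standalone sketch of the sampling path (sample_token is illustrative, not an API in this repo):

import numpy
from deepsparse.utils.data import numpy_softmax

def sample_token(logits: numpy.ndarray, temperature: float = 1.0) -> int:
    # logits: [batch, seq_len, vocab] with batch == 1
    last_logits = logits[0, -1, :] / temperature
    probs = numpy_softmax(last_logits)  # 1-D input, so the default axis=0 is correct
    return int(numpy.random.choice(len(probs), p=probs))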
7 changes: 4 additions & 3 deletions src/deepsparse/transformers/metrics.py
@@ -17,6 +17,7 @@
"""


from itertools import compress
from typing import Any, Dict, List, Optional

import numpy
@@ -89,10 +90,10 @@ def add_batch(self, predictions: List[str]):
# with <PAD> tokens from the left side. We need to remove
# them and zero-pad from the right side up to the length
# of the longest sequence in the batch
encoded_batch = numpy.array(encoded_batch) * numpy.array(attention_mask)

encoded_batch = [
list(filter(lambda num: num != 0, sequence))
for sequence in encoded_batch
list(compress(sequence, attn_mask))
for (sequence, attn_mask) in zip(encoded_batch, attention_mask)
]
max_sequence_len = max([len(sequence) for sequence in encoded_batch])

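For context: the old code multiplied by the attention mask and then dropped every zero, which also deleted any legitimate token with id 0; itertools.compress keeps exactly the positions the mask marks as real. A small illustration of the difference (values are made up):

from itertools import compress

sequence = [0, 0, 15, 0, 42]   # left-padded; the middle 0 is a real token id
attn_mask = [0, 0, 1, 1, 1]    # 1 marks real tokens, 0 marks padding

print(list(compress(sequence, attn_mask)))           # [15, 0, 42]
print(list(filter(lambda num: num != 0, sequence)))  # [15, 42] -- loses the real 0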
3 changes: 2 additions & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
@@ -470,7 +470,7 @@ def has_cache(self) -> bool:

@staticmethod
def join_engine_outputs(
batch_outputs: List[List[numpy.ndarray]],
batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
) -> List[numpy.ndarray]:
"""
Takes a list of outputs (batches) from the engine
@@ -479,6 +479,7 @@ def join_engine_outputs(
they can be concatenated.
:param batch_outputs: A list of outputs from the engine
:param orig_batch_size: The original batch size
:return: A list of joined outputs
"""
tokens, logits = zip(*batch_outputs)
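For context: the signature now matches the base-class hook added in pipeline.py, and zip(*batch_outputs) transposes the per-batch (tokens, logits) pairs into one tuple of token arrays and one of logits arrays before padding and concatenation. In miniature (placeholder values):

batch_outputs = [("tokens_0", "logits_0"), ("tokens_1", "logits_1")]
tokens, logits = zip(*batch_outputs)
print(tokens)  # ('tokens_0', 'tokens_1')
print(logits)  # ('logits_0', 'logits_1')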
17 changes: 0 additions & 17 deletions src/deepsparse/transformers/utils/helpers.py
@@ -30,7 +30,6 @@
"overwrite_onnx_model_inputs_for_kv_cache_models",
"generate_session_id",
"pad_to_fixed_length",
"softmax",
]

_LOGGER = logging.getLogger(__name__)
@@ -107,22 +106,6 @@ def generate_session_id() -> str:
return session_id


def softmax(x: numpy.ndarray) -> numpy.ndarray:
"""
Compute softmax values for x. This function guards
against overflow/underflow by shifting the input
vector: the maximum element is subtracted from all
elements before exponentiation
:param x: input array
:return: softmax values
"""
z = x - max(x)
numerator = numpy.exp(z)
denominator = numpy.sum(numerator)
return numerator / denominator


def pad_to_fixed_length(
array: numpy.ndarray, max_len: int, axis: int = 0, value: int = 0
) -> numpy.ndarray:
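For context: this removes a duplicate; numpy_softmax in src/deepsparse/utils/data.py (next file) provides the same functionality, and the decoder engine above already switched to it. Callers migrate like this (a minimal sketch; the sum check is just for illustration):

import numpy
from deepsparse.utils.data import numpy_softmax

x = numpy.array([1.0, 2.0, 3.0])
probs = numpy_softmax(x)   # the default axis=0 matches the old 1-D helper
print(probs.sum())         # ~1.0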
11 changes: 10 additions & 1 deletion src/deepsparse/utils/data.py
@@ -168,7 +168,7 @@ def numpy_softmax(x: numpy.ndarray, axis: int = 0):

def split_engine_inputs(
items: List[numpy.ndarray], batch_size: int
) -> Tuple[List[List[numpy.ndarray]], int]:
) -> List[List[numpy.ndarray]]:
"""
Splits each item into numpy arrays with the first dimension == `batch_size`.
@@ -200,6 +200,11 @@ def split_engine_inputs(
In the case where the total input batch size isn't divisible by `batch_size`, it
will pad the last mini batch. Look at `padding_is_needed`
:param items: list of numpy arrays to split
:param batch_size: size of each batch to split into
:return: list of batches, where each batch is a list of numpy arrays
"""
# The engine expects to receive data in numpy format, so at this point it should be numpy arrays
assert all(isinstance(item, numpy.ndarray) for item in items)
@@ -242,6 +247,10 @@ def join_engine_outputs(
the remainder as padding.
This is the opposite of `split_engine_inputs` and is meant to be used in tandem.
:param batch_outputs: List of engine outputs
:param orig_batch_size: The original batch size of the inputs
:return: List of engine outputs joined together
"""
assert all(isinstance(item, (List, Tuple)) for item in batch_outputs)

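For context: the padding note above is why join_engine_outputs now takes orig_batch_size. Splitting pads the final mini batch when the input count is not divisible by the engine batch size, and joining truncates that padding back off. A sketch of the round trip, following the call pattern in pipeline.py above (the printed values are expectations based on the documented behavior, not verified output):

import numpy
from deepsparse.utils.data import join_engine_outputs, split_engine_inputs

engine_inputs = [numpy.zeros((5, 8))]  # one input tensor with batch size 5
batches, orig_batch_size = split_engine_inputs(engine_inputs, batch_size=2)
print(len(batches), orig_batch_size)   # expected: 3 5 (last mini batch padded)

outputs = join_engine_outputs(batches, orig_batch_size)
print(outputs[0].shape)                # expected: (5, 8) -- padding truncated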
