Add batched_run to Engine (#1100)
* Add batched_run to Engine

* Refactor and add a test

* bad return

* Refactor the Pipeline split/join into utils/data.py and write tests

* Format

* Fix yolact

* rebase and quality

* Reset some state for test_onnx

* .

* Failing test shouldn't fail anymore!

* Try slicing outputs differently

* Format
mgoin authored Jul 11, 2023
1 parent 286574f commit 3f9ff51
Showing 8 changed files with 207 additions and 140 deletions.
22 changes: 22 additions & 0 deletions src/deepsparse/engine.py
@@ -29,8 +29,10 @@
from deepsparse.utils import (
generate_random_inputs,
get_output_names,
join_engine_outputs,
model_to_path,
override_onnx_input_shapes,
split_engine_inputs,
)


@@ -440,6 +442,26 @@ def generate_random_inputs(self) -> List[numpy.ndarray]:
"""
return generate_random_inputs(self.model_path, self.batch_size)

def batched_run(
self,
inp: List[numpy.ndarray],
) -> List[numpy.ndarray]:

if self.batch_size == 1:
_LOGGER.warning(
"Using batched_run with an Engine of batch_size=1 isn't recommended "
"for optimal performance."
)

# Split inputs into batches of size `self.batch_size`
batch_inputs, orig_batch_size = split_engine_inputs(inp, self.batch_size)

# Run each split batch through the engine
batch_outputs = list(map(self.run, batch_inputs))

# Join together the batches of size `self.batch_size`
return join_engine_outputs(batch_outputs, orig_batch_size)

def run(
self,
inp: List[numpy.ndarray],
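For reference, a minimal usage sketch of the new `Engine.batched_run`; the ONNX path, input shape, and dtype below are placeholders rather than anything from this commit:

```python
# Illustrative only: "model.onnx" and the (5, 3, 224, 224) input are assumptions.
import numpy
from deepsparse import Engine

# engine compiled for a fixed batch size of 2
engine = Engine(model="model.onnx", batch_size=2)

# 5 samples are not divisible by 2: batched_run pads the last mini-batch
# internally and trims the padding from the joined outputs
inputs = [numpy.random.rand(5, 3, 224, 224).astype(numpy.float32)]
outputs = engine.batched_run(inputs)
print([out.shape[0] for out in outputs])  # leading dimension is 5 again
```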
79 changes: 11 additions & 68 deletions src/deepsparse/pipeline.py
@@ -36,7 +36,13 @@
validate_identifier,
)
from deepsparse.tasks import SupportedTasks, dynamic_import_task
from deepsparse.utils import InferenceStages, StagedTimer, TimerManager
from deepsparse.utils import (
InferenceStages,
StagedTimer,
TimerManager,
join_engine_outputs,
split_engine_inputs,
)


__all__ = [
@@ -256,13 +262,15 @@ def __call__(self, *args, **kwargs) -> BaseModel:
# ------ INFERENCE ------
# split inputs into batches of size `self._batch_size`
timer.start(InferenceStages.ENGINE_FORWARD)
batches = self.split_engine_inputs(engine_inputs, self._batch_size)
batches, orig_batch_size = split_engine_inputs(
engine_inputs, self._batch_size
)

# submit split batches to engine threadpool
batch_outputs = list(self.executor.map(self.engine_forward, batches))

# join together the batches of size `self._batch_size`
engine_outputs = self.join_engine_outputs(batch_outputs)
engine_outputs = join_engine_outputs(batch_outputs, orig_batch_size)
timer.stop(InferenceStages.ENGINE_FORWARD)

self.log(
@@ -302,71 +310,6 @@ def __call__(self, *args, **kwargs) -> BaseModel:

return pipeline_outputs

@staticmethod
def split_engine_inputs(
items: List[numpy.ndarray], batch_size: int
) -> List[List[numpy.ndarray]]:
"""
Splits each item into numpy arrays with the first dimension == `batch_size`.
For example, if `items` has three numpy arrays with the following
shapes: `[(4, 32, 32), (4, 64, 64), (4, 128, 128)]`
Then with `batch_size==4` the output would be:
```
[[(4, 32, 32), (4, 64, 64), (4, 128, 128)]]
```
Then with `batch_size==2` the output would be:
```
[
[(2, 32, 32), (2, 64, 64), (2, 128, 128)],
[(2, 32, 32), (2, 64, 64), (2, 128, 128)],
]
```
Then with `batch_size==1` the output would be:
```
[
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
]
```
"""
# if not all items here are numpy arrays, there's an internal
# bug in the processing code
assert all(isinstance(item, numpy.ndarray) for item in items)

# if not all items have the same batch size, there's an
# internal bug in the processing code
total_batch_size = items[0].shape[0]
assert all(item.shape[0] == total_batch_size for item in items)

if total_batch_size % batch_size != 0:
raise RuntimeError(
f"batch size of {total_batch_size} passed into pipeline "
f"is not divisible by model batch size of {batch_size}"
)

batches = []
for i_batch in range(total_batch_size // batch_size):
start = i_batch * batch_size
batches.append([item[start : start + batch_size] for item in items])
return batches

@staticmethod
def join_engine_outputs(
batch_outputs: List[List[numpy.ndarray]],
) -> List[numpy.ndarray]:
"""
Joins list of engine outputs together into one list using `numpy.concatenate`.
This is the opposite of `Pipeline.split_engine_inputs`.
"""
return list(map(numpy.concatenate, zip(*batch_outputs)))

@staticmethod
def _get_task_constructor(task: str) -> Type["Pipeline"]:
"""
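With the split/join helpers now shared through `deepsparse.utils`, a Pipeline no longer requires the number of inputs to be divisible by its batch size (see the updated test further down). A minimal sketch, assuming the token-classification task used in that test and a default SparseZoo model:

```python
# Illustrative only: the task choice and inputs are assumptions, not part of this commit.
from deepsparse import Pipeline

pipeline = Pipeline.create(task="token_classification", batch_size=2)

# three sequences against batch_size=2: the last mini-batch is padded internally
# and exactly three predictions come back
outputs = pipeline(["first sequence", "second sequence", "third sequence"])
print(outputs)
```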
93 changes: 92 additions & 1 deletion src/deepsparse/utils/data.py
@@ -14,7 +14,7 @@

import logging
import re
from typing import List
from typing import List, Tuple

import numpy

@@ -25,6 +25,8 @@
"verify_outputs",
"parse_input_shapes",
"numpy_softmax",
"split_engine_inputs",
"join_engine_outputs",
]


@@ -162,3 +164,92 @@ def numpy_softmax(x: numpy.ndarray, axis: int = 0):
e_x_sum = numpy.sum(e_x, axis=axis, keepdims=True)
softmax_x = e_x / e_x_sum
return softmax_x


def split_engine_inputs(
items: List[numpy.ndarray], batch_size: int
) -> Tuple[List[List[numpy.ndarray]], int]:
"""
Splits each item into numpy arrays with the first dimension == `batch_size`.
For example, if `items` has three numpy arrays with the following
shapes: `[(4, 32, 32), (4, 64, 64), (4, 128, 128)]`
Then with `batch_size==4` the output would be:
```
[[(4, 32, 32), (4, 64, 64), (4, 128, 128)]]
```
Then with `batch_size==2` the output would be:
```
[
[(2, 32, 32), (2, 64, 64), (2, 128, 128)],
[(2, 32, 32), (2, 64, 64), (2, 128, 128)],
]
```
Then with `batch_size==1` the output would be:
```
[
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
[(1, 32, 32), (1, 64, 64), (1, 128, 128)],
]
```
In the case where the total input batch size isn't divisible by `batch_size`,
the last mini-batch is padded; see `padding_is_needed` below.
"""
# The engine expects to receive data in numpy format, so at this point it should be
assert all(isinstance(item, numpy.ndarray) for item in items)

# Check that all inputs have the same batch size
total_batch_size = items[0].shape[0]
if not all(arr.shape[0] == total_batch_size for arr in items):
raise ValueError("Not all inputs have matching batch size")

batches = []
for section_idx in range(0, total_batch_size, batch_size):
padding_is_needed = section_idx + batch_size > total_batch_size
if padding_is_needed:
# If we can't evenly divide with batch size, pad the last batch
input_sections = []
for arr in items:
pads = ((0, section_idx + batch_size - total_batch_size),) + (
(0, 0),
) * (arr.ndim - 1)
section = numpy.pad(
arr[section_idx : section_idx + batch_size], pads, mode="edge"
)
input_sections.append(section)
batches.append(input_sections)
else:
# Otherwise we just take our slice as the batch
batches.append(
[arr[section_idx : section_idx + batch_size] for arr in items]
)

return batches, total_batch_size


def join_engine_outputs(
batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
) -> List[numpy.ndarray]:
"""
Joins the list of engine outputs together into one list using `numpy.concatenate`.
If the original batch size wasn't evenly divisible by the engine batch size,
the padded remainder is trimmed off.
This is the inverse of `split_engine_inputs` and is meant to be used in tandem with it.
"""
assert all(isinstance(item, List) for item in batch_outputs)

candidate_output = list(map(numpy.concatenate, zip(*batch_outputs)))

# If we can't evenly divide with batch size, remove the remainder as padding
if candidate_output[0].shape[0] > orig_batch_size:
for i in range(len(candidate_output)):
candidate_output[i] = candidate_output[i][:orig_batch_size]

return candidate_output
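A quick round trip through the two new helpers (shapes are illustrative, and an identity "forward pass" stands in for a real engine call):

```python
# Illustrative only: the shapes and the identity forward pass are assumptions.
import numpy
from deepsparse.utils import join_engine_outputs, split_engine_inputs

inputs = [numpy.random.rand(5, 32, 32), numpy.random.rand(5, 64, 64)]

# 5 is not divisible by 2, so the third mini-batch is edge-padded up to 2
batches, orig_batch_size = split_engine_inputs(inputs, batch_size=2)
print(len(batches), orig_batch_size)  # -> 3 5

# pretend each mini-batch ran through an engine that returns its inputs unchanged
batch_outputs = [list(batch) for batch in batches]

joined = join_engine_outputs(batch_outputs, orig_batch_size)
print([arr.shape for arr in joined])  # -> [(5, 32, 32), (5, 64, 64)], padding trimmed
```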
27 changes: 9 additions & 18 deletions src/deepsparse/yolact/pipelines.py
@@ -166,24 +166,6 @@ def process_inputs(
)
return [image_batch], postprocessing_kwargs

@staticmethod
def join_engine_outputs(
batch_outputs: List[List[numpy.ndarray]],
) -> List[numpy.ndarray]:
boxes, confidence, masks, priors, protos = Pipeline.join_engine_outputs(
batch_outputs
)

# priors never has a batch dimension
# so the above step doesn't concat along a batch dimension
# reshape into a batch dimension
num_priors = boxes.shape[1]
batch_priors = numpy.reshape(priors, (-1, num_priors, 4))

# all the priors should be equal, so only use the first one
assert (batch_priors == batch_priors[0]).all()
return [boxes, confidence, masks, batch_priors[0], protos]

def _preprocess_image(self, image) -> numpy.ndarray:
if isinstance(image, str):
image = cv2.imread(image)
@@ -209,6 +191,15 @@ def process_engine_outputs(
priors = torch.from_numpy(priors).cpu()
protos = torch.from_numpy(protos).cpu()

# priors never has a batch dimension
# so the above step doesn't concat along a batch dimension
# reshape into a batch dimension
# all the priors should be equal, so only use the first one
num_priors = boxes.shape[1]
batch_priors = numpy.reshape(priors, (-1, num_priors, 4))
assert (batch_priors == batch_priors[0]).all()
priors = batch_priors[0]

batch_size, num_priors, _ = boxes.size()

# Preprocess every image in the batch individually
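For context, a small sketch of the priors de-duplication that moved into `process_engine_outputs` (the prior count and values are made up):

```python
# Illustrative only: num_priors, batch_size, and the prior values are assumptions.
import numpy

num_priors, batch_size = 6, 4
single_priors = numpy.random.rand(num_priors, 4)

# after join_engine_outputs, the per-batch copies of the priors end up
# concatenated along dim 0, because priors carry no real batch dimension
priors = numpy.concatenate([single_priors] * batch_size)  # (batch_size * num_priors, 4)

# reshape into a batch dimension and keep the first copy, since all are equal
batch_priors = numpy.reshape(priors, (-1, num_priors, 4))
assert (batch_priors == batch_priors[0]).all()
priors = batch_priors[0]  # back to (num_priors, 4)
```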
47 changes: 5 additions & 42 deletions tests/deepsparse/pipelines/test_pipeline.py
@@ -18,48 +18,10 @@

import numpy

import pytest
from deepsparse.pipeline import Pipeline, _initialize_executor_and_workers
from tests.utils import mock_engine


def test_split_engine_inputs():
inp = [numpy.zeros((4, 28)) for _ in range(3)]

out = Pipeline.split_engine_inputs(inp, batch_size=4)
assert numpy.array(out).shape == (1, 3, 4, 28)

out = Pipeline.split_engine_inputs(inp, batch_size=2)
assert numpy.array(out).shape == (2, 3, 2, 28)

out = Pipeline.split_engine_inputs(inp, batch_size=1)
assert numpy.array(out).shape == (4, 3, 1, 28)


def test_join_opposite_of_split():
inp = [numpy.random.rand(4, 28) for _ in range(3)]

out = Pipeline.split_engine_inputs(inp, batch_size=2)
assert numpy.array(out).shape == (2, 3, 2, 28)

joined = Pipeline.join_engine_outputs(out)
assert numpy.array(joined).shape == (3, 4, 28)

for i, j in zip(inp, joined):
assert (i == j).all()


def test_split_engine_inputs_uneven_raises_error():
with pytest.raises(
RuntimeError,
match=(
"batch size of 3 passed into pipeline "
"is not divisible by model batch size of 2"
),
):
Pipeline.split_engine_inputs([numpy.zeros((3, 28))], batch_size=2)


@mock_engine(rng_seed=0)
def test_split_interaction_with_forward_batch_size_1(engine_mock):
pipeline = Pipeline.create("token_classification", batch_size=1)
@@ -82,14 +44,15 @@ def test_split_interaction_with_forward_batch_size_2(engine_forward):
with mock.patch.object(
Pipeline, "engine_forward", wraps=pipeline.engine_forward
) as engine_forward:
with pytest.raises(RuntimeError, match="is not divisible"):
pipeline("word")
# this is okay because we can pad batches
pipeline("word")
assert engine_forward.call_count == 1

pipeline("two words".split())
assert engine_forward.call_count == 1
assert engine_forward.call_count == 2

pipeline("two words for me".split())
assert engine_forward.call_count == 3
assert engine_forward.call_count == 4


def test_pipeline_executor_num_workers():
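The updated call counts follow directly from the padding behavior: each pipeline call now issues ceil(n_samples / batch_size) engine forwards, and the test's assertions are cumulative. A quick sanity check with the batch size of 2 used above:

```python
# Illustrative only: reproduces the call-count arithmetic, not the test itself.
import math

batch_size = 2
for n_samples in (1, 2, 4):
    # split_engine_inputs pads the last mini-batch, so each pipeline call
    # triggers ceil(n_samples / batch_size) engine forwards
    print(n_samples, math.ceil(n_samples / batch_size))  # 1 -> 1, 2 -> 1, 4 -> 2
```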