Server accepts file for pre-/post-processing functions (#1033)
* Add support to read in custom pre-/post-processing functions

* Style fixes
* Better Documentation

* Changed print message -> log info message
Updated documentation as requested by @bfineran

* fix small typo
rahul-tuli committed May 25, 2023
1 parent 3bf3e84 commit 6029f89
Showing 2 changed files with 115 additions and 5 deletions.
57 changes: 56 additions & 1 deletion docs/user-guide/deepsparse-server.md
@@ -272,7 +272,62 @@ Check out the [Use Case](../use-cases) page for detailed documentation on task-s

## Custom Use Cases

Stay tuned for documentation on using a custom DeepSparse Pipeline within the Server!
The endpoints can also take in a custom task, along with custom preprocessing and postprocessing functions:

```yaml
# custom-processing-config.yaml

endpoints:
- task: custom
model: ~/models/resnet50.onnx
kwargs:
processing_file: ~/processing.py
```

Here, `model` must be a valid ONNX model that exists on the system, and `processing_file` must be a
valid Python file containing pre- and/or post-processing functions. The `preprocess` function must return
a list of `numpy.ndarray`s, and the `postprocess` function must take in a list of `numpy.ndarray`s. For example:

(Make sure you have `torchvision` installed to run this exact example.)

```python
# processing.py

from typing import List

import torch
from PIL import Image
from torchvision import transforms

IMAGENET_RGB_MEANS = [0.485, 0.456, 0.406]
IMAGENET_RGB_STDS = [0.229, 0.224, 0.225]

# standard ImageNet eval transforms: resize, center-crop, tensorize, normalize
preprocess_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_RGB_MEANS, std=IMAGENET_RGB_STDS),
])


def preprocess(img_file) -> List["numpy.ndarray"]:
    # load the image, apply the transforms, and add a batch dimension
    with open(img_file, "rb") as f:
        img = Image.open(f).convert("RGB")
        img = preprocess_transforms(img)
    batch = torch.stack([img])  # shape: (1, 3, 224, 224)
    return [batch.numpy()]


def postprocess(outputs: List["numpy.ndarray"]):
    # return the raw engine outputs unchanged
    return outputs
```
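
As a quick sanity check, these functions can be exercised locally before serving them. A minimal sketch, assuming `processing.py` is importable from the working directory and `dog.jpg` is any local image (a hypothetical path):

```python
# sanity-check processing.py outside of the Server
from processing import preprocess, postprocess

arrays = preprocess("dog.jpg")  # hypothetical local image
print([a.shape for a in arrays])  # expect [(1, 3, 224, 224)]
print(postprocess(arrays)[0].shape)
```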

Spinning up:

```bash
deepsparse.server \
--config-file custom-processing-config.yaml
```

Now the custom preprocess and postprocess functions will be used when
requests are made to this server!
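
The same configuration can also be exercised without the Server by constructing the pipeline directly in Python. A minimal sketch, assuming the model and processing file from the config above, and that `Pipeline.create` forwards extra keyword arguments such as `processing_file` through to the custom pipeline (`dog.jpg` is a hypothetical local image):

```python
from os.path import expanduser

from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="custom",
    model_path=expanduser("~/models/resnet50.onnx"),
    processing_file=expanduser("~/processing.py"),
)

# with no input schema configured, raw inputs go straight to preprocess()
outputs = pipeline("dog.jpg")
```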

## Multi-Stream

63 changes: 59 additions & 4 deletions src/deepsparse/pipelines/custom_pipeline.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy
@@ -21,20 +21,23 @@
from deepsparse.utils.onnx import model_to_path


_LOGGER = logging.getLogger(__name__)


@Pipeline.register(task="custom")
class CustomTaskPipeline(Pipeline):
"""
A utility class provided to make specifying custom pipelines easier.
Instead of creating a subclass of Pipeline, you can instantiate this directly
by passing in functions to call for pre and post processing.
by passing in functions to call for pre- and post-processing.
The easiest way to use this class is to just pass in the model path, which
lets you directly interact with engine inputs/outputs:
```python
pipeline = CustomPipeline(model_path="...")
```
Alternatively, you can pass the pre/post processing functions into
Alternatively, you can pass the pre-/post-processing functions into
the constructor:
```python
def yolo_preprocess(inputs: YOLOInput) -> List[np.ndarray]:
@@ -52,6 +56,27 @@ def yolo_postprocess(engine_outputs: List[np.ndarray]) -> YOLOOutput:
)
```
Alternatively, you can pass a processing file containing `preprocess`
and `postprocess` functions to the constructor via kwargs:
`processing.py`
```python
def preprocess(inputs: YOLOInput) -> List[np.ndarray]:
    ...

def postprocess(engine_outputs: List[np.ndarray]):
    ...
```
```python
yolo = CustomPipeline(
    model_path="...",
    input_schema=YOLOInput,
    output_schema=YOLOOutput,
    processing_file="processing.py",
)
```
:param model_path: path on local system or SparseZoo stub to load the model from.
Passed to :class:`Pipeline`.
:param input_schema: Optional pydantic schema that describes the input to
@@ -63,7 +88,7 @@ def yolo_postprocess(engine_outputs: List[np.ndarray]) -> YOLOOutput:
maps an `InputSchema` object to a list of numpy arrays that can be directly
passed into the forward pass of the pipeline engine. If `None`, raw data is
passed to the engine.
:param process_outputs_fn: Optional callable (function, method, lambda, etc) that
:param process_outputs_fn: Optional callable (function, method, lambda, etc.) that
maps the list of numpy arrays that are the output of the engine forward pass
into an `OutputSchema` object. If `None`, engine outputs are directly returned.
"""
@@ -97,6 +122,15 @@ def __init__(
                f"output_schema must subclass BaseModel. Found {output_schema}"
            )

        processing_file = kwargs.pop("processing_file", None)
        if processing_file is not None:
            (
                process_inputs_fn,
                process_outputs_fn,
            ) = self._read_processing_functions_from_file(
                processing_file=processing_file
            )

        if process_inputs_fn is None:
            process_inputs_fn = _passthrough

@@ -148,6 +182,27 @@ def process_engine_outputs(
        """
        return self._process_outputs_fn(engine_outputs, **kwargs)

    def _read_processing_functions_from_file(self, processing_file: str):
        """
        Parses the file containing the `preprocess` and `postprocess` functions

        :pre-condition: The file is a valid `.py` file that exists and may
            contain a preprocess and a postprocess function
        :param processing_file: The path to the file containing the preprocess
            and postprocess functions
        :return: The preprocess and postprocess functions from the file
        """
        _LOGGER.info(
            "Overriding preprocess and postprocess "
            f"functions using {processing_file}"
        )
        spec = importlib.util.spec_from_file_location(
            "custom_processing_functions", processing_file
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "preprocess", None), getattr(module, "postprocess", None)


def _passthrough(x, **kwargs):
    return x
