Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V2 Pipeline] Middleware manager #1471

Merged
merged 21 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/deepsparse/middlewares/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,28 @@
import threading
from typing import Any, Callable, Dict, Iterator, Optional, Protocol, Sequence, Type

from deepsparse.operators import Operator


horheynm marked this conversation as resolved.
Show resolved Hide resolved
class MiddlewareCallable(Protocol):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
"""
Newly created middlewares should inherit this class
"""

def __call__(self, *args, **kwargs):
"""
Pipeline, Operator callable will be overwritten
and will eventually its callable
horheynm marked this conversation as resolved.
Show resolved Hide resolved
"""
...
horheynm marked this conversation as resolved.
Show resolved Hide resolved

def send(self, dct: Dict):
def send(self, reducer: Callable[[Dict], Dict]):
"""
Update middleware Manager state
Logic defined in MiddlewareManager._update_middleware_spec_send

:param reducer: A callable that contains logic to update
the middleware state
"""
...
horheynm marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -108,12 +117,18 @@ def _update_middleware_spec_send(
if middleware is not None:
for next_middleware, init_args in middleware:

# allow the middleware to communivate with the manager
# allow the middleware to communicate with the manager
next_middleware.send = self.recieve

self.middleware.append(MiddlewareSpec(next_middleware, **init_args))

def build_middleware_stack(self, next_call: Callable):
horheynm marked this conversation as resolved.
Show resolved Hide resolved
for middleware, init_args in reversed(self.middleware):
next_call = middleware(next_call, **init_args)
if self.middleware is not None:
for middleware, init_args in reversed(self.middleware):
next_call = middleware(next_call, **init_args)
return next_call

def wrap(self, operator: Operator) -> Callable:
"""Add middleware to the operator"""
wrapped_operator = self.build_middleware_stack(operator)
return wrapped_operator
18 changes: 1 addition & 17 deletions src/deepsparse/operators/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@

from pydantic import BaseModel

from deepsparse.middlewares import MiddlewareManager
from deepsparse.operators.registry import OperatorRegistry
from deepsparse.utils import InferenceState


__all__ = ["Operator"]


class BaseOperator(ABC):
class Operator(ABC):
"""
Base operator class - an operator should be defined for each atomic, functional
part of the pipeline.
Expand Down Expand Up @@ -138,18 +137,3 @@ def yaml(self):

def json(self):
pass


class Operator(BaseOperator):
def __init__(
self, middleware_manager: Optional[MiddlewareManager] = None, *args, **kwargs
):
self.middleware_manager = middleware_manager
super().__init__(*args, **kwargs)

def __call__(self, *args, **kwargs):
next_call = super().__call__
if self.middleware_manager is not None:
next_call = self.middleware_manager.build_middleware_stack(next_call)

return next_call(*args, **kwargs)
42 changes: 22 additions & 20 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
]


class BasePipeline(Operator):
class Pipeline(Operator):
"""
Pipeline accepts a series of operators, schedulers, and a router. Calling a pipeline
will use the router to run through all the defined operators. The operators should
Expand All @@ -71,13 +71,15 @@ def __init__(
schedulers: List[OperatorScheduler],
continuous_batching_scheduler: Optional[ContinuousBatchingScheduler] = None,
pipeline_state: Optional[PipelineState] = None,
middleware_manager: Optional[MiddlewareManager] = None,
):

self.ops = ops
self.router = router
self.schedulers = schedulers
self.pipeline_state = pipeline_state
self._continuous_batching_scheduler = continuous_batching_scheduler
self.middleware_manager = middleware_manager
horheynm marked this conversation as resolved.
Show resolved Hide resolved
self.validate()

self._scheduler_group = SchedulerGroup(self.schedulers)
Expand All @@ -94,7 +96,7 @@ def _run_next(
else:
func = self._scheduler_group.submit

return run_func(
return self.run_func_with_middleware(
func=func,
operator=self.ops[next_step],
inp=inp,
Expand Down Expand Up @@ -196,7 +198,7 @@ async def run_async(self, *args, inference_state: InferenceState, **kwargs):
return operator_output

if next_step == self.router.START_ROUTE:
outputs = run_func(
outputs = self.run_func_with_middleware(
*args,
func=self._scheduler_group.submit,
operator=self.ops[next_step],
Expand Down Expand Up @@ -308,7 +310,7 @@ def run(
return operator_output

if next_step == self.router.START_ROUTE:
operator_output = run_func(
operator_output = self.run_func_with_middleware(
horheynm marked this conversation as resolved.
Show resolved Hide resolved
*args,
func=self._scheduler_group.submit,
operator=self.ops[next_step],
Expand Down Expand Up @@ -361,7 +363,11 @@ def __call__(self, *args, **kwargs):

kwargs["inference_state"] = inference_state

return self.run(*args, **kwargs)
next_call = self.run
if self.middleware_manager is not None:
next_call = self.middleware_manager.build_middleware_stack(next_call)
horheynm marked this conversation as resolved.
Show resolved Hide resolved
horheynm marked this conversation as resolved.
Show resolved Hide resolved

return next_call(*args, **kwargs)

def expand_inputs(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -397,6 +403,17 @@ def validate(self):
elif isinstance(router_validation, str):
raise ValueError(f"Invalid Router for operators: {router_validation}")

def run_func_with_middleware(
horheynm marked this conversation as resolved.
Show resolved Hide resolved
self,
*args,
operator: Operator,
**kwargs,
):
wrapped_operator = operator
if self.middleware_manager is not None:
wrapped_operator = self.middleware_manager.wrap(operator)
return run_func(*args, operator=wrapped_operator, **kwargs)


def text_generation_pipeline(*args, **kwargs) -> "Pipeline":
"""
Expand Down Expand Up @@ -503,18 +520,3 @@ def zero_shot_text_classification_pipeline(*args, **kwargs) -> "Pipeline":
is returned depends on the value of the passed model_scheme argument.
"""
return Pipeline.create("zero_shot_text_classification", *args, **kwargs)


class Pipeline(BasePipeline):
def __init__(
self, middleware_manager: Optional[MiddlewareManager] = None, *args, **kwargs
):
self.middleware_manager = middleware_manager
super().__init__(*args, **kwargs)

def __call__(self, *args, **kwargs):
next_call = super().__call__
if self.middleware_manager is not None:
next_call = self.middleware_manager.build_middleware_stack(next_call)

return next_call(*args, **kwargs)
22 changes: 22 additions & 0 deletions tests/deepsparse/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# flake8: noqa

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests.deepsparse.middlewares.utils import (
DummyMiddleware,
PrintingMiddleware,
ReducerMiddleware,
SendStateMiddleware,
)
36 changes: 36 additions & 0 deletions tests/deepsparse/middlewares/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,39 @@ def reducer(self, state: Dict, *args, **kwargs):
state[name] = []
state[name].append(args[0])
return state


class PrintingMiddleware(MiddlewareCallable):
def __init__(
self, call_next: MiddlewareCallable, identifier: str = "PrintingMiddleware"
):
self.identifier: str = identifier
self.call_next: MiddlewareCallable = call_next

def __call__(self, *args, **kwargs) -> Any:
print(f"{self.identifier}: before call_next")
result = self.call_next(*args, **kwargs)
print(f"{self.identifier}: after call_next: {result}")
return result


class SendStateMiddleware(MiddlewareCallable):
def __init__(
self, call_next: MiddlewareCallable, identifier: str = "SendStateMiddleware"
):
self.identifier: str = identifier
self.call_next: MiddlewareCallable = call_next

def __call__(self, *args, **kwargs) -> Any:
self.send(self.reducer, 0)
result = self.call_next(*args, **kwargs)
self.send(self.reducer, 1)

return result

def reducer(self, state: Dict, *args, **kwargs):
name = self.__class__.__name__
if name not in state:
state[name] = []
state[name].append(args[0])
return state
77 changes: 77 additions & 0 deletions tests/deepsparse/pipelines/test_middleware_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Simple example and testing middlewares in Ops and Pipeline
"""

from typing import Dict

from pydantic import BaseModel

from deepsparse import Pipeline
from deepsparse.middlewares import MiddlewareManager, MiddlewareSpec
from deepsparse.operators import Operator
from deepsparse.routers import LinearRouter
from deepsparse.schedulers import OperatorScheduler
from tests.deepsparse.middlewares import PrintingMiddleware, SendStateMiddleware


class IntSchema(BaseModel):
value: int


class AddOneOperator(Operator):
input_schema = IntSchema
output_schema = IntSchema

def run(self, inp: IntSchema, **kwargs) -> Dict:
return {"value": inp.value + 1}


class AddTwoOperator(Operator):
input_schema = IntSchema
output_schema = IntSchema

def run(self, inp: IntSchema, **kwargs) -> Dict:
return {"value": inp.value + 2}


middlewares = [
MiddlewareSpec(PrintingMiddleware),
MiddlewareSpec(SendStateMiddleware),
]

middleware_manager = MiddlewareManager(middlewares)

AddThreePipeline = Pipeline(
ops=[AddOneOperator(), AddTwoOperator()],
router=LinearRouter(end_route=2),
schedulers=[OperatorScheduler()],
middleware_manager=middleware_manager,
)


def test_middleware_execution_in_pipeline_and_operator():
pipeline_input = IntSchema(value=5)
pipeline_output = AddThreePipeline(pipeline_input)

assert pipeline_output.value == 8

# SendStateMiddleware, order of calls:
# Pipeline start, AddOneOperator start, AddOneOperator end
# AddTwoOperator start, AddTwoOperator end, Pipeline_ end
expected_order = [0, 0, 1, 0, 1, 1]
state = AddThreePipeline.middleware_manager.state
assert state["SendStateMiddleware"] == expected_order
Loading