Skip to content

Commit

Permalink
postprocessing_kwargs -> context
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin committed Aug 2, 2023
1 parent 9b916f3 commit 72234a5
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
# batch size of the inputs may be `> self._batch_size` at this point
engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
if isinstance(engine_inputs, tuple):
engine_inputs, postprocess_kwargs = engine_inputs
engine_inputs, context = engine_inputs
else:
postprocess_kwargs = {}
context = {}

timer.stop(InferenceStages.PRE_PROCESS)
self.log(
Expand All @@ -248,9 +248,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:
)

# submit split batches to engine threadpool
engine_forward_with_context = partial(
self.engine_forward, context=postprocess_kwargs
)
engine_forward_with_context = partial(self.engine_forward, context=context)
batch_outputs = list(
self.executor.map(engine_forward_with_context, batches)
)
Expand All @@ -276,9 +274,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:

# ------ POSTPROCESSING ------
timer.start(InferenceStages.POST_PROCESS)
pipeline_outputs = self.process_engine_outputs(
engine_outputs, **postprocess_kwargs
)
pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
if not isinstance(pipeline_outputs, self.output_schema):
raise ValueError(
f"Outputs of {self.__class__} must be instances of "
Expand Down

0 comments on commit 72234a5

Please sign in to comment.