second round of refactoring
KSGulin committed Feb 20, 2023
1 parent 908f200 commit c175205
Showing 11 changed files with 377 additions and 90 deletions.
4 changes: 3 additions & 1 deletion setup.py
@@ -93,7 +93,9 @@ def _parse_requirements_file(file_path):
"protobuf>=3.12.2,<=3.20.1",
"click>=7.1.2,!=8.0.0", # latest version < 8.0 + blocked version with reported bug
]
_nm_deps = []#[f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_base}"]
_nm_deps = (
[]
) # [f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_base}"]
_dev_deps = [
"beautifulsoup4>=4.9.3",
"black==22.12.0",
2 changes: 1 addition & 1 deletion src/deepsparse/open_pif_paf/__init__.py
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import *
from .utils import *
43 changes: 30 additions & 13 deletions src/deepsparse/open_pif_paf/pipelines.py
@@ -22,7 +22,11 @@

import cv2
import torch
from deepsparse.open_pif_paf.schemas import OpenPifPafInput, OpenPifPafOutput, OpenPifPafFields
from deepsparse.open_pif_paf.schemas import (
OpenPifPafFields,
OpenPifPafInput,
OpenPifPafOutput,
)
from deepsparse.pipeline import Pipeline
from deepsparse.yolact.utils import preprocess_array
from openpifpaf import decoder, network
@@ -57,10 +61,19 @@ class OpenPifPafPipeline(Pipeline):
"""

def __init__(
self, *, output_fields = False, **kwargs
self,
*,
image_size: Union[int, Tuple[int, int]] = (384, 384),
return_cifcaf_fields: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.output_fields = output_fields
self._image_size = (
image_size if isinstance(image_size, Tuple) else (image_size, image_size)
)
# whether to return the cif and caf fields or the
# complete decoded output
self.return_cifcaf_fields = return_cifcaf_fields
# necessary openpifpaf dependencies for now
model_cpu, _ = network.Factory().factory(head_metas=None)
self.processor = decoder.factory(model_cpu.head_metas)
@@ -77,7 +90,7 @@ def output_schema(self) -> Type[OpenPifPafOutput]:
"""
:return: pydantic model class that outputs to this pipeline must comply to
"""
return OpenPifPafOutput if not self.output_fields else OpenPifPafFields
return OpenPifPafOutput if not self.return_cifcaf_fields else OpenPifPafFields

def setup_onnx_file_path(self) -> str:
"""
@@ -91,13 +104,14 @@ class properties into an inference ready onnx file to be compiled by the

def process_inputs(self, inputs: OpenPifPafInput) -> List[numpy.ndarray]:

image = inputs.images
image = image.astype(numpy.float32)
#image = image.transpose(0, 2, 3, 1)
#image /= 255
image = numpy.ascontiguousarray(image)
images = inputs.images
if not isinstance(images, list):
images = [images]

image_batch = list(self.executor.map(self._preprocess_image, images))
image_batch = numpy.concatenate(image_batch, axis=0)

return [image]
return [image_batch]

def process_engine_outputs(
self, fields: List[numpy.ndarray], **kwargs
@@ -109,8 +123,11 @@ def process_engine_outputs(
:return: Outputs of engine post-processed into an object in the `output_schema`
format of this pipeline
"""
if self.output_fields:
return OpenPifPafFields(fields=fields)
if self.return_cifcaf_fields:
batch_fields = []
for cif, caf in zip(*fields):
batch_fields.append([cif, caf])
return OpenPifPafFields(fields=batch_fields)

data_batch, skeletons_batch, scores_batch, keypoints_batch = [], [], [], []

@@ -134,4 +151,4 @@ def _preprocess_image(self, image) -> numpy.ndarray:
if isinstance(image, str):
image = cv2.imread(image)

return preprocess_array(image)
return preprocess_array(image, input_image_size=self._image_size)
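
For context on the pipeline changes above: the refactor replaces the untyped output_fields flag with an explicit return_cifcaf_fields argument and adds a configurable image_size. A minimal usage sketch (not part of the diff), assuming the pipeline is registered under the "open_pif_paf" task and that the model path and image file below are placeholders:

from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="open_pif_paf",
    model_path="openpifpaf-model.onnx",  # hypothetical ONNX deployment path
    image_size=(384, 384),               # forwarded to OpenPifPafPipeline.__init__
    return_cifcaf_fields=False,          # True returns raw Cif/Caf fields instead of decoded poses
)

output = pipeline(images=["example.jpg"])  # hypothetical input image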
15 changes: 13 additions & 2 deletions src/deepsparse/open_pif_paf/schemas.py
@@ -13,6 +13,7 @@
# limitations under the License.

from typing import List, Tuple

import numpy
from pydantic import BaseModel, Field

@@ -33,12 +34,22 @@ class OpenPifPafInput(ComputerVisionSchema):

pass


class OpenPifPafFields(BaseModel):
"""
# TODO
Open Pif Paf is composed of two stages:
- Computing Cif/Caf fields using a parametrized model
- Applying a matching algorithm to obtain the final pose
predictions
In some cases (e.g. for validation), it may be useful to
obtain the Cif/Caf fields as output.
"""

fields: List[numpy.ndarray] = Field(description="")
fields: List[List[numpy.ndarray]] = Field(
description="Cif/Caf fields returned by the network. "
"The outer list is the batch dimension, while the second "
"list contains two numpy arrays: Cif and Caf field values"
)

class Config:
arbitrary_types_allowed = True
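
To illustrate the schema change above: fields is now a nested list with one [cif, caf] pair per image in the batch. A small sketch with dummy arrays (the shapes are hypothetical and do not come from a real model):

import numpy
from deepsparse.open_pif_paf.schemas import OpenPifPafFields

cif = numpy.zeros((17, 5, 48, 64), dtype=numpy.float32)  # dummy Cif head output
caf = numpy.zeros((19, 9, 48, 64), dtype=numpy.float32)  # dummy Caf head output

# one inner [cif, caf] list per batch item; arbitrary_types_allowed permits ndarrays
batch_fields = OpenPifPafFields(fields=[[cif, caf]])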
14 changes: 14 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/__init__.py
@@ -1 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .deepsparse_evaluator import *
49 changes: 36 additions & 13 deletions src/deepsparse/open_pif_paf/utils/validation/deepsparse_decoder.py
@@ -1,17 +1,36 @@
import time
import torch
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from typing import List, Union

from deepsparse import Pipeline
from deepsparse.open_pif_paf.utils.validation.helpers import deepsparse_fields_to_torch
from openpifpaf.decoder import CifCaf
from openpifpaf.decoder.decoder import DummyPool
from typing import List
from deepsparse import Pipeline

from deepsparse.open_pif_paf.utils.validation.helpers import deepsparse_fields_to_torch

LOG = logging.getLogger(__name__)


class DeepSparseCifCaf(CifCaf):
def __init__(self, pipeline: Pipeline, head_metas: List[None]):
def __init__(
self,
head_metas: List[Union["Cif", "Caf"]], # noqa: F821
pipeline: Pipeline,
):
self.pipeline = pipeline
cif_metas, caf_metas = head_metas
super().__init__([cif_metas], [caf_metas])
@@ -23,7 +42,9 @@ def batch(self, model, image_batch, *, device=None, gt_anns_batch=None):
"""From image batch straight to annotations batch."""
start_nn = time.perf_counter()
fields_batch = self.fields_batch(model, image_batch, device=device)
fields_batch = deepsparse_fields_to_torch(self.pipeline(images=image_batch.numpy()))
fields_batch = deepsparse_fields_to_torch(
self.pipeline(images=image_batch.numpy())
)
self.last_nn_time = time.perf_counter() - start_nn

if gt_anns_batch is None:
@@ -34,14 +55,16 @@ def batch(self, model, image_batch, *, device=None, gt_anns_batch=None):
image_batch = [None for _ in fields_batch]
gt_anns_batch = [None for _ in fields_batch]

LOG.debug('parallel execution with worker %s', self.worker_pool)
LOG.debug("parallel execution with worker %s", self.worker_pool)
start_decoder = time.perf_counter()
result = self.worker_pool.starmap(
self._mappable_annotations, zip(fields_batch, image_batch, gt_anns_batch))
self._mappable_annotations, zip(fields_batch, image_batch, gt_anns_batch)
)
self.last_decoder_time = time.perf_counter() - start_decoder

LOG.debug('time: nn = %.1fms, dec = %.1fms',
self.last_nn_time * 1000.0,
self.last_decoder_time * 1000.0)
LOG.debug(
"time: nn = %.1fms, dec = %.1fms",
self.last_nn_time * 1000.0,
self.last_decoder_time * 1000.0,
)
return result
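
As a rough sketch of how DeepSparseCifCaf is meant to be wired up (the dataset choice, the datamodule factory call, and the pipeline arguments below are assumptions for illustration, not taken from this commit):

import openpifpaf

from deepsparse import Pipeline
from deepsparse.open_pif_paf.utils.validation.deepsparse_decoder import DeepSparseCifCaf

# hypothetical deployment path; return_cifcaf_fields=True so the pipeline
# hands raw Cif/Caf fields to the decoder instead of decoded poses
pipeline = Pipeline.create(
    task="open_pif_paf",
    model_path="openpifpaf-model.onnx",
    return_cifcaf_fields=True,
)

# assumed openpifpaf API for building a datamodule (e.g. COCO keypoints);
# head_metas must unpack into a (cif_metas, caf_metas) pair
datamodule = openpifpaf.datasets.factory("cocokp")
decoder = DeepSparseCifCaf(head_metas=datamodule.head_metas, pipeline=pipeline)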

src/deepsparse/open_pif_paf/utils/validation/deepsparse_evaluator.py
@@ -1,23 +1,40 @@
from collections import defaultdict
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import typing as t
from openpifpaf.eval import Evaluator, count_ops
from openpifpaf import network, __version__

from deepsparse.open_pif_paf.utils.validation.helpers import apply_deepsparse_preprocessing
from deepsparse.open_pif_paf.utils.validation.deepsparse_predictor import DeepSparsePredictor

from collections import defaultdict

from deepsparse import Pipeline
from deepsparse.open_pif_paf.utils.validation.deepsparse_predictor import (
DeepSparsePredictor,
)
from deepsparse.open_pif_paf.utils.validation.helpers import (
apply_deepsparse_preprocessing,
)
from openpifpaf import __version__, network
from openpifpaf.eval import Evaluator


LOG = logging.getLogger(__name__)

__all__ = ["DeepSparseEvaluator"]


# adapted from OPENPIFPAF GITHUB:
# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/eval.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>
@@ -38,39 +55,41 @@ def evaluate(self, output: t.Optional[str]):
# skip existing?
if self.skip_epoch0:
assert network.Factory.checkpoint is not None
if network.Factory.checkpoint.endswith('.epoch000'):
print('Not evaluating epoch 0.')
if network.Factory.checkpoint.endswith(".epoch000"):
print("Not evaluating epoch 0.")
return
if self.skip_existing:
stats_file = output + '.stats.json'
stats_file = output + ".stats.json"
if os.path.exists(stats_file):
print('Output file {} exists already. Exiting.'.format(stats_file))
print("Output file {} exists already. Exiting.".format(stats_file))
return
print('{} not found. Processing: {}'.format(stats_file, network.Factory.checkpoint))
print(
"{} not found. Processing: {}".format(
stats_file, network.Factory.checkpoint
)
)

# deepsparse edit: allow for passing in a pipeline
predictor = DeepSparsePredictor(pipeline = self.pipeline, head_metas=self.datamodule.head_metas)
predictor = DeepSparsePredictor(
pipeline=self.pipeline, head_metas=self.datamodule.head_metas
)
metrics = self.datamodule.metrics()

total_time = self.accumulate(predictor, metrics)

# model stats
# deepsparse edit: compute stats for ONNX file not torch model
counted_ops = list(count_ops(predictor.model_cpu))
file_size = -1.0 # TODO get file size
# deepsparse edit: removed model stats that are
# only applicable to torch models

# write
additional_data = {
'args': sys.argv,
'version': __version__,
'dataset': self.dataset_name,
'total_time': total_time,
'checkpoint': network.Factory.checkpoint,
'count_ops': counted_ops,
'file_size': file_size,
'n_images': predictor.total_images,
'decoder_time': predictor.total_decoder_time,
'nn_time': predictor.total_nn_time,
"args": sys.argv,
"version": __version__,
"dataset": self.dataset_name,
"total_time": total_time,
"n_images": predictor.total_images,
"decoder_time": predictor.total_decoder_time,
"nn_time": predictor.total_nn_time,
}

metric_stats = defaultdict(list)
@@ -79,22 +98,23 @@ def evaluate(self, output: t.Optional[str]):
metric.write_predictions(output, additional_data=additional_data)

this_metric_stats = metric.stats()
assert (len(this_metric_stats.get('text_labels', []))
== len(this_metric_stats.get('stats', [])))
assert len(this_metric_stats.get("text_labels", [])) == len(
this_metric_stats.get("stats", [])
)

for k, v in this_metric_stats.items():
metric_stats[k] = metric_stats[k] + v

stats = dict(**metric_stats, **additional_data)

# write stats file
with open(output + '.stats.json', 'w') as f:
with open(output + ".stats.json", "w") as f:
json.dump(stats, f)

LOG.info('stats:\n%s', json.dumps(stats, indent=4))
LOG.info("stats:\n%s", json.dumps(stats, indent=4))
LOG.info(
'time per image: decoder = %.0fms, nn = %.0fms, total = %.0fms',
1000 * stats['decoder_time'] / stats['n_images'],
1000 * stats['nn_time'] / stats['n_images'],
1000 * stats['total_time'] / stats['n_images'],
"time per image: decoder = %.0fms, nn = %.0fms, total = %.0fms",
1000 * stats["decoder_time"] / stats["n_images"],
1000 * stats["nn_time"] / stats["n_images"],
1000 * stats["total_time"] / stats["n_images"],
)