Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenPifPaf] ONNX Evaluation Pipeline #915

Merged
merged 18 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _parse_requirements_file(file_path):
"protobuf>=3.12.2,<=3.20.1",
"click>=7.1.2,!=8.0.0", # latest version < 8.0 + blocked version with reported bug
]
_nm_deps = [f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_base}"]
_nm_deps = []#[f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_base}"]
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
_dev_deps = [
"beautifulsoup4>=4.9.3",
"black==22.12.0",
Expand Down Expand Up @@ -133,7 +133,7 @@ def _parse_requirements_file(file_path):
"opencv-python<=4.6.0.66",
]
_openpifpaf_integration_deps = [
"openpifpaf==0.13.6",
"openpifpaf==0.13.11",
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
"opencv-python<=4.6.0.66",
]
# haystack dependencies are installed from a requirements file to avoid
Expand Down
10 changes: 10 additions & 0 deletions src/deepsparse/open_pif_paf/README_temp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
For training and evaluation, you need to download the dataset.

mkdir data-crowdpose
cd data-crowdpose
# download links here: https://github.com/Jeff-sjtu/CrowdPose
unzip annotations.zip
unzip images.zip
Now you can use the standard openpifpaf.train and openpifpaf.eval commands as documented in Training with --dataset=crowdpose.

install pycocotools
1 change: 1 addition & 0 deletions src/deepsparse/open_pif_paf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import *
29 changes: 13 additions & 16 deletions src/deepsparse/open_pif_paf/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import cv2
import torch
from deepsparse.open_pif_paf.schemas import OpenPifPafInput, OpenPifPafOutput
from deepsparse.open_pif_paf.schemas import OpenPifPafInput, OpenPifPafOutput, OpenPifPafFields
from deepsparse.pipeline import Pipeline
from deepsparse.yolact.utils import preprocess_array
from openpifpaf import decoder, network
Expand Down Expand Up @@ -57,12 +57,10 @@ class OpenPifPafPipeline(Pipeline):
"""

def __init__(
self, *, image_size: Union[int, Tuple[int, int]] = (384, 384), **kwargs
self, *, output_fields = False, **kwargs
):
super().__init__(**kwargs)
self._image_size = (
image_size if isinstance(image_size, Tuple) else (image_size, image_size)
)
self.output_fields = output_fields
# necessary openpifpaf dependencies for now
model_cpu, _ = network.Factory().factory(head_metas=None)
self.processor = decoder.factory(model_cpu.head_metas)
Expand All @@ -79,7 +77,7 @@ def output_schema(self) -> Type[OpenPifPafOutput]:
"""
:return: pydantic model class that outputs to this pipeline must comply to
"""
return OpenPifPafOutput
return OpenPifPafOutput if not self.output_fields else OpenPifPafFields

def setup_onnx_file_path(self) -> str:
"""
Expand All @@ -93,16 +91,13 @@ class properties into an inference ready onnx file to be compiled by the

def process_inputs(self, inputs: OpenPifPafInput) -> List[numpy.ndarray]:

images = inputs.images

if not isinstance(images, list):
images = [images]

image_batch = list(self.executor.map(self._preprocess_image, images))

image_batch = numpy.concatenate(image_batch, axis=0)
image = inputs.images
image = image.astype(numpy.float32)
#image = image.transpose(0, 2, 3, 1)
#image /= 255
image = numpy.ascontiguousarray(image)

return [image_batch]
return [image]

def process_engine_outputs(
self, fields: List[numpy.ndarray], **kwargs
Expand All @@ -114,6 +109,8 @@ def process_engine_outputs(
:return: Outputs of engine post-processed into an object in the `output_schema`
format of this pipeline
"""
if self.output_fields:
return OpenPifPafFields(fields=fields)

data_batch, skeletons_batch, scores_batch, keypoints_batch = [], [], [], []

Expand All @@ -137,4 +134,4 @@ def _preprocess_image(self, image) -> numpy.ndarray:
if isinstance(image, str):
image = cv2.imread(image)

return preprocess_array(image, input_image_size=self._image_size)
return preprocess_array(image)
13 changes: 12 additions & 1 deletion src/deepsparse/open_pif_paf/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from typing import List, Tuple

import numpy
from pydantic import BaseModel, Field

from deepsparse.pipelines.computer_vision import ComputerVisionSchema
Expand All @@ -22,6 +22,7 @@
__all__ = [
"OpenPifPafInput",
"OpenPifPafOutput",
"OpenPifPafFields",
]


Expand All @@ -32,6 +33,16 @@ class OpenPifPafInput(ComputerVisionSchema):

pass

class OpenPifPafFields(BaseModel):
"""
# TODO
"""

fields: List[numpy.ndarray] = Field(description="")

class Config:
arbitrary_types_allowed = True


class OpenPifPafOutput(BaseModel):
"""
Expand Down
1 change: 1 addition & 0 deletions src/deepsparse/open_pif_paf/utils/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .deepsparse_evaluator import *
47 changes: 47 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/deepsparse_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import time
import torch
import logging
from openpifpaf.decoder import CifCaf
from openpifpaf.decoder.decoder import DummyPool
from typing import List
from deepsparse import Pipeline

from deepsparse.open_pif_paf.utils.validation.helpers import deepsparse_fields_to_torch

LOG = logging.getLogger(__name__)

class DeepSparseCifCaf(CifCaf):
def __init__(self, pipeline: Pipeline, head_metas: List[None]):
self.pipeline = pipeline
cif_metas, caf_metas = head_metas
super().__init__([cif_metas], [caf_metas])

# adapted from OPENPIFPAF GITHUB:
# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/decoder/decoder.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>
def batch(self, model, image_batch, *, device=None, gt_anns_batch=None):
"""From image batch straight to annotations batch."""
start_nn = time.perf_counter()
fields_batch = self.fields_batch(model, image_batch, device=device)
fields_batch = deepsparse_fields_to_torch(self.pipeline(images=image_batch.numpy()))
self.last_nn_time = time.perf_counter() - start_nn

if gt_anns_batch is None:
gt_anns_batch = [None for _ in fields_batch]

if not isinstance(self.worker_pool, DummyPool):
# remove debug_images to save time during pickle
image_batch = [None for _ in fields_batch]
gt_anns_batch = [None for _ in fields_batch]

LOG.debug('parallel execution with worker %s', self.worker_pool)
start_decoder = time.perf_counter()
result = self.worker_pool.starmap(
self._mappable_annotations, zip(fields_batch, image_batch, gt_anns_batch))
self.last_decoder_time = time.perf_counter() - start_decoder

LOG.debug('time: nn = %.1fms, dec = %.1fms',
self.last_nn_time * 1000.0,
self.last_decoder_time * 1000.0)
return result

100 changes: 100 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/deepsparse_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from collections import defaultdict
import json
import logging
import os
import sys
import typing as t
from openpifpaf.eval import Evaluator, count_ops
from openpifpaf import network, __version__

from deepsparse.open_pif_paf.utils.validation.helpers import apply_deepsparse_preprocessing
from deepsparse.open_pif_paf.utils.validation.deepsparse_predictor import DeepSparsePredictor


from deepsparse import Pipeline


LOG = logging.getLogger(__name__)

__all__ = ["DeepSparseEvaluator"]

# adapted from OPENPIFPAF GITHUB:
# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/eval.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>
class DeepSparseEvaluator(Evaluator):
# deepsparse edit: allow for passing in a pipeline
def __init__(self, pipeline: Pipeline, img_size: int, **kwargs):
self.pipeline = pipeline
super().__init__(**kwargs)
# deepsparse edit: required to enforce square images
apply_deepsparse_preprocessing(self.data_loader, img_size)

def evaluate(self, output: t.Optional[str]):
# generate a default output filename
if output is None:
assert self.args is not None
output = self.default_output_name(self.args)

# skip existing?
if self.skip_epoch0:
assert network.Factory.checkpoint is not None
if network.Factory.checkpoint.endswith('.epoch000'):
print('Not evaluating epoch 0.')
return
if self.skip_existing:
stats_file = output + '.stats.json'
if os.path.exists(stats_file):
print('Output file {} exists already. Exiting.'.format(stats_file))
return
print('{} not found. Processing: {}'.format(stats_file, network.Factory.checkpoint))

# deepsparse edit: allow for passing in a pipeline
predictor = DeepSparsePredictor(pipeline = self.pipeline, head_metas=self.datamodule.head_metas)
metrics = self.datamodule.metrics()

total_time = self.accumulate(predictor, metrics)

# model stats
# deepsparse edit: compute stats for ONNX file not torch model
counted_ops = list(count_ops(predictor.model_cpu))
file_size = -1.0 # TODO get file size

# write
additional_data = {
'args': sys.argv,
'version': __version__,
'dataset': self.dataset_name,
'total_time': total_time,
'checkpoint': network.Factory.checkpoint,
'count_ops': counted_ops,
'file_size': file_size,
'n_images': predictor.total_images,
'decoder_time': predictor.total_decoder_time,
'nn_time': predictor.total_nn_time,
}

metric_stats = defaultdict(list)
for metric in metrics:
if self.write_predictions:
metric.write_predictions(output, additional_data=additional_data)

this_metric_stats = metric.stats()
assert (len(this_metric_stats.get('text_labels', []))
== len(this_metric_stats.get('stats', [])))

for k, v in this_metric_stats.items():
metric_stats[k] = metric_stats[k] + v

stats = dict(**metric_stats, **additional_data)

# write stats file
with open(output + '.stats.json', 'w') as f:
json.dump(stats, f)

LOG.info('stats:\n%s', json.dumps(stats, indent=4))
LOG.info(
'time per image: decoder = %.0fms, nn = %.0fms, total = %.0fms',
1000 * stats['decoder_time'] / stats['n_images'],
1000 * stats['nn_time'] / stats['n_images'],
1000 * stats['total_time'] / stats['n_images'],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging
from openpifpaf import Predictor
from deepsparse import Pipeline
from deepsparse.open_pif_paf.utils.validation.deepsparse_decoder import DeepSparseCifCaf

LOG = logging.getLogger(__name__)

# adapted from OPENPIFPAF GITHUB:
# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/predictor.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>
class DeepSparsePredictor(Predictor):
"""Convenience class to predict from various inputs with a common configuration."""
# deepsparse edit: allow for passing in a pipeline
def __init__(self, pipeline: Pipeline, **kwargs):
super().__init__(**kwargs)
# deepsparse edit: allow for passing in a pipeline and fix the processor
# to CifCaf processor
self.processor = DeepSparseCifCaf(pipeline = pipeline, head_metas = self.model_cpu.head_metas)


15 changes: 15 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
from openpifpaf import transforms
import numpy

__all__ = ["apply_deepsparse_preprocessing", "deepsparse_fields_to_torch"]

def deepsparse_fields_to_torch(fields_batch, device='cpu'):
result = []
fields = fields_batch.fields
for idx, (cif, caf) in enumerate(zip(*fields)):
result.append([torch.from_numpy(cif).to(device), torch.from_numpy(caf).to(device)])
return

def apply_deepsparse_preprocessing(data_loader: torch.utils.data.DataLoader, img_size: int) -> torch.utils.data.DataLoader:
data_loader.dataset.preprocess.preprocess_list[2] = transforms.CenterPad(img_size)
21 changes: 21 additions & 0 deletions src/deepsparse/open_pif_paf/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Evaluation on COCO data."""
from openpifpaf.eval import cli
from deepsparse.open_pif_paf.utils.validation.deepsparse_evaluator import DeepSparseEvaluator
from deepsparse import Pipeline
def main():
args = cli()
args.dataset = "cocokp"
args.output = "funny"
args.decoder = ["cifcaf"]
args.checkpoint = "shufflenetv2k16"
pipeline = Pipeline.create(task="open_pif_paf", model_path="openpifpaf-resnet50.onnx", output_fields=True)
evaluator = DeepSparseEvaluator(pipeline = pipeline, dataset_name=args.dataset, skip_epoch0=False, img_size= args.coco_eval_long_edge)
if args.watch:
assert args.output is None
evaluator.watch(args.checkpoint, args.watch)
else:
evaluator.evaluate(args.output)


if __name__ == '__main__':
main()
8 changes: 0 additions & 8 deletions src/deepsparse/yolact/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,6 @@ def preprocess_array(
"""
image = image.astype(numpy.float32)
image = _assert_channels_last(image)
if image.ndim == 4 and image.shape[:2] != input_image_size:
image = numpy.stack([cv2.resize(img, input_image_size) for img in image])

else:
if image.shape[:2] != input_image_size:
image = cv2.resize(image, input_image_size)
image = numpy.expand_dims(image, 0)

image = image.transpose(0, 3, 1, 2)
image /= 255
image = numpy.ascontiguousarray(image)
Expand Down