Skip to content

Commit

Permalink
[OpenPifPaf] ONNX Evaluation Pipeline (#915)
Browse files Browse the repository at this point in the history
* something odd about how images are being passed

* getting the same mAP!

* second round of refactoring

* ready to break it down into smaller PRs

* initial commit

* ready for testing

* remove torch model from the ported code

* solving some issues with logging

* ready for review

* Update src/deepsparse/open_pif_paf/utils/validation/cli.py

* prohibit openpifpaf fields in the server

---------

Co-authored-by: Konstantin <konstantin@neuralmagic.com>
  • Loading branch information
dbogunowicz and KSGulin committed Mar 2, 2023
1 parent 89abc5d commit ffbd351
Show file tree
Hide file tree
Showing 11 changed files with 615 additions and 4 deletions.
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ def _parse_requirements_file(file_path):
"opencv-python<=4.6.0.66",
]
_openpifpaf_integration_deps = [
"openpifpaf==0.13.6",
"openpifpaf==0.13.11",
"opencv-python<=4.6.0.66",
"pycocotools >=2.0.6",
"scipy==1.10.1",
]
_yolov8_integration_deps = _yolo_integration_deps + ["ultralytics==8.0.30"]

Expand Down Expand Up @@ -277,6 +279,7 @@ def _setup_entry_points() -> Dict:
"deepsparse.yolov8.annotate=deepsparse.yolov8.annotate:main",
"deepsparse.yolov8.eval=deepsparse.yolov8.validation:main",
"deepsparse.pose_estimation.annotate=deepsparse.open_pif_paf.annotate:main",
"deepsparse.pose_estimation.eval=deepsparse.open_pif_paf.validation:main",
"deepsparse.image_classification.annotate=deepsparse.image_classification.annotate:main", # noqa E501
"deepsparse.instance_segmentation.annotate=deepsparse.yolact.annotate:main",
f"deepsparse.image_classification.eval={ic_eval}",
Expand Down
62 changes: 59 additions & 3 deletions src/deepsparse/open_pif_paf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,70 @@ DeepSparse pipeline for OpenPifPaf
```python
from deepsparse import Pipeline

model_path: str = ... # path to open_pif_paf model
pipeline = Pipeline.create(task="open_pif_paf", batch_size=1, model_path=model_path)
model_path: str = ... # path to open_pif_paf model (SparseZoo stub or onnx model)
pipeline = Pipeline.create(task="open_pif_paf", model_path=model_path)
predictions = pipeline(images=['dancers.jpg'])
# predictions have attributes `data', 'keypoints', 'scores', 'skeletons'
predictions[0].scores
>> scores=[0.8542259724243828, 0.7930507659912109]
```
# predictions have attributes `data', 'keypoints', 'scores', 'skeletons'
### Output CifCaf fields
Alternatively, instead of returning the detected poses, it is possible to return the intermediate output - the CifCaf fields.
This is the representation returned directly by the neural network, but not yet processed by the matching algorithm

```python
...
pipeline = Pipeline.create(task="open_pif_paf", model_path=model_path, return_cifcaf_fields=True)
predictions = pipeline(images=['dancers.jpg'])
predictions.fields
```

## Validation script:
This paragraph describes how to run validation of the ONNX model/SparseZoo stub

### Dataset
For evaluation, you need to download the dataset. The [Open Pif Paf documentation](https://openpifpaf.github.io/) describes
thoroughly how to prepare different datasets for validation. This is the example for `crowdpose` dataset:

```bash
mkdir data-crowdpose
cd data-crowdpose
# download links here: https://github.com/Jeff-sjtu/CrowdPose
unzip annotations.zip
unzip images.zip
# Now you can use the standard openpifpaf.train and openpifpaf.eval
# commands as documented in Training with --dataset=crowdpose.
```
### Create an ONNX model:

```bash
python3 -m openpifpaf.export_onnx --input-width 641 --input-height 641
```

### Validation command
Once the dataset has been downloaded, run the command:
```bash
deepsparse.pose_estimation.eval --model-path openpifpaf-resnet50.onnx --dataset cocokp --image_size 641
```

This should result in the evaluation output similar to this:
```bash
...
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.502
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.732
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.523
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.429
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.605
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.534
Average Recall (AR) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.744
Average Recall (AR) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.554
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.457
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.643
...
````


### Expected output:

## The necessity of external OpenPifPaf helper function
<img width="678" alt="image" src="https://user-images.githubusercontent.com/97082108/203295520-42fa325f-8a94-4241-af6f-75938ef26b14.png">
Expand Down
2 changes: 2 additions & 0 deletions src/deepsparse/open_pif_paf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from .utils import *
1 change: 1 addition & 0 deletions src/deepsparse/open_pif_paf/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.
# flake8: noqa
from .annotate import *
from .validation import *
16 changes: 16 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from .cli import *
from .deepsparse_evaluator import *
94 changes: 94 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

import torch
from deepsparse.open_pif_paf.utils.validation.deepsparse_evaluator import (
DeepSparseEvaluator,
)
from deepsparse.open_pif_paf.utils.validation.deepsparse_predictor import (
DeepSparsePredictor,
)
from openpifpaf import __version__, datasets, decoder, logger, network, show, visualizer
from openpifpaf.eval import CustomFormatter


LOG = logging.getLogger(__name__)

__all__ = ["cli"]


# adapted from OPENPIFPAF GITHUB:
# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/eval.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>
def cli():
parser = argparse.ArgumentParser(
prog="python3 -m openpifpaf.eval",
usage="%(prog)s [options]",
description=__doc__,
formatter_class=CustomFormatter,
)
parser.add_argument(
"--version",
action="version",
version="OpenPifPaf {version}".format(version=__version__),
)

datasets.cli(parser)
decoder.cli(parser)
logger.cli(parser)
network.Factory.cli(parser)
DeepSparsePredictor.cli(parser, skip_batch_size=True, skip_loader_workers=True)
show.cli(parser)
visualizer.cli(parser)
DeepSparseEvaluator.cli(parser)

parser.add_argument("--disable-cuda", action="store_true", help="disable CUDA")
parser.add_argument(
"--output", default=None, help="output filename without file extension"
)
parser.add_argument(
"--watch",
default=False,
const=60,
nargs="?",
type=int,
help=(
"Watch a directory for new checkpoint files. "
"Optionally specify the number of seconds between checks."
),
)

# deepsparse edit: replace the parse_args call with parse_known_args
args, _ = parser.parse_known_args()

# add args.device
args.device = torch.device("cpu")
args.pin_memory = False
if not args.disable_cuda and torch.cuda.is_available():
args.device = torch.device("cuda")
args.pin_memory = True
LOG.debug("neural network device: %s", args.device)

datasets.configure(args)
decoder.configure(args)
network.Factory.configure(args)
DeepSparsePredictor.configure(args)
show.configure(args)
visualizer.configure(args)
DeepSparseEvaluator.configure(args)

return args
75 changes: 75 additions & 0 deletions src/deepsparse/open_pif_paf/utils/validation/deepsparse_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from typing import List, Union

from deepsparse import Pipeline
from deepsparse.open_pif_paf.utils.validation.helpers import deepsparse_fields_to_torch
from openpifpaf.decoder import CifCaf
from openpifpaf.decoder.decoder import DummyPool


LOG = logging.getLogger(__name__)

__all__ = ["DeepSparseCifCaf"]


class DeepSparseCifCaf(CifCaf):
def __init__(
self,
head_metas: List[Union["Cif", "Caf"]], # noqa: F821
pipeline: Pipeline,
):
self.pipeline = pipeline
cif_metas, caf_metas = head_metas
super().__init__([cif_metas], [caf_metas])

# adapted from OPENPIFPAF GITHUB:
# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/decoder/decoder.py
# the appropriate edits are marked with # deepsparse edit: <edit comment>

# deepsparse edit: removed model argument (not needed, substituted with '_')
def batch(self, _, image_batch, *, device=None, gt_anns_batch=None):
"""From image batch straight to annotations batch."""
start_nn = time.perf_counter()
# deepsparse edit: inference using deepsparse pipeline
# instead of torch model
fields_batch = deepsparse_fields_to_torch(
self.pipeline(images=image_batch.numpy())
)
self.last_nn_time = time.perf_counter() - start_nn

if gt_anns_batch is None:
gt_anns_batch = [None for _ in fields_batch]

if not isinstance(self.worker_pool, DummyPool):
# remove debug_images to save time during pickle
image_batch = [None for _ in fields_batch]
gt_anns_batch = [None for _ in fields_batch]

LOG.debug("parallel execution with worker %s", self.worker_pool)
start_decoder = time.perf_counter()
result = self.worker_pool.starmap(
self._mappable_annotations, zip(fields_batch, image_batch, gt_anns_batch)
)
self.last_decoder_time = time.perf_counter() - start_decoder

LOG.debug(
"time: nn = %.1fms, dec = %.1fms",
self.last_nn_time * 1000.0,
self.last_decoder_time * 1000.0,
)
return result
Loading

0 comments on commit ffbd351

Please sign in to comment.