Skip to content

Commit

Permalink
Feature/sg 000 fix predict in pose estimation (#1358)
Browse files Browse the repository at this point in the history
* Update readme

* Fix small bug in __repr__ implementation of KeypointsImageToTensor

* Test

* Test

* Test

* Test

* Test

* Test

* Make graphsurgeon an optional

* Make graphsurgeon an optional

* Properly handle imports of optional packages

* Added empty __init__.py files

* Do imports of gs inside the export call

* Do imports of gs inside the export call

* Fix DEKR's missing HasPredict interface

* Update notebook & example doc to reflect changes in imports & function names

* Update readme

* Put back images
  • Loading branch information
BloodAxe committed Aug 10, 2023
1 parent 24a708f commit b9d3e75
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from super_gradients.common.registry.registry import register_model
from super_gradients.common.object_names import Models
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.utils.predict import ImagesPoseEstimationPrediction
from super_gradients.training.models.sg_module import SgModule
from super_gradients.training.models.arch_params_factory import get_arch_params
Expand Down Expand Up @@ -294,7 +295,7 @@ def forward(self, x):


@register_model(Models.DEKR_CUSTOM)
class DEKRPoseEstimationModel(SgModule):
class DEKRPoseEstimationModel(SgModule, HasPredict):
"""
Implementation of HRNet model from DEKR paper (https://arxiv.org/abs/2104.02300).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_equivalent_preprocessing(self) -> List:
]

def __repr__(self):
return self.__class__.__name__ + f"(permutation={self.permutation})"
return self.__class__.__name__ + "()"


@register_transform(Transforms.KeypointsImageStandardize)
Expand Down
16 changes: 16 additions & 0 deletions tests/unit_tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@ def test_classification_models(self):
predictions.show()
predictions.save(output_folder=tmp_dirname)

def test_pose_estimation_models(self):
model = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose")

with tempfile.TemporaryDirectory() as tmp_dirname:
predictions = model.predict(self.images)
predictions.show()
predictions.save(output_folder=tmp_dirname)

def test_detection_models(self):
model = models.get(Models.YOLO_NAS_S, pretrained_weights="coco")

with tempfile.TemporaryDirectory() as tmp_dirname:
predictions = model.predict(self.images)
predictions.show()
predictions.save(output_folder=tmp_dirname)


if __name__ == "__main__":
unittest.main()

0 comments on commit b9d3e75

Please sign in to comment.