Squashed changes with YoloNASPose & Loss #1512

Merged

Changes from all commits (41 commits):
238e41c
Introduce sample-centric keypoint transforms
BloodAxe Oct 2, 2023
efeb4ef
Cleanup leftovers
BloodAxe Oct 2, 2023
fa1f79d
Fixed way of checking transforms that require additional samples
BloodAxe Oct 3, 2023
88cef73
Docstrings
BloodAxe Oct 3, 2023
a29f58a
:attr -> :param
BloodAxe Oct 3, 2023
c797100
Added docs clarifying behavior of mosaic & mixup
BloodAxe Oct 3, 2023
7b1b9ee
Added docs clarifying behavior of mosaic & mixup
BloodAxe Oct 3, 2023
0d00668
Improved tests
BloodAxe Oct 3, 2023
a445917
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release
BloodAxe Oct 3, 2023
f388fcc
Additional docstrings & typing annotations
BloodAxe Oct 3, 2023
5b24d06
Merge remote-tracking branch 'origin/feature/SG-1060-yolo-nas-pose-re…
BloodAxe Oct 3, 2023
a7c97ce
Added missing additional_samples_count field
BloodAxe Oct 4, 2023
b84817d
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release
BloodAxe Oct 4, 2023
4a21438
KeypointsRemoveSmallObjects
BloodAxe Oct 4, 2023
53ecb3e
KeypointsRemoveSmallObjects
BloodAxe Oct 4, 2023
577804c
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release
BloodAxe Oct 5, 2023
fd41b71
Feature/sg 1060 yolo nas pose release pr to add datasets and metric (…
BloodAxe Oct 9, 2023
81091b7
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release
BloodAxe Oct 9, 2023
a1284be
Squashed changes with YoloNASPose & Loss
BloodAxe Oct 9, 2023
2d01377
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release-add-…
BloodAxe Oct 9, 2023
3ed26ff
Remove print statement
BloodAxe Oct 10, 2023
359dcf6
Fixed attribute name that was not renamed
BloodAxe Oct 10, 2023
928da84
Improve docstrings to use 'Num Keypoints' instead of magic number 17
BloodAxe Oct 10, 2023
0e540e5
Fixed PoseNMS export to work with custom number of keypoints
BloodAxe Oct 10, 2023
8dc7647
Added docstrings
BloodAxe Oct 10, 2023
0e23800
Simplify forward/forward_eval
BloodAxe Oct 10, 2023
15cef48
Simplify forward/forward_eval
BloodAxe Oct 10, 2023
295b0c7
_insert_heads_list_params
BloodAxe Oct 10, 2023
0474536
Merge master
shaydeci Oct 11, 2023
94a15fc
Added tests
BloodAxe Oct 11, 2023
74ea314
Refactor the way we generate usage instructions. Should be easier to …
BloodAxe Oct 11, 2023
d43cf2b
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release-add-…
BloodAxe Oct 12, 2023
408b5d0
Improve docstrings
BloodAxe Oct 12, 2023
e8b3f18
Improved docstrings
BloodAxe Oct 12, 2023
b2ff5e9
Improved docstrings
BloodAxe Oct 12, 2023
90bca43
Improved docstrings
BloodAxe Oct 12, 2023
84aa12b
Improved docstrings
BloodAxe Oct 12, 2023
fe00d56
Rename bboxes -> bboxes_xyxy
BloodAxe Oct 12, 2023
632d62d
Fixed instructions text
BloodAxe Oct 12, 2023
11d3eb2
Improve efficiency of training
BloodAxe Oct 12, 2023
91dd113
Merge branch 'master' into feature/SG-1060-yolo-nas-pose-release-add-…
BloodAxe Oct 13, 2023
6 changes: 6 additions & 0 deletions src/super_gradients/common/object_names.py
@@ -16,6 +16,7 @@ class Losses:
    DICE_CE_EDGE_LOSS = "DiceCEEdgeLoss"
    DEKR_LOSS = "DEKRLoss"
    RESCORING_LOSS = "RescoringLoss"
    YOLONAS_POSE_LOSS = "YoloNASPoseLoss"


class Metrics:
@@ -314,6 +315,11 @@ class Models:
    POSE_RESCORING = "pose_rescoring_custom"
    POSE_RESCORING_COCO = "pose_rescoring_coco"

    YOLO_NAS_POSE_N = "yolo_nas_pose_n"
    YOLO_NAS_POSE_S = "yolo_nas_pose_s"
    YOLO_NAS_POSE_M = "yolo_nas_pose_m"
    YOLO_NAS_POSE_L = "yolo_nas_pose_l"


class ConcatenatedTensorFormats:
    XYXY_LABEL = "XYXY_LABEL"
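With these registry names in place, the pose variants can be requested through the standard models factory. A minimal sketch, assuming the COCO pose pretrained-weights key is "coco_pose" (the key itself is not shown in this diff):

    from super_gradients.common.object_names import Models
    from super_gradients.training import models

    # Instantiate the large pose variant by its newly registered name.
    # pretrained_weights="coco_pose" is an assumption, not taken from this PR.
    yolo_nas_pose = models.get(Models.YOLO_NAS_POSE_L, pretrained_weights="coco_pose")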
380 changes: 380 additions & 0 deletions src/super_gradients/conversion/onnx/pose_nms.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
@@ -1,5 +1,7 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ModelHasNoPreprocessingParamsException
from .exceptions import ModelHasNoPreprocessingParamsException
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_pose_estimation import ExportablePoseEstimationModel, PoseEstimationModelExportResult, AbstractPoseEstimationDecodingModule
from .pose_estimation_post_prediction_callback import AbstractPoseEstimationPostPredictionCallback, PoseEstimationPredictions

__all__ = [
@@ -11,4 +13,7 @@
"ModelHasNoPreprocessingParamsException",
"AbstractPoseEstimationPostPredictionCallback",
"PoseEstimationPredictions",
"ExportablePoseEstimationModel",
"PoseEstimationModelExportResult",
"AbstractPoseEstimationDecodingModule",
]
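With the expanded __all__, the pose-estimation export interfaces become importable from the package root. A small sketch of what that enables (the supports_pose_export helper is hypothetical):

    from super_gradients.module_interfaces import ExportablePoseEstimationModel

    # Hypothetical helper: downstream code can now type-check against the
    # public interface instead of importing from private submodules.
    def supports_pose_export(model) -> bool:
        return isinstance(model, ExportablePoseEstimationModel)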
6 changes: 6 additions & 0 deletions src/super_gradients/module_interfaces/exceptions.py
@@ -0,0 +1,6 @@
class ModelHasNoPreprocessingParamsException(Exception):
    """
    Exception raised when a model does not have preprocessing parameters.
    """

    pass
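A sketch of how this exception is meant to be handled at export time; the preprocessing=True/False keyword mirrors the existing detection export API and is assumed to carry over to pose models:

    from super_gradients.module_interfaces import ModelHasNoPreprocessingParamsException

    try:
        export_result = model.export("yolo_nas_pose_l.onnx", preprocessing=True)
    except ModelHasNoPreprocessingParamsException:
        # The model cannot embed its own preprocessing step; export without it
        # and normalize input images on the caller side instead.
        export_result = model.export("yolo_nas_pose_l.onnx", preprocessing=False)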
552 changes: 552 additions & 0 deletions src/super_gradients/module_interfaces/exportable_pose_estimation.py

Large diffs are not rendered by default.

155 changes: 155 additions & 0 deletions src/super_gradients/module_interfaces/usage_instructions.py
@@ -0,0 +1,155 @@
import numpy as np
from torch import nn

from super_gradients.conversion.conversion_utils import torch_dtype_to_numpy_dtype
from super_gradients.conversion.conversion_enums import ExportTargetBackend, ExportQuantizationMode, DetectionOutputFormatMode


def build_preprocessing_hint_text(preprocessing_module: nn.Module) -> str:
    module_repr = repr(preprocessing_module)
    return f"""
Exported model already contains a preprocessing (normalization) step, so you don't need to apply it manually.
Preprocessing steps to be applied to the input image are:
{module_repr}
"""


def build_postprocessing_hint_text(num_pre_nms_predictions, max_predictions_per_image, nms_threshold, confidence_threshold, output_predictions_format) -> str:
    return f"""
Exported model contains a postprocessing (NMS) step with the following parameters:
    num_pre_nms_predictions={num_pre_nms_predictions}
    max_predictions_per_image={max_predictions_per_image}
    nms_threshold={nms_threshold}
    confidence_threshold={confidence_threshold}
    output_predictions_format={output_predictions_format}

"""


def build_usage_instructions_for_pose_estimation(
    *,
    output,
    batch_size,
    input_image_channels,
    input_image_shape,
    input_image_dtype,
    preprocessing,
    preprocessing_module,
    postprocessing,
    postprocessing_module,
    num_pre_nms_predictions,
    max_predictions_per_image,
    confidence_threshold,
    nms_threshold,
    output_predictions_format,
    engine: ExportTargetBackend,
    quantization_mode: ExportQuantizationMode,
) -> str:
    # Add usage instructions
    usage_instructions = f"""
Model exported successfully to {output}
Model expects input image of shape [{batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]}]
Input image dtype is {input_image_dtype}"""

    if preprocessing:
        preprocessing_hint_text = build_preprocessing_hint_text(preprocessing_module)
        usage_instructions += f"\n{preprocessing_hint_text}"

    if postprocessing:
        postprocessing_hint_text = build_postprocessing_hint_text(
            num_pre_nms_predictions, max_predictions_per_image, nms_threshold, confidence_threshold, output_predictions_format
        )
        usage_instructions += f"\n{postprocessing_hint_text}"

    if engine in (ExportTargetBackend.ONNXRUNTIME, ExportTargetBackend.TENSORRT):
        dtype_name = np.dtype(torch_dtype_to_numpy_dtype(input_image_dtype)).name

        usage_instructions += f"""
Exported model is in ONNX format and can be used with ONNXRuntime.
To run inference with ONNXRuntime, please use the following code snippet:

    import onnxruntime
    import numpy as np
    session = onnxruntime.InferenceSession("{output}", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    inputs = [o.name for o in session.get_inputs()]
    outputs = [o.name for o in session.get_outputs()]

    example_input_image = np.zeros(({batch_size}, {input_image_channels}, {input_image_shape[0]}, {input_image_shape[1]})).astype(np.{dtype_name})
    predictions = session.run(outputs, {{inputs[0]: example_input_image}})

Exported model can also be used with TensorRT.
To run inference with TensorRT, please see the TensorRT deployment documentation.
You can benchmark the model using the following command:

    trtexec --onnx={output} {'--int8' if quantization_mode == ExportQuantizationMode.INT8 else '--fp16'} --avgRuns=100 --duration=15

"""

    if postprocessing is True:
        if output_predictions_format == DetectionOutputFormatMode.FLAT_FORMAT:
            usage_instructions += f"""
Exported model has predictions in {output_predictions_format} format:

    # flat_predictions is a 2D array of [N,K] shape
    # Each row represents (image_index, x_min, y_min, x_max, y_max, confidence, joints...)
    # Please note all values are floats, so you have to convert them to integers if needed

    [flat_predictions] = predictions"""

            if batch_size == 1:
                usage_instructions += f"""
    pred_bboxes = flat_predictions[:, 1:5]
    pred_scores = flat_predictions[:, 5]
    pred_joints = flat_predictions[:, 6:].reshape((len(pred_bboxes), -1, 3))
    for i in range(len(pred_bboxes)):
        confidence = pred_scores[i]
        x_min, y_min, x_max, y_max = pred_bboxes[i]
        print(f"Detected pose with confidence={{confidence}}, x_min={{x_min}}, y_min={{y_min}}, x_max={{x_max}}, y_max={{y_max}}")
        for joint_index, (x, y, confidence) in enumerate(pred_joints[i]):
            print(f"Joint {{joint_index}} has coordinates x={{x}}, y={{y}}, confidence={{confidence}}")

"""

            else:
                usage_instructions += f"""
    for current_sample in range({batch_size}):
        predictions_for_current_sample = flat_predictions[flat_predictions[:, 0] == current_sample]
        print("Predictions for sample " + str(current_sample))
        pred_bboxes = predictions_for_current_sample[:, 1:5]
        pred_scores = predictions_for_current_sample[:, 5]
        pred_joints = predictions_for_current_sample[:, 6:].reshape((len(pred_bboxes), -1, 3))
        for i in range(len(pred_bboxes)):
            confidence = pred_scores[i]
            x_min, y_min, x_max, y_max = pred_bboxes[i]
            print(f"Detected pose with confidence={{confidence}}, x_min={{x_min}}, y_min={{y_min}}, x_max={{x_max}}, y_max={{y_max}}")
            for joint_index, (x, y, confidence) in enumerate(pred_joints[i]):
                print(f"Joint {{joint_index}} has coordinates x={{x}}, y={{y}}, confidence={{confidence}}")

"""

        elif output_predictions_format == DetectionOutputFormatMode.BATCH_FORMAT:
            # fmt: off
            usage_instructions += f"""Exported model has predictions in {output_predictions_format} format:

    num_detections, pred_boxes, pred_scores, pred_joints = predictions
    for image_index in range(num_detections.shape[0]):
        for i in range(num_detections[image_index, 0]):
            confidence = pred_scores[image_index, i]
            x_min, y_min, x_max, y_max = pred_boxes[image_index, i]
            joints_for_pose = pred_joints[image_index, i]
            print(f"Detected pose with confidence={{confidence}}, x_min={{x_min}}, y_min={{y_min}}, x_max={{x_max}}, y_max={{y_max}}")
            for joint_index, (x, y, confidence) in enumerate(joints_for_pose):
                print(f"Joint {{joint_index}} has coordinates x={{x}}, y={{y}}, confidence={{confidence}}")

"""
    elif postprocessing is False:
        usage_instructions += """Model exported with postprocessing=False
No decoding or NMS is added to the model, so you will have to decode predictions manually.
Please refer to the documentation of the model you exported."""
    elif isinstance(postprocessing_module, nn.Module):
        usage_instructions += f"""Exported model contains a custom postprocessing step.
We are unable to provide usage instructions for a user-provided postprocessing module,
but here is a human-friendly representation of it:
{repr(postprocessing_module)}"""

    return usage_instructions
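To preview what this builder renders, it can be called directly. A minimal sketch with illustrative argument values (none of these values are defaults taken from the PR):

    import torch

    from super_gradients.conversion.conversion_enums import DetectionOutputFormatMode, ExportQuantizationMode, ExportTargetBackend
    from super_gradients.module_interfaces.usage_instructions import build_usage_instructions_for_pose_estimation

    text = build_usage_instructions_for_pose_estimation(
        output="yolo_nas_pose_l.onnx",  # illustrative output path
        batch_size=1,
        input_image_channels=3,
        input_image_shape=(640, 640),
        input_image_dtype=torch.uint8,
        preprocessing=False,
        preprocessing_module=None,
        postprocessing=True,
        postprocessing_module=None,
        num_pre_nms_predictions=1000,
        max_predictions_per_image=10,
        confidence_threshold=0.5,
        nms_threshold=0.7,
        output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
        engine=ExportTargetBackend.ONNXRUNTIME,
        quantization_mode=ExportQuantizationMode.FP16,
    )
    print(text)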
@@ -0,0 +1,143 @@
in_channels: 3

backbone:
  NStageBackbone:

    stem:
      YoloNASStem:
        out_channels: 48

    stages:
      - YoloNASStage:
          out_channels: 96
          num_blocks: 2
          activation_type: relu
          hidden_channels: 96
          concat_intermediates: True

      - YoloNASStage:
          out_channels: 192
          num_blocks: 3
          activation_type: relu
          hidden_channels: 128
          concat_intermediates: True

      - YoloNASStage:
          out_channels: 384
          num_blocks: 5
          activation_type: relu
          hidden_channels: 256
          concat_intermediates: True

      - YoloNASStage:
          out_channels: 768
          num_blocks: 2
          activation_type: relu
          hidden_channels: 512
          concat_intermediates: True

    context_module:
      SPP:
        output_channels: 768
        activation_type: relu
        k: [5,9,13]

    out_layers: [stage1, stage2, stage3, context_module]

neck:
  YoloNASPANNeckWithC2:

    neck1:
      YoloNASUpStage:
        out_channels: 192
        num_blocks: 4
        hidden_channels: 128
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck2:
      YoloNASUpStage:
        out_channels: 96
        num_blocks: 4
        hidden_channels: 128
        width_mult: 1
        depth_mult: 1
        activation_type: relu
        reduce_channels: True

    neck3:
      YoloNASDownStage:
        out_channels: 192
        num_blocks: 4
        hidden_channels: 128
        activation_type: relu
        width_mult: 1
        depth_mult: 1

    neck4:
      YoloNASDownStage:
        out_channels: 384
        num_blocks: 4
        hidden_channels: 256
        activation_type: relu
        width_mult: 1
        depth_mult: 1

heads:
  YoloNASPoseNDFLHeads:
    num_classes: 17
    reg_max: 16
    heads_list:
      - YoloNASPoseDFLHead:
          bbox_inter_channels: 128
          pose_inter_channels: 128
          pose_regression_blocks: 2
          shared_stem: False
          width_mult: 1
          pose_conf_in_class_head: True
          pose_block_use_repvgg: False
          first_conv_group_size: 0
          num_classes:
          stride: 8
          reg_max: 16
          cls_dropout_rate: 0.0
          reg_dropout_rate: 0.0

      - YoloNASPoseDFLHead:
          bbox_inter_channels: 256
          pose_inter_channels: 512
          pose_regression_blocks: 2
          shared_stem: False
          width_mult: 1
          pose_conf_in_class_head: True
          pose_block_use_repvgg: False
          first_conv_group_size: 0
          num_classes:
          stride: 16
          reg_max: 16
          cls_dropout_rate: 0.0
          reg_dropout_rate: 0.0

      - YoloNASPoseDFLHead:
          bbox_inter_channels: 512
          pose_inter_channels: 512
          pose_regression_blocks: 3
          shared_stem: False
          width_mult: 1
          pose_conf_in_class_head: True
          pose_block_use_repvgg: False
          first_conv_group_size: 0
          num_classes:
          stride: 32
          reg_max: 16
          cls_dropout_rate: 0.0
          reg_dropout_rate: 0.0

bn_eps: 1e-6
bn_momentum: 0.03
inplace_act: True

_convert_: all
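For orientation, a sketch of reading values back out of this arch-params file with OmegaConf; the file name below is hypothetical, since the diff header does not show the config's path:

    from omegaconf import OmegaConf

    # Hypothetical file name; the diff does not reveal the config's actual path.
    arch_params = OmegaConf.load("yolo_nas_pose_arch_params.yaml")

    heads = arch_params["heads"]["YoloNASPoseNDFLHeads"]
    print(heads["num_classes"])  # 17, i.e. the number of keypoints per pose
    for head in heads["heads_list"]:
        print(head["YoloNASPoseDFLHead"]["stride"])  # 8, 16, 32 across the three scales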