Skip to content

Commit

Permalink
fix (#1367)
Browse files Browse the repository at this point in the history
* fix

* add spacing
  • Loading branch information
Louis-Dupont committed Aug 15, 2023
1 parent ee24164 commit fed8756
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], m

predictions = list()
for prediction, confidence, image_input in zip(classifier_predictions, confidence_predictions, model_input):
predictions.append(ClassificationPrediction(confidence=float(confidence), labels=int(prediction), image_shape=image_input.shape))
predictions.append(ClassificationPrediction(confidence=float(confidence), label=int(prediction), image_shape=image_input.shape))
return predictions

def _instantiate_image_prediction(self, image: np.ndarray, prediction: ClassificationPrediction) -> ImagePrediction:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def draw(self, show_confidence: bool = True) -> np.ndarray:
"""

image = self.image.copy()
return draw_label(
image=image, label=self.class_names[self.prediction.labels], confidence=str(self.prediction.confidence), image_shape=self.prediction.image_shape[1:]
)
return draw_label(image=image, label=self.class_names[self.prediction.label], confidence=self.prediction.confidence)

def show(self, show_confidence: bool = True) -> None:
"""Display the image with predicted label.
Expand Down
21 changes: 9 additions & 12 deletions src/super_gradients/training/utils/predict/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,24 @@ class ClassificationPrediction(Prediction):
"""Represents a Classification prediction"""

confidence: float
labels: int
label: int
image_shape: Tuple[int, int]

def __init__(self, confidence: float, label: int, image_shape: Optional[Tuple[int, int]]):
    """Store a single classification result after validating its types.

    :param confidence:  Confidence score of the predicted class.
    :param label:       Predicted class, as an integer class index.
    :param image_shape: Shape of the image the prediction is made on, (H, W). May be None.
                        NOTE(review): the classification pipeline passes `image_input.shape`
                        here - confirm it carries no channel dimension.
    :raises ValueError: If `confidence` is not a float or `label` is not an int.
    """
    # Fail fast on wrong types before any attribute is set.
    self._validate_input(confidence, label)

    self.confidence = confidence
    self.label = label
    self.image_shape = image_shape

def _validate_input(self, confidence: float, label: int) -> None:
    """Check that the prediction fields have the expected built-in types.

    :param confidence: Value to be stored as the prediction confidence; must be a `float`.
    :param label:      Value to be stored as the predicted label; must be an `int`.
    :raises ValueError: If either argument has the wrong type.
    """
    # The checks are strict on purpose: callers are expected to cast tensors or
    # arrays to plain Python scalars (e.g. `float(...)` / `int(...)`) first.
    if not isinstance(confidence, float):
        raise ValueError(f"Argument confidence must be a float, not {type(confidence)}")
    if not isinstance(label, int):
        raise ValueError(f"Argument label must be an integer, not {type(label)}")
92 changes: 58 additions & 34 deletions src/super_gradients/training/utils/visualization/classification.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,67 @@
from typing import Tuple

import cv2
import numpy as np


def draw_label(
    image: np.ndarray,
    label: str,
    confidence: float,
    *,
    font_scale: float = 0.8,
    thickness: int = 1,
    line_spacing: int = 5,
    alpha: float = 0.6,
) -> np.ndarray:
    """Draw a label and its confidence on a semi-transparent banner near the top of an image.

    The label is written on the first line and the confidence, formatted as a percentage,
    on the second. Both lines are centered horizontally inside a white banner that is itself
    centered horizontally on the image. The image is modified in place and also returned.

    :param image:        The image on which to draw the label and confidence, in RGB format, and Channel Last (H, W, C)
    :param label:        The label to draw.
    :param confidence:   The confidence of the label; multiplied by 100 and shown with 3 decimals.
    :param font_scale:   Font scale passed to OpenCV for both text lines.
    :param thickness:    Stroke thickness of the text, in pixels.
    :param line_spacing: Extra vertical spacing, in pixels, between the two text lines.
    :param alpha:        Weight of the banner when blended over the image (the image gets `1 - alpha`).
    :return:             The same `image` array, with the banner and text drawn on it.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    top_margin = 5  # Vertical offset of the banner from the top edge, in pixels

    # Format confidence as a percentage
    confidence_str = f"{confidence * 100:.3f}%"

    # getTextSize returns ((width, height), baseline); only the size is needed
    label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
    confidence_size = cv2.getTextSize(confidence_str, font, font_scale, thickness)[0]

    # The banner must fit the wider of the two lines and both heights plus padding
    text_width = max(label_size[0], confidence_size[0])
    text_height = label_size[1] + confidence_size[1] + thickness * 3 + line_spacing

    # Centered horizontally, anchored at the top
    start_x = (image.shape[1] - text_width) // 2
    start_y = top_margin

    # Draw the opaque banner on a copy, then blend it back so that only the
    # banner region changes (everywhere else overlay == image).
    overlay = image.copy()
    bg_color = (255, 255, 255)  # White
    bg_start = (start_x, start_y)
    bg_end = (start_x + text_width, start_y + text_height)
    cv2.rectangle(overlay, bg_start, bg_end, bg_color, thickness=-1)
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)

    # putText takes the bottom-left corner of the text, hence the added text heights
    text_color = (0, 0, 0)  # Black
    label_org = (start_x + (text_width - label_size[0]) // 2, start_y + label_size[1])
    confidence_org = (
        start_x + (text_width - confidence_size[0]) // 2,
        start_y + label_size[1] + confidence_size[1] + thickness + line_spacing,
    )
    for text, org in ((label, label_org), (confidence_str, confidence_org)):
        cv2.putText(image, text, org, font, font_scale, text_color, thickness, lineType=cv2.LINE_AA)

    return image

0 comments on commit fed8756

Please sign in to comment.