From 8b3d7de2c7dc8ab83fbb61813fab1f6f4c043756 Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Thu, 10 Aug 2023 20:44:01 +0300 Subject: [PATCH] fix --- .../training/pipelines/pipelines.py | 2 +- .../utils/predict/prediction_results.py | 4 +- .../training/utils/predict/predictions.py | 21 ++--- .../utils/visualization/classification.py | 89 ++++++++++++------- 4 files changed, 66 insertions(+), 50 deletions(-) diff --git a/src/super_gradients/training/pipelines/pipelines.py b/src/super_gradients/training/pipelines/pipelines.py index b7e58f7c3e..49d4a27759 100644 --- a/src/super_gradients/training/pipelines/pipelines.py +++ b/src/super_gradients/training/pipelines/pipelines.py @@ -421,7 +421,7 @@ def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], m predictions = list() for prediction, confidence, image_input in zip(classifier_predictions, confidence_predictions, model_input): - predictions.append(ClassificationPrediction(confidence=float(confidence), labels=int(prediction), image_shape=image_input.shape)) + predictions.append(ClassificationPrediction(confidence=float(confidence), label=int(prediction), image_shape=image_input.shape)) return predictions def _instantiate_image_prediction(self, image: np.ndarray, prediction: ClassificationPrediction) -> ImagePrediction: diff --git a/src/super_gradients/training/utils/predict/prediction_results.py b/src/super_gradients/training/utils/predict/prediction_results.py index f22295660f..6b324f0d17 100644 --- a/src/super_gradients/training/utils/predict/prediction_results.py +++ b/src/super_gradients/training/utils/predict/prediction_results.py @@ -64,9 +64,7 @@ def draw(self, show_confidence: bool = True) -> np.ndarray: """ image = self.image.copy() - return draw_label( - image=image, label=self.class_names[self.prediction.labels], confidence=str(self.prediction.confidence), image_shape=self.prediction.image_shape[1:] - ) + return draw_label(image=image, 
label=self.class_names[self.prediction.label], confidence=self.prediction.confidence) def show(self, show_confidence: bool = True) -> None: """Display the image with predicted label. diff --git a/src/super_gradients/training/utils/predict/predictions.py b/src/super_gradients/training/utils/predict/predictions.py index 56a75bf975..70962c48d0 100644 --- a/src/super_gradients/training/utils/predict/predictions.py +++ b/src/super_gradients/training/utils/predict/predictions.py @@ -113,27 +113,24 @@ class ClassificationPrediction(Prediction): """Represents a Classification prediction""" confidence: float - labels: int + label: int image_shape: Tuple[int, int] - def __init__(self, confidence: float, labels: int, image_shape: Optional[Tuple[int, int]]): + def __init__(self, confidence: float, label: int, image_shape: Optional[Tuple[int, int]]): """ :param confidence: Confidence scores for each bounding box - :param labels: Labels for each bounding box. + :param label: Index of the predicted class. :param image_shape: Shape of the image the prediction is made on, (H, W). 
""" - self._validate_input(confidence, labels) + self._validate_input(confidence, label) self.confidence = confidence - self.labels = labels + self.label = label self.image_shape = image_shape - def _validate_input(self, confidence: np.ndarray, labels: np.ndarray) -> None: + def _validate_input(self, confidence: float, label: int) -> None: if not isinstance(confidence, float): - raise ValueError(f"Argument confidence must be a numpy array, not {type(confidence)}") - if not isinstance(labels, int): - raise ValueError(f"Argument labels must be a numpy array, not {type(labels)}") - - def __len__(self): - return len(self.labels) + raise ValueError(f"Argument confidence must be a float, not {type(confidence)}") + if not isinstance(label, int): + raise ValueError(f"Argument labels must be an integer, not {type(label)}") diff --git a/src/super_gradients/training/utils/visualization/classification.py b/src/super_gradients/training/utils/visualization/classification.py index ddbdc691ed..0205e1ac49 100644 --- a/src/super_gradients/training/utils/visualization/classification.py +++ b/src/super_gradients/training/utils/visualization/classification.py @@ -1,43 +1,64 @@ -from typing import Tuple - import cv2 import numpy as np -def draw_label(image: np.ndarray, label: str, confidence: str, image_shape: Tuple) -> np.ndarray: +def draw_label(image: np.ndarray, label: str, confidence: float) -> np.ndarray: """Draw a label and confidence on an image. - - :param image: Image on which to draw the bounding box. - :param label: Label to display on an image. - :param confidence: Confidence of the predicted label to display on an image - :param image_shape: Image shape of the image + :param image: The image on which to draw the label and confidence, in RGB format, and Channel Last (H, W, C) + :param label: The label to draw. + :param confidence: The confidence of the label. 
""" - # Determine the size of the label text - (label_width, label_height), _ = cv2.getTextSize(text=label, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=1) - - # Calculate the position to draw the label - image_width, image_height = image_shape - start_point = ((image_width - label_width) // 2, (image_height - label_height) // 4) - - # Draw a filled rectangle as the background for the label - label_color = (0, 0, 0) - bg_position = (start_point[0], start_point[1] - label_height) - bg_size = (label_width, label_height * 2) # Double the height to accommodate two lines - cv2.rectangle(image, bg_position, (bg_position[0] + bg_size[0], bg_position[1] + bg_size[1]), label_color, thickness=-1) - - text_org = [(start_point[0], start_point[1]), (start_point[0], start_point[1] + label_height)] - for text, org in zip([label, confidence], text_org): - - cv2.putText( - img=image, - text=text, - org=org, - fontFace=cv2.FONT_HERSHEY_SIMPLEX, - fontScale=0.5, - color=(255, 255, 255), - thickness=1, - lineType=cv2.LINE_AA, - ) + # Format confidence as a percentage + confidence_str = f"{confidence * 100:.3f}%" + + # Use a slightly smaller font scale and a moderate thickness + fontScale = 0.8 + thickness = 1 + + # Determine the size of the label and confidence text + label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale, thickness)[0] + confidence_size = cv2.getTextSize(confidence_str, cv2.FONT_HERSHEY_SIMPLEX, fontScale, thickness)[0] + + # Determine the size of the bounding rectangle + text_width = max(label_size[0], confidence_size[0]) + text_height = label_size[1] + confidence_size[1] + thickness * 3 + + # Calculate the position to draw the label, centered horizontally and at the top + start_x = (image.shape[1] - text_width) // 2 + start_y = 5 + + # Draw a filled rectangle with transparency as the background for the label + overlay = image.copy() + bg_color = (255, 255, 255) # White + bg_start = (start_x, start_y) + bg_end = (start_x + text_width, 
start_y + text_height) + cv2.rectangle(overlay, bg_start, bg_end, bg_color, thickness=-1) + + alpha = 0.6 + cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) + + # Center the label and confidence text within the bounding rectangle + text_color = (0, 0, 0) # Black + cv2.putText( + image, + label, + (start_x + (text_width - label_size[0]) // 2, start_y + label_size[1]), + cv2.FONT_HERSHEY_SIMPLEX, + fontScale, + text_color, + thickness, + lineType=cv2.LINE_AA, + ) + cv2.putText( + image, + confidence_str, + (start_x + (text_width - confidence_size[0]) // 2, start_y + label_size[1] + confidence_size[1] + thickness), + cv2.FONT_HERSHEY_SIMPLEX, + fontScale, + text_color, + thickness, + lineType=cv2.LINE_AA, + ) return image