diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index ceff2b73e1..828e2ba21b 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -4,7 +4,7 @@
 import signal
 import time
 from abc import ABC, abstractmethod
-from typing import List, Union, Optional, Sequence, Mapping
+from typing import List, Union, Optional, Sequence, Mapping, Tuple
 
 import csv
 import cv2
@@ -27,7 +27,7 @@
 from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
-from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
+from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback, cxcywh2xyxy, xyxy2cxcywh
 from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average, maybe_all_gather_np_images
 from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
 from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
@@ -1112,6 +1112,167 @@ def _is_more_extreme(self, score: float) -> bool:
         return self.extreme_score > score
 
 
+@register_callback("ExtremeBatchDetectionVisualizationCallback")
+class ExtremeBatchDetectionVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
+    """
+    ExtremeBatchDetectionVisualizationCallback
+
+    Visualizes the worst/best batch in an epoch for object detection.
+    For clarity, the batch is saved twice in the SG Logger - once with the model's predictions and once with
+    the ground-truth targets.
+
+    Assumptions on bbox formats:
+        - After applying post_prediction_callback on context.preds, the predictions are a list/Tensor s.t.
+          predictions[i] is a tensor of shape nx6 - (x1, y1, x2, y2, confidence, class) where x and y are in pixel units.
+
+        - context.targets is a tensor of shape (total_num_targets, 6), in LABEL_CXCYWH format: (index, label, cx, cy, w, h).
+
+    Example usage in a YAML config:
+
+        training_hyperparams:
+          phase_callbacks:
+            - ExtremeBatchDetectionVisualizationCallback:
+                metric:
+                  DetectionMetrics_050:
+                    score_thres: 0.1
+                    top_k_predictions: 300
+                    num_cls: ${num_classes}
+                    normalize_targets: True
+                    post_prediction_callback:
+                      _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
+                      score_threshold: 0.01
+                      nms_top_k: 1000
+                      max_predictions: 300
+                      nms_threshold: 0.7
+                metric_component_name: 'mAP@0.50'
+                post_prediction_callback:
+                  _target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
+                  score_threshold: 0.25
+                  nms_top_k: 1000
+                  max_predictions: 300
+                  nms_threshold: 0.7
+                normalize_targets: True
+
+    :param metric: Metric, will be the metric which is monitored.
+
+    :param metric_component_name: In case metric returns multiple values (as Mapping),
+        the value at metric.compute()[metric_component_name] will be the one monitored.
+
+    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+        Monitoring loss follows the same logic as metric_to_watch in Trainer.train(...); when watching a loss, it should be:
+
+        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
+            <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
+
+        If a single item is returned rather than a tuple:
+            <LOSS_CLASS.__name__>.
+
+        When there is no such attribute and criterion.forward(..) returns a tuple:
+            <LOSS_CLASS.__name__>"/"Loss_"<IDX>
+
+    :param max: bool, whether to take the batch corresponding to the max value of the metric/loss or
+        the minimum (default=False).
+
+    :param freq: int, epoch frequency to perform all of the above (default=1).
+
+    :param classes: List[str], a list of class names corresponding to the class indices for display.
+        When None, will try to fetch this through a "classes" attribute of the validation dataset. If no such
+        attribute exists, an error will be raised (default=None).
+
+    :param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
+        are in pixel-value range, this needs to be set to True (default=False).
+    """
+
+    def __init__(
+        self,
+        post_prediction_callback: DetectionPostPredictionCallback,
+        metric: Optional[Metric] = None,
+        metric_component_name: Optional[str] = None,
+        loss_to_monitor: Optional[str] = None,
+        max: bool = False,
+        freq: int = 1,
+        classes: Optional[List[str]] = None,
+        normalize_targets: bool = False,
+    ):
+        super(ExtremeBatchDetectionVisualizationCallback, self).__init__(
+            metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
+        )
+        self.post_prediction_callback = post_prediction_callback
+        if classes is None:
+            logger.info(
+                "No classes have been passed to ExtremeBatchDetectionVisualizationCallback. "
+                "Will try to fetch them through context.valid_loader.dataset classes attribute if it exists."
+            )
+        self.classes = classes
+        self.normalize_targets = normalize_targets
+
+    @staticmethod
+    def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
+        """
+        A universal undoing of preprocessing, to be passed to DetectionVisualization.visualize_batch's
+        undo_preprocessing_func kwarg.
+
+        :param inputs: batch of image tensors in NCHW layout.
+        :return: the batch as a contiguous uint8 NHWC numpy array, rescaled to [0, 255] with channel order reversed.
+        """
+        inputs -= inputs.min()
+        inputs /= inputs.max()
+        inputs *= 255
+        inputs = inputs.to(torch.uint8)
+        inputs = inputs.cpu().numpy()
+        # Reverse channel order and convert NCHW -> NHWC for visualization.
+        inputs = inputs[:, ::-1, :, :].transpose(0, 2, 3, 1)
+        inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
+        return inputs
+
+    def process_extreme_batch(self) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Processes the extreme batch and returns two image batches for visualization - one with predictions and one with GT boxes.
+
+        :return: Tuple[np.ndarray, np.ndarray] - the predictions batch and the GT batch.
+        """
+        inputs = self.extreme_batch
+        preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
+        targets = self.extreme_targets.clone()
+        if self.normalize_targets:
+            target_bboxes = targets[:, 2:]
+            target_bboxes = cxcywh2xyxy(target_bboxes)
+            _, _, height, width = inputs.shape
+            target_bboxes[:, [0, 2]] /= width
+            target_bboxes[:, [1, 3]] /= height
+            target_bboxes = xyxy2cxcywh(target_bboxes)
+            targets[:, 2:] = target_bboxes
+
+        images_to_save_preds = DetectionVisualization.visualize_batch(
+            inputs, preds, targets, "extreme_batch_preds", self.classes, gt_alpha=0.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
+        )
+        images_to_save_preds = np.stack(images_to_save_preds)
+
+        images_to_save_gt = DetectionVisualization.visualize_batch(
+            inputs, None, targets, "extreme_batch_gt", self.classes, gt_alpha=1.0, undo_preprocessing_func=self.universal_undo_preprocessing_fn
+        )
+        images_to_save_gt = np.stack(images_to_save_gt)
+
+        return images_to_save_preds, images_to_save_gt
+
+    def on_validation_loader_end(self, context: PhaseContext) -> None:
+        if self.classes is None:
+            if hasattr(context.valid_loader.dataset, "classes"):
+                self.classes = context.valid_loader.dataset.classes
+            else:
+                raise RuntimeError("Couldn't fetch classes from valid_loader, please pass classes explicitly")
+        if context.epoch % self.freq == 0:
+            images_to_save_preds, images_to_save_gt = self.process_extreme_batch()
+            images_to_save_preds = maybe_all_gather_np_images(images_to_save_preds)
+            images_to_save_gt = maybe_all_gather_np_images(images_to_save_gt)
+
+            if not context.ddp_silent_mode:
+                context.sg_logger.add_images(tag=f"{self._tag}_preds", images=images_to_save_preds, global_step=context.epoch, data_format="NHWC")
+                context.sg_logger.add_images(tag=f"{self._tag}_GT", images=images_to_save_gt, global_step=context.epoch, data_format="NHWC")
+
+            self._reset()
+
+
 @register_callback("ExtremeBatchSegVisualizationCallback")
 class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
     """
diff --git a/src/super_gradients/training/utils/detection_utils.py b/src/super_gradients/training/utils/detection_utils.py
index 1158b19c8b..53450f9024 100755
--- a/src/super_gradients/training/utils/detection_utils.py
+++ b/src/super_gradients/training/utils/detection_utils.py
@@ -501,7 +501,7 @@ def visualize_batch(
         :param image_tensor:            rgb images, (B, H, W, 3)
         :param pred_boxes:              boxes after NMS for each image in a batch, each (Num_boxes, 6),
                                         values on dim 1 are: x1, y1, x2, y2, confidence, class
-        :param target_boxes:            (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h
+        :param target_boxes:            (Num_targets, 6), values on dim 1 are: image id in a batch, class, cx cy w h (coordinates scaled to [0, 1])
         :param batch_name:              id of the current batch to use for image naming
@@ -518,6 +518,8 @@
         """
         image_np = undo_preprocessing_func(image_tensor.detach())
         targets = DetectionVisualization._scaled_ccwh_to_xyxy(target_boxes.detach().cpu().numpy(), *image_np.shape[1:3], image_scale)
+        if pred_boxes is None:
+            pred_boxes = [None for _ in range(image_np.shape[0])]
 
         out_images = []
         for i in range(image_np.shape[0]):
diff --git a/tests/unit_tests/extreme_batch_cb_test.py b/tests/unit_tests/extreme_batch_cb_test.py
index 902571eb57..b4692b6274 100644
--- a/tests/unit_tests/extreme_batch_cb_test.py
+++ b/tests/unit_tests/extreme_batch_cb_test.py
@@ -2,19 +2,27 @@
 from super_gradients import Trainer
 from super_gradients.common.object_names import Models
 from super_gradients.training import models
-from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader
+from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader, detection_test_dataloader
+from super_gradients.training.losses import PPYoloELoss
 from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
-from super_gradients.training.metrics import IoU
-from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchSegVisualizationCallback
+from super_gradients.training.metrics import IoU, DetectionMetrics_050
+from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
+from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchSegVisualizationCallback, ExtremeBatchDetectionVisualizationCallback
 
 
 # Helper method to set up Trainer and model with common parameters
-def setup_trainer_and_model(experiment_name: str):
+def setup_trainer_and_model_seg(experiment_name: str):
     trainer = Trainer(experiment_name)
     model = models.get(Models.DDRNET_23, arch_params={"use_aux_heads": True}, pretrained_weights="cityscapes")
     return trainer, model
 
 
+def setup_trainer_and_model_detection(experiment_name: str):
+    trainer = Trainer(experiment_name)
+    model = models.get(Models.YOLO_NAS_S, num_classes=1)
+    return trainer, model
+
+
 class DummyIOU(IoU):
     """
     Metric for testing the segmentation callback works with compound metrics
@@ -28,13 +36,12 @@ def compute(self):
 
 
 class ExtremeBatchSanityTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.training_params = {
+        cls.seg_training_params = {
             "max_epochs": 3,
             "initial_lr": 1e-2,
             "loss": DDRNetLoss(),
             "lr_mode": "poly",
             "ema": True,
-            "average_best_models": True,
             "optimizer": "SGD",
             "mixed_precision": False,
             "optimizer_params": {"weight_decay": 5e-4, "momentum": 0.9},
@@ -45,25 +52,74 @@ def setUpClass(cls):
             "greater_metric_to_watch_is_better": True,
         }
 
+        cls.od_training_params = {
+            "max_epochs": 3,
+            "initial_lr": 1e-2,
+            "loss": PPYoloELoss(num_classes=1, use_static_assigner=False, reg_max=16),
+            "lr_mode": "poly",
+            "ema": True,
+            "optimizer": "SGD",
+            "mixed_precision": False,
+            "optimizer_params": {"weight_decay": 5e-4, "momentum": 0.9},
+            "load_opt_params": False,
+            "valid_metrics_list": [
+                DetectionMetrics_050(
+                    normalize_targets=True,
+                    post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
+                    num_cls=1,
+                )
+            ],
+            "train_metrics_list": [],
+            "metric_to_watch": "mAP@0.50",
+            "greater_metric_to_watch_is_better": True,
+        }
+
+    def test_detection_extreme_batch_with_metric_sanity(self):
+        trainer, model = setup_trainer_and_model_detection("test_detection_extreme_batch_with_metric_sanity")
+        self.od_training_params["phase_callbacks"] = [
+            ExtremeBatchDetectionVisualizationCallback(
+                classes=["1"],
+                metric=DetectionMetrics_050(
+                    normalize_targets=True,
+                    post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
+                    num_cls=1,
+                ),
+                metric_component_name="mAP@0.50",
+                post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
+            )
+        ]
+        trainer.train(model=model, training_params=self.od_training_params, train_loader=detection_test_dataloader(), valid_loader=detection_test_dataloader())
+
+    def test_detection_extreme_batch_with_loss_sanity(self):
+        trainer, model = setup_trainer_and_model_detection("test_detection_extreme_batch_with_loss_sanity")
+        self.od_training_params["phase_callbacks"] = [
+            ExtremeBatchDetectionVisualizationCallback(
+                classes=["1"],
+                loss_to_monitor="PPYoloELoss/loss_cls",
+                post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65),
+            )
+        ]
+        trainer.train(model=model, training_params=self.od_training_params, train_loader=detection_test_dataloader(), valid_loader=detection_test_dataloader())
+
     def test_segmentation_extreme_batch_with_metric_sanity(self):
-        trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_metric_sanity")
-        self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(IoU(5))]
+        trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_with_metric_sanity")
+        self.seg_training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(IoU(5))]
         trainer.train(
-            model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
+            model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
         )
 
     def test_segmentation_extreme_batch_with_compound_metric_sanity(self):
-        trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_compound_metric_sanity")
-        self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(DummyIOU(5), metric_component_name="diou_minus")]
+        trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_with_compound_metric_sanity")
+        self.seg_training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(DummyIOU(5), metric_component_name="diou_minus")]
         trainer.train(
-            model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
+            model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
         )
 
     def test_segmentation_extreme_batch_with_loss_sanity(self):
-        trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_loss_sanity")
-        self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1")]
+        trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_with_loss_sanity")
+        self.seg_training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1")]
         trainer.train(
-            model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
+            model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
         )
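
A minimal end-to-end sketch of wiring ExtremeBatchDetectionVisualizationCallback into Trainer.train, condensed from test_detection_extreme_batch_with_metric_sanity above; the model, dataloaders, and PPYoloE hyperparameters mirror the unit test, while the experiment name and the make_post_cb helper are illustrative, not part of the patch:

    from super_gradients import Trainer
    from super_gradients.common.object_names import Models
    from super_gradients.training import models
    from super_gradients.training.dataloaders.dataloaders import detection_test_dataloader
    from super_gradients.training.losses import PPYoloELoss
    from super_gradients.training.metrics import DetectionMetrics_050
    from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
    from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchDetectionVisualizationCallback


    def make_post_cb():
        # Same post-prediction settings as the unit test; a fresh instance per consumer.
        return PPYoloEPostPredictionCallback(score_threshold=0.03, nms_top_k=1000, max_predictions=300, nms_threshold=0.65)


    trainer = Trainer("extreme_batch_detection_visualization_example")  # illustrative experiment name
    model = models.get(Models.YOLO_NAS_S, num_classes=1)

    training_params = {
        "max_epochs": 3,
        "initial_lr": 1e-2,
        "loss": PPYoloELoss(num_classes=1, use_static_assigner=False, reg_max=16),
        "lr_mode": "poly",
        "ema": True,
        "optimizer": "SGD",
        "mixed_precision": False,
        "optimizer_params": {"weight_decay": 5e-4, "momentum": 0.9},
        "load_opt_params": False,
        "valid_metrics_list": [DetectionMetrics_050(normalize_targets=True, post_prediction_callback=make_post_cb(), num_cls=1)],
        "train_metrics_list": [],
        "metric_to_watch": "mAP@0.50",
        "greater_metric_to_watch_is_better": True,
        # Saves the extreme batch twice per epoch in the SG Logger: once with predictions, once with GT.
        "phase_callbacks": [
            ExtremeBatchDetectionVisualizationCallback(
                classes=["1"],
                metric=DetectionMetrics_050(normalize_targets=True, post_prediction_callback=make_post_cb(), num_cls=1),
                metric_component_name="mAP@0.50",
                post_prediction_callback=make_post_cb(),
            )
        ],
    }

    trainer.train(model=model, training_params=training_params, train_loader=detection_test_dataloader(), valid_loader=detection_test_dataloader())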
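
The loss_to_monitor naming rule from the class docstring can be derived programmatically; a short sketch, assuming PPYoloELoss exposes component_names as the "PPYoloELoss/loss_cls" string in the loss test implies:

    from super_gradients.training.losses import PPYoloELoss

    criterion = PPYoloELoss(num_classes=1, use_static_assigner=False, reg_max=16)

    # When the criterion exposes component_names and forward(..) returns a tuple,
    # valid loss_to_monitor values are <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>;
    # otherwise the bare class name is used.
    if hasattr(criterion, "component_names"):
        valid_names = [f"{type(criterion).__name__}/{name}" for name in criterion.component_names]
    else:
        valid_names = [type(criterion).__name__]

    print(valid_names)  # expected to include "PPYoloELoss/loss_cls", the value used in the loss test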
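
The normalize_targets branch of process_extreme_batch can be reproduced standalone; a worked sketch with one illustrative pixel-space target (the sample values are made up):

    import torch

    from super_gradients.training.utils.detection_utils import cxcywh2xyxy, xyxy2cxcywh

    height, width = 480, 640
    # One LABEL_CXCYWH target in pixel units: (batch index, label, cx, cy, w, h).
    targets = torch.tensor([[0.0, 3.0, 320.0, 240.0, 64.0, 48.0]])

    bboxes = cxcywh2xyxy(targets[:, 2:].clone())  # clone: cxcywh2xyxy works in place
    bboxes[:, [0, 2]] /= width
    bboxes[:, [1, 3]] /= height
    targets[:, 2:] = xyxy2cxcywh(bboxes)

    print(targets)  # the bbox part becomes cx=0.5, cy=0.5, w=0.1, h=0.1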
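
Finally, universal_undo_preprocessing_fn can be sanity-checked on a dummy batch; a quick sketch assuming the patch is applied (note the helper rescales over the whole batch and mutates its input in place):

    import numpy as np
    import torch

    from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchDetectionVisualizationCallback

    batch = torch.randn(2, 3, 480, 640)  # arbitrary "preprocessed" NCHW batch

    images = ExtremeBatchDetectionVisualizationCallback.universal_undo_preprocessing_fn(batch)

    # Rescaled to [0, 255] uint8, channel order reversed, and laid out as contiguous
    # NHWC numpy - the layout DetectionVisualization.visualize_batch expects.
    assert images.shape == (2, 480, 640, 3)
    assert images.dtype == np.uint8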