Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve extreme batch visualization callbacks #1488

Merged
merged 7 commits into from
Oct 3, 2023
237 changes: 154 additions & 83 deletions src/super_gradients/training/utils/callbacks/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import copy
import csv
import math
import os
import signal
import time
from abc import ABC, abstractmethod
from typing import List, Union, Optional, Sequence, Mapping, Tuple
from typing import List, Union, Optional, Sequence, Mapping

import csv
import cv2
import numpy as np
import onnx
Expand All @@ -18,22 +18,21 @@

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.deprecate import deprecated
from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks
from super_gradients.common.plugins.deci_client import DeciClient
from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS
from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks
from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber
from super_gradients.training.utils import get_param
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback, cxcywh2xyxy, xyxy2cxcywh
from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average, maybe_all_gather_np_images
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
from super_gradients.training.utils.utils import unwrap_model
from super_gradients.common.deprecate import deprecated

from super_gradients.training.utils.utils import unwrap_model, infer_model_device, tensor_container_to_device

logger = get_logger(__name__)

Expand Down Expand Up @@ -1070,7 +1069,7 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC):

:param freq: int, epoch frequency to perform all of the above (default=1).

Inheritors should implement process_extreme_batch which returns an image, as an np.array (uint8) with shape BCHW.
Inheritors should implement process_extreme_batch which returns an image, as np.array (uint8) with shape BCHW.
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
"""

@resolve_param("metric", MetricsFactory())
Expand All @@ -1081,7 +1080,21 @@ def __init__(
loss_to_monitor: Optional[str] = None,
max: bool = False,
freq: int = 1,
enable_on_train_loader: bool = False,
enable_on_valid_loader: bool = True,
max_images: int = -1,
):
"""

:param metric:
:param metric_component_name:
:param loss_to_monitor:
:param max:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
:param freq: Frequency (in epochs) of performing this callback. 1 means every epoch. 2 means every other epoch. Default is 1.
:param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
:param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
:param max_images: Maximum images to save. If -1, save all images.
"""
super(ExtremeBatchCaseVisualizationCallback, self).__init__()

if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None):
Expand All @@ -1098,15 +1111,19 @@ def __init__(
self.loss_to_monitor = loss_to_monitor
self.max = max
self.freq = freq
self.extreme_score = -1 * np.inf if max else np.inf

self.extreme_score = None
self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None

self._first_call = True
self._idx_loss_tuple = None

self.enable_on_train_loader = enable_on_train_loader
self.enable_on_valid_loader = enable_on_valid_loader
self.max_images = max_images

def _set_tag_attr(self, loss_to_monitor, max, metric, metric_component_name):
if metric_component_name:
monitored_val_name = metric_component_name
Expand All @@ -1126,68 +1143,101 @@ def process_extreme_batch(self) -> np.ndarray:
"""
raise NotImplementedError

def on_train_loader_start(self, context: PhaseContext) -> None:
self._reset()

def on_train_batch_end(self, context: PhaseContext) -> None:
if self.enable_on_train_loader and context.epoch % self.freq == 0:
self._on_batch_end(context)

def on_train_loader_end(self, context: PhaseContext) -> None:
if self.enable_on_train_loader and context.epoch % self.freq == 0:
self._gather_extreme_batch_images_and_log(context, "train")
self._reset()

def on_validation_loader_start(self, context: PhaseContext) -> None:
self._reset()

def on_validation_batch_end(self, context: PhaseContext) -> None:
if context.epoch % self.freq == 0:
# FOR METRIC OBJECTS, RESET THEM AND COMPUTE SCORE ONLY ON BATCH.
if self.metric is not None:
self.metric.update(**context.__dict__)
score = self.metric.compute()
if self.metric_component_name is not None:
if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
raise RuntimeError(
f"metric_component_name: {self.metric_component_name} is not a component "
f"of the monitored metric: {self.metric.__class__.__name__}"
)
score = score[self.metric_component_name]
elif len(score) > 1:
raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
else:
score = score.pop(list(score.keys())[0])
self.metric.reset()
if self.enable_on_valid_loader and context.epoch % self.freq == 0:
self._on_batch_end(context)

def on_validation_loader_end(self, context: PhaseContext) -> None:
if self.enable_on_valid_loader and context.epoch % self.freq == 0:
self._gather_extreme_batch_images_and_log(context, "valid")
self._reset()

def _gather_extreme_batch_images_and_log(self, context, loader_name: str):
images_to_save = self.process_extreme_batch()
images_to_save = maybe_all_gather_np_images(images_to_save)
if self.max_images > 0:
images_to_save = images_to_save[: self.max_images]
if not context.ddp_silent_mode:
context.sg_logger.add_images(tag=f"{loader_name}/{self._tag}", images=images_to_save, global_step=context.epoch, data_format="NHWC")

def _on_batch_end(self, context: PhaseContext) -> None:
if self.metric is not None:
self.metric.update(**context.__dict__)
score = self.metric.compute()
if self.metric_component_name is not None:
if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
raise RuntimeError(
f"metric_component_name: {self.metric_component_name} is not a component " f"of the monitored metric: {self.metric.__class__.__name__}"
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
)
score = score[self.metric_component_name]
elif len(score) > 1:
raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
else:
score = score.pop(list(score.keys())[0])
self.metric.reset()

else:

# FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
loss_tuple = context.loss_log_items
if self._first_call:
self._init_loss_attributes(context)
score = loss_tuple[self._idx_loss_tuple]
# FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
loss_tuple = context.loss_log_items
if self._first_call:
self._init_loss_attributes(context)
score = loss_tuple[self._idx_loss_tuple].detach().cpu().item()

# IN CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCES IN DDP
device = next(context.net.parameters()).device
score.to(device)
score = maybe_all_reduce_tensor_average(score)
# IN CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCES IN DDP
device = infer_model_device(context.net)
score = torch.tensor(score, device=device)
score = maybe_all_reduce_tensor_average(score)

if self._is_more_extreme(score):
self.extreme_score = score
self.extreme_batch = context.inputs
self.extreme_preds = context.preds
self.extreme_targets = context.target
if self._is_more_extreme(score):
self.extreme_score = tensor_container_to_device(score, device="cpu", detach=True, non_blocking=False)
self.extreme_batch = tensor_container_to_device(context.inputs, device="cpu", detach=True, non_blocking=False)
self.extreme_preds = tensor_container_to_device(context.preds, device="cpu", detach=True, non_blocking=False)
self.extreme_targets = tensor_container_to_device(context.target, device="cpu", detach=True, non_blocking=False)

def _init_loss_attributes(self, context: PhaseContext):
if self.loss_to_monitor not in context.loss_logging_items_names:
raise ValueError(f"{self.loss_to_monitor} not a loss or loss component.")
self._idx_loss_tuple = context.loss_logging_items_names.index(self.loss_to_monitor)
self._first_call = False

def on_validation_loader_end(self, context: PhaseContext) -> None:
if context.epoch % self.freq == 0:
images_to_save = self.process_extreme_batch()
images_to_save = maybe_all_gather_np_images(images_to_save)
if not context.ddp_silent_mode:
context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch)

self._reset()

def _reset(self):
self.extreme_score = -1 * np.inf if self.max else np.inf
self.extreme_score = None
self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None
if self.metric is not None:
self.metric.reset()

def _is_more_extreme(self, score: float) -> bool:
def _is_more_extreme(self, score: Union[float, torch.Tensor]) -> bool:
"""
Checks whether computed score is the more extreme than the current extreme score.
If the current score is None (first call), returns True.
:param score: A newly computed score.
:return: True if score is more extreme than the current extreme score, False otherwise.
"""
# A score can be Nan/Inf (rare but possible event when training diverges).
# In such case the both < and > operators would return False according to IEEE 754.
# As a consequence, self.extreme_inputs / self.extreme_outputs would not be updated
# and that would crash at the attempt to visualize batch.
if self.extreme_score is None:
return True

if self.max:
return self.extreme_score < score
else:
Expand Down Expand Up @@ -1254,18 +1304,21 @@ class ExtremeBatchDetectionVisualizationCallback(ExtremeBatchCaseVisualizationCa
When there is no such attributes and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/"Loss_"<IDX>

:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).
:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).

:param freq: int, epoch frequency to perform all of the above (default=1).
:param freq: int, epoch frequency to perform all of the above (default=1).

:param classes: List[str], a list of class names corresponding to the class indices for display.
When None, will try to fetch this through a "classes" attribute of the valdiation dataset. If such attribute does
not exist an error will be raised (default=None).
:param classes: List[str], a list of class names corresponding to the class indices for display.
When None, will try to fetch this through a "classes" attribute of the valdiation dataset. If such attribute does
not exist an error will be raised (default=None).

:param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
:param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
are in pixel values range, this needs to be set to True (default=False)

:param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
:param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
:param max_images: Maximum images to save. If -1, save all images.
"""

def __init__(
Expand All @@ -1278,9 +1331,19 @@ def __init__(
freq: int = 1,
classes: Optional[List[str]] = None,
normalize_targets: bool = False,
enable_on_train_loader: bool = False,
enable_on_valid_loader: bool = True,
max_images: int = -1,
):
super(ExtremeBatchDetectionVisualizationCallback, self).__init__(
metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
metric=metric,
metric_component_name=metric_component_name,
loss_to_monitor=loss_to_monitor,
max=max,
freq=freq,
enable_on_valid_loader=enable_on_valid_loader,
enable_on_train_loader=enable_on_train_loader,
max_images=max_images,
)
self.post_prediction_callback = post_prediction_callback
if classes is None:
Expand All @@ -1307,10 +1370,12 @@ def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
return inputs

def process_extreme_batch(self) -> Tuple[np.ndarray, np.ndarray]:
def process_extreme_batch(self) -> np.ndarray:
"""
Processes the extreme batch, and returns 2 image batches for visualization - one with predictions and one with GT boxes.
:return:Tuple[np.ndarray, np.ndarray], the predictions batch, the GT batch
Processes the extreme batch, and returns list of images for visualization.
Default implementations stacks GT and prediction overlays horisontally.

:return: np.ndarray A 4D tensor of [BHWC] shape with visualizations of the extreme batch.
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
"""
inputs = self.extreme_batch
preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
Expand All @@ -1334,25 +1399,16 @@ def process_extreme_batch(self) -> Tuple[np.ndarray, np.ndarray]:
)
images_to_save_gt = np.stack(images_to_save_gt)

return images_to_save_preds, images_to_save_gt
# Stack the predictions and GT images together
return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)

def on_validation_loader_end(self, context: PhaseContext) -> None:
def on_validation_loader_start(self, context: PhaseContext) -> None:
if self.classes is None:
if hasattr(context.valid_loader.dataset, "classes"):
self.classes = context.valid_loader.dataset.classes

else:
raise RuntimeError("Couldn't fetch classes from valid_loader, please pass classes explicitly")
if context.epoch % self.freq == 0:
images_to_save_preds, images_to_save_gt = self.process_extreme_batch()
images_to_save_preds = maybe_all_gather_np_images(images_to_save_preds)
images_to_save_gt = maybe_all_gather_np_images(images_to_save_gt)

if not context.ddp_silent_mode:
context.sg_logger.add_images(tag=f"{self._tag}_preds", images=images_to_save_preds, global_step=context.epoch, data_format="NHWC")
context.sg_logger.add_images(tag=f"{self._tag}_GT", images=images_to_save_gt, global_step=context.epoch, data_format="NHWC")

self._reset()
super().on_validation_loader_start(context)


@register_callback("ExtremeBatchSegVisualizationCallback")
Expand Down Expand Up @@ -1405,12 +1461,14 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
When there is no such attributesand criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/"Loss_"<IDX>

:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).

:param freq: int, epoch frequency to perform all of the above (default=1).
:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).

:param freq: int, epoch frequency to perform all of the above (default=1).

:param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
:param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
:param max_images: Maximum images to save. If -1, save all images.
"""

def __init__(
Expand All @@ -1421,13 +1479,24 @@ def __init__(
max: bool = False,
freq: int = 1,
ignore_idx: int = -1,
enable_on_train_loader: bool = False,
enable_on_valid_loader: bool = True,
max_images: int = -1,
):
super(ExtremeBatchSegVisualizationCallback, self).__init__(
metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
metric=metric,
metric_component_name=metric_component_name,
loss_to_monitor=loss_to_monitor,
max=max,
freq=freq,
enable_on_valid_loader=enable_on_valid_loader,
enable_on_train_loader=enable_on_train_loader,
max_images=max_images,
)
self.ignore_idx = ignore_idx

def process_extreme_batch(self) -> np.array:
@torch.no_grad()
def process_extreme_batch(self) -> np.ndarray:
inputs = self.extreme_batch
inputs -= inputs.min()
inputs /= inputs.max()
Expand All @@ -1445,6 +1514,8 @@ def process_extreme_batch(self) -> np.array:
colors = ["green", "red"]
images_to_save = []
for i in range(len(inputs)):
images_to_save.append(draw_segmentation_masks(inputs[i].cpu(), overlay[i], colors=colors, alpha=0.4).detach().numpy())
images_to_save = np.array(images_to_save)
image = draw_segmentation_masks(inputs[i].cpu(), overlay[i].cpu(), colors=colors, alpha=0.4).numpy()
image = np.transpose(image, (1, 2, 0))
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
images_to_save.append(image)
images_to_save = np.stack(images_to_save)
return images_to_save
Loading