Support for segmentation extreme batch cases (#1282)
* tested version

* changed base to abc and abstractmethod

* comments wip

* refactoring, docs

* removed testing metric

* removed device arg from maybe all reduce

* fixed metrice typo in docs

* loss_name changed to loss_to_monitor

* unit tests added

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
shaydeci and BloodAxe authored Aug 1, 2023
1 parent 5792080 commit 699b972
Showing 6 changed files with 370 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -441,7 +441,7 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tuple:
# COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS
loss, loss_log_items = self._get_losses(outputs, targets)

context.update_context(preds=outputs, loss_log_items=loss_log_items)
context.update_context(preds=outputs, loss_log_items=loss_log_items, loss_logging_items_names=self.loss_logging_items_names)
self.phase_callback_handler.on_train_batch_loss_end(context)

if not self.ddp_silent_mode and batch_idx == 0:
@@ -1316,6 +1316,7 @@ def forward(self, inputs, targets):
metric_to_watch=self.metric_to_watch,
device=device_config.device,
ema_model=self.ema_model,
valid_metrics=self.valid_metrics,
)
self.phase_callback_handler.on_training_start(context)

@@ -1985,6 +1986,7 @@ def evaluate(

lr_warmup_epochs = self.training_params.lr_warmup_epochs if self.training_params else None
context = PhaseContext(
net=self.net,
epoch=epoch,
metrics_compute_fn=metrics,
loss_avg_meter=loss_avg_meter,
@@ -1994,6 +1996,7 @@
sg_logger=self.sg_logger,
train_loader=self.train_loader,
valid_loader=self.valid_loader,
loss_logging_items_names=self.loss_logging_items_names,
)

with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader:
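The loss_logging_items_names list threaded into the PhaseContext above holds the per-component loss names the trainer derives from the criterion, aligned index-for-index with loss_log_items. A minimal sketch of how a callback can use it to look up one monitored component (the names and values below are made-up examples, not values from this diff):

    # Hypothetical values for illustration; in practice both lists come from the PhaseContext.
    loss_logging_items_names = ["DiceCEEdgeLoss/main_loss", "DiceCEEdgeLoss/aux_loss0"]
    loss_log_items = [0.52, 0.31]

    loss_to_monitor = "DiceCEEdgeLoss/aux_loss0"
    idx = loss_logging_items_names.index(loss_to_monitor)  # raises ValueError if the name is absent
    score = loss_log_items[idx]  # 0.31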
src/super_gradients/training/utils/callbacks/base_callbacks.py
@@ -53,6 +53,7 @@ def __init__(
metric_to_watch=None,
valid_metrics=None,
ema_model=None,
loss_logging_items_names=None,
):
self.epoch = epoch
self.batch_idx = batch_idx
@@ -82,6 +83,7 @@ def __init__(
self.metric_to_watch = metric_to_watch
self.valid_metrics = valid_metrics
self.ema_model = ema_model
self.loss_logging_items_names = loss_logging_items_names

def update_context(self, **kwargs):
for attr, attr_val in kwargs.items():
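For context, update_context (shown truncated above) follows a simple kwargs-to-attributes pattern, which is what lets the trainer publish new fields such as loss_logging_items_names without touching every call site. A minimal sketch of the assumed behavior:

    # Minimal sketch of the assumed update_context pattern (not the full PhaseContext).
    class MiniContext:
        def update_context(self, **kwargs):
            for attr, attr_val in kwargs.items():
                setattr(self, attr, attr_val)

    ctx = MiniContext()
    ctx.update_context(loss_logging_items_names=["MyLoss"])
    print(ctx.loss_logging_items_names)  # ['MyLoss']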
257 changes: 257 additions & 0 deletions src/super_gradients/training/utils/callbacks/callbacks.py
@@ -3,6 +3,7 @@
import os
import signal
import time
from abc import ABC, abstractmethod
from typing import List, Union, Optional, Sequence, Mapping

import csv
@@ -13,19 +14,25 @@
import torch
from deprecated import deprecated
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection, Metric

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.plugins.deci_client import DeciClient
from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS
from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks
from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber
from super_gradients.training.utils import get_param
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average, maybe_all_gather_np_images
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
from super_gradients.training.utils.utils import unwrap_model
from torchvision.utils import draw_segmentation_masks

logger = get_logger(__name__)

@@ -948,3 +955,253 @@ def create_lr_scheduler_callback(
raise ValueError(f"Unknown lr_mode: {lr_mode}")

return sg_lr_callback


class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
"""
ExtremeBatchCaseVisualizationCallback
A base class for visualizing worst/best validation batches in an epoch
according to some metric or loss value, with full DDP support.
Images are saved with training_hyperparams.sg_logger.
:param metric: Metric, the metric to be monitored.
:param metric_component_name: In case metric returns multiple values (as Mapping),
the value at metric.compute()[metric_component_name] will be the one monitored.
:param loss_to_monitor: str, the loss name corresponding to the 'criterion' passed through training_params in Trainer.train(...).
Monitoring a loss follows the same logic as metric_to_watch in Trainer.train(...), and the name should be:
if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
If a single item is returned rather than a tuple:
<LOSS_CLASS.__name__>.
When there is no such attribute and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/""Loss_"<IDX>
:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).
:param freq: int, epoch frequency to perform all of the above (default=1).
Inheritors should implement process_extreme_batch, which returns the images as a uint8 np.ndarray of shape BCHW.
"""

@resolve_param("metric", MetricsFactory())
def __init__(
self,
metric: Optional[Metric] = None,
metric_component_name: Optional[str] = None,
loss_to_monitor: Optional[str] = None,
max: bool = False,
freq: int = 1,
):
super(ExtremeBatchCaseVisualizationCallback, self).__init__()

if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None):
raise RuntimeError("Must pass exactly one of: loss, metric != None")

self._set_tag_attr(loss_to_monitor, max, metric, metric_component_name)
self.metric = metric
if self.metric:
self.metric = MetricCollection(self.metric)
self.metric.to(device_config.device)

self.metric_component_name = metric_component_name

self.loss_to_monitor = loss_to_monitor
self.max = max
self.freq = freq
self.extreme_score = -1 * np.inf if max else np.inf

self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None

self._first_call = True
self._idx_loss_tuple = None

def _set_tag_attr(self, loss_to_monitor, max, metric, metric_component_name):
if metric_component_name:
monitored_val_name = metric_component_name
elif metric:
monitored_val_name = metric.__class__.__name__
else:
monitored_val_name = loss_to_monitor
self._tag = f"max_{monitored_val_name}_batch" if max else f"min_{monitored_val_name}_batch"

@abstractmethod
def process_extreme_batch(self) -> np.ndarray:
"""
This method is called right before adding the images to the SGLogger (inside the on_validation_loader_end call).
It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images as an np.ndarray.
Output should be of shape N,3,H,W and uint8.
:return: images to save, np.ndarray
"""
raise NotImplementedError

def on_validation_batch_end(self, context: PhaseContext) -> None:
if context.epoch % self.freq == 0:
# FOR METRIC OBJECTS, RESET THEM AND COMPUTE SCORE ONLY ON BATCH.
if self.metric is not None:
self.metric.update(**context.__dict__)
score = self.metric.compute()
if self.metric_component_name is not None:
if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
raise RuntimeError(
f"metric_component_name: {self.metric_component_name} is not a component "
f"of the monitored metric: {self.metric.__class__.__name__}"
)
score = score[self.metric_component_name]
elif len(score) > 1:
raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
else:
score = score.pop(list(score.keys())[0])
self.metric.reset()

else:

# FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
loss_tuple = context.loss_log_items
if self._first_call:
self._init_loss_attributes(context)
score = loss_tuple[self._idx_loss_tuple]

# IN CONTRAST TO METRICS, LOSS VALUES NEED TO BE REDUCED IN DDP
device = next(context.net.parameters()).device
score = score.to(device)
score = maybe_all_reduce_tensor_average(score)

if self._is_more_extreme(score):
self.extreme_score = score
self.extreme_batch = context.inputs
self.extreme_preds = context.preds
self.extreme_targets = context.target

def _init_loss_attributes(self, context: PhaseContext):
if self.loss_to_monitor not in context.loss_logging_items_names:
raise ValueError(f"{self.loss_to_monitor} not a loss or loss component.")
self._idx_loss_tuple = context.loss_logging_items_names.index(self.loss_to_monitor)
self._first_call = False

def on_validation_loader_end(self, context: PhaseContext) -> None:
if context.epoch % self.freq == 0:
images_to_save = self.process_extreme_batch()
images_to_save = maybe_all_gather_np_images(images_to_save)
if not context.ddp_silent_mode:
context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch)

self._reset()

def _reset(self):
self.extreme_score = -1 * np.inf if self.max else np.inf
self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None
if self.metric is not None:
self.metric.reset()

def _is_more_extreme(self, score: float) -> bool:
if self.max:
return self.extreme_score < score
else:
return self.extreme_score > score
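
To illustrate the inheritor contract (process_extreme_batch turns the stored extreme batch into a uint8 np.ndarray of shape N,3,H,W), a minimal hypothetical subclass that visualizes the raw inputs without any overlay could be sketched as follows; this is an illustration of the contract, not a callback shipped by this commit:

    # Hypothetical inheritor, for illustration only. Assumes self.extreme_batch
    # is a non-constant float tensor of shape N,3,H,W.
    class ExtremeBatchRawVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
        def process_extreme_batch(self) -> np.ndarray:
            inputs = self.extreme_batch.clone()
            inputs -= inputs.min()
            inputs /= inputs.max()
            inputs *= 255
            return inputs.to(torch.uint8).cpu().numpy()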


@register_callback("ExtremeBatchSegVisualizationCallback")
class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
"""
ExtremeBatchSegVisualizationCallback
Visualizes worst/best batch in an epoch, for segmentation.
Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.
Correct predictions are marked in green, incorrect ones in red.
Example usage in training_params definition:
    training_hyperparams = {
        ...
        "phase_callbacks": [
            ExtremeBatchSegVisualizationCallback(
                metric=IoU(20, ignore_idx=19),
                max=False,
                ignore_idx=19),
            ExtremeBatchSegVisualizationCallback(
                loss_to_monitor="LabelSmoothingCrossEntropyLoss",
                max=True,
                ignore_idx=19)],
        ...
    }
Example usage in Yaml config:
training_hyperparams:
phase_callbacks:
- ExtremeBatchSegVisualizationCallback:
loss_to_monitor: DiceCEEdgeLoss/aux_loss0
ignore_idx: 19
:param metric: Metric, the metric to be monitored.
:param metric_component_name: In case metric returns multiple values (as Mapping),
the value at metric.compute()[metric_component_name] will be the one monitored.
:param loss_to_monitor: str, the loss name corresponding to the 'criterion' passed through training_params in Trainer.train(...).
Monitoring a loss follows the same logic as metric_to_watch in Trainer.train(...), and the name should be:
if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
If a single item is returned rather than a tuple:
<LOSS_CLASS.__name__>.
When there is no such attribute and criterion.forward(..) returns a tuple:
<LOSS_CLASS.__name__>"/""Loss_"<IDX>
:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).
:param freq: int, epoch frequency to perform all of the above (default=1).
"""

def __init__(
self,
metric: Optional[Metric] = None,
metric_component_name: Optional[str] = None,
loss_to_monitor: Optional[str] = None,
max: bool = False,
freq: int = 1,
ignore_idx: int = -1,
):
super(ExtremeBatchSegVisualizationCallback, self).__init__(
metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
)
self.ignore_idx = ignore_idx

def process_extreme_batch(self) -> np.ndarray:
inputs = self.extreme_batch
inputs -= inputs.min()
inputs /= inputs.max()
inputs *= 255
inputs = inputs.to(torch.uint8)
preds = self.extreme_preds
if isinstance(preds, tuple):
preds = preds[0]
preds = preds.argmax(1)
p_mask = preds == self.extreme_targets
n_mask = preds != self.extreme_targets
p_mask[self.extreme_targets == self.ignore_idx] = False
n_mask[self.extreme_targets == self.ignore_idx] = False
overlay = torch.cat([p_mask.unsqueeze(1), n_mask.unsqueeze(1)], 1)
colors = ["green", "red"]
images_to_save = []
for i in range(len(inputs)):
images_to_save.append(draw_segmentation_masks(inputs[i].cpu(), overlay[i], colors=colors, alpha=0.4).detach().numpy())
images_to_save = np.array(images_to_save)
return images_to_save
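
Wiring the new callback into Trainer.train(...) follows the docstring examples above; here is a sketch, in which the model, loaders, and remaining hyperparameters are placeholders, and the IoU kwargs are copied from the docstring example rather than verified against the metric's signature:

    from super_gradients import Trainer
    from super_gradients.training.metrics import IoU
    from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchSegVisualizationCallback

    trainer = Trainer("seg_experiment")  # placeholder experiment name
    training_params = {
        # ... the other required training hyperparameters ...
        "phase_callbacks": [
            # Save the batch with the lowest IoU each epoch:
            ExtremeBatchSegVisualizationCallback(metric=IoU(20, ignore_idx=19), max=False, ignore_idx=19),
            # Save the batch with the highest value of one loss component:
            ExtremeBatchSegVisualizationCallback(loss_to_monitor="DiceCEEdgeLoss/aux_loss0", max=True, ignore_idx=19),
        ],
    }
    # trainer.train(model=model, training_params=training_params,
    #               train_loader=train_loader, valid_loader=valid_loader)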
34 changes: 34 additions & 0 deletions src/super_gradients/training/utils/distributed_training_utils.py
@@ -4,10 +4,12 @@
from typing import List, Tuple
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda.amp import autocast
from torch.distributed import get_rank, all_gather_object
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
@@ -412,3 +414,35 @@ def __init__(self):
">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)"
)
super().__init__(self.message)


def maybe_all_reduce_tensor_average(tensor: torch.Tensor) -> torch.Tensor:
"""
When in DDP - mean-reduces the tensor across all devices.
When not in DDP - returns the input tensor.
:param tensor: tensor to (maybe) reduce
:return: the mean-reduced tensor when in DDP, otherwise the input tensor
"""
if is_distributed():
tensor = distributed_all_reduce_tensor_average(tensor=tensor, n=torch.distributed.get_world_size())
return tensor


def maybe_all_gather_np_images(image: np.ndarray) -> np.ndarray:
"""
When in DDP- gathers images (as np.ndarray objects) from all processes.
Returns the concatenated np.array across dim=0.
When not in DDP - returns the input tensor.
:param image: np.ndarray, the local rank's tensor to gather
:return: np.ndarray, the output image as described above
"""
if is_distributed():
rank = get_rank()
output_container = [None for _ in range(get_world_size())]
all_gather_object(output_container, image)
if rank == 0:
image = np.concatenate(output_container, 0)
return image
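
Both helpers are deliberately safe to call outside DDP, so callbacks like the one above can use them unconditionally; a minimal sketch of the intended call pattern:

    # Sketch of the intended call pattern; each call is a no-op outside DDP.
    import numpy as np
    import torch

    from super_gradients.training.utils.distributed_training_utils import (
        maybe_all_reduce_tensor_average,
        maybe_all_gather_np_images,
    )

    score = torch.tensor(0.42)
    score = maybe_all_reduce_tensor_average(score)  # mean across ranks, or unchanged

    images = np.zeros((2, 3, 64, 64), dtype=np.uint8)
    images = maybe_all_gather_np_images(images)  # concatenated on rank 0, or unchanged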
