Skip to content

Commit

Permalink
Add utilities to plot datasets to Weights & Biases + Add callback to …
Browse files Browse the repository at this point in the history
…log validation predictions to Weights & Biases (#1167)

* add: plot_detection_dataset_on_wandb

* add: doctring for plot_detection_dataset_on_wandb

* update: bbox type

* update: visualize_image_detection_prediction_on_wandb

* update: fix wandb module and linting

* add: WandBDetectionValidationPredictionLoggerCallback

* update: docstring for WandBDetectionValidationPredictionLoggerCallback

* update: WandBDetectionValidationPredictionLoggerCallback

* update: plot_detection_dataset_on_wandb

* update: imports

* fix: linting

* fix: linting

* fix: linting

* fix: imports in validation_logger.py

* Update log_predictions.py

* update: add max_predictions_plotted parameter to WandBDetectionValidationPredictionLoggerCallback

* update: docstring for WandBDetectionValidationPredictionLoggerCallback

* update: validation logger

* update: WandBDetectionValidationPredictionLoggerCallback + visualize_image_detection_prediction_on_wandb

* update: wandb import + docstring

* update: doctring

* update: docstrings + channel reversal

* update: make ci happy

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
soumik12345 and BloodAxe authored Aug 8, 2023
1 parent bc1b24d commit 2d3004a
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 7 deletions.
14 changes: 12 additions & 2 deletions src/super_gradients/common/plugins/wandb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from super_gradients.common.plugins.wandb.log_predictions import log_detection_results_to_wandb
from super_gradients.common.plugins.wandb.log_predictions import (
visualize_image_detection_prediction_on_wandb,
log_detection_results_to_wandb,
plot_detection_dataset_on_wandb,
)
from super_gradients.common.plugins.wandb.validation_logger import WandBDetectionValidationPredictionLoggerCallback


__all__ = ["log_detection_results_to_wandb"]
__all__ = [
"visualize_image_detection_prediction_on_wandb",
"log_detection_results_to_wandb",
"plot_detection_dataset_on_wandb",
"WandBDetectionValidationPredictionLoggerCallback",
]
66 changes: 61 additions & 5 deletions src/super_gradients/common/plugins/wandb/log_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,27 @@
except (ModuleNotFoundError, ImportError, NameError):
pass # no action or logging - this is normal in most cases

import numpy as np
from tqdm import tqdm

from super_gradients.training.transforms.transforms import DetectionTargetsFormatTransform
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.datasets.detection_datasets import DetectionDataset

from super_gradients.training.utils.predict import ImageDetectionPrediction, ImagesDetectionPrediction


def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool):
def visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPrediction, show_confidence: bool, reverse_channels: bool = False):
"""Visualize detection results on a single image.
:param prediction: Prediction results of a single image
(a `super_gradients.training.models.prediction_results.ImageDetectionPrediction` object)
:param show_confidence: Whether to log confidence scores to Weights & Biases or not.
:param reverse_channels: Reverse the order of channels on the images while plotting.
"""
boxes = []
image = prediction.image.copy()
image = image[:, :, ::-1] if reverse_channels else image
height, width, _ = image.shape
class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(prediction.class_names)}

Expand All @@ -28,9 +43,7 @@ def _visualize_image_detection_prediction_on_wandb(prediction: ImageDetectionPre
box["scores"] = {"confidence": float(round(prediction.prediction.confidence[pred_i], 2))}
boxes.append(box)

wandb_image = wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}})

wandb.log({"Predictions": wandb_image})
return wandb.Image(image, boxes={"predictions": {"box_data": boxes, "class_labels": class_id_to_labels}})


def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_confidence: bool = True):
Expand All @@ -42,4 +55,47 @@ def log_detection_results_to_wandb(prediction: ImagesDetectionPrediction, show_c
if wandb.run is None:
raise wandb.Error("Images and bounding boxes cannot be visualized on Weights & Biases without initializing a run using `wandb.init()`")
for prediction in prediction._images_prediction_lst:
_visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence)
wandb_image = visualize_image_detection_prediction_on_wandb(prediction=prediction, show_confidence=show_confidence)
wandb.log({"Predictions": wandb_image})


def plot_detection_dataset_on_wandb(detection_dataset: DetectionDataset, max_examples: int = None, dataset_name: str = None, reverse_channels: bool = True):
"""Log a detection dataset to Weights & Biases Table.
:param detection_dataset: The Detection Dataset (a `super_gradients.training.datasets.detection_datasets.DetectionDataset` object)
:param max_examples: Maximum number of examples from the detection dataset to plot (an `int`).
:param dataset_name: Name of the dataset (a `str`).
:param reverse_channels: Reverse the order of channels on the images while plotting.
"""
max_examples = len(detection_dataset) if max_examples is None else max_examples
wandb_table = wandb.Table(columns=["Images", "Class-Frequencies"])
input_format = detection_dataset.output_target_format
target_format_transform = DetectionTargetsFormatTransform(input_format=input_format, output_format=XYXY_LABEL)
class_id_to_labels = {int(_id): str(_class_name) for _id, _class_name in enumerate(detection_dataset.classes)}
for data_idx in tqdm(range(max_examples), desc="Plotting Examples on Weights & Biases"):
image, targets, *_ = detection_dataset[data_idx]
image = image.transpose(1, 2, 0).astype(np.int32)
sample = target_format_transform({"image": image, "target": targets})
boxes = sample["target"][:, 0:4]
boxes = boxes[(boxes != 0).any(axis=1)]
classes = targets[:, 0].tolist()
wandb_boxes = []
class_frequencies = {str(_class_name): 0 for _id, _class_name in enumerate(detection_dataset.classes)}
for idx in range(boxes.shape[0]):
wandb_boxes.append(
{
"position": {
"minX": float(boxes[idx][0] / image.shape[1]),
"maxX": float(boxes[idx][2] / image.shape[1]),
"minY": float(boxes[idx][1] / image.shape[0]),
"maxY": float(boxes[idx][3] / image.shape[0]),
},
"class_id": int(classes[idx]),
"box_caption": str(class_id_to_labels[int(classes[idx])]),
}
)
class_frequencies[str(class_id_to_labels[int(classes[idx])])] += 1
image = image[:, :, ::-1] if reverse_channels else image
wandb_table.add_data(wandb.Image(image, boxes={"ground_truth": {"box_data": wandb_boxes, "class_labels": class_id_to_labels}}), class_frequencies)
dataset_name = "Dataset" if dataset_name is None else dataset_name
wandb.log({dataset_name: wandb_table}, commit=False)
99 changes: 99 additions & 0 deletions src/super_gradients/common/plugins/wandb/validation_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

import torch
import numpy as np

from super_gradients.training.utils.callbacks import Callback, PhaseContext
from super_gradients.common.plugins.wandb.log_predictions import visualize_image_detection_prediction_on_wandb
from super_gradients.training.models.predictions import DetectionPrediction
from super_gradients.training.utils.predict import ImageDetectionPrediction
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.utils.utils import unwrap_model

try:
import wandb
except (ModuleNotFoundError, ImportError, NameError):
pass # no action or logging - this is normal in most cases


class WandBDetectionValidationPredictionLoggerCallback(Callback):
def __init__(
self,
class_names,
max_predictions_plotted: Optional[int] = None,
post_prediction_callback: Optional[DetectionPostPredictionCallback] = None,
reverse_channels: bool = True,
) -> None:
"""A callback for logging object detection predictions to Weights & Biases during training. This callback is logging images on each batch in validation
and accumulating generated images in a `wandb.Table` in the RAM. This could potentially cause OOM errors for very large datasets like COCO. In order to
avoid this, it is recommended to explicitly set the parameter `max_predictions_plotted` to a small value, thus limiting the number of images logged in
the table.
:param class_names: A list of class names.
:param max_predictions_plotted: Maximum number of predictions to be plotted per epoch. This is set to `None` by default which means that the
predictions corresponding to all images from `context.inputs` is logged, otherwise only `max_predictions_plotted`
number of images is logged. Since `WandBDetectionValidationPredictionLoggerCallback` accumulates the generated
images in the RAM, it is advisable that the value of this parameter be explicitly specified for larger datasets in
order to avoid out-of-memory errors.
:param post_prediction_callback: `DetectionPostPredictionCallback` for post-processing outputs of the model.
:param reverse_channels: Reverse the order of channels on the images while plotting.
"""
super().__init__()
self.class_names = class_names
self.max_predictions_plotted = max_predictions_plotted
self.post_prediction_callback = post_prediction_callback
self.reverse_channels = reverse_channels
self.wandb_images = []
self.epoch_count = 0
self.mean_prediction_dicts = []
self.wandb_table = wandb.Table(columns=["Epoch", "Prediction", "Mean-Confidence"])

def on_validation_batch_end(self, context: PhaseContext) -> None:
self.wandb_images = []
mean_prediction_dict = {class_name: 0.0 for class_name in self.class_names}
if isinstance(context.net, HasPredict):
post_nms_predictions = context.net(context.inputs)
else:
self.post_prediction_callback = (
unwrap_model(context.net).get_post_prediction_callback() if self.post_prediction_callback is None else self.post_prediction_callback
)
self.post_prediction_callback.fuse_layers = False
post_nms_predictions = self.post_prediction_callback(context.preds, device=context.device)
if self.max_predictions_plotted is not None:
post_nms_predictions = post_nms_predictions[: self.max_predictions_plotted]
input_images = context.inputs[: self.max_predictions_plotted]
else:
input_images = context.inputs
for prediction, image in zip(post_nms_predictions, input_images):
prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32)
prediction = prediction.detach().cpu().numpy()
postprocessed_image = image.detach().cpu().numpy().transpose(1, 2, 0).astype(np.int32)
image_prediction = ImageDetectionPrediction(
image=postprocessed_image,
class_names=self.class_names,
prediction=DetectionPrediction(
bboxes=prediction[:, :4],
confidence=prediction[:, 4],
labels=prediction[:, 5],
bbox_format="xyxy",
image_shape=image.shape,
),
)
for predicted_label, prediction_confidence in zip(prediction[:, 5], prediction[:, 4]):
mean_prediction_dict[self.class_names[int(predicted_label)]] += prediction_confidence
mean_prediction_dict = {k: v / len(prediction[:, 4]) for k, v in mean_prediction_dict.items()}
self.mean_prediction_dicts.append(mean_prediction_dict)
wandb_image = visualize_image_detection_prediction_on_wandb(
prediction=image_prediction, show_confidence=True, reverse_channels=self.reverse_channels
)
self.wandb_images.append(wandb_image)

def on_validation_loader_end(self, context: PhaseContext) -> None:
for wandb_image, mean_prediction_dict in zip(self.wandb_images, self.mean_prediction_dicts):
self.wandb_table.add_data(self.epoch_count, wandb_image, mean_prediction_dict)
self.wandb_images, self.mean_prediction_dicts = [], []
self.epoch_count += 1

def on_training_end(self, context: PhaseContext) -> None:
wandb.log({"Validation-Prediction": self.wandb_table})

0 comments on commit 2d3004a

Please sign in to comment.