Updating OD example by adding additional details on customizing metri… #1449

Merged 3 commits on Sep 5, 2023
Changes from 1 commit

167 changes: 161 additions & 6 deletions documentation/source/ObjectDetection.md
@@ -40,6 +40,65 @@
```python
val_dataloader = dataloaders.get(name='coco2017_val',
)
```

### Loss functions

Generally speaking, in object detection the loss function is tightly coupled with the model and cannot be used interchangeably: you cannot use the YoloX loss with a YoloNAS model, and vice versa.
This is different from classification or segmentation tasks, where the model output is "standard" and usually does not change.
In object detection, the model output format may vary greatly, and the training objective is often tailored to a specific model architecture.

To indicate compatibility between a model and a loss function, we use the convention that the model name and the loss function name start with the same prefix. For example: the `SSDLiteMobileNetV2` model & `SSDLoss`, `YoloX_S` and `YoloXFastDetectionLoss`, etc.

Of course, you are free to adjust the hyperparameters of the loss function to your liking. Let's take the `PPYoloELoss` class as an example.
It has the following constructor:

```python
@register_loss(Losses.PPYOLOE_LOSS)
class PPYoloELoss(nn.Module):
def __init__(
self,
num_classes: int,
use_varifocal_loss: bool = True,
use_static_assigner: bool = True,
reg_max: int = 16,
classification_loss_weight: float = 1.0,
iou_loss_weight: float = 2.5,
dfl_loss_weight: float = 0.5,
):
...
```

In your recipe, you can pass the desired value for each parameter.
In the example shown below, we increase the classification component weight to 10 and set the DFL & IoU components of the loss to 1.0:

```yaml
training_hyperparams:
loss:
ppyoloe_loss:
num_classes: ${arch_params.num_classes}
classification_loss_weight: 10
iou_loss_weight: 1.0
dfl_loss_weight: 1.0
```
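
If you train from Python code rather than from a recipe, the same hyperparameters can be passed when instantiating the loss directly. Below is a minimal sketch; the import path and the ability to pass a loss instance through `training_params` are assumptions to verify against your SuperGradients version:

```python
from super_gradients.training.losses import PPYoloELoss  # assumed import path

# Same hyperparameters as in the YAML recipe above.
loss = PPYoloELoss(
    num_classes=80,  # must match your model's number of classes
    classification_loss_weight=10,
    iou_loss_weight=1.0,
    dfl_loss_weight=1.0,
)

# The instance can then be passed to Trainer.train(...) as
# training_params={"loss": loss, ...} (assumed to be supported).
```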

This is how you can modify the loss hyperparameters. If you need to modify the loss itself, you can subclass it and override the `forward` method to fit your needs.

```python
@register_loss()
class MyCustomPPYoloELoss(nn.Module):
def forward(self, outputs, target):
...
```

```yaml
training_hyperparams:
loss:
MyCustomPPYoloELoss:
num_classes: ${arch_params.num_classes}
classification_loss_weight: 10
iou_loss_weight: 1.0
dfl_loss_weight: 1.0
```
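
For completeness, here is a hedged sketch of the subclassing approach: it assumes the parent computation can be reused via `super().forward()` and that SuperGradients losses return a `(loss, loss_items)` tuple, so verify both against the version you are using:

```python
from super_gradients.common.registry.registry import register_loss  # assumed import path
from super_gradients.training.losses import PPYoloELoss             # assumed import path


@register_loss()
class MyCustomPPYoloELoss(PPYoloELoss):
    def forward(self, outputs, targets):
        # Reuse the parent computation, then adjust the result to your needs,
        # e.g. rescale the total loss while keeping the logged components intact.
        loss, loss_items = super().forward(outputs, targets)
        return loss * 0.5, loss_items
```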

### Metrics

A typical metric for object detection is mean average precision, mAP for short.
@@ -48,6 +107,46 @@
Both a single value and a range can be used as the IoU threshold, where a range means the mAP is averaged over several IoU values within that range.
The most popular mAP metric on COCO is mAP@0.5:0.95, and SuperGradients provides an implementation of it, [DetectionMetrics](https://docs.deci.ai/super-gradients/docstring/training/metrics.html#training.metrics.detection_metrics.DetectionMetrics).
It is written to be as close as possible to the official metric implementation from [COCO API](https://pypi.org/project/pycocotools/), while being much faster and DDP-friendly.

We provide a few metrics for object detection with pre-defined IoU levels to fit the most frequent use cases:

* `DetectionMetrics_050_095` - computes mAP at the IoU range [0.5; 0.95] with a step of 0.05 (the default COCO metric)
* `DetectionMetrics_050` - computes mAP at IoU level 0.5
* `DetectionMetrics_075` - computes mAP at IoU level 0.75
* `DetectionMetrics` - computes mAP at a user-specified IoU level or range (defaults to [0.5; 0.95])

You can also specify a custom IoU range or a single IoU level for the metric.
In addition to computing mAP, `DetectionMetrics` also computes other metrics such as:

* Recall score at a given score threshold
* Precision score at a given score threshold
* F-1 detection score at a given score threshold
* Average precision score for each class

DetectionMetrics can even find the optimal confidence threshold that maximizes mean F1 score.
Here is how to enable computing all these metrics:

```yaml
training_hyperparams:
valid_metrics_list:
- DetectionMetrics:
num_cls: ${num_classes}
normalize_targets: True
score_thres: 0.1 # A lower bound rejection threshold for predictions
top_k_predictions: 300 # At most 300 predictions per image will be considered with confidence above score_thres
iou_thres: [0.6, 0.8] # <--- IoU range [0.6; 0.8] with 0.05 step
include_classwise_ap: True # Enables computing AP for each class (helps to find problematic classes)
calc_best_score_thresholds: True # Enables computing optimal confidence threshold that maximizes mean F1 score
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
score_threshold: 0.01
nms_top_k: 1000
max_predictions: 300
nms_threshold: 0.7

metric_to_watch: 'mAP@0.60:0.80'
```
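
The same metric can also be constructed from Python code, for example to evaluate a trained model with `trainer.test(...)`. The sketch below simply mirrors the YAML above; the `DetectionMetrics` import path, the tuple form of `iou_thres`, and the `trainer.test` usage are assumptions to verify against your version:

```python
from super_gradients.training.metrics import DetectionMetrics  # assumed import path
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback

metric = DetectionMetrics(
    num_cls=80,  # number of classes in your dataset
    normalize_targets=True,
    score_thres=0.1,
    top_k_predictions=300,
    iou_thres=(0.6, 0.8),  # IoU range [0.6; 0.8]
    include_classwise_ap=True,
    calc_best_score_thresholds=True,
    post_prediction_callback=PPYoloEPostPredictionCallback(
        score_threshold=0.01, nms_top_k=1000, max_predictions=300, nms_threshold=0.7
    ),
)

# e.g. trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=[metric])
```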


In order to use `DetectionMetrics` you have to pass a so-called `post_prediction_callback` to the metric, which is responsible for the postprocessing of the model's raw output into final predictions and is explained below.

### Postprocessing
@@ -76,12 +175,58 @@
Box coordinates are in absolute (pixel) units.
Visualization of the model predictions is a very important part of the training process for any computer vision task.
By visualizing the predicted boxes, developers and researchers can identify errors or inaccuracies in the model's output and adjust the model's architecture or training data accordingly.

SuperGradients provides an implementation of [DetectionVisualizationCallback](https://docs.deci.ai/super-gradients/docstring/training/utils.html#training.utils.callbacks.callbacks.DetectionVisualizationCallback).
You can use this callback in your training pipeline to visualize predictions during training. To do so, just add it to `training_hyperparams.phase_callbacks` in your yaml.
During training, the callback will generate a visualization of the model predictions and save it to TensorBoard or Weights & Biases, depending on which logger you are using (default is TensorBoard).
#### Extreme Batch Visualization during training

SuperGradients provides an implementation of [ExtremeBatchDetectionVisualizationCallback](https://docs.deci.ai/super-gradients/docstring/training/utils.html#src.super_gradients.training.utils.callbacks.callbacks.ExtremeBatchDetectionVisualizationCallback).
You can use this callback in your training pipeline to visualize the best or worst batch during training.
The callback observes a specific metric during the training epoch and logs the most extreme batch to the configured logger (default is TensorBoard).
The logged visualization includes the ground-truth boxes and the model's predictions.

To use this callback you would need to add it to `training_hyperparams.phase_callbacks` in your yaml:

```yaml
training_hyperparams:
phase_callbacks:
- ExtremeBatchDetectionVisualizationCallback:
metric: # Defines which metric to observe
DetectionMetrics_050:
score_thres: 0.1
top_k_predictions: 300
num_cls: ${num_classes}
normalize_targets: True
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
score_threshold: 0.01
nms_top_k: 1000
max_predictions: 300
nms_threshold: 0.7
max: False # Indicates that we want to log batch with the lowest metric value
metric_component_name: 'mAP@0.50'
post_prediction_callback:
_target_: super_gradients.training.models.detection_models.pp_yolo_e.PPYoloEPostPredictionCallback
score_threshold: 0.25
nms_top_k: 1000
max_predictions: 300
nms_threshold: 0.7
normalize_targets: True

```
Note that in the example above, the `ExtremeBatchDetectionVisualizationCallback` callback observes a user-provided metric class that computes the score for **each batch**.
You may also observe the entire loss or individual components of the loss.
In this case, instead of passing the `metric` argument to the constructor of `ExtremeBatchDetectionVisualizationCallback`, pass the `loss_to_monitor` argument.
The fully qualified name of the loss consists of the loss class name and the component name, separated by `/`:

```yaml
training_hyperparams:
phase_callbacks:
- ExtremeBatchDetectionVisualizationCallback:
loss_to_monitor: "YoloNASPoseLoss/loss"
max: True

```

#### Visualization of predictions after training

If you would like to do the visualization outside of training, you can use the `DetectionVisualization` class as follows:
```python
import torch
import numpy as np
# @@ -100,7 +245,7 @@  (preceding lines collapsed in the diff view)

def my_undo_image_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    ...  # function body collapsed in the diff view

model = models.get("yolox_s", pretrained_weights="coco", num_classes=80)
imgs, targets = next(iter(train_dataloader))
# previously: preds = YoloXPostPredictionCallback(conf=0.1, iou=0.6)(model(imgs))
preds = model.get_post_prediction_callback(conf=0.1, iou=0.6)(model(imgs))
DetectionVisualization.visualize_batch(imgs, preds, targets, batch_name='train', class_names=COCO_DETECTION_CLASSES_LIST,
checkpoint_dir='/path/for/saved_images/', gt_alpha=0.5,
undo_preprocessing_func=my_undo_image_preprocessing)
```

@@ -112,6 +257,16 @@
The saved train image for a dataset with a mosaic transform should look something like this:

![train_24](images/train_24.jpg)

#### Visualization of predictions after training using predict()

If you would like to visualize predictions outside of training, you can also use the `predict()` method, which is implemented for most of our detection models.
```python
from super_gradients.training import models

model = models.get("yolox_s", pretrained_weights="coco", num_classes=80)
model.predict("https://deci-pretrained-models.s3.amazonaws.com/sample_images/beatles-abbeyroad.jpg").show()
```
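
Besides `show()`, the returned prediction object can typically also be written to disk. A small sketch, assuming the result exposes a `save()` method that accepts an output folder (verify against your SuperGradients version):

```python
predictions = model.predict("https://deci-pretrained-models.s3.amazonaws.com/sample_images/beatles-abbeyroad.jpg", conf=0.25)
predictions.save(output_folder="visualized_predictions/")  # assumed signature
```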

See the [Predict API](https://docs.deci.ai/super-gradients/V3_1/documentation/source/ModelPredictions.html) documentation for more details.

### Let's train!

As stated above, training can be launched with just one command. For the curious ones, let's see how all the components we've just discussed fall into place in one yaml.