
Update doc (#1381)
* update doc

* add typing

* fix

* add explanation about multi_process_safe
Louis-Dupont committed Aug 22, 2023
1 parent 0b4710a commit 7b182ff
Showing 2 changed files with 719 additions and 182 deletions.
253 changes: 150 additions & 103 deletions documentation/source/PhaseCallbacks.md
# Phase Callbacks

Integrating your own code into an existing training pipeline can require significant effort on the user's end.
To tackle this challenge, a list of callables triggered at specific points of the training code can
be passed through `training_params.phase_callbacks` when calling `Trainer.train(...)`.

SG's `super_gradients.training.utils.callbacks` module implements some common use cases as callbacks:

TrainingStageSwitchCallbackBase
YoloXTrainingStageSwitchCallback

For example, the YoloX COCO detection training recipe uses `YoloXTrainingStageSwitchCallback` to turn
off augmentations and incorporate L1 loss starting from epoch 285:

`super_gradients/recipes/training_hyperparams/coco2017_yolox_train_params.yaml`:

```yaml
max_epochs: 300
...

phase_callbacks:
...
```

Another example is how we use `BinarySegmentationVisualizationCallback` to visualize predictions
during training in the [Segmentation Transfer Learning Notebook](https://bit.ly/3qKwMbe).


## Integrating Your Code with Callbacks

Integrating your code requires implementing a callback that `Trainer` will trigger at the proper phases of SG's training pipeline.
The base class to inherit from is `super_gradients.training.utils.callbacks.base_callbacks.Callback`.

### How Callbacks work

`Callback` implements the following methods:

```python
# super_gradients.training.utils.callbacks.base_callbacks.Callback
class Callback:
    def on_training_start(self, context: PhaseContext) -> None: pass
    def on_train_loader_start(self, context: PhaseContext) -> None: pass
    def on_train_batch_start(self, context: PhaseContext) -> None: pass
    def on_train_batch_loss_end(self, context: PhaseContext) -> None: pass
    def on_train_batch_backward_end(self, context: PhaseContext) -> None: pass
    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None: pass
    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None: pass
    def on_train_batch_end(self, context: PhaseContext) -> None: pass
    def on_train_loader_end(self, context: PhaseContext) -> None: pass
    def on_validation_loader_start(self, context: PhaseContext) -> None: pass
    def on_validation_batch_start(self, context: PhaseContext) -> None: pass
    def on_validation_batch_end(self, context: PhaseContext) -> None: pass
    def on_validation_loader_end(self, context: PhaseContext) -> None: pass
    def on_validation_end_best_epoch(self, context: PhaseContext) -> None: pass
    def on_test_loader_start(self, context: PhaseContext) -> None: pass
    def on_test_batch_start(self, context: PhaseContext) -> None: pass
    def on_test_batch_end(self, context: PhaseContext) -> None: pass
    def on_test_loader_end(self, context: PhaseContext) -> None: pass
    def on_training_end(self, context: PhaseContext) -> None: pass
```

The order of the events is as follows:
```python
on_training_start(context)                            # called once before training starts, good for setting up the warmup LR

for epoch in range(epochs):
    on_train_loader_start(context)
    for batch in train_loader:
        on_train_batch_start(context)
        on_train_batch_loss_end(context)              # called after the loss has been computed
        on_train_batch_backward_end(context)          # called after .backward() was called
        on_train_batch_gradient_step_start(context)   # called before the optimizer step (gradient clipping, logging of gradients)
        on_train_batch_gradient_step_end(context)     # called after the gradient step, good place to update LR (for step-based schedulers)
        on_train_batch_end(context)
    on_train_loader_end(context)

    on_validation_loader_start(context)
    for batch in validation_loader:
        on_validation_batch_start(context)
        on_validation_batch_end(context)
    on_validation_loader_end(context)
    on_validation_end_best_epoch(context)             # called at the end of a validation epoch that achieved a new best metric

on_test_loader_start(context)
for batch in test_loader:
    on_test_batch_start(context)
    on_test_batch_end(context)
on_test_loader_end(context)

on_training_end(context)                              # called once after training ends
```

Callbacks are implemented by inheriting from the `Callback` class and overriding any of the above methods with the desired behavior.
Which methods to override depends on the points of the training pipeline at which you want your code to be triggered.
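
For instance, here is a minimal sketch of such a callback. The class name and the print-based logging are purely illustrative, and it assumes `context.epoch` is already populated at this point (see the Phase Context section below):

```python
from super_gradients.training.utils.callbacks import Callback, PhaseContext


class EpochEndPrinter(Callback):
    """Toy callback: print a short message at the end of every training epoch."""

    def on_train_loader_end(self, context: PhaseContext) -> None:
        # The current epoch index is exposed through the PhaseContext instance.
        print(f"Finished iterating over the train loader of epoch {context.epoch}")
```
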
### Phase Context

You may have noticed that all of `Callback`'s methods expect a single argument: a `PhaseContext` instance.

`PhaseContext` exposes a wide range of training attributes at any given point of the training:

```
- epoch
- batch_idx
- optimizer
- metrics_dict
- inputs
- preds
- target
- metrics_compute_fn
- loss_avg_meter
- loss_log_items
- criterion
- device
- experiment_name
- ckpt_dir
- net
- lr_warmup_epochs
- sg_logger
- train_loader
- valid_loader
- test_loader
- training_params
- ddp_silent_mode
- checkpoint_params
- architecture
- arch_params
- metric_to_watch
- valid_metrics
- ema_model
- loss_logging_items_names
```

Each of these attributes is `None` by default, and is only populated once it has been computed or defined in the training pipeline.
- E.g. `epoch` will still be `None` within `on_training_start` because, as explained above, this step happens before the first epoch begins.
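
As a quick illustration, here is a hypothetical callback that reads some of these attributes defensively (the class name and the print-based logging are purely illustrative):

```python
from super_gradients.training.utils.callbacks import Callback, PhaseContext


class LrPrinterCallback(Callback):
    """Toy callback: print the current learning rates after each optimizer step."""

    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        # The optimizer is defined by this point of training, but guarding against
        # None is a safe habit since attributes stay None until the pipeline sets them.
        if context.optimizer is not None:
            lrs = [group["lr"] for group in context.optimizer.param_groups]
            print(f"epoch={context.epoch} batch={context.batch_idx} lrs={lrs}")
```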

You can find which context attributes are set at each point by looking at each method's docstring:
```python
class Callback:

    ...

    def on_training_start(self, context: PhaseContext) -> None:
        """
        Called once before the start of the first epoch.
        At this point, the context argument will have the following attributes:
            - optimizer
            - criterion
            - device
            - experiment_name
            - ckpt_dir
            - net
            - sg_logger
            - train_loader
            - valid_loader
            - training_params
            - checkpoint_params
            - arch_params
            - metric_to_watch
            - valid_metrics

        The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
        :param context:
        """
        pass
```

### Build your own Callback

Suppose we would like to implement a simple callback that saves the first batch of images of each epoch, for both
training and validation, in a new folder called "batch_images" under the local checkpoints directory.

This callback needs to be triggered in 3 places:
1. At the start of training, create a new "batch_images" folder under the local checkpoints directory.
2. Before passing a train image batch through the network, save it in the new folder.
3. Before passing a validation image batch through the network, save it in the new folder.

Therefore, the callback will override `Callback`'s `on_training_start`, `on_train_batch_start`, and `on_validation_batch_start` methods:

```python
from super_gradients.training.utils.callbacks import Callback, PhaseContext
...  # remaining imports and the start of the class are omitted here


class SaveFirstBatchCallback(Callback):
    ...

    @multi_process_safe
    def on_validation_batch_start(self, context: PhaseContext) -> None:
        if context.batch_idx == 0 and not self.saved_first_validation_batch:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_validation_batch_epoch_{context.epoch}.png"))
            self.saved_first_validation_batch = True
```
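
The snippet above only shows part of the class. A minimal sketch of what the full callback could look like is shown below; the `multi_process_safe` import path, the use of `torchvision.utils.save_image`, and the exact bookkeeping are assumptions to adapt to your version of SuperGradients:

```python
import os

from torchvision.utils import save_image

from super_gradients.common.environment.ddp_utils import multi_process_safe  # assumed import path
from super_gradients.training.utils.callbacks import Callback, PhaseContext


class SaveFirstBatchCallback(Callback):
    def __init__(self):
        self.outputs_path = None
        self.saved_first_validation_batch = False

    @multi_process_safe
    def on_training_start(self, context: PhaseContext) -> None:
        # Create the "batch_images" folder inside the experiment's checkpoint directory.
        self.outputs_path = os.path.join(context.ckpt_dir, "batch_images")
        os.makedirs(self.outputs_path, exist_ok=True)

    @multi_process_safe
    def on_train_batch_start(self, context: PhaseContext) -> None:
        # Save the first train batch of every epoch.
        if context.batch_idx == 0:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_train_batch_epoch_{context.epoch}.png"))

    @multi_process_safe
    def on_validation_batch_start(self, context: PhaseContext) -> None:
        # Same pattern as the truncated snippet above.
        if context.batch_idx == 0 and not self.saved_first_validation_batch:
            save_image(context.inputs, os.path.join(self.outputs_path, f"first_validation_batch_epoch_{context.epoch}.png"))
            self.saved_first_validation_batch = True
```
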
**IMPORTANT**

When training with multiple processes, e.g. DDP (see [DDP](device.md)), the callback will be called at each step, once for
every process you are working with. This behaviour may be useful in some specific cases, but in general you will
want each method to be triggered only once per step. Adding the `@multi_process_safe` decorator ensures
that only the main process triggers the callback.

In our example, we want each image to be saved only once per step, so we add the `@multi_process_safe` decorator.

### Using Custom Callback within Python Script

The callback can be passed directly through `training_params.phase_callbacks`:

```python
trainer = Trainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...
model = ...

train_params = {
    ...
    "loss": "cross_entropy",
    "criterion_params": {},
    "phase_callbacks": [SaveFirstBatchCallback()],
    ...
}

trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
```

### Using Custom Callback in a Recipe

If you are working with [configuration files](configuration_files.md), one extra step is required.
This is the same as for any other custom object used in a recipe, and is described in detail in the [configuration files documentation](configuration_files.md).

To summarize, you need to register the new callback by decorating it with the `register_callback` decorator,
so that SuperGradients knows how to instantiate it from the `.yaml` recipe.

```python
from super_gradients.training.utils.callbacks import Callback, PhaseContext
...
```

The registered callback can then be referenced by name under `phase_callbacks` in the recipe:

```yaml
phase_callbacks:
  - SaveFirstBatchCallback
```
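
The Python block above only shows the import. A minimal sketch of the registration step could look like this, assuming `register_callback` is importable from `super_gradients.common.registry.registry` (check the registry module of your SuperGradients version):

```python
from super_gradients.common.registry.registry import register_callback  # assumed import path
from super_gradients.training.utils.callbacks import Callback, PhaseContext


@register_callback()  # registers the class under its own name
class SaveFirstBatchCallback(Callback):
    ...  # same implementation as shown earlier
```
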

Last, make sure to import `SaveFirstBatchCallback` in the script you use to launch training from config. Even though the class itself appears unused, the import is what triggers its registration:

```python

...

if __name__ == "__main__":
    run()
```
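
For context, a minimal sketch of what such a launch script could look like is shown below; the module name `my_callbacks`, the Hydra `config_path`/`config_name`, and the use of `Trainer.train_from_config` are assumptions to adapt to your project layout:

```python
import hydra
from omegaconf import DictConfig

from super_gradients import Trainer, init_trainer
from my_callbacks import SaveFirstBatchCallback  # noqa: F401  # imported only to trigger the registration


@hydra.main(config_path="recipes", config_name="my_training_recipe")
def run(cfg: DictConfig) -> None:
    Trainer.train_from_config(cfg)


if __name__ == "__main__":
    init_trainer()
    run()
```
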

This is required, as otherwise `SaveFirstBatchCallback` would not be imported at all and therefore SuperGradients
would fail to recognize and instantiate it.