Feature/sg 1039 add factory doc #1395

Merged: 13 commits, Aug 23, 2023

88 changes: 10 additions & 78 deletions documentation/source/configuration_files.md
@@ -53,8 +53,8 @@ train_dataset_params:
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: ${dataset_params.img_mean}
std: ${dataset_params.img_std}
mean: [0.485, 0.456, 0.406] # mean for normalization
std: [0.229, 0.224, 0.225] # std for normalization

val_dataset_params:
root: /data/Imagenet/val
@@ -65,8 +65,8 @@ val_dataset_params:
size: 224
- ToTensor
- Normalize:
mean: ${dataset_params.img_mean}
std: ${dataset_params.img_std}
mean: [0.485, 0.456, 0.406] # mean for normalization
std: [0.229, 0.224, 0.225] # std for normalization
```

Configuration files can also help you track the exact settings used for each of your experiments, tweak and tune those settings, and share them with others.
@@ -107,7 +107,6 @@ python -m super_gradients.evaluate_from_recipe --config-name=cifar10_resnet
This will run only the evaluation part of the recipe (without any training iterations).



## Hydra
Hydra is an open-source Python framework that provides many useful features for YAML management. You can learn about Hydra
[here](https://hydra.cc/docs/intro). We use Hydra to load YAML files and convert them into dictionaries, while
@@ -130,7 +129,6 @@ in the first arg of the command line.

A `.hydra` subdirectory will be created in the experiment directory, where Hydra saves the configuration files related to the run.

--------
Two Hydra features worth mentioning are _YAML Composition_ and _Command-Line Overrides_.

#### YAML Composition
@@ -163,6 +161,7 @@ initial learning-rate. This feature is extremely useful when experimenting with
Note that the arguments are referenced without the `--` prefix and that each parameter is referenced with its full path in the
configuration tree, concatenated with a `.`.


## Resolvers
Resolvers convert strings from the YAML file into Python objects or values. The most basic resolvers are Hydra's native resolvers.
Here are a few simple examples:
@@ -178,79 +177,12 @@ third_of_list: "${getitem: ${my_list}, 2}"
first_of_list: "${first: ${my_list}}"
last_of_list: "${last: ${my_list}}"
```
You can register any additional resolver you want by simply following the official [documentation](https://omegaconf.readthedocs.io/en/latest/usage.html#resolvers).

The more advanced resolvers will instantiate objects. In the following example we define a few transforms that
will be used to augment a dataset.
```yaml
train_dataset_params:
transforms:
# for more options see common.factories.transforms_factory.py
- SegColorJitter:
brightness: 0.1
contrast: 0.1
saturation: 0.1

- SegRandomFlip:
prob: 0.5

- SegRandomRescale:
scales: [ 0.4, 1.6 ]
```
Each one of the keys (`SegColorJitter`, `SegRandomFlip`, `SegRandomRescale`) is mapped to a type, and the configuration parameters under that key will be passed
to the type's constructor by name (as keyword arguments).

If you want to see where this magic is happening, you can look for the `@resolve_param` decorator in the code

```python
class ImageNetDataset(torch_datasets.ImageFolder):

@resolve_param("transforms", factory=TransformsFactory())
def __init__(self, root: str, transforms: Union[list, dict] = [], *args, **kwargs):
...
...
```

The `@resolve_param` decorator wraps a function and resolves a string or dictionary argument (in the example above, "transforms") into an object.
To do so, it uses a factory object that maps strings and dictionaries to types. When `__init__(...)` is called, the function receives
an object rather than a dictionary. The parameters under "transforms" in the YAML are passed as
arguments when instantiating the objects. We will learn how to add a new type of object to these mappings in the next sections.
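The mechanism can be sketched as follows. This is a simplified toy version under assumed names (`ToyFactory`, a toy `resolve_param`); the real decorator and factories in SuperGradients handle more input formats and validation:

```python
import functools

class ToyFactory:
    """Maps a registered name to a class, mimicking a SuperGradients factory."""
    registry = {}

    def get(self, conf):
        if isinstance(conf, dict):
            (name, params), = conf.items()  # e.g. {"SegRandomFlip": {"prob": 0.5}}
            return self.registry[name](**params)
        return conf  # already an instance: pass through unchanged

def resolve_param(name, factory):
    """Swap a dict-valued keyword argument for the object it describes."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if isinstance(kwargs.get(name), dict):
                kwargs[name] = factory.get(kwargs[name])
            return fn(*args, **kwargs)
        return wrapper
    return decorator

class SegRandomFlip:
    def __init__(self, prob):
        self.prob = prob

ToyFactory.registry["SegRandomFlip"] = SegRandomFlip

class Dataset:
    @resolve_param("transforms", ToyFactory())
    def __init__(self, root, transforms=None):
        self.transforms = transforms  # receives an object, never a raw dict

ds = Dataset(root="/data", transforms={"SegRandomFlip": {"prob": 0.5}})
print(type(ds.transforms).__name__, ds.transforms.prob)  # SegRandomFlip 0.5
```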

## Registering a new object
To use a new object from your configuration file, you need to define a mapping from a string to a type.
This is done using one of the many registration functions supported by SG.
```python
register_model
register_detection_module
register_metric
register_loss
register_dataloader
register_callback
register_transform
register_dataset
```

These decorator functions can be imported and used as follows:

```python
from super_gradients.common.registry import register_model

@register_model(name="MyNet")
class MyExampleNet(nn.Module):
def __init__(self, num_classes: int):
....
```

This simple decorator maps the name "MyNet" to the type `MyExampleNet`. Note that if your constructor
includes required arguments, you will be expected to provide them when using this string

```yaml
...
architecture:
MyNet:
num_classes: 8
...

```
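Under the hood, a registration decorator essentially adds an entry to a dictionary. A minimal sketch, with hypothetical names (the real SuperGradients registry keeps a separate registry per object type and adds error handling):

```python
# Hypothetical simplified version of a registration decorator such as register_model.
ARCHITECTURES = {}

def register_model(name=None):
    def decorator(cls):
        # Map the chosen name (or the class name by default) to the class.
        ARCHITECTURES[name or cls.__name__] = cls
        return cls  # the class itself is returned unchanged
    return decorator

@register_model(name="MyNet")
class MyExampleNet:
    def __init__(self, num_classes):
        self.num_classes = num_classes

# The string "MyNet" from the recipe can now be resolved to the class and instantiated:
model = ARCHITECTURES["MyNet"](num_classes=8)
print(model.num_classes)  # prints 8
```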
## Factories
Factories are similar to resolvers but were built specifically to instantiate SuperGradients objects within a recipe.
This is a key feature of SuperGradients, used in all of our recipes, and we recommend going over
this [introduction to Factories](factories.md).

## Required Hyper-Parameters
Most parameters can be defined by default when including `default_train_params` in your `defaults`.
207 changes: 207 additions & 0 deletions documentation/source/factories.md
@@ -0,0 +1,207 @@
# Working with Factories

Factories in SuperGradients provide a powerful and concise way to instantiate objects in your configuration files.

Prerequisites:
- [Training with Configuration Files](configuration_files.md)

In this tutorial, we'll cover how to use existing factories, register new ones, and briefly explore the implementation details.

## Using Existing Factories

If you had a look at the [recipes](https://github.com/Deci-AI/super-gradients/tree/master/src/super_gradients/recipes), you may have noticed that many objects are defined directly in the recipes.

In the [Supervisely dataset recipe](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/dataset_params/supervisely_persons_dataset_params.yaml) you can see the following

```yaml
train_dataset_params:
transforms:
- SegColorJitter:
brightness: 0.1
contrast: 0.1
saturation: 0.1
- SegRandomFlip:
prob: 0.5
- SegRandomRescale:
scales: [0.4, 1.6]
```
If you load the `.yaml` recipe as-is into a Python dictionary, you get the following
```python
{
"train_dataset_params": {
"transforms": [
{
"SegColorJitter": {
"brightness": 0.1,
"contrast": 0.1,
"saturation": 0.1
}
},
{
"SegRandomFlip": {
"prob": 0.5
}
},
{
"SegRandomRescale": {
"scales": [0.4, 1.6]
}
}
]
}
}
```

This configuration alone is not very useful, as we need instances of the classes, not just their configurations.
So we would like to somehow instantiate these classes `SegColorJitter`, `SegRandomFlip` and `SegRandomRescale`.

Factories in SuperGradients come into play here! All these objects were registered beforehand in SuperGradients,
so that when you write these names in the recipe, SuperGradients will detect and instantiate them for you.
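As a quick sanity check, you can reproduce that dictionary by loading the snippet with PyYAML. SuperGradients itself loads recipes through Hydra, so this is only an illustration of the intermediate structure:

```python
import yaml

recipe = """
train_dataset_params:
  transforms:
    - SegColorJitter:
        brightness: 0.1
        contrast: 0.1
        saturation: 0.1
    - SegRandomFlip:
        prob: 0.5
    - SegRandomRescale:
        scales: [0.4, 1.6]
"""

cfg = yaml.safe_load(recipe)
transforms = cfg["train_dataset_params"]["transforms"]
print(transforms[1])  # {'SegRandomFlip': {'prob': 0.5}}
```

Each list entry is a one-key dictionary: the key is the registered name, and the value holds the constructor arguments.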

## Registering a Class

As explained above, only registered objects can be instantiated.
This registration consists of mapping the object name to the corresponding class type.

In the example above, the string `"SegColorJitter"` was mapped to the class `SegColorJitter`; this is how SuperGradients knows how to convert the string defined in the recipe into an object.

You can register the class using a name different from the actual class name.
However, it's generally recommended to use the same name for consistency and clarity.

### Example

```python
from super_gradients.common.registry import register_transform

@register_transform(name="MyTransformName")
class MyTransform:
def __init__(self, prob: float):
...
```
In this simple example, we register a new transform.
Note that, for the sake of the example, we registered the class `MyTransform` under the different name `MyTransformName`.
We strongly recommend not doing this; instead, register each class under its own name.

Once you have registered a class, you can use it in your recipe. Here, we add this transform to the original recipe
```yaml
train_dataset_params:
transforms:
- SegColorJitter:
brightness: 0.1
contrast: 0.1
saturation: 0.1
- SegRandomFlip:
prob: 0.5
- SegRandomRescale:
scales: [0.4, 1.6]
- MyTransformName: # We use the name used to register, which may be different from the name of the class
prob: 0.7
```

Final Step: Ensure that you import the module containing `MyTransformName` into your script.
Doing so will trigger the registration function, allowing SuperGradients to recognize it.

Here is an example (adapted from the [train_from_recipe script](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/train_from_recipe.py)).

```python
from .my_module import MyTransform  # Importing the module is enough, as it triggers the register_transform decorator

# The code below is the same as the basic `train_from_recipe.py` script
# See: https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/train_from_recipe.py
from omegaconf import DictConfig
import hydra

from super_gradients import Trainer, init_trainer


@hydra.main(config_path="recipes", version_base="1.2")
def _main(cfg: DictConfig) -> None:
Trainer.train_from_config(cfg)


def main() -> None:
init_trainer() # `init_trainer` needs to be called before `@hydra.main`
_main()


if __name__ == "__main__":
main()

```

## Under the Hood

So far, we have seen how to use existing factories and how to register new ones.
In some cases, you may want to write your own objects that benefit from factories.

### Basic
The most basic way to use a factory is shown below.
```python
from super_gradients.common.factories import TransformsFactory
factory = TransformsFactory()
my_transform = factory.get({'MyTransformName': {'prob': 0.7}})
```
You may recognize that the input passed to `factory.get` is actually the dictionary we get after loading the recipe
(see [Using Existing Factories](#using-existing-factories)).

### Recommended
Factories become even more powerful when used with the `@resolve_param` decorator.
This feature allows functions to accept both instantiated objects and their dictionary representations.
This means you can pass either the actual Python object or a dictionary describing it, straight from the recipe.

```python
class ImageNetDataset(torch_datasets.ImageFolder):

    @resolve_param("transform", factory=TransformsFactory())
    def __init__(self, root: str, transform: Transform):
...
```

Now, `ImageNetDataset` can be passed both an instance of `MyTransform`

```python
my_transform = MyTransform(prob=0.7)
ImageNetDataset(root=..., transform=my_transform)
```

And a dictionary representing the same object
```python
my_transform = {'MyTransformName': {'prob': 0.7}}
ImageNetDataset(root=..., transform=my_transform)
```

This second way of instantiating the dataset combines perfectly with the concept of `.yaml` recipes.

## Supported Factory Types
So far we have focused on a single type of factory, `TransformsFactory`,
associated with the registration decorator `register_transform`.

SuperGradients supports a wide range of factories, used throughout the training process,
each associated with its own registration decorator.

```python
from super_gradients.common.registry import (
    register_model,
    register_kd_model,
    register_detection_module,
    register_metric,
    register_loss,
    register_dataloader,
    register_callback,
    register_transform,
    register_dataset,
    register_pre_launch_callback,
    register_unet_backbone_stage,
    register_unet_up_block,
    register_target_generator,
    register_lr_scheduler,
    register_lr_warmup,
    register_sg_logger,
    register_collate_function,
    register_sampler,
    register_optimizer,
    register_processing,
)
```
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -27,6 +27,7 @@ nav:
- Phase Callbacks: ./documentation/source/PhaseCallbacks.md
- YAMLs and Recipes:
- Configurations: ./documentation/source/configuration_files.md
- Factories: ./documentation/source/factories.md
- Recipes: ./src/super_gradients/recipes/Training_Recipes.md
- Checkpoints: ./documentation/source/Checkpoints.md
- Docker: ./documentation/source/SGDocker.md
47 changes: 44 additions & 3 deletions src/super_gradients/common/registry/__init__.py
@@ -1,4 +1,45 @@
from super_gradients.common.registry.registry import register_model, register_metric, register_loss, register_detection_module, register_lr_scheduler
from super_gradients.common.registry.registry import (
register_model,
register_kd_model,
register_detection_module,
register_metric,
register_loss,
register_dataloader,
register_callback,
register_transform,
register_dataset,
register_pre_launch_callback,
register_unet_backbone_stage,
register_unet_up_block,
register_target_generator,
register_lr_scheduler,
register_lr_warmup,
register_sg_logger,
register_collate_function,
register_sampler,
register_optimizer,
register_processing,
)


__all__ = ["register_model", "register_detection_module", "register_metric", "register_loss", "register_lr_scheduler"]
__all__ = [
"register_model",
"register_kd_model",
"register_detection_module",
"register_metric",
"register_loss",
"register_dataloader",
"register_callback",
"register_transform",
"register_dataset",
"register_pre_launch_callback",
"register_unet_backbone_stage",
"register_unet_up_block",
"register_target_generator",
"register_lr_scheduler",
"register_lr_warmup",
"register_sg_logger",
"register_collate_function",
"register_sampler",
"register_optimizer",
"register_processing",
]