Added support of multiple test loaders in train_from_config (#1641)
* Add test loaders support (WIP)

* Added test

* Added docs

* Pass test loaders

* Update test

* Move maybe_instantiate_test_loaders method to cfg_utils
BloodAxe committed Nov 15, 2023
1 parent 4931566 commit 7ff90e4
Showing 5 changed files with 197 additions and 7 deletions.
86 changes: 86 additions & 0 deletions documentation/source/Data.md
@@ -305,3 +305,89 @@ Last, in your ``my_train_from_recipe_script.py`` file, import the newly register
if __name__ == "__main__":
    run()
```

### Adding test datasets

In addition to the train and validation datasets, you can also add a test dataset or multiple test datasets to your configuration file.
At the end of training, metrics from each test dataset will be computed and returned in the final results.

#### Single test dataset

To add a single test dataset to a recipe, add the following properties to your configuration file:

```yaml
test_dataloaders: <dataloader_name>

dataset_params:
  test_dataset_params:
    ...

  test_dataloader_params:
    ...
```


#### Multiple test datasets

You can also add several test datasets to the same recipe, giving each one its own name.
This is how to achieve this in the YAML file:

#### Explicitly specifying all parameters

```yaml
test_dataloaders:
  test_dataset_name_1: <dataloader_name>
  test_dataset_name_2: <dataloader_name>

dataset_params:
  test_dataset_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...

  test_dataloader_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...
```

#### Without dataloader names

The `test_dataloaders` property of the configuration file is optional and can be omitted.
You may want to use this option when you don't have a registered dataloader factory method.
In this case you have to specify the dataset class in the corresponding dataloader params.

```yaml
dataset_params:
  test_dataset_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...

  test_dataloader_params:
    test_dataset_name_1:
      dataset: <dataset_class_name>
      ...
    test_dataset_name_2:
      dataset: <dataset_class_name>
      ...
```

#### Without dataloader params

The `dataset_params.test_dataloader_params` property is optional and can be omitted.
In this case `dataset_params.val_dataloader_params` will be used for instantiating the test dataloaders.
Please note that if you don't use the `test_dataloaders` and `test_dataloader_params` properties, `dataset_params.val_dataloader_params`
must contain a `dataset` property specifying the class name of the dataset to use.

```yaml
dataset_params:
  test_dataset_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...
```
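
As an end-to-end illustration (mirroring the unit test added later in this commit), such a recipe can be composed with Hydra and run through `Trainer.train_from_config`. This is a minimal sketch, assuming a recipe like `cifar10_multiple_test.yaml` is available in a local `configs` directory; the directory location is an assumption, everything else comes from this commit:

```python
import os

import hydra
from hydra.core.global_hydra import GlobalHydra

from super_gradients import Trainer
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path

# Register SuperGradients' Hydra resolvers and reset Hydra before composing the recipe.
register_hydra_resolvers()
GlobalHydra.instance().clear()

# Assumption: the recipe with test datasets lives at ./configs/cifar10_multiple_test.yaml
configs_dir = os.path.join(os.getcwd(), "configs")
with hydra.initialize_config_dir(config_dir=normalize_path(configs_dir), version_base="1.2"):
    cfg = hydra.compose(config_name="cifar10_multiple_test")

# Test loaders defined in the recipe are instantiated automatically and their
# metrics are reported at the end of training.
Trainer.train_from_config(cfg)
```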
37 changes: 36 additions & 1 deletion src/super_gradients/common/environment/cfg_utils.py
@@ -1,13 +1,14 @@
import os
from pathlib import Path
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Union, Dict, Any, Mapping

import hydra
import pkg_resources

from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, open_dict, DictConfig
from torch.utils.data import DataLoader

from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
@@ -195,3 +196,37 @@ def export_recipe(config_name: str, save_path: str, config_dir: str = pkg_resour
        cfg = compose(config_name=config_name)
        OmegaConf.save(config=cfg, f=save_path)
        logger.info(f"Successfully saved recipe at {save_path}. \n" f"Recipe content:\n {cfg}")


def maybe_instantiate_test_loaders(cfg) -> Optional[Mapping[str, DataLoader]]:
    """
    Instantiate test loaders if they are defined in the config.
    :param cfg: Recipe config
    :return: A mapping from dataset name to test loader or None if no test loaders are defined.
    """
    from super_gradients.training.utils.utils import get_param
    from super_gradients.training import dataloaders

    test_loaders = None
    if "test_dataset_params" in cfg.dataset_params:
        test_dataloaders = get_param(cfg, "test_dataloaders")
        test_dataset_params = cfg.dataset_params.test_dataset_params
        test_dataloader_params = get_param(cfg.dataset_params, "test_dataloader_params")

        if test_dataloaders is not None:
            if not isinstance(test_dataloaders, Mapping):
                raise ValueError("`test_dataloaders` should be a mapping from test_loader_name to test_loader_params.")

            if test_dataloader_params is not None and test_dataloader_params.keys() != test_dataset_params.keys():
                raise ValueError("test_dataloader_params and test_dataset_params should have the same keys.")

        test_loaders = {}
        for dataset_name, dataset_params in test_dataset_params.items():
            loader_name = test_dataloaders[dataset_name] if test_dataloaders is not None else None
            dataset_params = test_dataset_params[dataset_name]
            dataloader_params = test_dataloader_params[dataset_name] if test_dataloader_params is not None else cfg.dataset_params.val_dataloader_params
            loader = dataloaders.get(loader_name, dataset_params=dataset_params, dataloader_params=dataloader_params)
            test_loaders[dataset_name] = loader

    return test_loaders
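
For reference, the new helper can also be called directly on an already-composed recipe config (a `DictConfig` obtained e.g. via `hydra.compose`, as in the unit test added in this commit). A minimal sketch, with `cfg` assumed to be such a config:

```python
from super_gradients.common.environment.cfg_utils import maybe_instantiate_test_loaders

# Assumption: `cfg` is a composed recipe config that defines
# dataset_params.test_dataset_params (see the cifar10_multiple_test recipe below).
test_loaders = maybe_instantiate_test_loaders(cfg)

if test_loaders is not None:
    for dataset_name, loader in test_loaders.items():
        # Each value is a ready-to-use torch DataLoader keyed by the dataset name
        # from the recipe; the mapping can be passed to Trainer.train(test_loaders=...).
        print(dataset_name, type(loader).__name__)
```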
10 changes: 4 additions & 6 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -43,6 +43,7 @@
from super_gradients.common.factories.losses_factory import LossesFactory
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.environment.package_utils import get_installed_packages
from super_gradients.common.environment.cfg_utils import maybe_instantiate_test_loaders

from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.datasets.samplers import RepeatAugSampler
@@ -284,13 +285,15 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
            dataloader_params=cfg.dataset_params.val_dataloader_params,
        )

        test_loaders = maybe_instantiate_test_loaders(cfg)

        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
        # TRAIN
        res = trainer.train(
            model=model,
            train_loader=train_dataloader,
            valid_loader=val_dataloader,
            test_loaders=None,  # TODO: Add option to set test_loaders in recipe
            test_loaders=test_loaders,
            training_params=cfg.training_hyperparams,
            additional_configs_to_log=recipe_logged_cfg,
        )
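
The `test_loaders` argument passed to `trainer.train` above is simply a mapping from a dataset name to a `DataLoader`, so it can also be built by hand when the recipe machinery is not used. A minimal sketch; the dataset names and the use of `classification_test_dataloader` (imported in the unit tests below) are illustrative assumptions, not part of this change:

```python
from typing import Mapping

from torch.utils.data import DataLoader

from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader

# Mapping from dataset name to DataLoader - the same structure that
# maybe_instantiate_test_loaders() builds from a recipe. Metrics are computed
# for every entry at the end of training.
test_loaders: Mapping[str, DataLoader] = {
    "test_set_a": classification_test_dataloader(),
    "test_set_b": classification_test_dataloader(),
}

# Passed to training as: trainer.train(..., test_loaders=test_loaders)
```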
@@ -1679,7 +1682,6 @@ def _set_test_metrics(self, test_metrics_list):
        self.test_metrics = MetricCollection(test_metrics_list)

    def _initialize_mixed_precision(self, mixed_precision_enabled: bool):

        if mixed_precision_enabled and not device_config.is_cuda:
            warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision. (i.e. `mixed_precision=False`)")
            mixed_precision_enabled = False
@@ -1688,7 +1690,6 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
        self.scaler = GradScaler(enabled=mixed_precision_enabled)

        if mixed_precision_enabled:

            if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
                # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
                # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
@@ -1799,7 +1800,6 @@ def _switch_device(self, new_device):

    # FIXME - we need to resolve flake8's 'function is too complex' for this function
    def _load_checkpoint_to_model(self):

        self.checkpoint = {}
        strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
        ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
@@ -1921,7 +1921,6 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
        if isinstance(sg_logger, AbstractSGLogger):
            self.sg_logger = sg_logger
        elif isinstance(sg_logger, str):

            sg_logger_cls = SG_LOGGERS.get(sg_logger)
            if sg_logger_cls is None:
                raise RuntimeError(f"sg_logger={sg_logger} not registered in SuperGradients. Available {list(SG_LOGGERS.keys())}")
@@ -2178,7 +2177,6 @@ def evaluate(
        with tqdm(
            data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
        ) as progress_bar_data_loader:

            if not silent_mode:
                # PRINT TITLES
                pbar_start_msg = "Validating" if evaluation_type == EvaluationType.VALIDATION else "Testing"
56 changes: 56 additions & 0 deletions tests/unit_tests/configs/cifar10_multiple_test.yaml
@@ -0,0 +1,56 @@
defaults:
  - cifar10_resnet

test_dataloaders:
  cifar10: cifar10_val
  cifar10_v2: cifar10_val

dataset_params:
  train_dataloader_params:
    num_workers: 0

  val_dataloader_params:
    num_workers: 0

  test_dataset_params:
    cifar10:
      root: ./data/cifar10
      train: False
      transforms:
        - Resize:
            size: 32
        - ToTensor
        - Normalize:
            mean:
              - 0.4914
              - 0.4822
              - 0.4465
            std:
              - 0.2023
              - 0.1994
              - 0.2010
      target_transform: null
      download: True

    cifar10_v2:
      root: ./data/cifar10
      train: False
      transforms:
        - Resize:
            size: 32
        - ToTensor
        - Normalize:
            mean:
              - 0.5
              - 0.5
              - 0.5
            std:
              - 0.2
              - 0.2
              - 0.2
      target_transform: null
      download: True

hydra:
  searchpath:
    - pkg://super_gradients.recipes
15 changes: 15 additions & 0 deletions tests/unit_tests/train_with_intialized_param_args_test.py
@@ -1,12 +1,17 @@
import os
import unittest

import hydra
import numpy as np
import torch
from hydra.core.global_hydra import GlobalHydra
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from torchmetrics import F1Score

from super_gradients import Trainer
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -189,6 +194,16 @@ def test_train_with_external_dataloaders(self):
        }
        trainer.train(model=model, training_params=train_params, train_loader=train_loader, valid_loader=val_loader)

    def test_train_with_multiple_test_loaders(self):
        register_hydra_resolvers()
        GlobalHydra.instance().clear()
        configs_dir = os.path.join(os.path.dirname(__file__), "configs")
        with hydra.initialize_config_dir(config_dir=normalize_path(configs_dir), version_base="1.2"):
            cfg = hydra.compose(config_name="cifar10_multiple_test")

        cfg.training_hyperparams.max_epochs = 1
        Trainer.train_from_config(cfg)


if __name__ == "__main__":
    unittest.main()
