Added support of multiple test loaders in train_from_config #1641

Merged · 6 commits · Nov 15, 2023
86 changes: 86 additions & 0 deletions documentation/source/Data.md
@@ -305,3 +305,89 @@ Last, in your ``my_train_from_recipe_script.py`` file, import the newly register
if __name__ == "__main__":
    run()
```

### Adding test datasets

In addition to the train and validation datasets, you can also add a test dataset or multiple test datasets to your configuration file.
At the end of training, metrics for each test dataset are computed and included in the final results.
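
For reference, a recipe that declares test datasets is launched like any other recipe. Below is a minimal sketch of composing such a recipe and training from it; `my_recipes_dir` and `my_recipe` are placeholder names for your own config directory and recipe.

```python
import hydra
from hydra.core.global_hydra import GlobalHydra

from super_gradients import Trainer
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path

# Register SuperGradients' Hydra resolvers and reset any previous Hydra state.
register_hydra_resolvers()
GlobalHydra.instance().clear()

# "my_recipes_dir" and "my_recipe" are placeholders for your config directory and recipe name.
with hydra.initialize_config_dir(config_dir=normalize_path("my_recipes_dir"), version_base="1.2"):
    cfg = hydra.compose(config_name="my_recipe")

# Test loaders declared in the recipe are instantiated automatically and evaluated after training.
Trainer.train_from_config(cfg)
```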

#### Single test dataset

To add a single test dataset to a recipe, add the following properties to your configuration file:

```yaml
test_dataloaders: <dataloader_name>

dataset_params:
  test_dataset_params:
    ...

  test_dataloader_params:
    ...
```


#### Multiple test datasets

To evaluate on multiple test datasets, give each dataset a unique name and provide its parameters under that name.
Here is how to configure this in your YAML file:

#### Explicitly specifying all parameters

```yaml
test_dataloaders:
  test_dataset_name_1: <dataloader_name>
  test_dataset_name_2: <dataloader_name>

dataset_params:
  test_dataset_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...

  test_dataloader_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...
```
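
The same result can be achieved without a recipe: `Trainer.train` accepts a mapping from test-set name to `DataLoader` through its `test_loaders` argument, and metrics are reported under each name. The sketch below uses dummy classification loaders and illustrative hyperparameters; the experiment name, model choice, and all hyperparameter values are placeholders, and the exact registered names for the loss and scheduler may vary between SuperGradients versions.

```python
from super_gradients import Trainer
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy

trainer = Trainer(experiment_name="multiple_test_sets_example")  # placeholder experiment name
model = models.get(Models.RESNET18, num_classes=5)

# Illustrative hyperparameters only; registered names may differ between versions.
training_params = {
    "max_epochs": 1,
    "lr_updates": [1],
    "lr_decay_factor": 0.1,
    "lr_mode": "StepLRScheduler",
    "initial_lr": 0.1,
    "loss": "CrossEntropyLoss",
    "optimizer": "SGD",
    "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
    "train_metrics_list": [Accuracy()],
    "valid_metrics_list": [Accuracy()],
    "metric_to_watch": "Accuracy",
}

# Mapping from test-set name to its DataLoader; metrics are reported per name.
test_loaders = {
    "test_dataset_name_1": classification_test_dataloader(),
    "test_dataset_name_2": classification_test_dataloader(),
}

trainer.train(
    model=model,
    training_params=training_params,
    train_loader=classification_test_dataloader(),
    valid_loader=classification_test_dataloader(),
    test_loaders=test_loaders,
)
```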

#### Without dataloader names

The `test_dataloaders` property is optional and can be omitted.
You may want to do this when you don't have a registered dataloader factory method.
In that case you must specify the dataset class in the corresponding dataloader params.

```yaml
dataset_params:
  test_dataset_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...

  test_dataloader_params:
    test_dataset_name_1:
      dataset: <dataset_class_name>
      ...
    test_dataset_name_2:
      dataset: <dataset_class_name>
      ...
```
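
When no dataloader factory name is given, the loader name passed to `dataloaders.get` is simply `None` and the dataloader params (including the `dataset` entry) are forwarded as-is, roughly as sketched below. The dataset class name and parameter values here are hypothetical placeholders.

```python
from super_gradients.training import dataloaders

# Hypothetical values: "MyTestDataset" stands in for the dataset class named by
# the `dataset: <dataset_class_name>` line in the YAML above.
test_dataset_params = {"root": "./data/my_test_set"}
test_dataloader_params = {"dataset": "MyTestDataset", "batch_size": 16, "num_workers": 0}

loader = dataloaders.get(
    None,  # no registered dataloader factory name
    dataset_params=test_dataset_params,
    dataloader_params=test_dataloader_params,
)
```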

#### Without dataloader params

The `dataset_params.test_dataloader_params` property is also optional and can be omitted.
In that case `dataset_params.val_dataloader_params` is used to instantiate the test dataloaders.
Note that if you omit both the `test_dataloaders` and `test_dataloader_params` properties, `dataset_params.val_dataloader_params`
must contain a `dataset` property specifying the class name of the dataset to use.

```yaml
dataset_params:
  test_dataset_params:
    test_dataset_name_1:
      ...
    test_dataset_name_2:
      ...
```
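
To picture the fallback, the snippet below builds a minimal made-up config with test datasets but no `test_dataloader_params`; every test loader is then constructed with the validation dataloader params.

```python
from omegaconf import OmegaConf

# Made-up config: test datasets are declared, but no test_dataloader_params.
cfg = OmegaConf.create(
    {
        "dataset_params": {
            "val_dataloader_params": {"batch_size": 16, "num_workers": 0},
            "test_dataset_params": {"test_dataset_name_1": {}, "test_dataset_name_2": {}},
        }
    }
)

# Each test loader falls back to the validation dataloader params.
for name in cfg.dataset_params.test_dataset_params:
    print(name, "->", dict(cfg.dataset_params.val_dataloader_params))
```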
37 changes: 36 additions & 1 deletion src/super_gradients/common/environment/cfg_utils.py
@@ -1,13 +1,14 @@
import os
from pathlib import Path
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Union, Dict, Any, Mapping

import hydra
import pkg_resources

from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, open_dict, DictConfig
from torch.utils.data import DataLoader

from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
@@ -195,3 +196,37 @@ def export_recipe(config_name: str, save_path: str, config_dir: str = pkg_resour
        cfg = compose(config_name=config_name)
        OmegaConf.save(config=cfg, f=save_path)
        logger.info(f"Successfully saved recipe at {save_path}. \n" f"Recipe content:\n {cfg}")


def maybe_instantiate_test_loaders(cfg) -> Optional[Mapping[str, DataLoader]]:
    """
    Instantiate test loaders if they are defined in the config.

    :param cfg: Recipe config
    :return: A mapping from dataset name to test loader or None if no test loaders are defined.
    """
    from super_gradients.training.utils.utils import get_param
    from super_gradients.training import dataloaders

    test_loaders = None
    if "test_dataset_params" in cfg.dataset_params:
        test_dataloaders = get_param(cfg, "test_dataloaders")
        test_dataset_params = cfg.dataset_params.test_dataset_params
        test_dataloader_params = get_param(cfg.dataset_params, "test_dataloader_params")

        if test_dataloaders is not None:
            if not isinstance(test_dataloaders, Mapping):
                raise ValueError("`test_dataloaders` should be a mapping from test_loader_name to test_loader_params.")

            if test_dataloader_params is not None and test_dataloader_params.keys() != test_dataset_params.keys():
                raise ValueError("test_dataloader_params and test_dataset_params should have the same keys.")

        test_loaders = {}
        for dataset_name, dataset_params in test_dataset_params.items():
            loader_name = test_dataloaders[dataset_name] if test_dataloaders is not None else None
            dataloader_params = test_dataloader_params[dataset_name] if test_dataloader_params is not None else cfg.dataset_params.val_dataloader_params
            loader = dataloaders.get(loader_name, dataset_params=dataset_params, dataloader_params=dataloader_params)
            test_loaders[dataset_name] = loader

    return test_loaders
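
As a side note, the helper can be exercised on its own against a composed recipe. Below is a rough sketch using the `cifar10_multiple_test` config added in the test further down; it assumes the working directory is `tests/unit_tests`, where that config lives.

```python
import os

import hydra
from hydra.core.global_hydra import GlobalHydra

from super_gradients.common.environment.cfg_utils import maybe_instantiate_test_loaders
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path

register_hydra_resolvers()
GlobalHydra.instance().clear()

# Assumes the current working directory is tests/unit_tests, where the config lives.
configs_dir = os.path.join(os.getcwd(), "configs")
with hydra.initialize_config_dir(config_dir=normalize_path(configs_dir), version_base="1.2"):
    cfg = hydra.compose(config_name="cifar10_multiple_test")

test_loaders = maybe_instantiate_test_loaders(cfg)
print(list(test_loaders))  # expected: ['cifar10', 'cifar10_v2'], each mapped to a DataLoader
```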
10 changes: 4 additions & 6 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -43,6 +43,7 @@
from super_gradients.common.factories.losses_factory import LossesFactory
from super_gradients.common.factories.metrics_factory import MetricsFactory
from super_gradients.common.environment.package_utils import get_installed_packages
from super_gradients.common.environment.cfg_utils import maybe_instantiate_test_loaders

from super_gradients.training import utils as core_utils, models, dataloaders
from super_gradients.training.datasets.samplers import RepeatAugSampler
@@ -284,13 +285,15 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
            dataloader_params=cfg.dataset_params.val_dataloader_params,
        )

        test_loaders = maybe_instantiate_test_loaders(cfg)

        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
        # TRAIN
        res = trainer.train(
            model=model,
            train_loader=train_dataloader,
            valid_loader=val_dataloader,
            test_loaders=None,  # TODO: Add option to set test_loaders in recipe
            test_loaders=test_loaders,
            training_params=cfg.training_hyperparams,
            additional_configs_to_log=recipe_logged_cfg,
        )
@@ -1679,7 +1682,6 @@ def _set_test_metrics(self, test_metrics_list):
self.test_metrics = MetricCollection(test_metrics_list)

def _initialize_mixed_precision(self, mixed_precision_enabled: bool):

if mixed_precision_enabled and not device_config.is_cuda:
warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision. (i.e. `mixed_precision=False`)")
mixed_precision_enabled = False
@@ -1688,7 +1690,6 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
self.scaler = GradScaler(enabled=mixed_precision_enabled)

if mixed_precision_enabled:

if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
# IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
# BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
@@ -1799,7 +1800,6 @@ def _switch_device(self, new_device):

# FIXME - we need to resolve flake8's 'function is too complex' for this function
def _load_checkpoint_to_model(self):

self.checkpoint = {}
strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
@@ -1921,7 +1921,6 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
if isinstance(sg_logger, AbstractSGLogger):
self.sg_logger = sg_logger
elif isinstance(sg_logger, str):

sg_logger_cls = SG_LOGGERS.get(sg_logger)
if sg_logger_cls is None:
raise RuntimeError(f"sg_logger={sg_logger} not registered in SuperGradients. Available {list(SG_LOGGERS.keys())}")
@@ -2178,7 +2177,6 @@ def evaluate(
with tqdm(
data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
) as progress_bar_data_loader:

if not silent_mode:
# PRINT TITLES
pbar_start_msg = "Validating" if evaluation_type == EvaluationType.VALIDATION else "Testing"
56 changes: 56 additions & 0 deletions tests/unit_tests/configs/cifar10_multiple_test.yaml
@@ -0,0 +1,56 @@
defaults:
  - cifar10_resnet

test_dataloaders:
  cifar10: cifar10_val
  cifar10_v2: cifar10_val

dataset_params:
  train_dataloader_params:
    num_workers: 0

  val_dataloader_params:
    num_workers: 0

  test_dataset_params:
    cifar10:
      root: ./data/cifar10
      train: False
      transforms:
        - Resize:
            size: 32
        - ToTensor
        - Normalize:
            mean:
              - 0.4914
              - 0.4822
              - 0.4465
            std:
              - 0.2023
              - 0.1994
              - 0.2010
      target_transform: null
      download: True

    cifar10_v2:
      root: ./data/cifar10
      train: False
      transforms:
        - Resize:
            size: 32
        - ToTensor
        - Normalize:
            mean:
              - 0.5
              - 0.5
              - 0.5
            std:
              - 0.2
              - 0.2
              - 0.2
      target_transform: null
      download: True

hydra:
  searchpath:
    - pkg://super_gradients.recipes
15 changes: 15 additions & 0 deletions tests/unit_tests/train_with_intialized_param_args_test.py
@@ -1,12 +1,17 @@
import os
import unittest

import hydra
import numpy as np
import torch
from hydra.core.global_hydra import GlobalHydra
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from torchmetrics import F1Score

from super_gradients import Trainer
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -189,6 +194,16 @@ def test_train_with_external_dataloaders(self):
        }
        trainer.train(model=model, training_params=train_params, train_loader=train_loader, valid_loader=val_loader)

    def test_train_with_multiple_test_loaders(self):
        register_hydra_resolvers()
        GlobalHydra.instance().clear()
        configs_dir = os.path.join(os.path.dirname(__file__), "configs")
        with hydra.initialize_config_dir(config_dir=normalize_path(configs_dir), version_base="1.2"):
            cfg = hydra.compose(config_name="cifar10_multiple_test")

        cfg.training_hyperparams.max_epochs = 1
        Trainer.train_from_config(cfg)


if __name__ == "__main__":
unittest.main()