Added support for multiple test loaders in train_from_config #1641

Merged
6 commits merged on Nov 15, 2023
Changes from 2 commits
25 changes: 20 additions & 5 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -284,6 +284,26 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
            dataloader_params=cfg.dataset_params.val_dataloader_params,
        )

        test_loaders = {}
if "test_dataset_params" in cfg.dataset_params:
test_dataloaders = get_param(cfg, "test_dataloaders")
            test_dataset_params = cfg.dataset_params.test_dataset_params
            test_dataloader_params = get_param(cfg.dataset_params, "test_dataloader_params")

            if test_dataloaders is not None:
                if not isinstance(test_dataloaders, Mapping):
                    raise ValueError("`test_dataloaders` should be a mapping from test_loader_name to test_loader_params.")

            if test_dataloader_params is not None and test_dataloader_params.keys() != test_dataset_params.keys():
                raise ValueError("test_dataloader_params and test_dataset_params should have the same keys.")

            for dataset_name, dataset_params in test_dataset_params.items():
                loader_name = test_dataloaders[dataset_name] if test_dataloaders is not None else None
                dataset_params = test_dataset_params[dataset_name]
                dataloader_params = test_dataloader_params[dataset_name] if test_dataloader_params is not None else cfg.dataset_params.val_dataloader_params
                loader = dataloaders.get(loader_name, dataset_params=dataset_params, dataloader_params=dataloader_params)
                test_loaders[dataset_name] = loader

        recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
        # TRAIN
        res = trainer.train(
@@ -1679,7 +1699,6 @@ def _set_test_metrics(self, test_metrics_list):
        self.test_metrics = MetricCollection(test_metrics_list)

    def _initialize_mixed_precision(self, mixed_precision_enabled: bool):

        if mixed_precision_enabled and not device_config.is_cuda:
            warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision. (i.e. `mixed_precision=False`)")
            mixed_precision_enabled = False
@@ -1688,7 +1707,6 @@ def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
        self.scaler = GradScaler(enabled=mixed_precision_enabled)

        if mixed_precision_enabled:

            if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
                # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
                # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
@@ -1799,7 +1817,6 @@ def _switch_device(self, new_device):

    # FIXME - we need to resolve flake8's 'function is too complex' for this function
    def _load_checkpoint_to_model(self):

        self.checkpoint = {}
        strict_load = core_utils.get_param(self.training_params, "resume_strict_load", StrictLoad.ON)
        ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
@@ -1921,7 +1938,6 @@ def _initialize_sg_logger_objects(self, additional_configs_to_log: Dict = None):
        if isinstance(sg_logger, AbstractSGLogger):
            self.sg_logger = sg_logger
        elif isinstance(sg_logger, str):

            sg_logger_cls = SG_LOGGERS.get(sg_logger)
            if sg_logger_cls is None:
                raise RuntimeError(f"sg_logger={sg_logger} not registered in SuperGradients. Available {list(SG_LOGGERS.keys())}")
@@ -2178,7 +2194,6 @@ def evaluate(
            with tqdm(
                data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
            ) as progress_bar_data_loader:

                if not silent_mode:
                    # PRINT TITLES
                    pbar_start_msg = "Validating" if evaluation_type == EvaluationType.VALIDATION else "Testing"
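The heart of the change is the loop in the first hunk: every key under dataset_params.test_dataset_params becomes a named test loader, the optional top-level test_dataloaders mapping picks a registered dataloader by name for each dataset, and dataloader params fall back to val_dataloader_params when test_dataloader_params is not given. Below is a minimal standalone sketch of that resolution rule, using a toy OmegaConf config whose keys and values (setA, setB, example_val_loader_a, and so on) are illustrative placeholders rather than anything taken from this PR:

    # Standalone sketch of the loader-resolution rule introduced in the diff above.
    # The config is a toy stand-in for a recipe; the real work in SuperGradients is
    # done by dataloaders.get(...), which is only hinted at in a comment here.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            # Hypothetical placeholders: dataset name -> name of a registered dataloader.
            "test_dataloaders": {"setA": "example_val_loader_a", "setB": "example_val_loader_b"},
            "dataset_params": {
                "val_dataloader_params": {"batch_size": 64},
                "test_dataset_params": {
                    "setA": {"root": "./data/setA"},
                    "setB": {"root": "./data/setB"},
                },
                # No test_dataloader_params here, so every test loader falls back to val_dataloader_params.
            },
        }
    )

    test_dataloaders = cfg.get("test_dataloaders")
    test_dataset_params = cfg.dataset_params.test_dataset_params
    test_dataloader_params = cfg.dataset_params.get("test_dataloader_params")

    for dataset_name, dataset_params in test_dataset_params.items():
        loader_name = test_dataloaders[dataset_name] if test_dataloaders is not None else None
        dataloader_params = test_dataloader_params[dataset_name] if test_dataloader_params is not None else cfg.dataset_params.val_dataloader_params
        # In the trainer this becomes:
        #   test_loaders[dataset_name] = dataloaders.get(loader_name, dataset_params=dataset_params, dataloader_params=dataloader_params)
        print(dataset_name, "->", loader_name, OmegaConf.to_container(dataloader_params))

Running the sketch prints one line per test dataset, showing which loader name it resolved to and that both fall back to the shared val_dataloader_params.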
50 changes: 50 additions & 0 deletions tests/unit_tests/configs/cifar10_multiple_test.yaml
@@ -0,0 +1,50 @@
defaults:
  - cifar10_resnet

test_dataloaders:
  cifar10: cifar10_val
  cifar100: cifar100_val

dataset_params:
  test_dataset_params:
    cifar10:
      root: ./data/cifar10
      train: False
      transforms:
        - Resize:
            size: 32
        - ToTensor
        - Normalize:
            mean:
              - 0.4914
              - 0.4822
              - 0.4465
            std:
              - 0.2023
              - 0.1994
              - 0.2010
      target_transform: null
      download: True

    cifar100:
      root: ./data/cifar100
      train: False
      transforms:
        - Resize:
            size: 32
        - ToTensor
        - Normalize:
            mean:
              - 0.4914
              - 0.4822
              - 0.4465
            std:
              - 0.2023
              - 0.1994
              - 0.2010
      target_transform: null
      download: True

hydra:
  searchpath:
    - pkg://super_gradients.recipes
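Because train_from_config looks up each entry of dataset_params.test_dataset_params in the top-level test_dataloaders mapping, the dataset names in this recipe (cifar10 and cifar100) have to appear in both places, and here each points at a registered validation loader (cifar10_val, cifar100_val). A small hedged check of that alignment, run against the raw YAML file at the path shown in the diff header above:

    # Hedged sanity check for the recipe above: load the raw YAML without Hydra
    # composition (the `defaults` entry is left unresolved) and confirm that the
    # dataset names under test_dataloaders match those under
    # dataset_params.test_dataset_params, mirroring how train_from_config looks
    # up each test dataset's loader name by its dataset key.
    from omegaconf import OmegaConf

    raw = OmegaConf.load("tests/unit_tests/configs/cifar10_multiple_test.yaml")  # path assumed from the diff header
    assert set(raw.test_dataloaders.keys()) == set(raw.dataset_params.test_dataset_params.keys())
    print(sorted(raw.test_dataloaders.keys()))  # ['cifar10', 'cifar100']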
16 changes: 16 additions & 0 deletions tests/unit_tests/train_with_intialized_param_args_test.py
@@ -1,12 +1,17 @@
import os
import unittest

import hydra
import numpy as np
import torch
from hydra.core.global_hydra import GlobalHydra
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from torchmetrics import F1Score

from super_gradients import Trainer
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
@@ -189,6 +194,17 @@ def test_train_with_external_dataloaders(self):
        }
        trainer.train(model=model, training_params=train_params, train_loader=train_loader, valid_loader=val_loader)

    def test_train_with_multiple_test_loaders(self):
        register_hydra_resolvers()
        GlobalHydra.instance().clear()
        configs_dir = os.path.join(os.path.dirname(__file__), "configs")
        with hydra.initialize_config_dir(config_dir=normalize_path(configs_dir), version_base="1.2"):
            cfg = hydra.compose(config_name="cifar10_multiple_test")

        cfg.training_hyperparams.max_epochs = 5
        trained_model, metrics = Trainer.train_from_config(cfg)
        print(metrics)


if __name__ == "__main__":
    unittest.main()
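The new test only prints whatever metrics come back. Since train_from_config is annotated as returning Tuple[nn.Module, Tuple] (first hunk header above), a slightly stronger smoke check could assert that coarse return shape. The hedged sketch below mirrors the test's Hydra composition and only checks the outer types, because this diff does not show how per-loader results are packed into the returned tuple:

    # Hedged follow-up to the test above: assert the coarse return shape of
    # train_from_config instead of only printing the metrics.
    import os

    import hydra
    import torch.nn as nn
    from hydra.core.global_hydra import GlobalHydra

    from super_gradients import Trainer
    from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
    from super_gradients.common.environment.path_utils import normalize_path

    register_hydra_resolvers()
    GlobalHydra.instance().clear()
    configs_dir = os.path.join(os.path.dirname(__file__), "configs")
    with hydra.initialize_config_dir(config_dir=normalize_path(configs_dir), version_base="1.2"):
        cfg = hydra.compose(config_name="cifar10_multiple_test")

    cfg.training_hyperparams.max_epochs = 1  # keep the smoke check short; the unit test uses 5
    trained_model, metrics = Trainer.train_from_config(cfg)
    assert isinstance(trained_model, nn.Module)
    assert isinstance(metrics, tuple)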