From 34fda6c43e9607e8fd6c39e3dca60fc1e2c512cd Mon Sep 17 00:00:00 2001
From: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com>
Date: Wed, 25 Oct 2023 11:26:03 +0300
Subject: [PATCH] Feature/sg 1198 mixed precision automatically changed with
 warning (#1567)

* fix

* work with tmpdir

* minor change of comment

* improve device_config
---
 .../common/environment/device_utils.py       | 16 +++++-
 .../training/sg_trainer/sg_trainer.py        | 11 +++-
 tests/deci_core_unit_test_suite_runner.py    |  2 +
 tests/unit_tests/__init__.py                 |  2 +
 tests/unit_tests/test_mixed_precision_cpu.py | 51 +++++++++++++++++++
 5 files changed, 79 insertions(+), 3 deletions(-)
 create mode 100644 tests/unit_tests/test_mixed_precision_cpu.py

diff --git a/src/super_gradients/common/environment/device_utils.py b/src/super_gradients/common/environment/device_utils.py
index 519310c4cb..7f9fa07500 100644
--- a/src/super_gradients/common/environment/device_utils.py
+++ b/src/super_gradients/common/environment/device_utils.py
@@ -19,10 +19,24 @@ def _get_assigned_rank() -> int:
 
 @dataclasses.dataclass
 class DeviceConfig:
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    _device: str = "cuda" if torch.cuda.is_available() else "cpu"
     multi_gpu: str = None
     assigned_rank: int = dataclasses.field(default=_get_assigned_rank(), init=False)
 
+    @property
+    def device(self) -> str:
+        return self._device
+
+    @device.setter
+    def device(self, value: str):
+        if "cuda" in value and not torch.cuda.is_available():
+            raise ValueError("CUDA is not available, cannot set device to cuda")
+        self._device = value
+
+    @property
+    def is_cuda(self):
+        return "cuda" in self._device
+
 
 # Singleton holding the device information
 device_config = DeviceConfig()
diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index c1ca383f93..9c46aaf3a7 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -2,6 +2,7 @@
 import inspect
 import os
 import typing
+import warnings
 from copy import deepcopy
 from typing import Union, Tuple, Mapping, Dict, Any, List, Optional
 
@@ -1331,7 +1332,7 @@ def forward(self, inputs, targets):
 
         self.pre_prediction_callback = CallbacksFactory().get(self.training_params.pre_prediction_callback)
 
-        self._initialize_mixed_precision(self.training_params.mixed_precision)
+        self.training_params.mixed_precision = self._initialize_mixed_precision(self.training_params.mixed_precision)
 
         self.ckpt_best_name = self.training_params.ckpt_best_name
 
@@ -1601,11 +1602,16 @@ def _set_test_metrics(self, test_metrics_list):
         self.test_metrics = MetricCollection(test_metrics_list)
 
     def _initialize_mixed_precision(self, mixed_precision_enabled: bool):
+
+        if mixed_precision_enabled and not device_config.is_cuda:
+            warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision. (i.e. `mixed_precision=False`)")
+            mixed_precision_enabled = False
+
         # SCALER IS ALWAYS INITIALIZED BUT IS DISABLED IF MIXED PRECISION WAS NOT SET
         self.scaler = GradScaler(enabled=mixed_precision_enabled)
 
         if mixed_precision_enabled:
-            assert device_config.device.startswith("cuda"), "mixed precision is not available for CPU"
+
             if device_config.multi_gpu == MultiGPUMode.DATA_PARALLEL:
                 # IN DATAPARALLEL MODE WE NEED TO WRAP THE FORWARD FUNCTION OF OUR MODEL SO IT WILL RUN WITH AUTOCAST.
                 # BUT SINCE THE MODULE IS CLONED TO THE DEVICES ON EACH FORWARD CALL OF A DATAPARALLEL MODEL,
@@ -1621,6 +1627,7 @@ def hook(module, _):
                     logger.warning("Mixed Precision - scaler state_dict not found in loaded model. This may case issues " "with loss scaling")
                 else:
                     self.scaler.load_state_dict(scaler_state_dict)
+        return mixed_precision_enabled
 
     def _validate_final_average_model(self, context: PhaseContext, checkpoint_dir_path: str, cleanup_snapshots_pkl_file=False):
         """
diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py
index b9abce904e..2f504dcfaa 100644
--- a/tests/deci_core_unit_test_suite_runner.py
+++ b/tests/deci_core_unit_test_suite_runner.py
@@ -24,6 +24,7 @@
     TestPostPredictionCallback,
     TestModelPredict,
     TestDeprecationDecorator,
+    TestMixedPrecisionDisabled,
 )
 from tests.end_to_end_tests import TestTrainer
 from tests.unit_tests.detection_utils_test import TestDetectionUtils
@@ -162,6 +163,7 @@ def _add_modules_to_unit_tests_suite(self):
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPoseEstimationModelExport))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(YoloNASPoseTests))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PoseEstimationSampleTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMixedPrecisionDisabled))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py
index d83d7a4b14..684d23fe88 100644
--- a/tests/unit_tests/__init__.py
+++ b/tests/unit_tests/__init__.py
@@ -26,6 +26,7 @@
 from tests.unit_tests.post_prediction_callback_test import TestPostPredictionCallback
 from tests.unit_tests.test_predict import TestModelPredict
 from tests.unit_tests.test_deprecate import TestDeprecationDecorator
+from tests.unit_tests.test_mixed_precision_cpu import TestMixedPrecisionDisabled
 
 __all__ = [
     "CrashTipTest",
@@ -55,4 +56,5 @@
     "TestPostPredictionCallback",
     "TestModelPredict",
     "TestDeprecationDecorator",
+    "TestMixedPrecisionDisabled",
 ]
diff --git a/tests/unit_tests/test_mixed_precision_cpu.py b/tests/unit_tests/test_mixed_precision_cpu.py
new file mode 100644
index 0000000000..8ca3cb3c5c
--- /dev/null
+++ b/tests/unit_tests/test_mixed_precision_cpu.py
@@ -0,0 +1,51 @@
+import unittest
+import tempfile
+
+from super_gradients import Trainer
+from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
+from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.training.models import ResNet18
+from super_gradients.training.utils.distributed_training_utils import setup_device
+
+
+class TestMixedPrecisionDisabled(unittest.TestCase):
+    def test_mixed_precision_automatically_changed_with_warning(self):
+        setup_device(device="cpu")
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            trainer = Trainer("test_mixed_precision_automatically_changed_with_warning", ckpt_root_dir=temp_dir)
+            net = ResNet18(num_classes=5, arch_params={})
+            train_params = {
+                "max_epochs": 2,
+                "lr_updates": [1],
+                "lr_decay_factor": 0.1,
+                "lr_mode": "StepLRScheduler",
+                "lr_warmup_epochs": 0,
+                "initial_lr": 0.1,
+                "loss": "CrossEntropyLoss",
+                "criterion_params": {"ignore_index": 0},
+                "train_metrics_list": [Accuracy(), Top5()],
+                "valid_metrics_list": [Accuracy(), Top5()],
+                "metric_to_watch": "Accuracy",
+                "greater_metric_to_watch_is_better": True,
+                "mixed_precision": True,  # This is not supported for CPU, so we expect a warning to be raised AND the code to run
+            }
+            import warnings
+
+            with warnings.catch_warnings(record=True) as w:
+                # Trigger a filter to always make warnings visible
+                warnings.simplefilter("always")
+
+                trainer.train(
+                    model=net,
+                    training_params=train_params,
+                    train_loader=classification_test_dataloader(batch_size=10),
+                    valid_loader=classification_test_dataloader(batch_size=10),
+                )
+
+                # Check if the desired warning is in the list of warnings
+                self.assertTrue(any("Mixed precision training is not supported on CPU" in str(warn.message) for warn in w))
+
+
+if __name__ == "__main__":
+    unittest.main()
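
For readers skimming the diff, the core of the change is a warn-and-downgrade pattern: instead of asserting
that CUDA is present, `_initialize_mixed_precision` now emits a warning, flips the flag to False, and returns
the effective value so the caller can write it back into `training_params`. Below is a minimal, self-contained
sketch of that pattern outside the Trainer (illustrative only; `resolve_mixed_precision` is a hypothetical
helper for this note, not part of the patch):

    import warnings

    import torch
    from torch.cuda.amp import GradScaler


    def resolve_mixed_precision(requested: bool, device: str) -> bool:
        # Mirror the patch: warn and downgrade instead of raising when AMP is
        # requested on a device that cannot support it, then return the
        # effective value so the caller can update its config.
        if requested and "cuda" not in device:
            warnings.warn("Mixed precision training is not supported on CPU. Disabling mixed precision.")
            requested = False
        return requested


    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = resolve_mixed_precision(requested=True, device=device)

    # As in the patch, the scaler is always constructed; enabled=False makes it
    # a no-op, so the training loop needs no separate non-AMP code path.
    scaler = GradScaler(enabled=use_amp)
    print(f"device={device}, mixed_precision={use_amp}, scaler_enabled={scaler.is_enabled()}")

Returning the effective flag, rather than silently mutating internal state, is what keeps
`self.training_params.mixed_precision` consistent with what actually ran, which matters for anything
later read or serialized from `training_params`.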