Fixed bug in ModelWeightAveraging class that led to corrupted model when metric to watch was NaN/Inf #1598

Merged (1 commit) on Nov 1, 2023
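Background on the failure mode: `_is_better` selects which snapshot slot to overwrite with `np.argmin`/`np.argmax` over the stored metric values and then compares the new metric against that slot, and both steps misbehave once a NaN or infinity gets involved. The snippet below is not part of the PR; it is a minimal illustration of the NumPy/IEEE-754 semantics that the new `np.isfinite` guard protects against:

```python
import numpy as np

print(np.nan > 0.5)   # False: every comparison against NaN is False
print(np.inf > 1e9)   # True: +inf wins every "greater is better" comparison

metrics = np.array([0.71, np.nan, 0.93])
print(np.argmax(metrics))  # 1: argmax/argmin report the NaN slot as the extremum
print(np.argmin(metrics))  # 1: so a single NaN poisons slot selection from then on
```

With `greater_is_better=True`, a `+inf` validation metric would therefore beat every stored value and push the (typically NaN-weighted) model into the snapshot dict, and any non-finite value written into `snapshots_metric` would keep being picked by `argmin`/`argmax` afterwards, so the averaged checkpoint ends up corrupted. Rejecting non-finite metrics up front, as this PR does, prevents both, and the new unit test below exercises exactly that scenario.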
44 changes: 27 additions & 17 deletions src/super_gradients/training/utils/weight_averaging_utils.py
@@ -1,6 +1,10 @@
 import os
+from typing import Mapping, Tuple, Optional, Union, Any

 import torch
 import numpy as np
+from torch import nn

 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
 from super_gradients.training.utils.utils import move_state_dict_to_device, unwrap_model

@@ -15,18 +19,18 @@ class ModelWeightAveraging:

     def __init__(
         self,
-        ckpt_dir,
-        greater_is_better,
-        metric_to_watch="acc",
-        load_checkpoint=False,
-        number_of_models_to_average=10,
+        ckpt_dir: str,
+        greater_is_better: bool,
+        metric_to_watch: str,
+        load_checkpoint: bool = False,
+        number_of_models_to_average: int = 10,
     ):
         """
         Init the ModelWeightAveraging
-        :param ckpt_dir: the directory where the checkpoints are saved
-        :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model
-        :param load_checkpoint: whether to load pre-existing snapshot dict.
-        :param number_of_models_to_average: number of models to average
+        :param ckpt_dir: The directory where the checkpoints are saved
+        :param metric_to_watch: Monitoring loss or acc, will be identical to that which determines best_model
+        :param load_checkpoint: Whether to load pre-existing snapshot dict.
+        :param number_of_models_to_average: Number of models to average
         """

         self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
@@ -37,7 +41,6 @@ def __init__(
         # if continuing training, copy previous snapshot dict if exist
         if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
             averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)
-
         else:
             averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}
             # if metric to watch is acc, hold a zero array, if loss hold inf array
@@ -48,7 +51,7 @@ def __init__(

         torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)

-    def update_snapshots_dict(self, model, validation_results_dict):
+    def update_snapshots_dict(self, model: nn.Module, validation_results_dict: Mapping[str, float]):
         """
         Update the snapshot dict and returns the updated average model for saving
         :param model: the latest model
@@ -64,11 +67,11 @@ def update_snapshots_dict(self, model, validation_results_dict):
             new_sd = move_state_dict_to_device(new_sd, "cpu")

             averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
-            averaging_snapshots_dict["snapshots_metric"][update_ind] = validation_results_dict[self.metric_to_watch]
+            averaging_snapshots_dict["snapshots_metric"][update_ind] = float(validation_results_dict[self.metric_to_watch])

         return averaging_snapshots_dict

-    def get_average_model(self, model, validation_results_dict=None):
+    def get_average_model(self, model, validation_results_dict=None) -> Mapping[str, torch.Tensor]:
         """
         Returns the averaged model
         :param model: will be used to determine arch
@@ -99,14 +102,21 @@ def cleanup(self):
         """
         os.remove(self.averaging_snapshots_file)

-    def _is_better(self, averaging_snapshots_dict, validation_results_dict):
+    def _is_better(
+        self, averaging_snapshots_dict: Mapping[str, Any], validation_results_dict: Mapping[str, Union[float, torch.Tensor]]
+    ) -> Tuple[bool, Optional[int]]:
         """
         Determines if the new model is better according to the specified metrics
         :param averaging_snapshots_dict: snapshot dict
-        :param validation_results_dict: latest model performance
+        :param validation_results_dict: latest model performance
+        :return: Tuple (bool, index) whether first item is True if the new model is better and False otherwise;
+            Second item is the index in the averaging_snapshots_dict to which the new model should be saved
         """
         snapshot_metric_array = averaging_snapshots_dict["snapshots_metric"]
-        val = validation_results_dict[self.metric_to_watch]
+        val = float(validation_results_dict[self.metric_to_watch])

+        if not np.isfinite(val):
+            return False, None
+
         if self.greater_is_better:
             update_ind = np.argmin(snapshot_metric_array)
@@ -119,4 +129,4 @@ def _is_better(self, averaging_snapshots_dict, validation_results_dict):
         return False, None

     def _get_averaging_snapshots_dict(self):
-        return torch.load(self.averaging_snapshots_file)
+        return torch.load(self.averaging_snapshots_file, map_location="cpu")
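For orientation, a rough sketch of how the class is used: construct it once per training run and call `get_average_model` after each validation pass to obtain the state dict of the running average over the best snapshots. The loop, paths, and metric values below are hypothetical placeholders; only the constructor arguments and the `get_average_model` call follow the API shown in the diff:

```python
import os

import torch

from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging

ckpt_dir = "/tmp/ckpts"           # hypothetical checkpoint directory
os.makedirs(ckpt_dir, exist_ok=True)

model = torch.nn.Linear(10, 2)    # stand-in for the real network

wa = ModelWeightAveraging(
    ckpt_dir=ckpt_dir,
    greater_is_better=True,
    metric_to_watch="acc",
    load_checkpoint=False,
    number_of_models_to_average=10,
)

for epoch in range(3):            # hypothetical training loop
    # ... train and validate here; validation produces a metrics dict ...
    validation_results = {"acc": 0.80 + 0.01 * epoch}

    # Updates the snapshot bookkeeping and returns the averaged state dict.
    # With this PR, a NaN/Inf value for "acc" is ignored rather than averaged in.
    averaged_sd = wa.get_average_model(model, validation_results)

torch.save({"net": averaged_sd}, os.path.join(ckpt_dir, "average_model.pth"))
```

Inside SuperGradients this is driven by the trainer's checkpoint-saving logic rather than user code; the sketch only shows the shape of the calls, and with this fix an epoch whose watched metric comes back as NaN/Inf simply leaves the running average untouched.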
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
@@ -49,6 +49,7 @@
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.test_deprecations import DeprecationsUnitTest
 from tests.unit_tests.test_min_samples_single_node import TestMinSamplesSingleNode
+from tests.unit_tests.test_model_weight_averaging import TestModelWeightAveraging
 from tests.unit_tests.test_train_with_torch_scheduler import TrainWithTorchSchedulerTest
 from tests.unit_tests.test_version_check import TestVersionCheck
 from tests.unit_tests.test_yolo_nas_pose import YoloNASPoseTests
@@ -172,6 +173,7 @@ def _add_modules_to_unit_tests_suite(self):
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DynamicModelTests))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertRecipeToCode))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestVersionCheck))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelWeightAveraging))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
70 changes: 70 additions & 0 deletions tests/unit_tests/test_model_weight_averaging.py
@@ -0,0 +1,70 @@ (new file)
import collections
import tempfile
import unittest

import numpy as np
import torch
from torch import nn

from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging


class TestModelWeightAveraging(unittest.TestCase):
    def test_model_weight_averaging_single_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            avg = ModelWeightAveraging(
                ckpt_dir=tmp_dir,
                greater_is_better=True,
                metric_to_watch="acc",
                load_checkpoint=False,
                number_of_models_to_average=10,
            )

            model = self._create_dummy_model()
            model_sd = model.state_dict()
            avg_model_sd = avg.get_average_model(model, {"acc": 0.99})
            self.assertStateDictAlmostEqual(avg_model_sd, model_sd)

    def test_model_weight_averaging_with_nan_metric(self):
        corrupted_metric_values = np.nan, +np.inf, -np.inf, torch.nan, torch.inf, -torch.inf

        for corrupted_metric_value in corrupted_metric_values:
            with self.subTest(corrupted_metric_value=corrupted_metric_value):
                with tempfile.TemporaryDirectory() as tmp_dir:
                    avg = ModelWeightAveraging(
                        ckpt_dir=tmp_dir,
                        greater_is_better=True,
                        metric_to_watch="acc",
                        load_checkpoint=False,
                        number_of_models_to_average=10,
                    )

                    model = self._create_dummy_model()
                    model_sd = model.state_dict()
                    avg.get_average_model(model, {"acc": 0.99})

                    corrupted_model = self._create_dummy_model()
                    corrupted_model.fc1.weight.data = torch.randn(10, 10) * torch.nan
                    avg_model_sd = avg.get_average_model(corrupted_model, {"acc": corrupted_metric_value})

                    self.assertStateDictAlmostEqual(avg_model_sd, model_sd)

    def assertStateDictAlmostEqual(self, sd1, sd2, eps=1e-5):
        self.assertEqual(set(sd1.keys()), set(sd2.keys()))
        for key in sd1.keys():
            v1 = sd1[key]
            v2 = sd2[key]
            if torch.is_floating_point(v1) and torch.is_floating_point(v2):
                difference = torch.nn.functional.l1_loss(v1, v2)
                self.assertLessEqual(difference, eps, msg=f"{key}: {v1} vs {v2}")
            else:
                self.assertEqual(v1, v2)

    def _create_dummy_model(self) -> nn.Module:
        net = nn.Sequential(collections.OrderedDict([("fc1", nn.Linear(10, 10)), ("bn", nn.BatchNorm1d(10))]))
        net.fc1.weight.data = torch.randn(10, 10)
        return net


if __name__ == "__main__":
    unittest.main()