
Commit

Fixed bug in ModelWeightAveraging class when metric to watch was NaN (#…
BloodAxe committed Nov 1, 2023
1 parent 20ddf3d commit c2821bb
Showing 3 changed files with 99 additions and 17 deletions.
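
In short: _is_better previously compared the incoming metric value directly against the stored snapshot metrics, so a non-finite value could slip through (for example, -inf always wins a lower-is-better comparison, and +inf a greater-is-better one), letting a model with NaN weights poison every later average. The fix casts the metric to float and rejects non-finite values up front. Below is a minimal standalone sketch of that guard; is_finite_metric is an illustrative name, not part of the library:

import numpy as np
import torch

def is_finite_metric(metric_value) -> bool:
    # float() unwraps 0-dim torch tensors and numpy scalars, so one
    # np.isfinite call covers plain floats, numpy and torch values alike.
    return bool(np.isfinite(float(metric_value)))

assert is_finite_metric(0.99)
for bad in (float("nan"), np.inf, -np.inf, torch.nan):
    assert not is_finite_metric(bad)  # rejected before any best/worst comparison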
44 changes: 27 additions & 17 deletions src/super_gradients/training/utils/weight_averaging_utils.py
@@ -1,6 +1,10 @@
 import os
+from typing import Mapping, Tuple, Optional, Union, Any
 
 import torch
+import numpy as np
+from torch import nn
 
 from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict
 from super_gradients.training.utils.utils import move_state_dict_to_device, unwrap_model

@@ -15,18 +19,18 @@ class ModelWeightAveraging:
 
     def __init__(
         self,
-        ckpt_dir,
-        greater_is_better,
-        metric_to_watch="acc",
-        load_checkpoint=False,
-        number_of_models_to_average=10,
+        ckpt_dir: str,
+        greater_is_better: bool,
+        metric_to_watch: str,
+        load_checkpoint: bool = False,
+        number_of_models_to_average: int = 10,
     ):
         """
         Init the ModelWeightAveraging
-        :param ckpt_dir: the directory where the checkpoints are saved
-        :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model
-        :param load_checkpoint: whether to load pre-existing snapshot dict.
-        :param number_of_models_to_average: number of models to average
+        :param ckpt_dir: The directory where the checkpoints are saved
+        :param metric_to_watch: Monitoring loss or acc, will be identical to that which determines best_model
+        :param load_checkpoint: Whether to load pre-existing snapshot dict.
+        :param number_of_models_to_average: Number of models to average
         """
 
         self.averaging_snapshots_file = os.path.join(ckpt_dir, "averaging_snapshots.pkl")
@@ -37,7 +41,6 @@ def __init__(
         # if continuing training, copy previous snapshot dict if exist
         if load_checkpoint and ckpt_dir is not None and os.path.isfile(self.averaging_snapshots_file):
             averaging_snapshots_dict = read_ckpt_state_dict(self.averaging_snapshots_file)
-
         else:
             averaging_snapshots_dict = {"snapshot" + str(i): None for i in range(self.number_of_models_to_average)}
             # if metric to watch is acc, hold a zero array, if loss hold inf array
@@ -48,7 +51,7 @@ def __init__(
 
         torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
 
-    def update_snapshots_dict(self, model, validation_results_dict):
+    def update_snapshots_dict(self, model: nn.Module, validation_results_dict: Mapping[str, float]):
         """
         Update the snapshot dict and returns the updated average model for saving
         :param model: the latest model
@@ -64,11 +67,11 @@ def update_snapshots_dict(self, model, validation_results_dict):
         new_sd = move_state_dict_to_device(new_sd, "cpu")
 
         averaging_snapshots_dict["snapshot" + str(update_ind)] = new_sd
-        averaging_snapshots_dict["snapshots_metric"][update_ind] = validation_results_dict[self.metric_to_watch]
+        averaging_snapshots_dict["snapshots_metric"][update_ind] = float(validation_results_dict[self.metric_to_watch])
 
         return averaging_snapshots_dict
 
-    def get_average_model(self, model, validation_results_dict=None):
+    def get_average_model(self, model, validation_results_dict=None) -> Mapping[str, torch.Tensor]:
         """
         Returns the averaged model
         :param model: will be used to determine arch
@@ -99,14 +102,21 @@ def cleanup(self):
         """
         os.remove(self.averaging_snapshots_file)
 
-    def _is_better(self, averaging_snapshots_dict, validation_results_dict):
+    def _is_better(
+        self, averaging_snapshots_dict: Mapping[str, Any], validation_results_dict: Mapping[str, Union[float, torch.Tensor]]
+    ) -> Tuple[bool, Optional[int]]:
         """
         Determines if the new model is better according to the specified metrics
         :param averaging_snapshots_dict: snapshot dict
-        :param validation_results_dict: latest model performance
+        :param validation_results_dict: latest model performance
+        :return: Tuple (bool, index) whether first item is True if the new model is better and False otherwise;
+                 Second item is the index in the averaging_snapshots_dict to which the new model should be saved
         """
         snapshot_metric_array = averaging_snapshots_dict["snapshots_metric"]
-        val = validation_results_dict[self.metric_to_watch]
+        val = float(validation_results_dict[self.metric_to_watch])
 
+        if not np.isfinite(val):
+            return False, None
+
         if self.greater_is_better:
             update_ind = np.argmin(snapshot_metric_array)
@@ -119,4 +129,4 @@ def _is_better(self, averaging_snapshots_dict, validation_results_dict):
         return False, None
 
     def _get_averaging_snapshots_dict(self):
-        return torch.load(self.averaging_snapshots_file)
+        return torch.load(self.averaging_snapshots_file, map_location="cpu")
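
For orientation, here is a minimal usage sketch of the class after this change, mirroring the new unit test further down (the directory, model, and metric values are illustrative, not prescribed by the library):

import tempfile
from torch import nn
from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging

with tempfile.TemporaryDirectory() as ckpt_dir:
    avg = ModelWeightAveraging(
        ckpt_dir=ckpt_dir,
        greater_is_better=True,  # watching "acc", so higher is better
        metric_to_watch="acc",
        load_checkpoint=False,
        number_of_models_to_average=10,
    )
    model = nn.Linear(10, 10)

    # A finite metric admits the model into the snapshot pool.
    averaged_sd = avg.get_average_model(model, {"acc": 0.99})

    # A non-finite metric is now rejected; the averaged weights stay unchanged.
    averaged_sd = avg.get_average_model(model, {"acc": float("nan")})

    avg.cleanup()  # removes the averaging_snapshots.pkl file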
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
@@ -49,6 +49,7 @@
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
 from tests.unit_tests.test_deprecations import DeprecationsUnitTest
 from tests.unit_tests.test_min_samples_single_node import TestMinSamplesSingleNode
+from tests.unit_tests.test_model_weight_averaging import TestModelWeightAveraging
 from tests.unit_tests.test_train_with_torch_scheduler import TrainWithTorchSchedulerTest
 from tests.unit_tests.test_version_check import TestVersionCheck
 from tests.unit_tests.test_yolo_nas_pose import YoloNASPoseTests
@@ -172,6 +173,7 @@ def _add_modules_to_unit_tests_suite(self):
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DynamicModelTests))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestConvertRecipeToCode))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestVersionCheck))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelWeightAveraging))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
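With the suite registration above, the new test runs as part of the default unit-test suite; it can also be executed on its own with the standard unittest runner:

python -m unittest tests.unit_tests.test_model_weight_averaging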
70 changes: 70 additions & 0 deletions tests/unit_tests/test_model_weight_averaging.py
@@ -0,0 +1,70 @@
+import collections
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from torch import nn
+
+from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
+
+
+class TestModelWeightAveraging(unittest.TestCase):
+    def test_model_weight_averaging_single_model(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            avg = ModelWeightAveraging(
+                ckpt_dir=tmp_dir,
+                greater_is_better=True,
+                metric_to_watch="acc",
+                load_checkpoint=False,
+                number_of_models_to_average=10,
+            )
+
+            model = self._create_dummy_model()
+            model_sd = model.state_dict()
+            avg_model_sd = avg.get_average_model(model, {"acc": 0.99})
+            self.assertStateDictAlmostEqual(avg_model_sd, model_sd)
+
+    def test_model_weight_averaging_with_nan_metric(self):
+        corrupted_metric_values = np.nan, +np.inf, -np.inf, torch.nan, torch.inf, -torch.inf
+
+        for corrupted_metric_value in corrupted_metric_values:
+            with self.subTest(corrupted_metric_value=corrupted_metric_value):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    avg = ModelWeightAveraging(
+                        ckpt_dir=tmp_dir,
+                        greater_is_better=True,
+                        metric_to_watch="acc",
+                        load_checkpoint=False,
+                        number_of_models_to_average=10,
+                    )
+
+                    model = self._create_dummy_model()
+                    model_sd = model.state_dict()
+                    avg.get_average_model(model, {"acc": 0.99})
+
+                    corrupted_model = self._create_dummy_model()
+                    corrupted_model.fc1.weight.data = torch.randn(10, 10) * torch.nan
+                    avg_model_sd = avg.get_average_model(corrupted_model, {"acc": corrupted_metric_value})
+
+                    self.assertStateDictAlmostEqual(avg_model_sd, model_sd)
+
+    def assertStateDictAlmostEqual(self, sd1, sd2, eps=1e-5):
+        self.assertEqual(set(sd1.keys()), set(sd2.keys()))
+        for key in sd1.keys():
+            v1 = sd1[key]
+            v2 = sd2[key]
+            if torch.is_floating_point(v1) and torch.is_floating_point(v2):
+                difference = torch.nn.functional.l1_loss(v1, v2)
+                self.assertLessEqual(difference, eps, msg=f"{key}: {v1} vs {v2}")
+            else:
+                self.assertEqual(v1, v2)
+
+    def _create_dummy_model(self) -> nn.Module:
+        net = nn.Sequential(collections.OrderedDict([("fc1", nn.Linear(10, 10)), ("bn", nn.BatchNorm1d(10))]))
+        net.fc1.weight.data = torch.randn(10, 10)
+        return net
+
+
+if __name__ == "__main__":
+    unittest.main()
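
Note that the dummy model deliberately includes a BatchNorm1d layer: its state dict carries the integer num_batches_tracked buffer alongside floating-point weights and running statistics, which exercises both branches of assertStateDictAlmostEqual. A quick illustration of the dtypes involved (standard PyTorch behavior):

import torch
from torch import nn

bn = nn.BatchNorm1d(10)
for name, value in bn.state_dict().items():
    print(name, value.dtype)
# weight / bias / running_mean / running_var -> torch.float32
# num_batches_tracked -> torch.int64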
