Merge pull request #177 from jrzaurin/fix_restore_best_weights
Fix #175 early stopping and model checkpoint restoring weights.
jrzaurin committed Jul 14, 2023
2 parents 26e1985 + 8406813 commit 42cfe5c
Showing 6 changed files with 188 additions and 39 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/bio_imbalanced_loader.py
@@ -86,7 +86,7 @@
n_epochs=1,
batch_size=32,
custom_dataloader=DataLoaderImbalanced,
oversample_mul=5,
**{"oversample_mul": 5},
)
print(
"Training time[s]: {}".format(
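Note that the two call forms are equivalent at the Python level; the dict-unpacking form presumably just makes explicit that oversample_mul travels as a generic keyword argument that the custom dataloader consumes from its **kwargs (see the pytorch_widedeep/dataloaders.py change below). A quick illustrative check, not library code:

def collect_kwargs(**kwargs):
    return kwargs

assert collect_kwargs(oversample_mul=5) == collect_kwargs(**{"oversample_mul": 5}) == {"oversample_mul": 5}
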
40 changes: 19 additions & 21 deletions pytorch_widedeep/callbacks.py
@@ -4,9 +4,9 @@
CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""
import os
import copy
import datetime
import warnings
import copy

import numpy as np
import torch
@@ -349,6 +349,10 @@ class just report best metric and best_epoch.
monitor: str, default="loss"
quantity to monitor. Typically _'val_loss'_ or metric name
(e.g. _'val_acc'_)
min_delta: float, default=0.
minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than min_delta will
count as no improvement.
verbose: int, default=0
verbosity mode
save_best_only: bool, default=False,
@@ -397,6 +401,7 @@ def __init__(
self,
filepath: Optional[str] = None,
monitor: str = "val_loss",
min_delta: float = 0.0,
verbose: int = 0,
save_best_only: bool = False,
mode: str = "auto",
@@ -407,6 +412,7 @@

self.filepath = filepath
self.monitor = monitor
self.min_delta = min_delta
self.verbose = verbose
self.save_best_only = save_best_only
self.mode = mode
@@ -450,6 +456,11 @@ def __init__(
self.monitor_op = np.less
self.best = np.Inf

if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1

def on_epoch_end( # noqa: C901
self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
):
@@ -468,33 +479,20 @@ def on_epoch_end(  # noqa: C901
RuntimeWarning,
)
else:
if self.monitor_op(current, self.best):
if self.monitor_op(current - self.min_delta, self.best):
if self.verbose > 0:
if self.filepath:
print(
"\nEpoch %05d: %s improved from %0.5f to %0.5f,"
" saving model to %s"
% (
epoch + 1,
self.monitor,
self.best,
current,
filepath,
)
f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best:.5f} to {current:.5f} "
f"Saving model to {filepath}"
)
else:
print(
"\nEpoch %05d: %s improved from %0.5f to %0.5f"
% (
epoch + 1,
self.monitor,
self.best,
current,
)
f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best:.5f} to {current:.5f} "
)
self.best = current
self.best_epoch = epoch
self.best_state_dict = self.model.state_dict()
self.best_state_dict = copy.deepcopy(self.model.state_dict())
if self.filepath:
torch.save(self.best_state_dict, filepath)
if self.max_save > 0:
@@ -508,8 +506,8 @@ def on_epoch_end(  # noqa: C901
else:
if self.verbose > 0:
print(
"\nEpoch %05d: %s did not improve from %0.5f"
% (epoch + 1, self.monitor, self.best)
f"\nEpoch {epoch + 1}: {self.monitor} did not improve from {self.best:.5f} "
f" considering a 'min_delta' improvement of {self.min_delta:.5f}"
)
if not self.save_best_only and self.filepath:
if self.verbose > 0:
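For reference, a minimal standalone sketch of the improvement test that the new min_delta parameter feeds into. The numbers are illustrative; the snippet only mirrors the numpy comparison used above and is not the library code itself:

import numpy as np

# "min" mode: the monitored quantity is a loss, so smaller is better and
# min_delta is stored with its sign flipped (as in the __init__ change above)
monitor_op, best, min_delta = np.less, 0.350, -0.01

current = 0.345  # a 0.005 improvement over best
improved = monitor_op(current - min_delta, best)  # np.less(0.355, 0.350) -> False
print(improved)  # the 0.005 drop is below the 0.01 threshold, so it does not count
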
1 change: 1 addition & 0 deletions pytorch_widedeep/dataloaders.py
@@ -85,6 +85,7 @@ def __init__(
self.with_lds = dataset.with_lds
if "oversample_mul" in kwargs:
oversample_mul = kwargs["oversample_mul"]
del kwargs["oversample_mul"]
else:
oversample_mul = 1
weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
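The del kwargs["oversample_mul"] added here reads the custom keyword and removes it, presumably so the remaining kwargs can be forwarded cleanly to the underlying torch DataLoader, which would reject an unexpected argument. A minimal sketch of that pattern with hypothetical names, not the actual DataLoaderImbalanced implementation:

import torch
from torch.utils.data import DataLoader

def make_oversampled_loader(dataset, weights, **kwargs):
    # read and drop the custom keyword before passing the rest on to DataLoader
    oversample_mul = kwargs.pop("oversample_mul", 1)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(dataset) * oversample_mul, replacement=True
    )
    return DataLoader(dataset, sampler=sampler, **kwargs)
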
38 changes: 28 additions & 10 deletions pytorch_widedeep/training/_base_trainer.py
@@ -1,5 +1,6 @@
import os
import sys
import warnings
from abc import ABC, abstractmethod

import numpy as np
@@ -130,17 +131,34 @@ def save(
):
raise NotImplementedError("Trainer.save method not implemented")

def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
def _restore_best_weights(self): # noqa: C901
early_stopping_min_delta = None
model_checkpoint_min_delta = None
already_restored = False

for callback in self.callback_container.callbacks:
if (
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
):
early_stopping_min_delta = callback.min_delta
already_restored = True

if callback.__class__.__name__ == "ModelCheckpoint":
model_checkpoint_min_delta = callback.min_delta

if (
early_stopping_min_delta is not None
and model_checkpoint_min_delta is not None
) and (early_stopping_min_delta != model_checkpoint_min_delta):
warnings.warn(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks",
UserWarning,
)

if already_restored:
# already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
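In practice this means that when both callbacks are configured, EarlyStopping's restoration takes precedence and a UserWarning is emitted if the two min_delta values disagree, since the callbacks would then define "improvement" differently. A hedged usage sketch with placeholder model, data and filepath:

from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=True),
    ModelCheckpoint(filepath="model_weights/wd_out", save_best_only=True, min_delta=0.01),
]
# trainer = Trainer(model, objective="binary", callbacks=callbacks)
# trainer.fit(...)  # restoring the best weights at the end of fit would warn
#                   # about the mismatched min_delta values
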
3 changes: 2 additions & 1 deletion pytorch_widedeep/training/_finetune.py
@@ -317,6 +317,7 @@ def _steps_up_down(self, steps: int, n_epochs: int = 1) -> Tuple[int, int]:
up, down: Tuple, int
number of steps increasing/decreasing the learning rate during the cycle
"""
up = round((steps * n_epochs) * 0.1)
# up = round((steps * n_epochs) * 0.1)
up = max([round((steps * n_epochs) * 0.1), 1])
down = (steps * n_epochs) - up
return up, down
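
A quick worked check of the guard added here, with illustrative numbers: for very short finetuning runs the old expression could round the warm-up down to zero steps, while the new one guarantees at least one.

steps, n_epochs = 4, 1
up_old = round((steps * n_epochs) * 0.1)            # round(0.4) -> 0 increasing steps
up_new = max([round((steps * n_epochs) * 0.1), 1])  # -> 1 increasing step
down = (steps * n_epochs) - up_new                  # -> 3 decreasing steps
print(up_old, up_new, down)  # 0 1 3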
143 changes: 137 additions & 6 deletions tests/test_model_functioning/test_callbacks.py
@@ -481,18 +481,21 @@ def test_early_stopping_get_state():
assert no_trainer and no_model


def test_early_stopping_restore_state():
# min_delta is large, so the early stopping condition will never be met except for the first epoch.
# ##############################################################################
# Test the restore weights functionalities after bug fixed
# ##############################################################################
def test_early_stopping_restore_weights_with_metric():
# min_delta is large, so the early stopping condition will be met in the first epoch.
early_stopping = EarlyStopping(
restore_best_weights=True, min_delta=1000, patience=1000
)
trainer_tt = Trainer(
trainer = Trainer(
model,
objective="regression",
callbacks=[early_stopping],
verbose=0,
)
trainer_tt.fit(
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
@@ -501,8 +504,136 @@ def test_early_stopping_restore_state():
)
assert early_stopping.wait > 0
# so early stopping is not triggered, but is over-fitting.
pred_val = trainer_tt.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer_tt.loss_fn(
pred_val = trainer.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer.loss_fn(
torch.tensor(pred_val), torch.tensor(target_val)
).item()
assert np.allclose(restored_metric, early_stopping.best)


def test_early_stopping_restore_weights_with_state():
# Long, perhaps too long, test to check early_stopping restore weights
# functionality

# this is repetitive, but for now I want this unit test "self-contained"

# We first define a model and train it, with early stopping that should
# set the weights back to those after the 1st epoch. We also use
# ModelCheckpoint and save all iterations
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)

fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint = ModelCheckpoint(
filepath=fpath,
save_best_only=False,
max_save=10,
min_delta=1000, # irrelevant here
)
early_stopping = EarlyStopping(
patience=3, min_delta=1000, restore_best_weights=True
)

trainer = Trainer(
model,
objective="binary",
callbacks=[early_stopping, model_checkpoint],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=5,
batch_size=16,
)

# We now define a brand new model
new_wide = Wide(np.unique(X_wide).shape[0], 1)
new_deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
new_model = WideDeep(wide=new_wide, deeptabular=new_deeptabular)

# In general, the best epoch is equal to the (stopped_epoch - patience) + 1
full_best_epoch_path = "_".join(
[
model_checkpoint.filepath,
str((early_stopping.stopped_epoch - early_stopping.patience) + 1) + ".p",
]
)

# we load the weights for the best epoch and these should match those of
# the original model if early_stopping worked
new_model.load_state_dict(torch.load(full_best_epoch_path))
new_model.to(next(model.parameters()).device)

shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")

assert torch.allclose(
new_model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
)


def test_model_checkpoint_restore_weights():
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)

fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint = ModelCheckpoint(
filepath=fpath,
save_best_only=True,
min_delta=1000, # irrelevant here
)
trainer = Trainer(
model,
objective="binary",
callbacks=[model_checkpoint],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=5,
batch_size=16,
)

new_wide = Wide(np.unique(X_wide).shape[0], 1)
new_deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
new_model = WideDeep(wide=new_wide, deeptabular=new_deeptabular)

full_best_epoch_path = "_".join(
[model_checkpoint.filepath, str(model_checkpoint.best_epoch + 1) + ".p"]
)

new_model.load_state_dict(torch.load(full_best_epoch_path))
new_model.to(next(model.parameters()).device)

shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")

assert torch.allclose(
new_model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
)
