From 9604d7bf8994615431af9d86c8de154677237b75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Jab=C5=82o=C5=84ski?= Date: Mon, 27 Apr 2020 11:02:33 +0200 Subject: [PATCH] Fix ModelCheckpoint not being picklable (#1632) --- pytorch_lightning/callbacks/model_checkpoint.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1d5d6d1474a09..44d240b503b50 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -116,10 +116,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve torch_inf = torch.tensor(np.Inf) mode_dict = { - 'min': (torch.lt, torch_inf, 'min'), - 'max': (torch.gt, -torch_inf, 'max'), - 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') - else (torch.lt, torch_inf, 'min'), + 'min': (torch_inf, 'min'), + 'max': (-torch_inf, 'max'), + 'auto': (-torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') + else (torch_inf, 'min'), } if mode not in mode_dict: @@ -127,7 +127,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve f'fallback to auto mode.', RuntimeWarning) mode = 'auto' - self.monitor_op, self.kth_value, self.mode = mode_dict[mode] + self.kth_value, self.mode = mode_dict[mode] def _del_model(self, filepath): if os.path.isfile(filepath): @@ -151,7 +151,12 @@ def check_monitor_top_k(self, current): if not isinstance(current, torch.Tensor): current = torch.tensor(current) - return self.monitor_op(current, self.best_k_models[self.kth_best_model]) + monitor_op = { + "min": torch.lt, + "max": torch.gt, + }[self.mode] + + return monitor_op(current, self.best_k_models[self.kth_best_model]) def format_checkpoint_name(self, epoch, metrics, ver=None): """Generate a filename according to the defined template.