Skip to content

Commit

Permalink
Merge pull request #1632 from quinor/patch-1
Browse files Browse the repository at this point in the history
Fix ModelCheckpoint not being picklable.
  • Loading branch information
williamFalcon committed Apr 27, 2020
2 parents 89877fe + 735520b commit 26933a9
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve

torch_inf = torch.tensor(np.Inf)
mode_dict = {
'min': (torch.lt, torch_inf, 'min'),
'max': (torch.gt, -torch_inf, 'max'),
'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
else (torch.lt, torch_inf, 'min'),
'min': (torch_inf, 'min'),
'max': (-torch_inf, 'max'),
'auto': (-torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
else (torch_inf, 'min'),
}

if mode not in mode_dict:
rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
f'fallback to auto mode.', RuntimeWarning)
mode = 'auto'

self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
self.kth_value, self.mode = mode_dict[mode]

def _del_model(self, filepath):
if os.path.isfile(filepath):
Expand All @@ -151,7 +151,12 @@ def check_monitor_top_k(self, current):
if not isinstance(current, torch.Tensor):
current = torch.tensor(current)

return self.monitor_op(current, self.best_k_models[self.kth_best_model])
monitor_op = {
"min": torch.lt,
"max": torch.gt,
}[self.mode]

return monitor_op(current, self.best_k_models[self.kth_best_model])

def format_checkpoint_name(self, epoch, metrics, ver=None):
"""Generate a filename according to the defined template.
Expand Down

0 comments on commit 26933a9

Please sign in to comment.