Skip to content

Commit

Permalink
feat: change Checkpoint callback's save_best_only to save_top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
Ir1d committed Aug 17, 2019
1 parent 6b54c1e commit 49e35b1
Showing 1 changed file with 57 additions and 19 deletions.
76 changes: 57 additions & 19 deletions pytorch_lightning/callbacks/pt_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,19 @@ class ModelCheckpoint(Callback):
"""

def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, save_weights_only=False,
save_top_k=0, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
self.epochs_since_last_check = 0
self.prefix = prefix
self.bestk = {} # {epoch: monitor}
self.best = 0

if mode not in ['auto', 'min', 'max']:
print('ModelCheckpoint mode %s is unknown, '
Expand All @@ -193,17 +195,35 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,

if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
self.kth = np.Inf
self.mode = 'min'
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
self.kth = -np.Inf
self.mode = 'max'
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.best = -np.Inf
self.kth = -np.Inf
self.mode = 'max'
else:
self.monitor_op = np.less
self.best = np.Inf
self.kth = np.Inf
self.mode = 'min'

def del_model(self, filepath):
dirpath = '/'.join(filepath.split('/')[:-1])

# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)

for filename in os.listdir(dirpath):
if self.prefix in filename:
path_to_delete = os.path.join(dirpath, filename)
try:
shutil.rmtree(path_to_delete)
except OSError:
os.remove(path_to_delete)

def save_model(self, filepath, overwrite):
dirpath = '/'.join(filepath.split('/')[:-1])
Expand All @@ -225,29 +245,44 @@ def save_model(self, filepath, overwrite):

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
self.epochs_since_last_check += 1
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, epoch + 1)
if self.save_best_only:
if self.save_top_k:
current = logs.get(self.monitor)
if current is None:
print('Can save best model only with %s available,'
' skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if (len(self.bestk.keys()) < self.save_top_k) or (self.monitor_op(current, self.bestk[self.kth])):
if len(self.bestk.keys()) == self.save_top_k:
# need to pop the kth
delpath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, self.kth + 1)
self.bestk.pop(self.kth)
self.del_model(delpath)
self.bestk[epoch] = current
if len(self.bestk.keys()) == self.save_top_k:
# monitor dict has reached k elements
if self.mode == 'min':
self.kth = max(self.bestk, key=self.bestk.get)
else:
self.kth = min(self.bestk, key=self.bestk.get)
if self.mode == 'min':
self.best = min(self.bestk.values())
else:
self.best = max(self.bestk.values())
if self.verbose > 0:
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch + 1, self.monitor, self.best,
current, filepath))
self.best = current
print('\nEpoch %05d: %s reached %s (best %s),'
' saving model to %s as top %d'
% (epoch + 1, self.monitor, current, self.best,
filepath, self.save_top_k))
self.save_model(filepath, overwrite=True)

else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve' %
(epoch + 1, self.monitor))
print('\nEpoch %05d: %s was not in top %d' %
(epoch + 1, self.monitor, self.save_top_k))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
Expand All @@ -262,3 +297,6 @@ def on_epoch_end(self, epoch, logs=None):
print(loss)
if should_stop:
break
w = ModelCheckpoint('res', save_top_k=2, verbose=1)
for i, loss in enumerate(losses):
w.on_epoch_end(i, logs={'val_loss': loss})

0 comments on commit 49e35b1

Please sign in to comment.