EvalResult support for val loop (PR 3/5) (#2651)
* add EvalResult support to val/test loops
williamFalcon authored Jul 22, 2020
1 parent a3934ad commit 62ce00f
Showing 21 changed files with 991 additions and 178 deletions.
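For orientation before the per-file diffs: a minimal sketch of how the EvalResult API added in this PR is meant to be used from a validation loop. The module, layer sizes, and metric name are hypothetical; only the early_stop_on/checkpoint_on keyword arguments and the result.log(...) call are taken from the changes below.

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult

class LitModel(LightningModule):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)

        # condition early stopping and checkpointing on the val loss;
        # these surface as 'val_early_stop_on' / 'val_checkpoint_on'
        # in trainer.callback_metrics (see the callback diffs below)
        result = EvalResult(early_stop_on=loss, checkpoint_on=loss)
        result.log('val_loss', loss)
        return result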
24 changes: 15 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -134,26 +134,32 @@ def load_state_dict(self, state_dict):
self.best_score = state_dict['best_score']
self.patience = state_dict['patience']

def on_sanity_check_end(self, trainer, pl_module):
logs = trainer.callback_metrics
self._validate_condition_metric(logs)

def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key) is not None:
self.monitor = val_es_key

# disable strict checking when using structured results
if val_es_key in trainer.callback_metrics:
self.strict = False

self._validate_condition_metric(trainer.callback_metrics)

def on_train_epoch_end(self, trainer, pl_module):
# disable early stopping in train loop when there's a val loop
if self.monitor == 'val_early_stop_on':
return

# early stopping can also work in the train loop when there is no val loop and when using structured results
should_check_early_stop = False
train_es_key = 'early_stop_on'
if trainer.callback_metrics.get(train_es_key, None) is not None:
self.monitor = train_es_key
should_check_early_stop = True

val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key, None) is not None:
self.monitor = val_es_key
should_check_early_stop = True

if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)
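A hedged sketch of what the change above enables, assuming the Trainer flags of this release: when validation_step returns an EvalResult with early_stop_on set, 'val_early_stop_on' appears in trainer.callback_metrics, and the callback retargets its monitor to that key (relaxing strict checking), so no explicit monitor needs to be configured.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# with structured results the callback switches its monitor to
# 'val_early_stop_on' automatically; no monitor key is passed here
trainer = Trainer(early_stop_callback=EarlyStopping())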

4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -279,6 +279,10 @@ def on_validation_end(self, trainer, pl_module):
if metrics.get('checkpoint_on') is not None:
self.monitor = 'checkpoint_on'

# conditioned val metrics override conditioned train loop metrics
if metrics.get('val_checkpoint_on') is not None:
self.monitor = 'val_checkpoint_on'

if self.save_top_k == 0:
# no models are saved
return
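The checkpoint side works the same way; a sketch with an illustrative filepath and save_top_k (both assumptions, not part of this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# 'val_checkpoint_on' (val loop) takes precedence over
# 'checkpoint_on' (train loop) when both are present
checkpoint_cb = ModelCheckpoint(filepath='checkpoints/', save_top_k=1)
trainer = Trainer(checkpoint_callback=checkpoint_cb)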
43 changes: 29 additions & 14 deletions pytorch_lightning/core/step_result.py
@@ -97,10 +97,23 @@ def log(
if 'meta' not in self:
self.__setitem__('meta', {})

self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)

# set the value
self.__setitem__(name, value)
        # if the user requests both step and epoch, we automatically split the metric in two:
        # one version is logged per step, the other per epoch
if on_step and on_epoch:
# set step version
step_name = f'step_{name}'
self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'epoch_{name}'
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
self.__setitem__(epoch_name, value)
else:
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)

# set the value
self.__setitem__(name, value)
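A small sketch of the behavior this block implements (the metric name is arbitrary): logging with both flags yields two derived keys rather than one.

import torch
from pytorch_lightning.core.step_result import TrainResult

result = TrainResult()
result.log('loss', torch.tensor(0.25), on_step=True, on_epoch=True)

# the metric is stored twice: a per-step and a per-epoch version
assert 'step_loss' in result and 'epoch_loss' in result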

def __set_meta(
self,
@@ -111,7 +124,7 @@ def __set_meta(
on_step: bool,
on_epoch: bool,
reduce_fx: Callable,
    ):
# set the meta for the item
meta_value = value
meta = dict(
@@ -122,6 +135,7 @@
reduce_fx=reduce_fx,
value=meta_value
)

self['meta'][name] = meta

# track whether any input requires reduction on epoch end
@@ -219,11 +233,13 @@ def __copy__(self):

@classmethod
def gather(cls, outputs):
        meta = outputs[0].get('meta')
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)

        if meta:
            result['meta'] = meta
        return result
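To illustrate gather (a sketch; the values are made up): per-step results are merged into a single Result with tensors stacked across steps, and after this change 'meta' is copied over only when the first output actually carries it.

import torch
from pytorch_lightning.core.step_result import EvalResult

outputs = []
for v in (0.5, 0.3):
    step_result = EvalResult()
    step_result.log('val_loss', torch.tensor(v))
    outputs.append(step_result)

# stacks 'val_loss' across the two steps and restores 'meta'
gathered = EvalResult.gather(outputs)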

@classmethod
@@ -326,11 +342,10 @@ def log(
):
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)

    def get_callback_metrics(self) -> dict:
        result = {
            'val_early_stop_on': self.early_stop_on,
            'val_checkpoint_on': self.checkpoint_on
        }
        return result

# if __name__ == '__main__':
#     import torch
#     result = TrainResult()
#     result.hiddens = torch.tensor(1)
#     result.log('some', 123)
#     print(result)
#     result.minimize = torch.tensor(1)
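A brief usage note for get_callback_metrics above (tensor values illustrative): it exposes the conditioned metrics under the val-prefixed keys that EarlyStopping and ModelCheckpoint now watch.

import torch
from pytorch_lightning.core.step_result import EvalResult

res = EvalResult(early_stop_on=torch.tensor(0.2), checkpoint_on=torch.tensor(0.2))
metrics = res.get_callback_metrics()
# -> {'val_early_stop_on': tensor(0.2000), 'val_checkpoint_on': tensor(0.2000)}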