Skip to content

Commit

Permalink
Do not override the logged epoch in logged_metrics (#7982)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Jun 16, 2021
1 parent 2134216 commit bc2c2db
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))


- MLFlowLogger now accepts `run_name` as a constructor argument ([#7622](https://github.com/PyTorchLightning/pytorch-lightning/issues/7622))


Expand Down Expand Up @@ -255,6 +257,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))


- Do not override the existing `epoch` value in `logged_metrics` when already logged by the user ([#7982](https://github.com/PyTorchLightning/pytorch-lightning/issues/7982))


- Support manual optimization with DeepSpeed ([#7970](https://github.com/PyTorchLightning/pytorch-lightning/pull/7970))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
"""
if self.trainer.logger is None or not metrics:
return

# add gpu memory
if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
mem_map = memory.get_memory_profile(self.log_gpu_memory)
Expand All @@ -99,21 +102,19 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -
# turn all tensors to scalars
scalar_metrics = metrics_to_scalars(metrics)

if "step" in scalar_metrics and step is None:
step = scalar_metrics.pop("step")

elif step is None:
# added metrics by Lightning for convenience
scalar_metrics['epoch'] = self.trainer.current_epoch
if step is None:
step = scalar_metrics.pop("step", None)
if step is None:
# added metrics for convenience
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
step = self.trainer.global_step

# log actual metrics
if self.trainer.logger is not None:
if self.trainer.is_global_zero:
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.trainer.logger.save()
if self.trainer.is_global_zero:
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.trainer.logger.save()

self._logged_metrics.update(scalar_metrics)
self._logged_metrics.update(scalar_metrics)

"""
Evaluation metric updates
Expand Down Expand Up @@ -149,9 +150,7 @@ def update_eval_step_metrics(self) -> None:

# logs user requested information to logger
assert not self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if metrics:
self.log_metrics(metrics, step=self._eval_log_step)
self.log_metrics(self.metrics[MetricSource.LOG], step=self._eval_log_step)

# increment the step even if nothing was logged
self._increment_eval_log_step()
Expand Down Expand Up @@ -179,9 +178,7 @@ def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:

if not self.trainer.sanity_checking:
# log all the metrics as a single dict
log_metrics = metrics[MetricSource.LOG]
if log_metrics:
self.log_metrics(log_metrics)
self.log_metrics(metrics[MetricSource.LOG])

self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK])

Expand Down Expand Up @@ -219,16 +216,13 @@ def update_train_step_metrics(self) -> None:

# when metrics should be logged
assert not self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if self.should_update_logs or self.trainer.fast_dev_run is True and metrics:
self.log_metrics(metrics)
if self.should_update_logs or self.trainer.fast_dev_run:
self.log_metrics(self.metrics[MetricSource.LOG])

def update_train_epoch_metrics(self) -> None:
# add the metrics to the loggers
assert self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if metrics:
self.log_metrics(metrics)
self.log_metrics(self.metrics[MetricSource.LOG])

# reset result collection for next epoch
self.trainer._results.reset(metrics=True)
Expand Down
16 changes: 16 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,19 @@ def test_result_collection_on_tensor_with_mean_reduction():
'loss_on_step_on_epoch_prog_bar_logger': mean,
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean
}


def test_logged_metrics_has_logged_epoch_value(tmpdir):
    """A user-logged ``epoch`` metric must not be clobbered by the trainer's own epoch value."""

    class EpochLoggingModel(BoringModel):

        def training_step(self, batch, batch_idx):
            # log a sentinel value under the reserved ``epoch`` key
            self.log('epoch', -batch_idx, logger=True)
            return super().training_step(batch, batch_idx)

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    trainer.fit(EpochLoggingModel())

    # the value logged on the last step (batch_idx == 1) should survive untouched
    assert trainer.logged_metrics == {'epoch': -1}

0 comments on commit bc2c2db

Please sign in to comment.