Support log_every_n_steps with validate|test #18895

Closed · wants to merge 4 commits
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/device_stats_monitor.py
@@ -74,7 +74,7 @@ def setup(
)

def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None:
if not trainer._logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs():
return

device = trainer.strategy.root_device
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/lr_monitor.py
@@ -159,7 +159,7 @@ def _check_no_key(key: str) -> bool:
self.last_weight_decay_values = {name + "-weight_decay": None for name in names_flatten}

def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
if not trainer._logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs():
return

if self.logging_interval != "epoch":
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -153,7 +153,7 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
)

def _compute(self, trainer: "Trainer", iter_num: Optional[int] = None) -> None:
if not trainer._logger_connector.should_update_logs:
if not trainer._logger_connector.should_update_logs(iter_num):
return
stage = trainer.state.stage
assert stage is not None
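
Taken together, the three callback diffs above switch `should_update_logs` from a property to a method call (optionally taking a step). As a rough sketch of the new call pattern for a similar custom callback: the `BatchStatsMonitor` class and the metric it logs are made up for illustration, and the example leans on the same private `trainer._logger_connector` API that the built-in monitors use in this PR.

import lightning.pytorch as pl


class BatchStatsMonitor(pl.Callback):
    """Hypothetical callback that only logs on steps the logger connector marks as logging steps."""

    def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch, batch_idx) -> None:
        # After this PR, `should_update_logs` is called as a method instead of read as a property.
        if not trainer._logger_connector.should_update_logs():
            return
        for logger in trainer.loggers:
            logger.log_metrics({"batch_idx": float(batch_idx)}, step=trainer.global_step)
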
@@ -47,21 +47,31 @@ def on_trainer_init(
self.trainer.log_every_n_steps = log_every_n_steps

@property
def should_update_logs(self) -> bool:
trainer = self.trainer
if trainer.log_every_n_steps == 0:
return False
if (loop := trainer._active_loop) is None:
return True
def current_step(self) -> int:
if (loop := self.trainer._active_loop) is None:
raise RuntimeError
if isinstance(loop, pl.loops._FitLoop):
# `+ 1` because it can be checked before a step is executed, for example, in `on_train_batch_start`
step = loop.epoch_loop._batches_that_stepped + 1
elif isinstance(loop, (pl.loops._EvaluationLoop, pl.loops._PredictionLoop)):
step = loop.batch_progress.current.ready
else:
raise NotImplementedError(loop)
should_log = step % trainer.log_every_n_steps == 0
return should_log or trainer.should_stop
return step

def should_update_logs(self, step: Optional[int] = None) -> bool:
trainer = self.trainer
if trainer.fast_dev_run:
return True
if trainer.log_every_n_steps == 0:
return False
if trainer.should_stop:
return True
if trainer.sanity_checking:
return False
if step is None:
step = self.current_step
return step % trainer.log_every_n_steps == 0

def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
if not logger:
@@ -86,14 +96,13 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
else:
self.trainer.loggers = [logger]

def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
"""Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
metrics["step"] as a step.
def log_metrics(self, metrics: _OUT_DICT, step: int, add_epoch: bool = False) -> None:
"""Logs the metric dict passed in.

Args:
metrics: Metric values
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
step: Step for which metrics should be logged.
add_epoch: Whether to add the current ``epoch``.

"""
if not self.trainer.loggers or not metrics:
@@ -104,13 +113,11 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
# turn all tensors to scalars
scalar_metrics = convert_tensors_to_scalars(metrics)

if step is None:
step = scalar_metrics.pop("step", None)

if step is None:
# added metrics for convenience
step = scalar_metrics.pop("step", step)
# this is for backwards compatibility
if add_epoch: # only enabled for training metrics
step -= 1
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
Contributor Author (comment on lines +117 to 120):
IMO this bit of logic is not well designed, but keeping it to limit the PR's scope.

step = self.trainer.fit_loop.epoch_loop._batches_that_stepped

# log actual metrics
for logger in self.trainer.loggers:
@@ -128,8 +135,8 @@ def _evaluation_epoch_end(self) -> None:

def update_eval_step_metrics(self, step: int) -> None:
assert isinstance(self._first_loop_iter, bool)
# logs user requested information to logger
self.log_metrics(self.metrics["log"], step=step)
if self.should_update_logs(step):
self.log_metrics(self.metrics["log"], step=step)

def update_eval_epoch_metrics(self) -> _OUT_DICT:
assert self._first_loop_iter is None
@@ -147,7 +154,7 @@ def log_eval_end_metrics(self, metrics: _OUT_DICT) -> None:
return

# log all the metrics as a single dict
self.log_metrics(metrics)
self.log_metrics(metrics, self.current_step)

"""
Train metric updates
@@ -159,13 +166,14 @@ def update_train_step_metrics(self) -> None:

# when metrics should be logged
assert isinstance(self._first_loop_iter, bool)
if self.should_update_logs or self.trainer.fast_dev_run:
self.log_metrics(self.metrics["log"])
step = self.current_step
if self.should_update_logs(step):
self.log_metrics(self.metrics["log"], step, add_epoch=True)

def update_train_epoch_metrics(self) -> None:
# add the metrics to the loggers
assert self._first_loop_iter is None
self.log_metrics(self.metrics["log"])

self.log_metrics(self.metrics["log"], self.current_step, add_epoch=True)

# reset result collection for next epoch
self.reset_results()
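
For orientation, the gating introduced above can be mocked in isolation. The following standalone sketch is not Lightning code; it only mirrors the order of checks in the new `should_update_logs` so the effect of `log_every_n_steps` on validation/test batches is easy to see (evaluation steps are counted via `batch_progress.current.ready`, i.e. 1, 2, 3, ...).

def should_log(step: int, log_every_n_steps: int, *, fast_dev_run: bool = False,
               should_stop: bool = False, sanity_checking: bool = False) -> bool:
    # Standalone mock of the checks in the new `should_update_logs`, in the same order.
    if fast_dev_run:
        return True
    if log_every_n_steps == 0:
        return False
    if should_stop:
        return True
    if sanity_checking:
        return False
    return step % log_every_n_steps == 0


# With log_every_n_steps=2, only every second validation/test batch reaches the loggers:
assert [s for s in range(1, 7) if should_log(s, 2)] == [2, 4, 6]
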
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/trainer.py
@@ -215,7 +215,7 @@ def __init__(
Set it to `-1` to run all batches in all validation dataloaders.
Default: ``2``.

log_every_n_steps: How often to log within steps.
log_every_n_steps: How often to log within training/testing/validation/prediction steps.
Default: ``50``.

enable_checkpointing: If ``True``, enable checkpointing.
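
As a usage sketch of what the expanded docstring implies (the tiny model and dataloader below are placeholders, not part of the PR), the same `log_every_n_steps` value now also throttles per-step logging during `trainer.validate()` and `trainer.test()`:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).mean()
        self.log("train_loss", loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # With this PR, the on_step metric below is written only every `log_every_n_steps` batches.
        self.log("val_loss", self.layer(batch[0]).mean(), on_step=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=2)
trainer = Trainer(max_epochs=1, log_every_n_steps=10)
trainer.fit(TinyModel(), train_dataloaders=loader, val_dataloaders=loader)
trainer.validate(TinyModel(), dataloaders=loader)  # also throttled by log_every_n_steps after this change
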
@@ -548,11 +548,11 @@ def test_step(self, batch, batch_idx):
call(metrics={"train_loss": ANY, "epoch": 0}, step=0),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1),
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 0}, step=0),
Contributor Author (comment):
The step value here corresponded to the training step; now it corresponds to the validation step.

call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY}, step=2),
call(metrics={"train_loss": ANY, "epoch": 1}, step=1),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=2),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=3),
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 1}, step=1),
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY}, step=2),
]

def get_metrics_at_idx(idx):
@@ -737,7 +737,6 @@ def test_dataloader(self):

mock_call = mock_log_metrics.mock_calls[0]
logged_metrics = mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
cb_metrics.add("epoch")
assert set(logged_metrics) == cb_metrics


@@ -954,6 +953,7 @@ def test_dataloader(self):
default_root_dir=tmpdir,
max_epochs=max_epochs,
limit_train_batches=1,
log_every_n_steps=1,
limit_val_batches=limit_batches,
limit_test_batches=limit_batches,
logger=TensorBoardLogger(tmpdir),