Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix on_train_batch_start hook to end epoch early #3700

Merged
merged 6 commits into from
Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `on_train_batch_start` hook to end epoch early ([#3700](https://github.com/PyTorchLightning/pytorch-lightning/pull/3700))

- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))

- Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/early_stopping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Early stopping

Stopping an epoch early
-----------------------
You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.on_batch_start` to return ``-1`` when some condition is met.
You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met.

If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire run.

Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,10 @@ def run_training_epoch(self):
# ------------------------------------
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

# when returning -1 from train_step, we end epoch early
if batch_output.signal == -1:
break

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
epoch_end_outputs = self.process_train_step_outputs(
Expand All @@ -529,9 +533,6 @@ def run_training_epoch(self):
# TODO: add outputs to batches
self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)

# when returning -1 from train_step, we end epoch early
self.trainer.should_stop = batch_output.signal == -1

# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
Expand Down
25 changes: 25 additions & 0 deletions tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,28 @@ def transfer_batch_to_device(self, data, device):
expected = torch.device('cuda', 0)
assert model.hook_called
assert batch_gpu.samples.device == batch_gpu.targets.device == expected


@pytest.mark.parametrize(
['max_epochs', 'batch_idx_'],
[
pytest.param(2, 5),
pytest.param(3, 8),
pytest.param(4, 12)
]
)
def test_on_train_batch_start_hook(max_epochs, batch_idx_):
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
class CurrentModel(EvalModelTemplate):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
if batch_idx == batch_idx_:
return -1

model = CurrentModel()
trainer = Trainer(max_epochs=max_epochs)
trainer.fit(model)
if batch_idx_ > len(model.val_dataloader()) - 1:
assert trainer.batch_idx == len(model.val_dataloader()) - 1
assert trainer.global_step == len(model.val_dataloader()) * max_epochs
else:
assert trainer.batch_idx == batch_idx_
assert trainer.global_step == (batch_idx_ + 1) * max_epochs