add more detail to tbptt example (#755)
* add more detail to tbptt example

* warn user about new arg in training_step
jeremyjordan authored Feb 1, 2020
1 parent 76a1c67 commit 589815f
Showing 2 changed files with 12 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -168,6 +168,14 @@ def training_step(self, batch, batch_idx, optimizer_idx):
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hiddens from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {
        "loss": ...,
        "hiddens": hiddens  # remember to detach() this
    }

You can also return -1 instead of a dict to stop the current loop. This is useful
if you want to break out of the current training epoch early.
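
Putting the pieces together, a minimal self-contained sketch of a LightningModule wired for truncated BPTT might look like the following. The layer sizes, loss, and data shapes are illustrative assumptions, as is the detail that an LSTM's hiddens is an (h, c) tuple that arrives as None on the first split of each batch:

import torch
import torch.nn as nn
import pytorch_lightning as pl

class TBPTTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # sizes are arbitrary placeholders for the sketch
        self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
        self.head = nn.Linear(20, 1)

    def training_step(self, batch, batch_idx, hiddens):
        x, y = batch  # assumed shapes: x [batch, time, features], y [batch, time, 1]
        # hiddens is None on the first split, the (h, c) tuple afterwards
        out, (h, c) = self.lstm(x, hiddens)
        loss = nn.functional.mse_loss(self.head(out), y)
        # detach so gradients stop at the split boundary
        return {"loss": loss, "hiddens": (h.detach(), c.detach())}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)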
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -448,6 +448,10 @@ def __init__(
# backprop every 5 steps in a batch
trainer = Trainer(truncated_bptt_steps=5)

Using this feature requires updating your LightningModule's `training_step()` to include
a `hiddens` arg.
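
For intuition about what the Trainer does with that setting, the splitting can be pictured with plain torch.split. The dim=1 time axis and the shapes are assumptions for illustration:

import torch

batch = torch.randn(32, 100, 10)       # [batch, time, features]
splits = torch.split(batch, 5, dim=1)  # chunk size = truncated_bptt_steps
print(len(splits), splits[0].shape)    # 20 torch.Size([32, 5, 10])

training_step() is then called once per chunk, receiving the hiddens returned from the previous chunk.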
resume_from_checkpoint (str): To resume training from a specific checkpoint, pass in the path here.
Example::
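
(The example block is collapsed in this diff view; a plausible usage, with a hypothetical checkpoint path:)

# resume from a previously saved checkpoint (path is hypothetical)
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')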