diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 0170b55e36afc..9c1da34b1b66f 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -168,6 +168,14 @@ def training_step(self, batch, batch_idx, optimizer_idx):
             # Truncated back-propagation through time
             def training_step(self, batch, batch_idx, hiddens):
                 # hiddens are the hiddens from the previous truncated backprop step
+                ...
+                out, hiddens = self.lstm(data, hiddens)
+                ...
+
+                return {
+                    "loss": ...,
+                    "hiddens": hiddens  # remember to detach() this
+                }
 
         You can also return a -1 instead of a dict to stop the current loop. This is useful
         if you want to break out of the current training epoch early.
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5794ba617c99f..914ef27246ba1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -448,6 +448,10 @@ def __init__(
             # backprop every 5 steps in a batch
             trainer = Trainer(truncated_bptt_steps=5)
 
+        Using this feature requires updating your LightningModule's `training_step()` to include
+        a `hiddens` arg.
+
+
         resume_from_checkpoint (str): To resume training from a specific checkpoint pass in the path here.
 
         Example::
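
For context, here is a minimal sketch (not part of the diff) of the pattern the documented change describes: a `LightningModule` whose `training_step()` accepts the `hiddens` carried over from the previous truncated slice and returns them, detached, alongside the loss. The module name, layer sizes, and loss function are illustrative assumptions, not code from this PR.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):  # hypothetical example module
    def __init__(self, input_size=10, hidden_size=20):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def training_step(self, batch, batch_idx, hiddens):
        # `batch` is one truncated slice of the sequence; `hiddens` is None for the
        # first slice and the state returned from the previous slice otherwise.
        x, y = batch
        out, hiddens = self.lstm(x, hiddens)
        loss = nn.functional.mse_loss(self.head(out), y)
        # Detach so the autograd graph does not grow across truncated slices.
        hiddens = tuple(h.detach() for h in hiddens)
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# The Trainer splits each batch along the time dimension and calls training_step
# once per truncated slice when the flag is set:
# trainer = pl.Trainer(truncated_bptt_steps=5)
```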