Type Hints for Lightning Core (Lightning-AI#946)
* first pass for LightningModule typehints

* fix return types

* add missing types

* add type annotations to grads.py

* add type annotations to hooks.py

* add type annotation to memory.py

* proper docstring quotation marks

* add type annotations to saving.py

* fix cyclic import problem

* fix cyclic import problem

* add missing whitespace

* finish type hints for load_from_ methods

* docs: prepare_data does not return anything

* fix auto types in docs

* revert typehint for trainer in hook

* remove unnecessary return docs

* some fixes for memory docs

* revert typing for args kwargs

* added all missing None return types

* remove unused import

* add more details to dict/list return types

* fix line too long

* optimize imports

* linted

* Revert "linted"

This reverts commit 8555961.

* remove whitespace

* update

* update

* update

* update

* update

* changelog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
3 people authored and tullie committed Apr 3, 2020
1 parent 1985c41 commit 963f87a
Showing 6 changed files with 143 additions and 120 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))

### Changed
3 changes: 2 additions & 1 deletion pytorch_lightning/core/grads.py
@@ -1,13 +1,14 @@
"""
Module to describe gradients
"""
from typing import Dict

from torch import nn


class GradInformation(nn.Module):

def grad_norm(self, norm_type):
def grad_norm(self, norm_type: float) -> Dict[str, int]:
results = {}
total_norm = 0
for name, p in self.named_parameters():
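
For context, a minimal sketch of how a helper with the new `grad_norm` annotation can be exercised on a plain `nn.Module`. The toy model, the dictionary key format, and the use of `float` values are illustrative assumptions, not the library's exact implementation.

```python
from typing import Dict

import torch
from torch import nn


class TinyNet(nn.Module):
    """Toy model used only to exercise the annotated helper below."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def grad_norm(self, norm_type: float) -> Dict[str, float]:
        # Map each parameter name to the norm of its gradient (keys are illustrative).
        results: Dict[str, float] = {}
        for name, p in self.named_parameters():
            if p.grad is not None:
                results[f"grad_{norm_type}_norm_{name}"] = p.grad.data.norm(norm_type).item()
        return results


model = TinyNet()
model(torch.randn(8, 4)).sum().backward()
print(model.grad_norm(2.0))
```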
39 changes: 16 additions & 23 deletions pytorch_lightning/core/hooks.py
@@ -14,10 +14,11 @@
3. Add the correct place in the :py:mod:`pytorch_lightning.models.trainer` where it should be called.
"""

from typing import Any

import torch

from torch import Tensor
from torch.optim.optimizer import Optimizer

try:
from apex import amp
@@ -36,48 +37,45 @@ def on_sanity_check_start(self):
:return:
"""

def on_train_start(self):
def on_train_start(self) -> None:
"""Called at the beginning of training before sanity check
:return:
"""
# do something at the start of training

def on_train_end(self):
def on_train_end(self) -> None:
"""
Called at the end of training before logger experiment is closed
:return:
"""
# do something at the end of training

def on_batch_start(self, batch):
def on_batch_start(self, batch: Any) -> None:
"""Called in the training loop before anything happens for that batch.
:param batch:
:return:
"""
# do something when the batch starts

def on_batch_end(self):
def on_batch_end(self) -> None:
"""Called in the training loop after the batch."""
# do something when the batch ends

def on_epoch_start(self):
def on_epoch_start(self) -> None:
"""Called in the training loop at the very beginning of the epoch."""
# do something when the epoch starts

def on_epoch_end(self):
def on_epoch_end(self) -> None:
"""Called in the training loop at the very end of the epoch."""
# do something when the epoch ends

def on_pre_performance_check(self):
def on_pre_performance_check(self) -> None:
"""Called at the very beginning of the validation loop."""
# do something before validation starts

def on_post_performance_check(self):
def on_post_performance_check(self) -> None:
"""Called at the very end of the validation loop."""
# do something before validation end

def on_before_zero_grad(self, optimizer):
def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""Called after optimizer.step() and before optimizer.zero_grad()
Called in the training loop after taking an optimizer step and before zeroing grads.
@@ -89,17 +87,13 @@ def on_before_zero_grad(self, optimizer):
model.on_before_zero_grad(optimizer) # < ---- called here
optimizer.zero_grad
:param optimizer:
:return:
:param optimizer: The optimizer for which grads should be zeroed.
"""
# do something with the optimizer or inspect it.

def on_after_backward(self):
"""Called after loss.backward() and before optimizers do anything.
:return:
def on_after_backward(self) -> None:
"""Called in the training loop after loss.backward() and before optimizers do anything.
Called in the training loop after model.backward()
This is the ideal place to inspect or log gradient information
.. code-block:: python
@@ -116,14 +110,13 @@ def on_after_backward(self):
"""

def backward(self, trainer, loss, optimizer, optimizer_idx):
def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Override backward with your own implementation if you need to
:param trainer: Pointer to the trainer
:param loss: Loss is already scaled by accumulated grads
:param optimizer: Current optimizer being used
:param optimizer_idx: Index of the current optimizer being used
:return:
Called to perform backward step.
Feel free to override as needed.
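
As a usage illustration (not part of this commit), the sketch below shows how a user module might override the newly annotated hooks. The class name and hook bodies are placeholders, and a real `LightningModule` would also need `training_step`, `configure_optimizers`, and the other required methods.

```python
from typing import Any

import pytorch_lightning as pl
from torch import Tensor
from torch.optim.optimizer import Optimizer


class MyModule(pl.LightningModule):
    """Placeholder module; only the annotated hook overrides matter here."""

    def on_train_start(self) -> None:
        # Runs once at the beginning of training, after the sanity check.
        print("training is starting")

    def on_batch_start(self, batch: Any) -> None:
        # Inspect or log the incoming batch before the training step runs.
        pass

    def on_before_zero_grad(self, optimizer: Optimizer) -> None:
        # Grads are still populated here; a reasonable place to log their norms.
        pass

    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
        # Default-style behaviour: a plain backward pass on the (already scaled) loss.
        loss.backward()
```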