Commit fbf3eff

simplify model forward transfer

awaelchli committed May 26, 2020
1 parent e951aae commit fbf3eff
Showing 1 changed file with 16 additions and 26 deletions.
42 changes: 16 additions & 26 deletions pytorch_lightning/core/memory.py
@@ -9,7 +9,8 @@
 import torch.nn as nn
 
 import pytorch_lightning as pl
-
+from pytorch_lightning.utilities import transfer_batch_to_device
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 UNKNOWN_SIZE = 'unknown'
 
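The two imports added above do the heavy lifting in the rewrite below. For reference, here is a minimal sketch of how apply_to_collection behaves (the example values are assumptions, not part of this commit): it applies a function to every element of a given type inside nested lists, tuples, and dicts, and passes everything else through unchanged.

    import torch
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    # Cast only the tensors in a heterogeneous batch; leave other values alone.
    batch = {'x': torch.rand(2, 3), 'label': 'cat', 'pair': (torch.ones(1), 7)}
    halved = apply_to_collection(batch, torch.Tensor, lambda t: t.half())
    # halved['x'].dtype == torch.float16; halved['label'] and halved['pair'][1]
    # are returned unchanged.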
@@ -59,12 +60,14 @@ def _register_hook(self):
         on the first forward pass. The hook will remove itself from the module, meaning that
         recursive models will only record their input- and output shapes once.
         """
+
         def hook(module, inp, out):
             if len(inp) == 1:
                 inp = inp[0]
             self._in_size = parse_batch_shape(inp)
             self._out_size = parse_batch_shape(out)
             self._hook_handle.remove()  # hook detaches itself from module
+
         return self._module.register_forward_hook(hook)
 
     @property
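For context, the self-removing hook above can be reproduced in plain PyTorch. A minimal, runnable sketch (the helper name attach_shape_hook and the shapes dict are illustrative assumptions, not part of this commit):

    import torch
    import torch.nn as nn

    shapes = {}

    def attach_shape_hook(module, name):
        # Record the first forward pass's input/output shapes, then detach.
        def hook(mod, inp, out):
            inp = inp[0] if len(inp) == 1 else inp
            shapes[name] = (tuple(inp.shape), tuple(out.shape))
            handle.remove()  # hook detaches itself from the module
        handle = module.register_forward_hook(hook)
        return handle

    layer = nn.Linear(4, 2)
    attach_shape_hook(layer, 'linear')
    layer(torch.rand(3, 4))  # fires the hook once, which then removes itself
    layer(torch.rand(5, 4))  # no longer recorded
    print(shapes)  # {'linear': ((3, 4), (3, 2))}

Because the hook removes its own handle after the first call, recursive or weight-shared modules record their shapes exactly once, which is what the docstring above points out.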
@@ -176,40 +179,27 @@ def summarize(self) -> Dict[str, LayerSummary]:

     def _forward_example_input(self) -> None:
         """ Run the example input through each layer to get input- and output sizes. """
+        model = self._model
+        trainer = self._model.trainer
 
         input_ = self._model.example_input_array
+        input_ = transfer_batch_to_device(input_, self._model.device)
 
-        # TODO: should rethink this to add support for GPU, TPU, AMP, ... and avoid code duplication
-        # or should it always be done on cpu?
-        if self._model.on_gpu:
-            device = next(self._model.parameters()).device
-            # test if input is a list or a tuple
-            if isinstance(input_, (list, tuple)):
-                input_ = [input_i.to(device) if torch.is_tensor(input_i) else input_i
-                          for input_i in input_]
-            else:
-                input_ = input_.to(device)
-
-        # if model.trainer.use_amp and self.use_native_amp:
-        #     model.forward = torch.cuda.amp.autocast()(model.forward)
+        if trainer is not None and trainer.use_amp:
+            if model.use_native_amp:
+                model.forward = torch.cuda.amp.autocast()(model.forward)
 
-        if self._model.trainer is not None and self._model.trainer.use_amp:
-            # test if it is not a list or a tuple
-            if isinstance(input_, (list, tuple)):
-                input_ = [input_i.half() if torch.is_tensor(input_i) else input_i
-                          for input_i in input_]
-            else:
-                input_ = input_.half()
+        input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))
 
-        mode = self._model.training
-        self._model.eval()
+        mode = model.training
+        model.eval()
         with torch.no_grad():
             # let the model hooks collect the input- and output shapes
             if isinstance(input_, (list, tuple)):
-                self._model(*input_)
+                model(*input_)
             else:
-                self._model(input_)
-        self._model.train(mode)  # restore mode of module
+                model(input_)
+        model.train(mode)  # restore mode of module
 
     def __str__(self):
         """

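Read in isolation, the simplified method amounts to: move the example input to the model's device, cast its tensors to the model's dtype, and run one no-grad forward pass so the shape-recording hooks fire. A standalone sketch under those assumptions (the function name run_example_input is hypothetical; the calls mirror the new code above):

    import torch
    from pytorch_lightning.utilities import transfer_batch_to_device
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    def run_example_input(model):
        input_ = model.example_input_array
        input_ = transfer_batch_to_device(input_, model.device)
        input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))

        mode = model.training
        model.eval()
        with torch.no_grad():
            if isinstance(input_, (list, tuple)):
                model(*input_)  # multi-input models take positional arguments
            else:
                model(input_)
        model.train(mode)  # restore the original train/eval mode

Casting with apply_to_collection replaces the two hand-written list/tuple branches of the old code, and transfer_batch_to_device replaces the manual .to(device) loop, which is the entire simplification.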