replace Hparams by init args (#1896)
* remove the need for hparams

* replace self.hparams

* fixed

* finished moco

* basic

* testing

* todo

* recurse

* hparams

* persist

* hparams

* chlog

* tests

* review

* saving

* tests

* docs

* finished moco

* hparams

* review

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* hparams

* overwrite

* transform

* cleaning

* tests

* examples

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* chp key

* tests

* Apply suggestions from code review

* class

* updated docs

* save

* wip

* fix

* flake8

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
4 people committed May 24, 2020
1 parent a20db4e commit caa9c67
Showing 38 changed files with 687 additions and 542 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -133,4 +133,5 @@ mnist/
# pl tests
ml-runs/
*.zip
pytorch\ lightning
pytorch\ lightning
test-reports/
3 changes: 3 additions & 0 deletions .run_local_tests.sh
@@ -14,3 +14,6 @@ rm -rf ./tests/tests/*
rm -rf ./lightning_logs
python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
python -m coverage report -m

# specific file
# python -m coverage run --source pytorch_lightning -m py.test -k test_trainer.py --flake8
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862))

- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))

### Deprecated

- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
122 changes: 86 additions & 36 deletions docs/source/hyperparameters.rst
@@ -75,7 +75,7 @@ Now in your main trainer file, add the Trainer args, the program args, and add t
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
parser = Trainer.add_argparse_args(parser)

hparams = parser.parse_args()
args = parser.parse_args()

Now you can run your program like so

@@ -87,39 +87,50 @@ Finally, make sure to start the training like so:

.. code-block:: python
# YES
model = LitModel(hparams)
trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)
# init the trainer like this
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
# NOT like this
trainer = Trainer(gpus=hparams.gpus, ...)
# init the model with Namespace directly
model = LitModel(args)
# or init the model with all the key-value pairs
dict_args = vars(args)
model = LitModel(**dict_args)
# NO
# model = LitModel(learning_rate=args.learning_rate, ...)
# trainer = Trainer(gpus=args.gpus, ...)
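Putting these pieces together, a minimal ``main.py`` could look like the sketch below. It assumes the ``LitMNIST`` module shown in the next section (which accepts ``**kwargs`` and provides ``add_model_specific_args``); the ``--data_path`` flag is purely illustrative.

.. code-block:: python

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    def main():
        parser = ArgumentParser()
        # program-level args (illustrative)
        parser.add_argument('--data_path', type=str, default='./data')
        # args the module itself needs (see add_model_specific_args below)
        parser = LitMNIST.add_model_specific_args(parser)
        # all Trainer flags: --gpus, --max_epochs, --fast_dev_run, ...
        parser = Trainer.add_argparse_args(parser)
        args = parser.parse_args()

        # the module accepts **kwargs, so extra program/Trainer args are simply ignored
        model = LitMNIST(**vars(args))
        trainer = Trainer.from_argparse_args(args)
        trainer.fit(model)

    if __name__ == '__main__':
        main()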
LightningModule hyperparameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

LightningModule hparams
^^^^^^^^^^^^^^^^^^^^^^^
.. warning:: The use of `hparams` is no longer recommended (but still supported)

Normally, we don't hard-code the values to a model. We usually use the command line to
modify the network and read those values in the LightningModule
A LightningModule is just an nn.Module; you can use it as you normally would. However, there are
some best practices that improve readability and reproducibility.

1. It's more readable to specify all the arguments that go into a module (with default values).
This helps users of your module know everything that is required to run it.

.. testcode::

class LitMNIST(LightningModule):

def __init__(self, hparams):
def __init__(self, layer_1_dim=128, layer_2_dim=256, learning_rate=1e-4, batch_size=32, **kwargs):
super().__init__()
self.layer_1_dim = layer_1_dim
self.layer_2_dim = layer_2_dim
self.learning_rate = learning_rate
self.batch_size = batch_size
# do this to save all arguments in any logger (tensorboard)
self.hparams = hparams

self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10)
self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.layer_2_dim, 10)

def train_dataloader(self):
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
return DataLoader(mnist_train, batch_size=self.batch_size)

def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.learning_rate)
return Adam(self.parameters(), lr=self.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
@@ -130,20 +141,59 @@ modify the network and read those values in the LightningModule
parser.add_argument('--learning_rate', type=float, default=0.002)
return parser

Now pass in the params when you init your model
2. You can also pass in a dict or Namespace, but this obscures the parameters your module is looking
for. The user would have to search the file to find what is parametrized.

.. code-block:: python
# using a argparse.Namespace
class LitMNIST(LightningModule):
def __init__(self, hparams, *args, **kwargs):
super().__init__()
self.hparams = hparams
self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
def train_dataloader(self):
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
One way to get around this is to convert a Namespace or dict into key-value pairs using `**`

.. code-block:: python
parser = ArgumentParser()
parser = LitMNIST.add_model_specific_args(parser)
hparams = parser.parse_args()
model = LitMNIST(hparams)
args = parser.parse_args()
dict_args = vars(args)
model = LitMNIST(**dict_args)
Within any LightningModule all the arguments you pass into your `__init__` will be stored in
the checkpoint so that you know all the values that went into creating this model.

We will also add all of those values to the TensorBoard hparams tab (unless a value is an object,
which we skip). We will also store those values in the checkpoint for you, which you can use to init your
models.

.. code-block:: python
The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule.
This does two things:
class LitMNIST(LightningModule):
def __init__(self, layer_1_dim, some_other_param):
super().__init__()
self.layer_1_dim = layer_1_dim
self.some_other_param = some_other_param
self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.some_other_param)
self.layer_3 = torch.nn.Linear(self.some_other_param, 10)
model = LitMNIST(10, 20)
1. It adds them automatically to TensorBoard logs under the hparams tab.
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
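As a quick illustration, the stored init arguments of the module above can be read back from a saved checkpoint. This is only a sketch: ``CKPT_PATH`` is a placeholder for a checkpoint written by the Trainer, and ``module_arguments`` is the checkpoint key described in the weights-loading docs below.

.. code-block:: python

    import torch

    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint['module_arguments'])
    # e.g. {'layer_1_dim': 10, 'some_other_param': 20}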
Trainer args
^^^^^^^^^^^^
@@ -171,27 +221,27 @@ polluting the main.py file, the LightningModule lets you define arguments for ea

class LitMNIST(LightningModule):

def __init__(self, hparams):
def __init__(self, layer_1_dim, **kwargs):
super().__init__()
self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser])
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--layer_1_dim', type=int, default=128)
return parser

.. testcode::

class GoodGAN(LightningModule):

def __init__(self, hparams):
def __init__(self, encoder_layers, **kwargs):
super().__init__()
self.encoder = Encoder(layers=hparams.encoder_layers)
self.encoder = Encoder(layers=encoder_layers)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser])
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--encoder_layers', type=int, default=12)
return parser

@@ -201,14 +251,14 @@ Now we can allow each model to inject the arguments it needs in the ``main.py``
.. code-block:: python
def main(args):
dict_args = vars(args)
# pick model
if args.model_name == 'gan':
model = GoodGAN(hparams=args)
model = GoodGAN(**dict_args)
elif args.model_name == 'mnist':
model = LitMNIST(hparams=args)
model = LitMNIST(**dict_args)
model = LitMNIST(hparams=args)
trainer = Trainer.from_argparse_args(args)
trainer.fit(model)
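One possible way to wire up the argument parsing around this ``main`` function is sketched below; the two-stage ``parse_known_args`` trick is an assumption about how the ``--model_name`` flag could be used to decide which module contributes its args.

.. code-block:: python

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='gan', help="'gan' or 'mnist'")
    parser = Trainer.add_argparse_args(parser)

    # peek at --model_name first, then let that model add its own args
    temp_args, _ = parser.parse_known_args()
    if temp_args.model_name == 'gan':
        parser = GoodGAN.add_model_specific_args(parser)
    elif temp_args.model_name == 'mnist':
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()
    main(args)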
13 changes: 6 additions & 7 deletions docs/source/lr_finder.rst
@@ -36,18 +36,17 @@ hyperparameters of the model.
# default: no automatic learning rate finder
trainer = Trainer(auto_lr_find=False)

When the ``lr`` or ``learning_rate`` key in hparams exists, this flag sets your learning_rate.
In both cases, if the respective fields are not found, an error will be thrown.

This flag sets your learning rate, which can be accessed via ``self.lr`` or ``self.learning_rate``.

.. testcode::

class LitModel(LightningModule):

def __init__(self, hparams):
self.hparams = hparams
def __init__(self, learning_rate):
super().__init__()
self.learning_rate = learning_rate

def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr|self.hparams.learning_rate)
return Adam(self.parameters(), lr=(self.lr or self.learning_rate))

# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
@@ -97,7 +96,7 @@ of this would look like
# update hparams of the model
model.hparams.lr = new_lr
# Fit model
trainer.fit(model)
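A possible end-to-end sketch of this manual flow, assuming the ``trainer.lr_find``/``suggestion`` API documented in this file and a module that stores ``learning_rate`` as an init arg:

.. code-block:: python

    model = LitModel(learning_rate=1e-3)
    trainer = Trainer()

    # run the LR range test and pick a suggested value
    lr_finder = trainer.lr_find(model)
    new_lr = lr_finder.suggestion()

    # with init args, update the attribute directly instead of model.hparams.lr
    model.learning_rate = new_lr
    trainer.fit(model)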
2 changes: 1 addition & 1 deletion docs/source/training_tricks.rst
@@ -67,7 +67,7 @@ a binary search.
.. code-block:: python
def train_dataloader(self):
return DataLoader(train_dataset, batch_size=self.hparams.batch_size)
return DataLoader(train_dataset, batch_size=self.batch_size)
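A sketch of how this could be combined with the automatic batch-size scaler described in this section, assuming the ``auto_scale_batch_size`` Trainer flag and a module that stores ``batch_size`` as an attribute (as above):

.. code-block:: python

    model = LitModel(batch_size=32)

    # 'binsearch' requests the binary search mentioned above
    trainer = Trainer(auto_scale_batch_size='binsearch')

    # depending on the version, the search runs at the start of fit or via a dedicated tuning call
    trainer.fit(model)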
.. warning::

55 changes: 28 additions & 27 deletions docs/source/weights_loading.rst
@@ -59,24 +59,20 @@ Or disable it by passing
trainer = Trainer(checkpoint_callback=False)


The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init.
The Lightning checkpoint also saves the arguments passed into the LightningModule init
under the `module_arguments` key in the checkpoint.

.. note:: hparams is a `Namespace <https://docs.python.org/2/library/argparse.html#argparse.Namespace>`_.

.. testcode::

from argparse import Namespace
.. code-block:: python
# usually these come from command line args
args = Namespace(learning_rate=0.001)
class MyLightningModule(LightningModule):
# define your module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
# this model (in this case, learning rate)
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, *args, **kwargs):
super().__init__()
def __init__(self, hparams, *args, **kwargs):
self.hparams = hparams
# all init args were saved to the checkpoint
checkpoint = torch.load(CKPT_PATH)
print(checkpoint['module_arguments'])
# {'learning_rate': the_value}
Manual saving
^^^^^^^^^^^^^
@@ -92,37 +88,42 @@ You can manually save checkpoints and restore your model from the checkpointed s
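A minimal sketch of that manual flow, assuming the ``trainer.save_checkpoint`` and ``load_from_checkpoint`` methods referenced in this section (``MyLightningModule`` is the module shown above):

.. code-block:: python

    model = MyLightningModule(learning_rate=1e-3)
    trainer = Trainer()
    trainer.fit(model)

    # save a checkpoint by hand
    trainer.save_checkpoint('example.ckpt')

    # restore it later, re-using the stored init arguments
    new_model = MyLightningModule.load_from_checkpoint(checkpoint_path='example.ckpt')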
Checkpoint Loading
------------------

To load a model along with its weights, biases and hyperparameters use following method.
To load a model along with its weights, biases and `module_arguments`, use the following method.

.. code-block:: python
model = MyLightingModule.load_from_checkpoint(PATH)
model.eval()
y_hat = model(x)
The above only works if you used `hparams` in your model definition
.. testcode::

class LitModel(LightningModule):
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint
def __init__(self, hparams):
self.hparams = hparams
self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
model.eval()
y_hat = model(x)
But if you don't and instead pass individual parameters
But if you don't want to use the values saved in the checkpoint, pass in your own here

.. testcode::

class LitModel(LightningModule):

def __init__(self, in_dim, out_dim):
self.l1 = nn.Linear(in_dim, out_dim)
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.l1 = nn.Linear(self.in_dim, self.out_dim)

you can restore the model like this

.. code-block:: python
# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)
# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)