Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace Hparams by init args #1896

Merged
merged 101 commits into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from 87 commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
1fced53
remove the need for hparams
williamFalcon May 19, 2020
7fe5f13
remove the need for hparams
williamFalcon May 19, 2020
0283055
remove the need for hparams
williamFalcon May 19, 2020
599c9ad
remove the need for hparams
williamFalcon May 19, 2020
32c7435
replace self.hparams
williamFalcon May 19, 2020
29d3e0a
replace self.hparams
williamFalcon May 19, 2020
f508424
replace self.hparams
williamFalcon May 19, 2020
28b85bd
replace self.hparams
williamFalcon May 19, 2020
355eb7a
replace self.hparams
williamFalcon May 19, 2020
5cc272a
replace self.hparams
williamFalcon May 19, 2020
a5bcd1c
replace self.hparams
williamFalcon May 19, 2020
a4a7407
replace self.hparams
williamFalcon May 19, 2020
8f7e8a2
replace self.hparams
williamFalcon May 19, 2020
b1cd0b5
replace self.hparams
williamFalcon May 19, 2020
e97237e
replace self.hparams
williamFalcon May 19, 2020
a2f6cb5
replace self.hparams
williamFalcon May 19, 2020
9216d28
replace self.hparams
williamFalcon May 19, 2020
7cbc1b2
replace self.hparams
williamFalcon May 19, 2020
137ae13
replace self.hparams
williamFalcon May 19, 2020
b6a9336
replace self.hparams
williamFalcon May 19, 2020
6ea138c
replace self.hparams
williamFalcon May 19, 2020
485ce20
replace self.hparams
williamFalcon May 19, 2020
14dab1b
replace self.hparams
williamFalcon May 19, 2020
268277a
replace self.hparams
williamFalcon May 19, 2020
2111e4b
replace self.hparams
williamFalcon May 19, 2020
07a1c00
replace self.hparams
williamFalcon May 19, 2020
90a1226
replace self.hparams
williamFalcon May 19, 2020
4429d22
replace self.hparams
williamFalcon May 19, 2020
6060a02
replace self.hparams
williamFalcon May 19, 2020
6f856df
replace self.hparams
williamFalcon May 19, 2020
da385fe
replace self.hparams
williamFalcon May 19, 2020
065226d
replace self.hparams
williamFalcon May 19, 2020
f634a8e
replace self.hparams
williamFalcon May 19, 2020
34055b5
replace self.hparams
williamFalcon May 19, 2020
e05c11b
replace self.hparams
williamFalcon May 19, 2020
0937108
replace self.hparams
williamFalcon May 19, 2020
f6587ce
fixed
williamFalcon May 19, 2020
0303695
fixed
williamFalcon May 19, 2020
2b0ceb8
fixed
williamFalcon May 19, 2020
e226c88
fixed
williamFalcon May 19, 2020
ec00520
fixed
williamFalcon May 19, 2020
72793c3
fixed
williamFalcon May 19, 2020
840265d
fixed
williamFalcon May 19, 2020
4bb28fa
fixed
williamFalcon May 19, 2020
91569a8
fixed
williamFalcon May 19, 2020
509036e
fixed
williamFalcon May 19, 2020
a99ffb7
fixed
williamFalcon May 19, 2020
5c3ea20
fixed
williamFalcon May 19, 2020
9d08be3
fixed
williamFalcon May 19, 2020
0452418
fixed
williamFalcon May 19, 2020
0b5557f
finished moco
williamFalcon May 20, 2020
6cd5ea9
basic
williamFalcon May 20, 2020
ed1090c
testing
Borda May 20, 2020
295654e
todo
Borda May 20, 2020
1465a03
recurse
Borda May 20, 2020
91ab93e
hparams
Borda May 20, 2020
0519723
persist
Borda May 20, 2020
a19df1d
hparams
Borda May 20, 2020
1f87263
chlog
Borda May 20, 2020
f35eab0
tests
Borda May 20, 2020
3555e83
tests
Borda May 20, 2020
2a1b2dc
tests
Borda May 20, 2020
3c79ae3
tests
Borda May 21, 2020
5767188
tests
Borda May 21, 2020
cbb00b5
tests
Borda May 21, 2020
acc020f
review
Borda May 21, 2020
b3b6236
saving
Borda May 21, 2020
b97e0b1
tests
Borda May 22, 2020
5a4740a
tests
Borda May 22, 2020
2a6be20
tests
Borda May 22, 2020
e80b006
docs
Borda May 22, 2020
e50b78f
finished moco
williamFalcon May 22, 2020
fd7be0d
hparams
Borda May 22, 2020
c319528
review
Borda May 22, 2020
b313477
Apply suggestions from code review
Borda May 22, 2020
488d18a
hparams
Borda May 22, 2020
d24b78e
overwrite
Borda May 22, 2020
0d7ee37
transform
Borda May 22, 2020
fb7898a
transform
Borda May 22, 2020
3d8a3db
transform
Borda May 22, 2020
db6f943
transform
Borda May 23, 2020
66717da
cleaning
Borda May 23, 2020
088c3bd
cleaning
Borda May 23, 2020
dfd3a26
tests
Borda May 23, 2020
72f4cd0
examples
Borda May 23, 2020
2a8872b
examples
Borda May 23, 2020
5fe6f02
examples
Borda May 23, 2020
bad8d11
Apply suggestions from code review
Borda May 24, 2020
55b58f7
chp key
Borda May 24, 2020
5383014
tests
Borda May 24, 2020
1fd8cce
Apply suggestions from code review
Borda May 24, 2020
ab3be59
class
Borda May 24, 2020
10ca1a8
Merge branch 'no_hparams' of https://github.com/PyTorchLightning/pyto…
Borda May 24, 2020
8f57274
updated docs
williamFalcon May 24, 2020
20ad2ca
updated docs
williamFalcon May 24, 2020
f683160
updated docs
williamFalcon May 24, 2020
db5d1bf
updated docs
williamFalcon May 24, 2020
6af55e7
save
Borda May 24, 2020
04e8e67
wip
Borda May 24, 2020
a432f6e
fix
Borda May 24, 2020
2892e5a
flake8
Borda May 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,5 @@ mnist/
# pl tests
ml-runs/
*.zip
pytorch\ lightning
pytorch\ lightning
test-reports/
3 changes: 3 additions & 0 deletions .run_local_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ rm -rf ./tests/tests/*
rm -rf ./lightning_logs
python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
python -m coverage report -m

# specific file
# python -m coverage run --source pytorch_lightning -m py.test -k test_trainer.py --flake8
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862))

- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))

### Deprecated

- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
Expand Down
52 changes: 37 additions & 15 deletions docs/source/hyperparameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,22 @@ modify the network and read those values in the LightningModule

class LitMNIST(LightningModule):

def __init__(self, hparams):
def __init__(self, layer_1_dim, layer_2_dim, learning_rate, batch_size):
super().__init__()
self.layer_1_dim = layer_1_dim
self.layer_2_dim = layer_2_dim
self.learning_rate = learning_rate
self.batch_size = batch_size

# do this to save all arguments in any logger (tensorboard)
self.hparams = hparams

self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim)
self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10)
self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.layer_2_dim)
self.layer_3 = torch.nn.Linear(self.layer_2_dim, 10)

def train_dataloader(self):
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
return DataLoader(mnist_train, batch_size=self.batch_size)

def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.learning_rate)
return Adam(self.parameters(), lr=self.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
Expand All @@ -136,14 +137,35 @@ Now pass in the params when you init your model

parser = ArgumentParser()
parser = LitMNIST.add_model_specific_args(parser)
hparams = parser.parse_args()
model = LitMNIST(hparams)
args = parser.parse_args()
model = LitMNIST(args)
Borda marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved

Within any LightningModule all the arguments you pass into your `__init__` will be available
simply with `self._module_arguments`. However, we won't overwrite any other arguments you have already defined.
Borda marked this conversation as resolved.
Show resolved Hide resolved
We will also add all of those values to the TensorBoard hparams tab (unless it's an object which
we won't). We also will store those values into checkpoints for you which you can use to init your
models.

.. code-block:: python

class LitMNIST(LightningModule):

def __init__(self, layer_1_dim, some_other_param):
super().__init__()
self.layer_1_dim = layer_1_dim
self.some_other_param = some_other_param

self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)

# self.some_other_param is automatically available
Borda marked this conversation as resolved.
Show resolved Hide resolved
self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.some_other_param)
self.layer_3 = torch.nn.Linear(self.some_other_param, 10)

self.some_other_param = 12
# but you can override it as normal
Borda marked this conversation as resolved.
Show resolved Hide resolved

The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule.
This does two things:
model = LitMNIST(10, 20)

1. It adds them automatically to TensorBoard logs under the hparams tab.
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.

Trainer args
^^^^^^^^^^^^
Expand Down
13 changes: 6 additions & 7 deletions docs/source/lr_finder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,17 @@ hyperparameters of the model.
# default: no automatic learning rate finder
trainer = Trainer(auto_lr_find=False)

When the ``lr`` or ``learning_rate`` key in hparams exists, this flag sets your learning_rate.
In both cases, if the respective fields are not found, an error will be thrown.

This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``.

.. testcode::

class LitModel(LightningModule):

def __init__(self, hparams):
self.hparams = hparams
def __init__(self, learning_rate):
self.learning_rate = learning_rate

def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr|self.hparams.learning_rate)
return Adam(self.parameters(), lr=(self.hparams.lr or self.hparams.learning_rate))
Borda marked this conversation as resolved.
Show resolved Hide resolved

# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SkafteNicki once we merge this we need to update setting LR.
I think it should just set it in the model at self.lr or self.learning_rate

Expand Down Expand Up @@ -97,7 +96,7 @@ of this would look like

# update hparams of the model
model.hparams.lr = new_lr

# Fit model
trainer.fit(model)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/training_tricks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ a binary search.
.. code-block:: python

def train_dataloader(self):
return DataLoader(train_dataset, batch_size=self.hparams.batch_size)
return DataLoader(train_dataset, batch_size=self.batch_size)

.. warning::

Expand Down
55 changes: 28 additions & 27 deletions docs/source/weights_loading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,20 @@ Or disable it by passing
trainer = Trainer(checkpoint_callback=False)


The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init.
The Lightning checkpoint also saves the arguments passed into the LightningModule init
under the `module_arguments` key in the checkpoint.

.. note:: hparams is a `Namespace <https://docs.python.org/2/library/argparse.html#argparse.Namespace>`_.

.. testcode::

from argparse import Namespace
.. code-block:: python

# usually these come from command line args
args = Namespace(learning_rate=0.001)
class MyLightningModule(LightningModule):

# define you module to have hparams as the first arg
# this means your checkpoint will have everything that went into making
# this model (in this case, learning rate)
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, *args, **kwargs):
super().__init__()

def __init__(self, hparams, *args, **kwargs):
self.hparams = hparams
# all init args were saved to the checkpoint
checkpoint = torch.load(CKPT_PATH)
print(checkpoint['module_arguments'])
# {'learning_rate': the_value}

Manual saving
^^^^^^^^^^^^^
Expand All @@ -92,37 +88,42 @@ You can manually save checkpoints and restore your model from the checkpointed s
Checkpoint Loading
------------------

To load a model along with its weights, biases and hyperparameters use following method.
To load a model along with its weights, biases and `module_arguments` use following method.
Borda marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

model = MyLightingModule.load_from_checkpoint(PATH)
model.eval()
y_hat = model(x)

The above only works if you used `hparams` in your model definition

.. testcode::

class LitModel(LightningModule):
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint

def __init__(self, hparams):
self.hparams = hparams
self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
model.eval()
y_hat = model(x)

But if you don't and instead pass individual parameters
But if you don't want to use the values saved in the checkpoint, pass in your own here

.. testcode::

class LitModel(LightningModule):

def __init__(self, in_dim, out_dim):
self.l1 = nn.Linear(in_dim, out_dim)
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.l1 = nn.Linear(self.in_dim, self.out_dim)

you can restore the model like this

.. code-block:: python

# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)


Expand Down
68 changes: 41 additions & 27 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,37 @@ class TransferLearningModel(pl.LightningModule):
dl_path: Path where the data will be downloaded
"""
def __init__(self,
hparams: argparse.Namespace,
dl_path: Union[str, Path]) -> None:
dl_path: Union[str, Path],
backbone: str = 'resnet50',
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 8,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6) -> None:
super().__init__()
self.hparams = hparams
self.dl_path = dl_path
self.backbone = backbone
self.train_bn = train_bn
self.milestones = milestones
self.batch_size = batch_size
self.lr = lr
self.lr_scheduler_gamma = lr_scheduler_gamma
self.num_workers = num_workers

self.dl_path = dl_path
self.__build_model()

def __build_model(self):
"""Define model layers & loss."""

# 1. Load pre-trained network:
model_func = getattr(models, self.hparams.backbone)
model_func = getattr(models, self.backbone)
backbone = model_func(pretrained=True)

_layers = list(backbone.children())[:-1]
self.feature_extractor = torch.nn.Sequential(*_layers)
freeze(module=self.feature_extractor, train_bn=self.hparams.train_bn)
freeze(module=self.feature_extractor, train_bn=self.train_bn)

# 2. Classifier:
_fc_layers = [torch.nn.Linear(2048, 256),
Expand Down Expand Up @@ -194,29 +208,29 @@ def train(self, mode=True):
super().train(mode=mode)

epoch = self.current_epoch
if epoch < self.hparams.milestones[0] and mode:
if epoch < self.milestones[0] and mode:
# feature extractor is frozen (except for BatchNorm layers)
freeze(module=self.feature_extractor,
train_bn=self.hparams.train_bn)
train_bn=self.train_bn)

elif self.hparams.milestones[0] <= epoch < self.hparams.milestones[1] and mode:
elif self.milestones[0] <= epoch < self.milestones[1] and mode:
# Unfreeze last two layers of the feature extractor
freeze(module=self.feature_extractor,
n=-2,
train_bn=self.hparams.train_bn)
train_bn=self.train_bn)

def on_epoch_start(self):
"""Use `on_epoch_start` to unfreeze layers progressively."""
optimizer = self.trainer.optimizers[0]
if self.current_epoch == self.hparams.milestones[0]:
if self.current_epoch == self.milestones[0]:
_unfreeze_and_add_param_group(module=self.feature_extractor[-2:],
optimizer=optimizer,
train_bn=self.hparams.train_bn)
train_bn=self.train_bn)

elif self.current_epoch == self.hparams.milestones[1]:
elif self.current_epoch == self.milestones[1]:
_unfreeze_and_add_param_group(module=self.feature_extractor[:-2],
optimizer=optimizer,
train_bn=self.hparams.train_bn)
train_bn=self.train_bn)

def training_step(self, batch, batch_idx):

Expand Down Expand Up @@ -246,7 +260,7 @@ def training_epoch_end(self, outputs):
for output in outputs]).mean()
train_acc_mean = torch.stack([output['num_correct']
for output in outputs]).sum().float()
train_acc_mean /= (len(outputs) * self.hparams.batch_size)
train_acc_mean /= (len(outputs) * self.batch_size)
return {'log': {'train_loss': train_loss_mean,
'train_acc': train_acc_mean,
'step': self.current_epoch}}
Expand All @@ -273,19 +287,19 @@ def validation_epoch_end(self, outputs):
for output in outputs]).mean()
val_acc_mean = torch.stack([output['num_correct']
for output in outputs]).sum().float()
val_acc_mean /= (len(outputs) * self.hparams.batch_size)
val_acc_mean /= (len(outputs) * self.batch_size)
return {'log': {'val_loss': val_loss_mean,
'val_acc': val_acc_mean,
'step': self.current_epoch}}

def configure_optimizers(self):
optimizer = optim.Adam(filter(lambda p: p.requires_grad,
self.parameters()),
lr=self.hparams.lr)
lr=self.lr)

scheduler = MultiStepLR(optimizer,
milestones=self.hparams.milestones,
gamma=self.hparams.lr_scheduler_gamma)
milestones=self.milestones,
gamma=self.lr_scheduler_gamma)

return [optimizer], [scheduler]

Expand Down Expand Up @@ -326,8 +340,8 @@ def __dataloader(self, train):

_dataset = self.train_dataset if train else self.valid_dataset
loader = DataLoader(dataset=_dataset,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True if train else False)

return loader
Expand Down Expand Up @@ -397,28 +411,28 @@ def add_model_specific_args(parent_parser):
return parser


def main(hparams: argparse.Namespace) -> None:
def main(args: argparse.Namespace) -> None:
"""Train the model.

Args:
hparams: Model hyper-parameters
args: Model hyper-parameters

Note:
For the sake of the example, the images dataset will be downloaded
to a temporary directory.
"""

with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir:
with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:

model = TransferLearningModel(hparams, dl_path=tmp_dir)
model = TransferLearningModel(dl_path=tmp_dir, **vars(args))

trainer = pl.Trainer(
weights_summary=None,
show_progress_bar=True,
num_sanity_val_steps=0,
gpus=hparams.gpus,
min_epochs=hparams.nb_epochs,
max_epochs=hparams.nb_epochs)
gpus=args.gpus,
min_epochs=args.nb_epochs,
max_epochs=args.nb_epochs)

trainer.fit(model)

Expand Down
Loading