Replace meta_tags.csv with hparams.yaml #1271

Merged
Changes from all commits
50 commits
af9555a
Add support for hierarchical dict
S-aiueo32 Mar 14, 2020
29ae37c
Support nested Namespace
S-aiueo32 Mar 14, 2020
fc7fe7c
Resolving conflict with the current master
S-aiueo32 Mar 15, 2020
8ced285
Add docstring
S-aiueo32 Mar 15, 2020
c2e00a4
Migrate hparam flattening to each logger
S-aiueo32 Mar 15, 2020
117acd7
Modify URLs in CHANGELOG
S-aiueo32 Mar 15, 2020
e128430
typo
S-aiueo32 Mar 15, 2020
5497d10
Simplify the conditional branch about Namespace
S-aiueo32 Mar 18, 2020
891c600
Update CHANGELOG.md
S-aiueo32 Mar 18, 2020
d54c597
added examples section to docstring
S-aiueo32 Mar 18, 2020
52fe621
Merge branch 'feature/support-hierarchical-dict' of https://github.co…
S-aiueo32 Mar 18, 2020
377b7c4
renamed _dict -> input_dict
S-aiueo32 Mar 18, 2020
942895f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
S-aiueo32 Mar 28, 2020
5e3d02e
Merge remote-tracking branch 'pl_origin/master'
S-aiueo32 Apr 26, 2020
ce7ec1b
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
S-aiueo32 May 12, 2020
1bbc90f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
S-aiueo32 May 12, 2020
e339a9c
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
S-aiueo32 May 13, 2020
f83c998
mata_tags.csv -> hparams.yaml
S-aiueo32 Mar 28, 2020
28f9abd
code style fixes
S-aiueo32 Mar 28, 2020
41c5b9a
add pyyaml
S-aiueo32 Mar 28, 2020
8bc6478
remove unused import
S-aiueo32 Mar 29, 2020
f4264a1
create the member NAME_HPARAMS_FILE
S-aiueo32 Mar 29, 2020
c876dee
improve tests
S-aiueo32 Mar 29, 2020
1092b44
Update tensorboard.py
Borda Apr 16, 2020
d8d8655
pass the local test w/o relavents of Horovod
S-aiueo32 Apr 26, 2020
8b022e8
formatting
S-aiueo32 Apr 26, 2020
6073405
update dependencies
S-aiueo32 May 10, 2020
4d22c8d
fix dependencies
S-aiueo32 May 10, 2020
e7c8d0a
Apply suggestions from code review
Borda May 10, 2020
ec76bef
add savings
Borda May 10, 2020
19d7369
warn
Borda May 10, 2020
1bfdc06
docstrings
Borda May 10, 2020
4383b7e
tests
Borda May 10, 2020
290fd0f
Apply suggestions from code review
Borda May 10, 2020
1124490
saving
Borda May 10, 2020
cebd3ed
Apply suggestions from code review
Borda May 10, 2020
ecead94
use default
S-aiueo32 May 10, 2020
34f5b79
remove logging
S-aiueo32 May 10, 2020
f312053
typo fixes
S-aiueo32 May 10, 2020
2a165af
update docs
S-aiueo32 May 10, 2020
b4c40ad
update CHANGELOG
S-aiueo32 May 10, 2020
f1d4502
clean imports
S-aiueo32 May 10, 2020
02d2537
add blank lines
S-aiueo32 May 10, 2020
1fc4ffd
Update pytorch_lightning/core/lightning.py
S-aiueo32 May 11, 2020
63d7abd
Update pytorch_lightning/core/lightning.py
S-aiueo32 May 11, 2020
d4f28e6
back to namespace
S-aiueo32 May 11, 2020
53c0a40
add docs
S-aiueo32 May 11, 2020
95a28ad
test fix
S-aiueo32 May 12, 2020
5d27e34
update dependencies
S-aiueo32 May 12, 2020
df6cd16
add space
S-aiueo32 May 13, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Replace meta_tags.csv with hparams.yaml ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))

- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
@@ -36,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))

### Removed

### Fixed
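As a quick illustration of the changelog entries above — a minimal sketch with hypothetical checkpoint and file paths, mirroring the `load_from_checkpoint` changes further down in this diff:

    # Deprecated path: `tags_csv` still works until v0.9.0, but now emits a
    # DeprecationWarning and is redirected internally to `hparams_file`.
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/checkpoint.ckpt',
        tags_csv='/path/to/meta_tags.csv',
    )

    # New form introduced by this PR:
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams.yaml',
    )
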
7 changes: 3 additions & 4 deletions docs/source/test_set.rst
@@ -21,10 +21,9 @@ To run the test set on a pre-trained model, use this method.

.. code-block:: python

model = MyLightningModule.load_from_metrics(
weights_path='/path/to/pytorch_checkpoint.ckpt',
tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
on_gpu=True,
model = MyLightningModule.load_from_checkpoint(
checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
map_location=None
)

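For context on the `hparams_file` argument used above, a small sketch of the hierarchical file it expects — the file name and keys here are made up, and PyYAML is used directly, just as the new saving helpers do:

    import yaml

    # Hypothetical hierarchical hyperparameters; nested dicts are what
    # hparams.yaml can express and the flat meta_tags.csv could not.
    hparams = {
        'drop_prob': 0.2,
        'dataloader': {
            'batch_size': 32,
            'num_workers': 4,
        },
    }

    # Write the file that `hparams_file` expects.
    with open('hparams.yaml', 'w') as fp:
        yaml.dump(hparams, fp)

    # Reading it back yields the same nested dict.
    with open('hparams.yaml') as fp:
        assert yaml.load(fp, Loader=yaml.SafeLoader) == hparams
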
1 change: 1 addition & 0 deletions environment.yml
@@ -13,6 +13,7 @@ dependencies:
- pytorch>=1.1
- tensorboard>=1.14
- future>=0.17.1
- pyyaml>=3.13

# For dev and testing
- tox
94 changes: 75 additions & 19 deletions pytorch_lightning/core/lightning.py
@@ -1,6 +1,7 @@
import collections
import inspect
import os
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -16,7 +17,7 @@
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1438,49 +1439,88 @@ def load_from_checkpoint(
cls,
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
hparams_file: Optional[str] = None,
tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0
hparam_overrides: Optional[Dict] = None,
*args, **kwargs
) -> 'LightningModule':
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
(output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
with an argument called ``hparams`` which is an instance of :class:`dict` or
:class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
when parsing command line arguments).
If you want ``hparams`` to have a hierarchical structure, define it as a :class:`dict`.
Any other arguments specified through \*args and \*\*kwargs will be passed to the model.

Example:
.. code-block:: python

# define hparams as Namespace
from argparse import Namespace
hparams = Namespace(**{'learning_rate': 0.1})

model = MyModel(hparams)

class MyModel(LightningModule):
def __init__(self, hparams):
def __init__(self, hparams: Namespace):
self.learning_rate = hparams.learning_rate

# ----------

# define hparams as dict
hparams = {
'drop_prob': 0.2,
'dataloader': {
'batch_size': 32
}
}

model = MyModel(hparams)

class MyModel(LightningModule):
def __init__(self, hparams: dict):
self.drop_prob = hparams['drop_prob']
self.batch_size = hparams['dataloader']['batch_size']

Args:
checkpoint_path: Path to checkpoint.
model_args: Any keyword args needed to init the model.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
The behaviour is the same as in :func:`torch.load`.
tags_csv: Optional path to a .csv file with two columns (key, value)
hparams_file: Optional path to a .yaml file with hierarchical structure
as in this example::

key,value
drop_prob,0.2
batch_size,32
drop_prob: 0.2
dataloader:
batch_size: 32

You most likely won't need this since Lightning will always save the hyperparameters
to the checkpoint.
However, if your checkpoint weights don't have the hyperparameters saved,
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into a :class:`~argparse.Namespace` and passed into your
use this method to pass in a .yaml file with the hparams you'd like to use.
These will be converted into a :class:`~dict` and passed into your
:class:`LightningModule` for use.

If your model's ``hparams`` argument is a :class:`~argparse.Namespace`
and the .yaml file has a hierarchical structure, you need to refactor your model to treat
``hparams`` as a :class:`dict`.

.csv files are still accepted here until v0.9.0; see the `tags_csv` argument for detailed usage.
tags_csv:
.. warning:: .. deprecated:: 0.7.6

The `tags_csv` argument is deprecated in v0.7.6 and will be removed in v0.9.0.

Optional path to a .csv file with two columns (key, value)
as in this example::

key,value
drop_prob,0.2
batch_size,32

Use this method to pass in a .csv file with the hparams you'd like to use.
hparam_overrides: A dictionary with keys to override in the hparams

Return:
@@ -1502,7 +1542,7 @@ def __init__(self, hparams):
# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
tags_csv='/path/to/hparams_file.csv'
hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
@@ -1531,9 +1571,22 @@ def __init__(self, hparams):

# add the hparams from csv file to checkpoint
if tags_csv is not None:
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)
hparams_file = tags_csv
rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed in v0.9.0', DeprecationWarning)

if hparams_file is not None:
extension = hparams_file.split('.')[-1]
if extension.lower() in ('csv',):
hparams = load_hparams_from_tags_csv(hparams_file)
elif extension.lower() in ('yml', 'yaml'):
hparams = load_hparams_from_yaml(hparams_file)
else:
raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')

hparams['on_gpu'] = False

# overwrite hparams by the given file
checkpoint['hparams'] = hparams

# override the hparam keys that were passed in
if hparam_overrides is not None:
Expand All @@ -1549,15 +1602,18 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh

if cls_takes_hparams:
if ckpt_hparams is not None:
is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
hparams_type = checkpoint.get('hparams_type', 'Namespace')
if hparams_type.lower() == 'dict':
hparams = ckpt_hparams
elif hparams_type.lower() == 'namespace':
hparams = Namespace(**ckpt_hparams)
else:
rank_zero_warn(
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
" contains argument 'hparams'. Will pass in an empty Namespace instead."
" Did you forget to store your model hyperparameters in self.hparams?"
)
hparams = Namespace()
hparams = {}
else: # The user's LightningModule does not define a hparams argument
if ckpt_hparams is None:
hparams = None
@@ -1568,7 +1624,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
)

# load the state_dict on the model automatically
if hparams:
if cls_takes_hparams:
kwargs.update(hparams=hparams)
model = cls(*args, **kwargs)
model.load_state_dict(checkpoint['state_dict'])
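Taken together, the changes above let a checkpoint be restored with hyperparameters read from a YAML file and selectively overridden. A minimal usage sketch — `LitModel`, the paths and the override key are hypothetical, and the module treats `hparams` as a dict so hierarchical values can be accessed by key:

    import pytorch_lightning as pl

    # Hypothetical module whose __init__ takes a dict-style hparams argument.
    class LitModel(pl.LightningModule):
        def __init__(self, hparams: dict):
            super().__init__()
            self.hparams = hparams
            self.drop_prob = hparams['drop_prob']
            self.batch_size = hparams['dataloader']['batch_size']

    model = LitModel.load_from_checkpoint(
        checkpoint_path='/path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams.yaml',   # .yml/.yaml, or .csv until v0.9.0
        hparam_overrides={'drop_prob': 0.5},    # override selected keys after loading
    )
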
93 changes: 72 additions & 21 deletions pytorch_lightning/core/saving.py
@@ -1,9 +1,12 @@
import ast
import csv
import os
import yaml
from argparse import Namespace
from typing import Union, Dict, Any

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn


class ModelIO(object):
@@ -79,30 +82,78 @@ def update_hparams(hparams: dict, updates: dict) -> None:
hparams.update({k: v})


def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
"""Load hparams from a file.

>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_csv = './testing-hparams.csv'
>>> save_hparams_to_tags_csv(path_csv, hparams)
>>> hparams_new = load_hparams_from_tags_csv(path_csv)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_csv)
"""
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
return Namespace()
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
return {}

with open(tags_csv) as f:
csv_reader = csv.reader(f, delimiter=',')
with open(tags_csv) as fp:
csv_reader = csv.reader(fp, delimiter=',')
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
ns = Namespace(**tags)
return ns

return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
if not os.path.isdir(os.path.dirname(tags_csv)):
raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')

if isinstance(hparams, Namespace):
hparams = vars(hparams)

with open(tags_csv, 'w') as fp:
fieldnames = ['key', 'value']
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({'key': 'key', 'value': 'value'})
for k, v in hparams.items():
writer.writerow({'key': k, 'value': v})


def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
"""Load hparams from a file.

>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
>>> save_hparams_to_yaml(path_yaml, hparams)
>>> hparams_new = load_hparams_from_yaml(path_yaml)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_yaml)
"""
if not os.path.isfile(config_yaml):
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
return {}

with open(config_yaml) as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)

return tags


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
if not os.path.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

if isinstance(hparams, Namespace):
hparams = vars(hparams)

with open(config_yaml, 'w', newline='') as fp:
yaml.dump(hparams, fp)


def convert(val: str) -> Union[int, float, bool, str]:
constructors = [int, float, str]

if isinstance(val, str):
if val.lower() == 'true':
return True
if val.lower() == 'false':
return False

for c in constructors:
try:
return c(val)
except ValueError:
pass
return val
try:
return ast.literal_eval(val)
except (ValueError, SyntaxError) as e:
log.debug(e)
return val
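
A hedged round-trip sketch of the new helpers in pytorch_lightning.core.saving shown above, using a temporary directory so nothing is assumed about the project layout:

    import os
    import tempfile
    from argparse import Namespace

    from pytorch_lightning.core.saving import (
        convert,
        load_hparams_from_yaml,
        save_hparams_to_yaml,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        path_yaml = os.path.join(tmpdir, 'hparams.yaml')

        # Namespace inputs are converted to a dict before dumping ...
        save_hparams_to_yaml(path_yaml, Namespace(batch_size=32, learning_rate=0.001))

        # ... and loading always returns a plain dict, never a Namespace.
        assert load_hparams_from_yaml(path_yaml) == {'batch_size': 32, 'learning_rate': 0.001}

    # `convert` now goes through ast.literal_eval, so numeric strings from the
    # legacy CSV format come back as Python values while plain words stay strings.
    assert convert('32') == 32
    assert convert('false') is False
    assert convert('relu') == 'relu'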