Replace meta_tags.csv with hparams.yaml (#1271)
* Add support for hierarchical dict

* Support nested Namespace

* Add docstring

* Migrate hparam flattening to each logger

* Modify URLs in CHANGELOG

* typo

* Simplify the conditional branch about Namespace

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* added examples section to docstring

* renamed _dict -> input_dict

* meta_tags.csv -> hparams.yaml

* code style fixes

* add pyyaml

* remove unused import

* create the member NAME_HPARAMS_FILE

* improve tests

* Update tensorboard.py

* pass the local tests without the Horovod-related parts

* formatting

* update dependencies

* fix dependencies

* Apply suggestions from code review

* add saving

* warn

* docstrings

* tests

* Apply suggestions from code review

* saving

* Apply suggestions from code review

* use default

* remove logging

* typo fixes

* update docs

* update CHANGELOG

* clean imports

* add blank lines

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* back to namespace

* add docs

* test fix

* update dependencies

* add space

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people committed May 13, 2020
1 parent 35fe2ef commit 22d7d03
Showing 12 changed files with 226 additions and 79 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Replaced `meta_tags.csv` with `hparams.yaml` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))

- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
@@ -36,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))

### Removed

### Fixed
7 changes: 3 additions & 4 deletions docs/source/test_set.rst
@@ -21,10 +21,9 @@ To run the test set on a pre-trained model, use this method.

.. code-block:: python

-   model = MyLightningModule.load_from_metrics(
-       weights_path='/path/to/pytorch_checkpoint.ckpt',
-       tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
-       on_gpu=True,
+   model = MyLightningModule.load_from_checkpoint(
+       checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
+       hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
        map_location=None
    )
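For reference, a minimal sketch of the hierarchical structure that hparams.yaml can now hold (the keys below are illustrative, not produced by this commit); nested mappings load as nested dicts, which the flat key/value rows of meta_tags.csv could not represent:

    import yaml

    # Hypothetical hparams.yaml contents; nested keys were not
    # representable in the flat meta_tags.csv format.
    example = """
    drop_prob: 0.2
    dataloader:
      batch_size: 32
    """

    hparams = yaml.safe_load(example)
    assert hparams['dataloader']['batch_size'] == 32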
1 change: 1 addition & 0 deletions environment.yml
@@ -13,6 +13,7 @@ dependencies:
  - pytorch>=1.1
  - tensorboard>=1.14
  - future>=0.17.1
+ - pyyaml>=3.13

# For dev and testing
- tox
94 changes: 75 additions & 19 deletions pytorch_lightning/core/lightning.py
@@ -1,6 +1,7 @@
import collections
import inspect
import os
+import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -16,7 +17,7 @@
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
-from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
+from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1438,49 +1439,88 @@ def load_from_checkpoint(
            cls,
            checkpoint_path: str,
            map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
-           tags_csv: Optional[str] = None,
+           hparams_file: Optional[str] = None,
+           tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
            hparam_overrides: Optional[Dict] = None,
            *args, **kwargs
    ) -> 'LightningModule':
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
(output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
with an argument called ``hparams`` which is an object of :class:`~dict` or
:class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
when parsing command line arguments).
If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`.
Any other arguments specified through \*args and \*\*kwargs will be passed to the model.
Example:
.. code-block:: python
# define hparams as Namespace
from argparse import Namespace
hparams = Namespace(**{'learning_rate': 0.1})
model = MyModel(hparams)
class MyModel(LightningModule):
def __init__(self, hparams):
def __init__(self, hparams: Namespace):
self.learning_rate = hparams.learning_rate
# ----------
# define hparams as dict
hparams = {
drop_prob: 0.2,
dataloader: {
batch_size: 32
}
}
model = MyModel(hparams)
class MyModel(LightningModule):
def __init__(self, hparams: dict):
self.learning_rate = hparams['learning_rate']
        Args:
            checkpoint_path: Path to checkpoint.
            model_args: Any keyword args needed to init the model.
            map_location:
                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup.
                The behaviour is the same as in :func:`torch.load`.
-           tags_csv: Optional path to a .csv file with two columns (key, value)
+           hparams_file: Optional path to a .yaml file with hierarchical structure
                as in this example::

-                   key,value
-                   drop_prob,0.2
-                   batch_size,32
+                   drop_prob: 0.2
+                   dataloader:
+                       batch_size: 32

                You most likely won't need this since Lightning will always save the hyperparameters
                to the checkpoint.
                However, if your checkpoint weights don't have the hyperparameters saved,
-               use this method to pass in a .csv file with the hparams you'd like to use.
-               These will be converted into a :class:`~argparse.Namespace` and passed into your
+               use this method to pass in a .yaml file with the hparams you'd like to use.
+               These will be converted into a :class:`~dict` and passed into your
                :class:`LightningModule` for use.
+               If your model's ``hparams`` argument is a :class:`~argparse.Namespace`
+               and the .yaml file has a hierarchical structure, you need to refactor your model
+               to treat ``hparams`` as a :class:`~dict`.
+               .csv files are acceptable here until v0.9.0; see the tags_csv argument for detailed usage.
+           tags_csv:
+               .. warning:: .. deprecated:: 0.7.6
+
+                   `tags_csv` argument is deprecated in v0.7.6. Will be removed in v0.9.0.
+
+               Optional path to a .csv file with two columns (key, value)
+               as in this example::
+
+                   key,value
+                   drop_prob,0.2
+                   batch_size,32
+
+               Use this method to pass in a .csv file with the hparams you'd like to use.
            hparam_overrides: A dictionary with keys to override in the hparams
Return:
@@ -1502,7 +1542,7 @@ def __init__(self, hparams):
            # or load weights and hyperparameters from separate files.
            MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
-               tags_csv='/path/to/hparams_file.csv'
+               hparams_file='/path/to/hparams_file.yaml'
            )

            # override some of the params with new values
@@ -1531,9 +1571,22 @@ def __init__(self, hparams):

        # add the hparams from csv file to checkpoint
        if tags_csv is not None:
-           hparams = load_hparams_from_tags_csv(tags_csv)
-           hparams.__setattr__('on_gpu', False)
-           checkpoint['hparams'] = vars(hparams)
+           hparams_file = tags_csv
+           rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed in v0.9.0.', DeprecationWarning)
+
+       if hparams_file is not None:
+           extension = hparams_file.split('.')[-1]
+           if extension.lower() in ('csv',):
+               hparams = load_hparams_from_tags_csv(hparams_file)
+           elif extension.lower() in ('yml', 'yaml'):
+               hparams = load_hparams_from_yaml(hparams_file)
+           else:
+               raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')
+
+           hparams['on_gpu'] = False
+
+           # overwrite hparams by the given file
+           checkpoint['hparams'] = hparams

        # override the hparam keys that were passed in
        if hparam_overrides is not None:
@@ -1549,15 +1602,18 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':

        if cls_takes_hparams:
            if ckpt_hparams is not None:
-               is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
-               hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
+               hparams_type = checkpoint.get('hparams_type', 'Namespace')
+               if hparams_type.lower() == 'dict':
+                   hparams = ckpt_hparams
+               elif hparams_type.lower() == 'namespace':
+                   hparams = Namespace(**ckpt_hparams)
            else:
                rank_zero_warn(
                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
                    " contains argument 'hparams'. Will pass in an empty Namespace instead."
                    " Did you forget to store your model hyperparameters in self.hparams?"
                )
-               hparams = Namespace()
+               hparams = {}
        else:  # The user's LightningModule does not define a hparams argument
            if ckpt_hparams is None:
                hparams = None
@@ -1568,7 +1624,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
                )

        # load the state_dict on the model automatically
-       if hparams:
+       if cls_takes_hparams:
            kwargs.update(hparams=hparams)
        model = cls(*args, **kwargs)
        model.load_state_dict(checkpoint['state_dict'])
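Putting the new signature together, a minimal usage sketch of the loading path (MyModel, the paths, and the override key are hypothetical; tags_csv still works during the deprecation window, but it routes through hparams_file and emits a DeprecationWarning):

    # MyModel is assumed to be a user-defined LightningModule subclass.
    model = MyModel.load_from_checkpoint(
        checkpoint_path='/path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams.yaml',   # .csv also accepted until v0.9.0
        hparam_overrides={'drop_prob': 0.1},    # applied after the file is read
    )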
93 changes: 72 additions & 21 deletions pytorch_lightning/core/saving.py
@@ -1,9 +1,12 @@
+import ast
import csv
import os
+import yaml
from argparse import Namespace
from typing import Union, Dict, Any

from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities import rank_zero_warn


class ModelIO(object):
@@ -79,30 +82,78 @@ def update_hparams(hparams: dict, updates: dict) -> None:
        hparams.update({k: v})


-def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
+def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = './testing-hparams.csv'
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    if not os.path.isfile(tags_csv):
-       log.warning(f'Missing Tags: {tags_csv}.')
-       return Namespace()
+       rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
+       return {}

-   with open(tags_csv) as f:
-       csv_reader = csv.reader(f, delimiter=',')
+   with open(tags_csv) as fp:
+       csv_reader = csv.reader(fp, delimiter=',')
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
-   ns = Namespace(**tags)
-   return ns
+
+   return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
    if not os.path.isdir(os.path.dirname(tags_csv)):
        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with open(tags_csv, 'w') as fp:
        fieldnames = ['key', 'value']
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        writer.writerow({'key': 'key', 'value': 'value'})
        for k, v in hparams.items():
            writer.writerow({'key': k, 'value': v})


def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    if not os.path.isfile(config_yaml):
        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
        return {}

    with open(config_yaml) as fp:
        tags = yaml.load(fp, Loader=yaml.SafeLoader)

    return tags


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    if not os.path.isdir(os.path.dirname(config_yaml)):
        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with open(config_yaml, 'w', newline='') as fp:
        yaml.dump(hparams, fp)


def convert(val: str) -> Union[int, float, bool, str]:
-   constructors = [int, float, str]
-
-   if isinstance(val, str):
-       if val.lower() == 'true':
-           return True
-       if val.lower() == 'false':
-           return False
-
-   for c in constructors:
-       try:
-           return c(val)
-       except ValueError:
-           pass
-   return val
+   try:
+       return ast.literal_eval(val)
+   except (ValueError, SyntaxError) as e:
+       log.debug(e)
+       return val
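As a quick sanity check of the changed helpers in this file, a sketch assuming this commit is installed (the scratch path is hypothetical):

    import os
    from argparse import Namespace

    from pytorch_lightning.core.saving import (
        convert, load_hparams_from_tags_csv, save_hparams_to_tags_csv)

    hparams = Namespace(batch_size=32, drop_prob=0.2)
    path_csv = './scratch-hparams.csv'  # hypothetical scratch file

    save_hparams_to_tags_csv(path_csv, hparams)    # Namespace is converted to a dict
    loaded = load_hparams_from_tags_csv(path_csv)  # now returns a dict, not a Namespace
    assert loaded == {'batch_size': 32, 'drop_prob': 0.2}
    os.remove(path_csv)

    # convert now delegates to ast.literal_eval: 'True' parses to a bool, but the
    # old lower-case 'true'/'false' branches are gone, so 'true' stays a string.
    assert convert('32') == 32
    assert convert('True') is True
    assert convert('true') == 'true'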