
Commit

Merge branch 'fix_limit_batches_infdl' of https://github.com/PyTorchLightning/pytorch-lightning into fix_limit_batches_infdl
Borda committed Aug 7, 2020
2 parents 0246e3c + cb7d3a7 commit c76b23e
Showing 59 changed files with 897 additions and 163 deletions.
10 changes: 5 additions & 5 deletions .circleci/config.yml
@@ -153,11 +153,11 @@ workflows:
filters:
branches:
# https://discuss.circleci.com/t/create-separate-steps-jobs-for-pr-forks-versus-branches/13419/4
only:
# only from forks
- /^pull\/.*$/
# only from canonical repository
- /^(?!pull\/).*$/
#only:
# # only from forks
# - /^pull\/.\d+$/
ignore:
- master
cleanup:
triggers:
- schedule:
5 changes: 2 additions & 3 deletions .github/CONTRIBUTING.md
@@ -137,7 +137,7 @@ formatting errors. In certain cases, a missing blank line or a wrong indent can
Run these commands

```bash
pip install -r requirements/docs.txt
pip install ".[docs]"
cd docs
make html
```
@@ -159,8 +159,7 @@ Testing your work locally will help you speed up the process since it allows you
To setup a local development environment, install both local and test dependencies:

```bash
python -m pip install -r requirements/devel.txt
python -m pip install -r requirements/examples.txt
python -m pip install ".[dev, examples]"
python -m pip install pre-commit
```

6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added SyncBN for DDP ([#2801](https://github.com/PyTorchLightning/pytorch-lightning/pull/2801))

- Added basic `CSVLogger` ([#2721](https://github.com/PyTorchLightning/pytorch-lightning/pull/2721))

- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))

- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
@@ -31,7 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))

- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787))
- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2840](https://github.com/PyTorchLightning/pytorch-lightning/pull/2840))

### Changed

6 changes: 6 additions & 0 deletions docs/source/loggers.rst
@@ -339,4 +339,10 @@ Test-tube
^^^^^^^^^

.. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
:noindex:

CSVLogger
^^^^^^^^^

.. autoclass:: pytorch_lightning.loggers.csv_logs.CSVLogger
:noindex:
4 changes: 2 additions & 2 deletions docs/source/results.rst
@@ -40,8 +40,8 @@ using the equivalent syntax via the `TrainResult` object:
--------------------

Validation loop example
-----------------------
Validation/Test loop example
-----------------------------
We can replace the following validation/test loop:

.. code-block:: python
2 changes: 1 addition & 1 deletion pytorch_lightning/__init__.py
@@ -1,6 +1,6 @@
"""Root package info."""

__version__ = '0.9.0rc6'
__version__ = '0.9.0rc9'
__author__ = 'William Falcon et al.'
__author_email__ = 'waf2107@columbia.edu'
__license__ = 'Apache-2.0'
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerator_backends/ddp_backend.py
@@ -176,6 +176,10 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
if self.trainer.on_gpu:
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
@@ -118,6 +118,10 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
if self.trainer.on_gpu:
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -46,6 +46,14 @@ def on_sanity_check_end(self, trainer, pl_module):
"""Called when the validation sanity check ends."""
pass

def on_train_batch_start(self, trainer, pl_module):
"""Called when the validation batch begins."""
pass

def on_train_batch_end(self, trainer, pl_module):
"""Called when the validation batch ends."""
pass

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the train epoch begins."""
pass
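For orientation, a minimal sketch (not part of this commit) of a custom callback wired to the `on_train_batch_start`/`on_train_batch_end` hooks added above; the `BatchCounter` class and the `Trainer(callbacks=...)` usage are illustrative assumptions.

```python
# Illustrative sketch only -- not part of this commit.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class BatchCounter(Callback):  # hypothetical example callback
    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_start(self, trainer, pl_module):
        # invoked right before each training batch, matching the new hook signature
        pass

    def on_train_batch_end(self, trainer, pl_module):
        # invoked right after each training batch
        self.batches_seen += 1


trainer = Trainer(callbacks=[BatchCounter()])
```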
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/lr_logger.py
@@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module):
# Initialize for storing values
self.lrs = {name: [] for name in names}

def on_batch_start(self, trainer, pl_module):
def on_train_batch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'step')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/progress.py
@@ -36,8 +36,8 @@ def __init__(self):
def disable(self):
self.enable = False
def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module) # don't forget this :)
def on_train_batch_end(self, trainer, pl_module):
super().on_train_batch_end(trainer, pl_module) # don't forget this :)
percent = (self.train_batch_idx / self.total_train_batches) * 100
sys.stdout.flush()
sys.stdout.write(f'{percent:.01f} percent complete \r')
@@ -138,7 +138,7 @@ def on_train_start(self, trainer, pl_module):
def on_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module):
self._train_batch_idx += 1

def on_validation_start(self, trainer, pl_module):
@@ -318,8 +318,8 @@ def on_epoch_start(self, trainer, pl_module):
self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module)
def on_train_batch_end(self, trainer, pl_module):
super().on_train_batch_end(trainer, pl_module)
if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
self.main_progress_bar.update(self.refresh_rate)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
21 changes: 21 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -77,6 +77,23 @@ def on_train_end(self) -> None:
"""
# do something at the end of training

def on_train_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
Args:
batch: The batched data as it is returned by the training DataLoader.
"""
# do something when the batch starts

def on_train_batch_end(self) -> None:
"""
Called in the training loop after the batch.
"""
# do something when the batch ends

def on_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.
@@ -85,12 +102,16 @@ def on_batch_start(self, batch: Any) -> None:
Args:
batch: The batched data as it is returned by the training DataLoader.
.. warning:: Deprecated in 0.9.0, will be removed in 1.0.0 (use `on_train_batch_start` instead)
"""
# do something when the batch starts

def on_batch_end(self) -> None:
"""
Called in the training loop after the batch.
.. warning:: Deprecated in 0.9.0, will be removed in 1.0.0 (use `on_train_batch_end` instead)
"""
# do something when the batch ends

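A minimal sketch (not part of this commit) of overriding the new model-level hooks with the signatures shown above; `LitModel` is a hypothetical module.

```python
# Illustrative sketch only -- not part of this commit.
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical module
    def on_train_batch_start(self, batch):
        # runs before each training batch; returning -1 skips the rest of the epoch
        pass

    def on_train_batch_end(self):
        # runs after each training batch; replaces the deprecated on_batch_end
        pass
```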
17 changes: 17 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -957,6 +957,23 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
Override to synchronize batchnorm between specific process groups instead
of the whole world or use a different sync_bn like `apex`'s version.
Args:
model: pointer to current :class:`LightningModule`.
Return:
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def configure_apex(
self, amp: object, model: 'LightningModule', optimizers: List[Optimizer], amp_level: str
) -> Tuple['LightningModule', List[Optimizer]]:
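A minimal sketch (not part of this commit) of the sync-BN path added above: the `sync_batchnorm` Trainer flag is inferred from the `trainer.sync_batchnorm` check in the DDP backends, and the override simply repeats the default `torch.nn.SyncBatchNorm` conversion with a customizable process group.

```python
# Illustrative sketch only -- not part of this commit.
import torch
from pytorch_lightning import LightningModule, Trainer


class LitModel(LightningModule):  # hypothetical module
    def configure_sync_batchnorm(self, model):
        # override to sync within a specific process group instead of the whole world
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)


# the DDP backends above call configure_sync_batchnorm before .cuda()/configure_ddp
trainer = Trainer(gpus=2, distributed_backend='ddp', sync_batchnorm=True)
```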
2 changes: 1 addition & 1 deletion pytorch_lightning/core/saving.py
@@ -313,7 +313,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
return {}

with open(config_yaml) as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)
tags = yaml.load(fp)

return tags

33 changes: 29 additions & 4 deletions pytorch_lightning/core/step_result.py
@@ -1,7 +1,9 @@
import numbers
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
from torch import Tensor
import torch
from copy import copy
from pytorch_lightning.metrics.converters import _sync_ddp_if_available


class Result(Dict):
@@ -89,11 +91,18 @@ def log(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
value = value.detach()

# sync across ddp
if sync_ddp and isinstance(value, (torch.Tensor, numbers.Number)):
value = _sync_ddp_if_available(value, group=sync_ddp_group, reduce_op=sync_ddp_op)

if 'meta' not in self:
self.__setitem__('meta', {})

@@ -338,6 +347,9 @@ def log(
on_epoch: bool = False,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a key, value
@@ -369,7 +381,8 @@ def log(
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

def log_dict(
self,
@@ -380,6 +393,9 @@ def log_dict(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a dictionary of values at once
@@ -399,7 +415,8 @@ def log_dict(
enable_graph:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)


class EvalResult(Result):
@@ -446,6 +463,9 @@ def log(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a key, value
@@ -476,7 +496,8 @@ def log(
reduce_fx: Torch.mean by default
enable_graph: if True, will not auto detach the graph
"""
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

def log_dict(
self,
@@ -487,6 +508,9 @@ def log_dict(
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_ddp: bool = False,
sync_ddp_op: Union[Any, str] = 'mean',
sync_ddp_group: Optional[Any] = None
):
"""
Log a dictionary of values at once
@@ -506,7 +530,8 @@ def log_dict(
enable_graph:
"""
for k, v in dictionary.items():
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph,
sync_ddp=sync_ddp, sync_ddp_group=sync_ddp_group, sync_ddp_op=sync_ddp_op)

def get_callback_metrics(self) -> dict:
result = {
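A minimal sketch (not part of this commit) of the new `sync_ddp*` logging arguments; `LitModel`, the `_shared_step` helper, and the top-level `TrainResult` import are illustrative assumptions.

```python
# Illustrative sketch only -- not part of this commit.
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical module
    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)  # hypothetical helper returning a scalar tensor
        result = pl.TrainResult(minimize=loss)
        # reduce the logged value across DDP processes with the given op
        result.log('train_loss', loss, sync_ddp=True, sync_ddp_op='mean')
        return result
```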
3 changes: 3 additions & 0 deletions pytorch_lightning/loggers/__init__.py
@@ -2,11 +2,14 @@

from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger


__all__ = [
'LightningLoggerBase',
'LoggerCollection',
'TensorBoardLogger',
'CSVLogger',
]

try:
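A minimal sketch (not part of this commit) of the newly exported `CSVLogger`; the save directory and experiment name are illustrative.

```python
# Illustrative sketch only -- not part of this commit.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

csv_logger = CSVLogger(save_dir='logs/', name='my_experiment')
trainer = Trainer(logger=csv_logger)
# trainer.fit(model)  # logged metrics are written to CSV under logs/my_experiment/
```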