Skip to content

Commit

Permalink
Join Horovod workers at the end of trainer.fit() to prevent race conditions following training (#1786)
Browse files Browse the repository at this point in the history

* Join Horovod workers at the end of trainer.fit() to prevent race conditions following training

* flake8

* flake8

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
tgaddair and Borda committed May 12, 2020
1 parent 7b60d49 commit acab068
Show file tree
Hide file tree
Showing 16 changed files with 34 additions and 16 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/core/model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.saving import * # noqa: F403

rank_zero_warn("`model_saving` module has been renamed to `saving` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.saving import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/core/root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.lightning import * # noqa: F403

rank_zero_warn("`root_module` module has been renamed to `lightning` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.lightning import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers import * # noqa: F403

rank_zero_warn("`logging` package has been renamed to `loggers` since v0.7.0"
" The deprecated package name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.comet import CometLogger # noqa: F403

rank_zero_warn("`logging.comet` module has been renamed to `loggers.comet` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.comet import CometLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.mlflow import MLFlowLogger # noqa: F403

rank_zero_warn("`logging.mlflow` module has been renamed to `loggers.mlflow` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.mlflow import MLFlowLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F403

rank_zero_warn("`logging.neptune` module has been renamed to `loggers.neptune` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/test_tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.test_tube import TestTubeLogger # noqa: F403

rank_zero_warn("`logging.test_tube` module has been renamed to `loggers.test_tube` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.test_tube import TestTubeLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F403

rank_zero_warn("`logging.wandb` module has been renamed to `loggers.wandb` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F403 E402
5 changes: 3 additions & 2 deletions pytorch_lightning/pt_overrides/override_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.overrides.data_parallel import ( # noqa: F402
get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel)

rank_zero_warn("`override_data_parallel` module has been renamed to `data_parallel` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.overrides.data_parallel import ( # noqa: F402 E402
get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel)
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.decorators import * # noqa: F403

rank_zero_warn("`root_module.decorators` module has been renamed to `core.decorators` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.decorators import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.grads import * # noqa: F403

rank_zero_warn("`root_module.grads` module has been renamed to `core.grads` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.grads import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.hooks import * # noqa: F403

rank_zero_warn("`root_module.hooks` module has been renamed to `core.hooks` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.hooks import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.memory import * # noqa: F403

rank_zero_warn("`root_module.memory` module has been renamed to `core.memory` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.memory import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.saving import * # noqa: F403

rank_zero_warn("`root_module.model_saving` module has been renamed to `core.saving` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.saving import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.lightning import * # noqa: F403

rank_zero_warn("`root_module.root_module` module has been renamed to `core.lightning` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.lightning import * # noqa: F403 E402
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,9 @@ def filter_named_parameters(model, optimizer):

self.run_pretrain_routine(model)

# Make sure all workers have finished training before returning to the user
hvd.join()


def normalize_parse_gpu_string_input(s):
if isinstance(s, str):
Expand Down

0 comments on commit acab068

Please sign in to comment.