Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join Horovod workers at the end of trainer.fit() to prevent race conditions following training #1786

Merged
merged 3 commits into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytorch_lightning/core/model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.saving import * # noqa: F403

rank_zero_warn("`model_saving` module has been renamed to `saving` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.saving import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/core/root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.lightning import * # noqa: F403

rank_zero_warn("`root_module` module has been renamed to `lightning` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.lightning import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers import * # noqa: F403

rank_zero_warn("`logging` package has been renamed to `loggers` since v0.7.0"
" The deprecated package name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.comet import CometLogger # noqa: F403

rank_zero_warn("`logging.comet` module has been renamed to `loggers.comet` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.comet import CometLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.mlflow import MLFlowLogger # noqa: F403

rank_zero_warn("`logging.mlflow` module has been renamed to `loggers.mlflow` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.mlflow import MLFlowLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F403

rank_zero_warn("`logging.neptune` module has been renamed to `loggers.neptune` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.neptune import NeptuneLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/test_tube.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.test_tube import TestTubeLogger # noqa: F403

rank_zero_warn("`logging.test_tube` module has been renamed to `loggers.test_tube` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.test_tube import TestTubeLogger # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/logging/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F403

rank_zero_warn("`logging.wandb` module has been renamed to `loggers.wandb` since v0.7.0."
" The deprecated module name will be removed in v0.9.0.", DeprecationWarning)

from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F403 E402
5 changes: 3 additions & 2 deletions pytorch_lightning/pt_overrides/override_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.overrides.data_parallel import ( # noqa: F402
get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel)

rank_zero_warn("`override_data_parallel` module has been renamed to `data_parallel` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.overrides.data_parallel import ( # noqa: F402 E402
get_a_var, parallel_apply, LightningDataParallel, LightningDistributedDataParallel)
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.decorators import * # noqa: F403

rank_zero_warn("`root_module.decorators` module has been renamed to `core.decorators` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.decorators import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.grads import * # noqa: F403

rank_zero_warn("`root_module.grads` module has been renamed to `core.grads` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.grads import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.hooks import * # noqa: F403

rank_zero_warn("`root_module.hooks` module has been renamed to `core.hooks` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.hooks import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.memory import * # noqa: F403

rank_zero_warn("`root_module.memory` module has been renamed to `core.memory` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.memory import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.saving import * # noqa: F403

rank_zero_warn("`root_module.model_saving` module has been renamed to `core.saving` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.saving import * # noqa: F403 E402
3 changes: 2 additions & 1 deletion pytorch_lightning/root_module/root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.core.lightning import * # noqa: F403

rank_zero_warn("`root_module.root_module` module has been renamed to `core.lightning` since v0.6.0."
" The deprecated module name will be removed in v0.8.0.", DeprecationWarning)

from pytorch_lightning.core.lightning import * # noqa: F403 E402
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,9 @@ def filter_named_parameters(model, optimizer):

self.run_pretrain_routine(model)

# Make sure all workers have finished training before returning to the user
hvd.join()


def normalize_parse_gpu_string_input(s):
if isinstance(s, str):
Expand Down