[bugfix] TPU test hangs to barrier on 1 process (#6272)
* update

* resolve flake8

* update

* update

* update changelog

* update

* resolve flake8

Co-authored-by: Your Name <you@example.com>
tchaton and Your Name authored Mar 2, 2021
1 parent 4157b35 commit 1aac481
Showing 7 changed files with 31 additions and 15 deletions.
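The fix boils down to skipping collective synchronization when no distributed process group is running, which is what happens when training spawns a single TPU process. A minimal sketch of that guard pattern, using plain `torch.distributed` as a stand-in for the plugin barriers changed below (not Lightning's actual code):

```python
import torch.distributed as dist


def barrier_if_initialized() -> None:
    """Rendezvous only when a process group actually exists.

    With a single process (e.g. one TPU core) no group is initialized,
    so an unconditional barrier would wait forever for peers.
    """
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```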
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


+- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+
+
- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))

6 changes: 2 additions & 4 deletions pytorch_lightning/accelerators/tpu.py
@@ -31,10 +31,8 @@ def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
            raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
        return super().setup(trainer, model)

-    def run_optimizer_step(
-        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
-    ) -> None:
-        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
+    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+        xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """
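For context, a hedged sketch of how `xm.optimizer_step` is used in a bare XLA training step (assumes a TPU runtime with `torch_xla` installed; the model, loss and batch names are illustrative). With `barrier=False` the call all-reduces gradients and steps the optimizer but does not issue the XLA step marker itself, leaving that to the surrounding loop or parallel loader:

```python
import torch_xla.core.xla_model as xm


def training_step(model, loss_fn, optimizer, batch):
    # Illustrative XLA step: reduce gradients across replicas, then call
    # optimizer.step(); barrier=False defers the graph cut (xm.mark_step)
    # to the data-loading wrapper, matching the plugin change above.
    optimizer.zero_grad()
    inputs, targets = batch
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=False)
    return loss
```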
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -15,6 +15,7 @@
from typing import Any, List, Optional, Union

import torch
+import torch.distributed as torch_distrib
from torch.optim.lr_scheduler import _LRScheduler, Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -116,7 +117,8 @@ def start_predicting(self, trainer):
        hvd.join()

    def barrier(self, *args, **kwargs):
-        hvd.join()
+        if torch_distrib.is_initialized():
+            hvd.join()

    def broadcast(self, obj: object, src: int = 0) -> object:
        obj = hvd.broadcast_object(obj, src)
20 changes: 17 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -17,6 +17,7 @@
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
+import torch.distributed as torch_distrib
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
@@ -125,7 +126,8 @@ def model_to_device(self) -> None:
        self._model.to(xm.xla_device())

    def barrier(self, name: Optional[str] = None) -> None:
-        rendezvous(f"pl.Trainer.{name}")
+        if torch_distrib.is_initialized():
+            rendezvous(f"pl.Trainer.{name}")

    def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
@@ -139,14 +141,26 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing trainer through model -> trainer?
        if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            xm.save(self.lightning_module.state_dict(), last_path)
+            self.save(self.lightning_module.state_dict(), last_path)

        if self.global_rank == 0:
            # todo, pass complete checkpoint as state dictionary
            self.mp_queue.put(best_model_path)
            self.mp_queue.put(last_path)
            self.mp_queue.put(results)

+    def save(self, state_dict: Dict, path: str) -> None:
+        """
+        Saving with ``xm.save`` can be unstable and miss the rendezvous that follows ``torch.save``.
+        The rendezvous does not affect the save itself,
+        so the ``RuntimeError`` can be ignored to reduce friction with TPUs.
+        """
+        try:
+            xm.save(state_dict, path)
+        except RuntimeError as e:
+            if "Failed to meet rendezvous" not in str(e):
+                raise e
+
    def broadcast(self, obj: object, src: int = 0) -> object:
        buffer = io.BytesIO()
        torch.save(obj, buffer)
@@ -294,4 +308,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
        # dump states as a checkpoint dictionary object
        _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
        # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
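The paths and results that `transfer_distrib_spawn_state_on_fit_end` pushes onto `mp_queue` are how the parent process learns where the best checkpoint lives once the spawned workers exit, which is the handoff that running `trainer.test` from the best path depends on. A hedged, standalone sketch of that spawn-and-queue pattern (not Lightning's actual spawn code; paths and payloads are illustrative):

```python
import torch.multiprocessing as mp


def worker(rank, mp_queue):
    # Stand-in for a spawned training process: only rank 0 hands the
    # best-checkpoint path and the results back to the parent.
    best_model_path = f"/tmp/best-{rank}.ckpt"  # illustrative path
    results = {"fit_ok": True}
    if rank == 0:
        mp_queue.put(best_model_path)
        mp_queue.put(results)


if __name__ == "__main__":
    smp = mp.get_context("spawn")
    queue = smp.SimpleQueue()
    mp.spawn(worker, args=(queue,), nprocs=2)
    # After the workers exit, the parent reads the values back, the same
    # way the trainer recovers best_model_path before running .test().
    best_model_path = queue.get()
    results = queue.get()
```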
@@ -496,7 +496,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
            # define the max CPU available
            self.num_processes = os.cpu_count()
        # special case with TPUs
-        elif self.distributed_backend == 'tpu':
+        elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
            self._device_type = DeviceType.TPU
        elif self.distributed_backend and self._distrib_type is None:
            self._distrib_type = DistributedType(self.distributed_backend)
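The device-selection hunk above means that passing `tpu_cores` by itself now resolves the device type to TPU, whereas previously this branch was only taken when the backend string was exactly `'tpu'`. A hedged usage sketch (assumes a TPU runtime):

```python
from pytorch_lightning import Trainer

# Either form now maps to DeviceType.TPU inside set_distributed_mode:
trainer = Trainer(tpu_cores=8)       # all eight cores, no backend string needed
# trainer = Trainer(tpu_cores=[1])   # a single, specific core
```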
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -56,7 +56,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -912,8 +912,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                )
                return {}
-            if not self._device_type == DeviceType.TPU:
-                self.accelerator.barrier()
+
+            self.training_type_plugin.barrier()

            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])
5 changes: 2 additions & 3 deletions tests/models/test_tpu.py
@@ -178,8 +178,6 @@ def test_model_16bit_tpu_cores_8(tmpdir):
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works"""

-    # todo: Test on 8 cores - hanging.
-
    class CustomBoringModel(BoringModel):

        def validation_step(self, *args, **kwargs):
@@ -196,9 +194,10 @@ def validation_step(self, *args, **kwargs):
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
-        tpu_cores=[1],
+        tpu_cores=8,
    )
    trainer.fit(model)
+    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))


@RunIf(tpu=True)
