Skip to content

Commit

Permalink
[bugfix] TPU test hangs to barrier on 1 process (#6272)
Browse files Browse the repository at this point in the history
* update

* resolve flake8

* update

* update

* update changelog

* update

* resolve flake8

Co-authored-by: Your Name <you@example.com>
  • Loading branch information
tchaton and Your Name committed Mar 9, 2021
1 parent efcd761 commit c388431
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 13 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221))



- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setup(self, trainer, model):
return super().setup(trainer, model)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
"""
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
from torch.optim.lr_scheduler import _LRScheduler, Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
Expand Down Expand Up @@ -116,7 +117,8 @@ def start_predicting(self, trainer):
hvd.join()

def barrier(self, *args, **kwargs):
hvd.join()
if torch_distrib.is_initialized():
hvd.join()

def broadcast(self, obj: object, src: int = 0) -> object:
obj = hvd.broadcast_object(obj, src)
Expand Down
20 changes: 17 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
Expand Down Expand Up @@ -112,7 +113,8 @@ def model_to_device(self) -> None:
self._model.to(xm.xla_device())

def barrier(self, name: Optional[str] = None) -> None:
rendezvous(f"pl.Trainer.{name}")
if torch_distrib.is_initialized():
rendezvous(f"pl.Trainer.{name}")

def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing callback through model -> trainer -> callback?
Expand All @@ -126,14 +128,26 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
xm.save(self.lightning_module.state_dict(), last_path)
self.save(self.lightning_module.state_dict(), last_path)

if self.global_rank == 0:
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)

def save(self, state_dict: Dict, path: str) -> None:
"""
Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
The rendez-vous doesn't affect directly saving.
We can ignore the ``RuntimeError`` to reduce friction with TPUs.
"""
try:
xm.save(state_dict, path)
except RuntimeError as e:
if "Failed to meet rendezvous" not in str(e):
raise e

def broadcast(self, obj: object, src: int = 0) -> object:
buffer = io.BytesIO()
torch.save(obj, buffer)
Expand Down Expand Up @@ -281,4 +295,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
# dump states as a checkpoint dictionary object
_checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
# Todo: TypeError: 'mappingproxy' object does not support item assignment
xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# define the max CPU available
self.num_processes = os.cpu_count()
# special case with TPUs
elif self.distributed_backend == 'tpu':
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
self._device_type = DeviceType.TPU
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.enums import LightningEnum
Expand Down Expand Up @@ -949,8 +949,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
f'specify a path for a checkpoint .test(ckpt_path=PATH)'
)
return {}
if not self._device_type == DeviceType.TPU:
self.accelerator.barrier()

self.training_type_plugin.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
Expand Down
5 changes: 2 additions & 3 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ def test_model_16bit_tpu_cores_8(tmpdir):
def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""

# todo: Test on 8 cores - hanging.

class CustomBoringModel(BoringModel):

def validation_step(self, *args, **kwargs):
Expand All @@ -195,9 +193,10 @@ def validation_step(self, *args, **kwargs):
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
tpu_cores=[1],
tpu_cores=8,
)
trainer.fit(model)
trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
Expand Down

0 comments on commit c388431

Please sign in to comment.