Enable TPU support (#868)
* added tpu docs

* added tpu flags

* add tpu docs + init training call

* amp

* optimizer step

* added auto data transfer to TPU

* fix test pkg create (#873)

* added test return and print

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Luis Capelo <luiscape@gmail.com>

* Fix segmentation example (#876)

* removed torchvision model and added custom model

* minor fix

* Fixed relative imports issue

* Fix/typo (#880)

* Update greetings.yml

* Changelog (#869)

* Create CHANGELOG.md

* Update CHANGELOG.md

* Update PULL_REQUEST_TEMPLATE.md

* Add PR links to Version 0.6.0 in CHANGELOG.md

* Add PR links for Unreleased in CHANGELOG.md

* Fixing Function Signatures (#871)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Luis Capelo <luiscape@gmail.com>
Co-authored-by: Akshay Kulkarni <akshayk.vnit@gmail.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Shikhar Chauhan <xssChauhan@users.noreply.github.com>
6 people committed Feb 17, 2020
1 parent e38b18e commit d4a31f0
Showing 14 changed files with 489 additions and 48 deletions.
30 changes: 24 additions & 6 deletions docs/source/apex.rst
@@ -1,13 +1,18 @@
16-bit training
=================
Lightning offers 16-bit training for CPUs, GPUs and TPUs.

GPU 16-bit
-----------
Lightning uses NVIDIA apex to handle 16-bit precision training.

To use 16-bit precision, do two things:

1. Install Apex
2. Set the amp trainer flag.
2. Set the "precision" trainer flag.

Install apex
----------------------------------------------
^^^^^^^^^^^^
.. code-block:: bash
$ git clone https://github.com/NVIDIA/apex
@@ -31,12 +36,25 @@ Install apex
Enable 16-bit
--------------
^^^^^^^^^^^^^

.. code-block:: python
# DEFAULT
trainer = Trainer(amp_level='O1', use_amp=False)
# turn on 16-bit
trainer = Trainer(amp_level='O1', precision=16)
If you need to configure the apex init for your particular use case or want to use a different way of doing
16-bit training, override :meth:`pytorch_lightning.core.LightningModule.configure_apex`.

TPU 16-bit
----------
16-bit on TPUs is much simpler. To use 16-bit with TPUs, set precision to 16 when using the TPU flag.

.. code-block:: python
# DEFAULT
trainer = Trainer(num_tpu_cores=8, precision=32)
# turn on 16-bit
trainer = Trainer(num_tpu_cores=8, precision=16)
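
A note for readers of this diff: the docs above point at ``configure_apex`` without showing it. A minimal sketch of such an override could look like the block below (hypothetical ``MyModule``; the hook signature is assumed from the 0.6.x-era LightningModule and should be checked against the installed version).

.. code-block:: python

# Sketch only: customize how apex initializes 16-bit training.
# The (amp, model, optimizers, amp_level) signature is an assumption
# based on the 0.6.x-era LightningModule referenced above.
import pytorch_lightning as pl


class MyModule(pl.LightningModule):

    def configure_apex(self, amp, model, optimizers, amp_level):
        # e.g. force a different opt_level than the trainer default
        model, optimizers = amp.initialize(model, optimizers, opt_level='O2')
        return model, optimizers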
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -55,6 +55,7 @@ PyTorch-Lightning Documentation
single_gpu
sequences
training_tricks
tpu
test_set
optimizers
profiler
2 changes: 1 addition & 1 deletion docs/source/new-project.rst
@@ -60,7 +60,7 @@ Then you could do rapid research by switching between these two and using the sa
else:
model = CoolerNotBERT()
trainer = Trainer(gpus=4, use_amp=True)
trainer = Trainer(gpus=4, precision=16)
trainer.fit(model)
175 changes: 175 additions & 0 deletions docs/source/tpu.rst
@@ -0,0 +1,175 @@
TPU support
===========

Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs,
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

Live demo
----------
Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.

TPU Terminology
---------------
A TPU is a Tensor Processing Unit. Each TPU has 8 cores, and each
core is optimized for 128x128 matrix multiplies. In general, a single
TPU is about as fast as 5 V100 GPUs!

A TPU pod hosts many TPUs. Currently, TPU pod v2 has 2048 cores!
You can request a full pod from Google Cloud, or a "slice" which gives you
a subset of those 2048 cores.

How to access TPUs
-------------------
There are two main ways to access TPUs:

1. Using Google Colab.
2. Using Google Cloud (GCP).

Colab TPUs
-----------
Colab is like a Jupyter notebook with a free GPU or TPU
hosted on GCP.

To get a TPU on Colab, follow these steps:

1. Go to https://colab.research.google.com/.

2. Click "new notebook" (bottom right of pop-up).

3. Click Runtime > Change runtime type. Select Python 3,
and set the hardware accelerator to "TPU". This will give you a TPU with 8 cores.

4. Next, insert this code into the first cell and execute. This
will install the xla library that interfaces between PyTorch and
the TPU.

.. code-block:: python
import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
    print('Updating server-side XRT to {} ...'.format(CONFIG.server))
    url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
        TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
        XRT_VERSION=CONFIG.server,
    )
    print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()
5. Once the above is done, install PyTorch Lightning (v 0.6.1+).

.. code-block::
! pip install pytorch-lightning
6. Then set up your LightningModule as normal (a minimal example is sketched after this list of steps).

7. TPUs require a DistributedSampler. That means you should change your
train_dataloader (and val and test dataloaders) code as follows.

.. code-block:: python
import torch_xla.core.xla_model as xm

@pl.data_loader
def train_dataloader(self):
    dataset = MNIST(
        os.getcwd(),
        train=True,
        download=True,
        transform=transforms.ToTensor()
    )

    # required for TPU support
    sampler = None
    if use_tpu:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )

    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=32
    )

    return loader
8. Configure the number of TPU cores in the trainer. You can only choose
1 or 8. To use a full TPU pod, skip to the TPU pod section.

.. code-block:: python
import pytorch_lightning as pl
my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8)
trainer.fit(my_model)
That's it! Your model will train on all 8 TPU cores.
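
For step 6 above, a minimal LightningModule could be sketched as follows (hypothetical MNIST model, assuming the 0.6.x API with ``@pl.data_loader`` and a dict returned from ``training_step``; only the dataloader from step 7 is TPU-specific).

.. code-block:: python

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class MyLightningModule(pl.LightningModule):
    """Minimal MNIST classifier; nothing in the module itself is TPU-specific."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    @pl.data_loader
    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True,
                        transform=transforms.ToTensor())
        # swap in the DistributedSampler from step 7 when running on TPUs
        return DataLoader(dataset, batch_size=32)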

TPU Pod
--------
To train on more than 8 cores, your code actually doesn't change!
All you need to do is submit the following command:

.. code-block:: bash
$ python -m torch_xla.distributed.xla_dist \
    --tpu=$TPU_POD_NAME \
    --conda-env=torch-xla-nightly \
    -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data
16-bit precision
-----------------
Lightning also supports training in 16-bit precision with TPUs.
By default, TPU training uses 32-bit precision. To enable 16-bit,
set the precision trainer flag to 16.

.. code-block:: python
import pytorch_lightning as pl
my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8, precision=16)
trainer.fit(my_model)
Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.


About XLA
----------
XLA is the library that interfaces PyTorch with the TPUs.
For more information check out `XLA <https://github.com/pytorch/xla>`_.
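
Not part of this diff, but a quick sanity check that the xla install from step 4 can actually see the TPU is to allocate a tensor on an XLA device (``xm.xla_device()`` comes from torch_xla itself and is independent of Lightning):

.. code-block:: python

# Allocate a tensor on an XLA device to confirm PyTorch can reach the TPU.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()             # e.g. xla:1 -- one of the TPU cores
t = torch.randn(2, 2, device=device)
print(device, t)                     # the tensor now lives on the TPU via XLA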
13 changes: 8 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -113,10 +113,10 @@ def on_after_backward(self):
"""

def backward(self, use_amp, loss, optimizer, optimizer_idx):
def backward(self, trainer, loss, optimizer, optimizer_idx):
"""Override backward with your own implementation if you need to
:param use_amp: Whether amp was requested or not
:param trainer: Pointer to the trainer
:param loss: Loss is already scaled by accumulated grads
:param optimizer: Current optimizer being used
:param optimizer_idx: Index of the current optimizer being used
@@ -137,8 +137,11 @@ def backward(self, use_amp, loss, optimizer):
loss.backward()
"""
if use_amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
if trainer.precision == 16:

# .backward is not special on 16-bit with TPUs
if not trainer.on_tpu:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
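
To make the hook change above concrete: after this commit, a user override of ``backward`` receives the trainer instead of the old ``use_amp`` flag. A sketch of such an override (hypothetical module), mirroring the default implementation shown in the diff:

.. code-block:: python

import pytorch_lightning as pl

try:
    from apex import amp  # only needed for the GPU 16-bit path
except ImportError:
    amp = None


class MyModule(pl.LightningModule):

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        # apex loss scaling on GPU 16-bit; plain .backward() otherwise,
        # including 16-bit on TPUs where no special handling is needed
        if trainer.precision == 16 and not trainer.on_tpu:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()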
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/auto_mix_precision.py
@@ -12,6 +12,9 @@

class TrainerAMPMixin(ABC):

def __init__(self):
self.use_amp = None

def init_amp(self, use_amp):
self.use_amp = use_amp and APEX_AVAILABLE
if self.use_amp:
25 changes: 19 additions & 6 deletions pytorch_lightning/trainer/data_loading.py
@@ -36,6 +36,8 @@ def __init__(self):
self.use_ddp2 = None
self.shown_warnings = None
self.val_check_interval = None
self.use_tpu = None
self.tpu_local_core_rank = None

def _percent_range_check(self, name):
value = getattr(self, name)
@@ -80,9 +82,10 @@ def init_train_dataloader(self, model):
self.val_check_batch = max(1, self.val_check_batch)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
needs_sampler = on_ddp or self.use_tpu
if needs_sampler and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
You're using multiple gpus and multiple nodes, or TPUs without using a
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.
@@ -119,13 +122,14 @@ def init_val_dataloader(self, model):
self.num_val_batches = int(self.num_val_batches * self.val_percent_check)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_val_dataloaders() is not None:
needs_sampler = on_ddp or self.use_tpu
if needs_sampler and self.get_val_dataloaders() is not None:
for dataloader in self.get_val_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
Your val_dataloader(s) don't use DistributedSampler.
You're using multiple gpus and multiple nodes without using a
You're using multiple gpus and multiple nodes, or TPUs without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.
@@ -162,13 +166,14 @@ def init_test_dataloader(self, model):
self.num_test_batches = int(self.num_test_batches * self.test_percent_check)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_test_dataloaders() is not None:
needs_sampler = on_ddp or self.use_tpu
if needs_sampler and self.get_test_dataloaders() is not None:
for dataloader in self.get_test_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
Your `test_dataloader(s)` don't use DistributedSampler.
You're using multiple gpus and multiple nodes without using a
You're using multiple gpus and multiple nodes, or TPUs without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.
@@ -210,6 +215,14 @@ def get_dataloaders(self, model):
self.get_test_dataloaders()
self.get_val_dataloaders()

# on TPUs load each dataloader only on process 0
# this will trigger the data downloads
if self.use_tpu:
if self.tpu_local_core_rank == 0:
self.get_train_dataloader()
self.get_test_dataloaders()
self.get_val_dataloaders()

# support IterableDataset for train data
self.is_iterable_train_dataloader = (
EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -144,6 +144,7 @@ def __init__(self):
self.distributed_backend = None
self.use_amp = None
self.amp_level = None
self.use_tpu = None

@abstractmethod
def copy_trainer_model_properties(self, model):
@@ -160,6 +161,13 @@ def init_optimizers(self, optimizers):
# this is just empty shell for code from other class
pass

def init_tpu(self):
# turn off all the GPU stuff
self.distributed_backend = None

# enable tpu
self.use_tpu = True

def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
# skip for CPU
if self.num_gpus == 0:

2 comments on commit d4a31f0


@Borda Borda commented on d4a31f0 Feb 17, 2020


Cool, Great work!


@Borda Borda commented on d4a31f0 Feb 17, 2020


@williamFalcon pls update CHANGELOG.md
