Enable TPU support #868

Merged Feb 17, 2020 (90 commits)

Commits
a915582
added tpu docs
williamFalcon Feb 16, 2020
04461ba
added tpu flags
williamFalcon Feb 16, 2020
8e6ba2a
add tpu docs + init training call
williamFalcon Feb 16, 2020
2409492
amp
williamFalcon Feb 16, 2020
80d23dd
amp
williamFalcon Feb 16, 2020
b0d6b8f
amp
williamFalcon Feb 16, 2020
607aa0e
amp
williamFalcon Feb 16, 2020
0faa217
optimizer step
williamFalcon Feb 16, 2020
64d6cdd
added auto data transfer to TPU
williamFalcon Feb 16, 2020
b771aa8
added auto data transfer to TPU
williamFalcon Feb 16, 2020
c3dfce3
added auto data transfer to TPU
williamFalcon Feb 16, 2020
dceeefc
added auto data transfer to TPU
williamFalcon Feb 16, 2020
7ebb64a
added auto data transfer to TPU
williamFalcon Feb 16, 2020
20f655e
added auto data transfer to TPU
williamFalcon Feb 16, 2020
a3553da
added auto data transfer to TPU
williamFalcon Feb 16, 2020
a2fcb00
added auto data transfer to TPU
williamFalcon Feb 16, 2020
ebcd6cc
added auto data transfer to TPU
williamFalcon Feb 16, 2020
935852f
added auto data transfer to TPU
williamFalcon Feb 16, 2020
5901d07
added auto data transfer to TPU
williamFalcon Feb 16, 2020
a51d468
added auto data transfer to TPU
williamFalcon Feb 16, 2020
120e146
added auto data transfer to TPU
williamFalcon Feb 17, 2020
c549e61
added auto data transfer to TPU
williamFalcon Feb 17, 2020
baa335a
added auto data transfer to TPU
williamFalcon Feb 17, 2020
3bda0bf
added auto data transfer to TPU
williamFalcon Feb 17, 2020
1e4d2f5
added auto data transfer to TPU
williamFalcon Feb 17, 2020
62ada48
added auto data transfer to TPU
williamFalcon Feb 17, 2020
4f5795b
added auto data transfer to TPU
williamFalcon Feb 17, 2020
cb671d5
added auto data transfer to TPU
williamFalcon Feb 17, 2020
495c0d5
added auto data transfer to TPU
williamFalcon Feb 17, 2020
9bc8f49
added auto data transfer to TPU
williamFalcon Feb 17, 2020
4334820
added auto data transfer to TPU
williamFalcon Feb 17, 2020
468073f
added auto data transfer to TPU
williamFalcon Feb 17, 2020
6029fad
fix test pkg create (#873)
Borda Feb 17, 2020
ad1de98
added auto data transfer to TPU
williamFalcon Feb 17, 2020
6dfc640
added auto data transfer to TPU
williamFalcon Feb 17, 2020
a98988b
added auto data transfer to TPU
williamFalcon Feb 17, 2020
55e2465
added test return and print
williamFalcon Feb 17, 2020
3cf04a8
added test return and print
williamFalcon Feb 17, 2020
299fd3e
added test return and print
williamFalcon Feb 17, 2020
91a1452
added test return and print
williamFalcon Feb 17, 2020
a741374
added test return and print
williamFalcon Feb 17, 2020
31833bd
Update pytorch_lightning/trainer/trainer.py
williamFalcon Feb 17, 2020
43ac63f
Fix segmentation example (#876)
akshaykulkarni07 Feb 17, 2020
93e8ad1
Fix/typo (#880)
ethanwharris Feb 17, 2020
a33beb6
Changelog (#869)
ethanwharris Feb 17, 2020
f44dfb3
Fixing Function Signatures (#871)
xssChauhan Feb 17, 2020
8c6547b
added tpu docs
williamFalcon Feb 16, 2020
87b8346
added tpu flags
williamFalcon Feb 16, 2020
c901459
add tpu docs + init training call
williamFalcon Feb 16, 2020
2d73b6b
amp
williamFalcon Feb 16, 2020
35e15f1
amp
williamFalcon Feb 16, 2020
4ba19a7
amp
williamFalcon Feb 16, 2020
097e24e
amp
williamFalcon Feb 16, 2020
3de4fb3
optimizer step
williamFalcon Feb 16, 2020
7839193
added auto data transfer to TPU
williamFalcon Feb 16, 2020
abadbfa
added auto data transfer to TPU
williamFalcon Feb 16, 2020
3b52a47
added auto data transfer to TPU
williamFalcon Feb 16, 2020
3a7717b
added auto data transfer to TPU
williamFalcon Feb 16, 2020
650506f
added auto data transfer to TPU
williamFalcon Feb 16, 2020
1d78400
added auto data transfer to TPU
williamFalcon Feb 16, 2020
2569fcc
added auto data transfer to TPU
williamFalcon Feb 16, 2020
802f3b2
added auto data transfer to TPU
williamFalcon Feb 16, 2020
9feb981
added auto data transfer to TPU
williamFalcon Feb 16, 2020
6e13655
added auto data transfer to TPU
williamFalcon Feb 16, 2020
1b0fcdc
added auto data transfer to TPU
williamFalcon Feb 16, 2020
561efea
added auto data transfer to TPU
williamFalcon Feb 16, 2020
8da918e
added auto data transfer to TPU
williamFalcon Feb 17, 2020
92235b1
added auto data transfer to TPU
williamFalcon Feb 17, 2020
2636a8d
added auto data transfer to TPU
williamFalcon Feb 17, 2020
b40d7b3
added auto data transfer to TPU
williamFalcon Feb 17, 2020
79f5110
added auto data transfer to TPU
williamFalcon Feb 17, 2020
f62f4d1
added auto data transfer to TPU
williamFalcon Feb 17, 2020
56269f6
added auto data transfer to TPU
williamFalcon Feb 17, 2020
f869817
added auto data transfer to TPU
williamFalcon Feb 17, 2020
e9c6be0
added auto data transfer to TPU
williamFalcon Feb 17, 2020
e3d8053
added auto data transfer to TPU
williamFalcon Feb 17, 2020
a160d8f
added auto data transfer to TPU
williamFalcon Feb 17, 2020
06993b0
added auto data transfer to TPU
williamFalcon Feb 17, 2020
09aa6b2
added auto data transfer to TPU
williamFalcon Feb 17, 2020
10cea4c
added auto data transfer to TPU
williamFalcon Feb 17, 2020
acb8472
added auto data transfer to TPU
williamFalcon Feb 17, 2020
5bc3583
added test return and print
williamFalcon Feb 17, 2020
8095c38
added test return and print
williamFalcon Feb 17, 2020
b56e29f
added test return and print
williamFalcon Feb 17, 2020
747b2b9
added test return and print
williamFalcon Feb 17, 2020
2d6afa2
added test return and print
williamFalcon Feb 17, 2020
69e12f6
added test return and print
williamFalcon Feb 17, 2020
b0cd1f4
added test return and print
williamFalcon Feb 17, 2020
73e3751
added test return and print
williamFalcon Feb 17, 2020
10da940
added test return and print
williamFalcon Feb 17, 2020
30 changes: 24 additions & 6 deletions docs/source/apex.rst
@@ -1,13 +1,18 @@
16-bit training
=================
Lightning offers 16-bit training for CPUs, GPUs and TPUs.

GPU 16-bit
-----------
Lightning uses NVIDIA apex to handle 16-bit precision training.

To use 16-bit precision, do two things:

1. Install Apex
2. Set the amp trainer flag.
2. Set the "precision" trainer flag.

Install apex
----------------------------------------------
^^^^^^^^^^^^
.. code-block:: bash

$ git clone https://github.com/NVIDIA/apex
@@ -31,12 +36,25 @@ Install apex


Enable 16-bit
--------------
^^^^^^^^^^^^^

.. code-block:: python

# DEFAULT
trainer = Trainer(amp_level='O1', use_amp=False)
# turn on 16-bit
trainer = Trainer(amp_level='O1', precision=16)

If you need to configure the apex init for your particular use case or want to use a different way of doing
16-bit training, override :meth:`pytorch_lightning.core.LightningModule.configure_apex`.
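
For reference, a minimal sketch of such an override. This assumes (for illustration) that the hook
receives the ``amp`` module, the model, the optimizers and the amp level, and returns the initialized
model and optimizers; check the signature in your installed version.

.. code-block:: python

import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    # hypothetical module; only the hook override matters here

    def configure_apex(self, amp, model, optimizers, amp_level):
        # sketch: run the apex init yourself, e.g. with a fixed loss scale
        model, optimizers = amp.initialize(
            model, optimizers, opt_level=amp_level, loss_scale=128.0
        )
        return model, optimizers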

TPU 16-bit
----------
16-bit on TPUs is much simpler. To use 16-bit with TPUs, set precision to 16 when using the TPU flag.

.. code-block:: python

# DEFAULT
trainer = Trainer(num_tpu_cores=8, precision=32)

# turn on 16-bit
trainer = Trainer(num_tpu_cores=8, precision=16)

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -55,6 +55,7 @@ PyTorch-Lightning Documentation
single_gpu
sequences
training_tricks
tpu
test_set
optimizers
profiler
2 changes: 1 addition & 1 deletion docs/source/new-project.rst
@@ -60,7 +60,7 @@ Then you could do rapid research by switching between these two and using the sa
else:
model = CoolerNotBERT()

trainer = Trainer(gpus=4, use_amp=True)
trainer = Trainer(gpus=4, precision=16)
trainer.fit(model)


175 changes: 175 additions & 0 deletions docs/source/tpu.rst
@@ -0,0 +1,175 @@
TPU support
===========

Lightning supports running on TPUs. At this moment, TPUs are only available
on Google Cloud (GCP). For more information on TPUs
`watch this video <https://www.youtube.com/watch?v=kPMpmcl_Pyw>`_.

Live demo
----------
Check out this `Google Colab <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_ to see how to train MNIST on TPUs.

TPU Terminology
---------------
A TPU is a Tensor processing unit. Each TPU has 8 cores where each
core is optimized for 128x128 matrix multiplies. In general, a single
TPU is about as fast as 5 V100 GPUs!

A TPU pod hosts many TPUs on it. Currently, TPU pod v2 has 2048 cores!
You can request a full pod from Google cloud or a "slice" which gives you
some subset of those 2048 cores.

How to access TPUs
-------------------
There are two main ways to access TPUs:

1. Using Google Colab.
2. Using Google Cloud (GCP).

Colab TPUs
-----------
Colab is like a Jupyter notebook with a free GPU or TPU
hosted on GCP.

To get a TPU on colab, follow these steps:

1. Go to https://colab.research.google.com/.

2. Click "new notebook" (bottom right of pop-up).

3. Click Runtime > Change runtime type. Select Python 3,
and hardware accelerator "TPU". This will give you a TPU with 8 cores.

4. Next, insert this code into the first cell and execute. This
will install the xla library that interfaces between PyTorch and
the TPU.

.. code-block:: python

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0" #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
    print('Updating server-side XRT to {} ...'.format(CONFIG.server))
    url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
        TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
        XRT_VERSION=CONFIG.server,
    )
    print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

5. Once the above is done, install PyTorch Lightning (v 0.6.1+).

.. code-block::

! pip install pytorch-lightning

6. Then set up your LightningModule as normal.
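
As a hedged illustration of this step (the class name, layer size and learning rate are placeholders),
a minimal LightningModule might look like the sketch below; the TPU-aware dataloader from step 7
completes it.

.. code-block:: python

import torch
from torch.nn import functional as F
import pytorch_lightning as pl

class MyLightningModule(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # a single linear layer, just to keep the sketch small
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)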

7. TPUs require a DistributedSampler. That means you should change your
train_dataloader (and val and test dataloaders) code as follows.

.. code-block:: python

import torch_xla.core.xla_model as xm

@pl.data_loader
def train_dataloader(self):
    dataset = MNIST(
        os.getcwd(),
        train=True,
        download=True,
        transform=transforms.ToTensor()
    )

    # required for TPU support
    sampler = None
    if use_tpu:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True
        )

    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=32
    )

    return loader

8. Configure the number of TPU cores in the trainer. You can only choose
1 or 8. To use a full TPU pod, skip to the TPU pod section.

.. code-block:: python

import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8)
trainer.fit(my_model)

That's it! Your model will train on all 8 TPU cores.

TPU Pod
--------
To train on more than 8 cores, your code actually doesn't change!
All you need to do is submit the following command:

.. code-block:: bash

$ python -m torch_xla.distributed.xla_dist \
  --tpu=$TPU_POD_NAME \
  --conda-env=torch-xla-nightly \
  -- python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --fake_data

16 bit precision
-----------------
Lightning also supports training in 16-bit precision with TPUs.
By default, TPU training will use 32-bit precision. To enable 16-bit,
set the precision flag to 16.

.. code-block:: python

import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8, precision=16)
trainer.fit(my_model)

Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
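
As a rough, hedged illustration of that mechanism (an assumption about torch_xla's behaviour, not
something you need to do yourself; Lightning is expected to handle it when precision is 16),
bfloat16 is typically toggled through an environment variable:

.. code-block:: python

import os

# sketch: ask torch_xla to run tensors in bfloat16 on the TPU
os.environ['XLA_USE_BF16'] = str(1)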


About XLA
----------
XLA is the library that interfaces PyTorch with the TPUs.
For more information check out `XLA <https://github.com/pytorch/xla>`_.
13 changes: 8 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -113,10 +113,10 @@ def on_after_backward(self):

"""

def backward(self, use_amp, loss, optimizer, optimizer_idx):
def backward(self, trainer, loss, optimizer, optimizer_idx):
"""Override backward with your own implementation if you need to

:param use_amp: Whether amp was requested or not
:param trainer: Pointer to the trainer
:param loss: Loss is already scaled by accumulated grads
:param optimizer: Current optimizer being used
:param optimizer_idx: Index of the current optimizer being used
@@ -137,8 +137,11 @@ def backward(self, use_amp, loss, optimizer):
loss.backward()

"""
if use_amp:
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
if trainer.precision == 16:

    # .backward is not special on 16-bit with TPUs
    if not trainer.on_tpu:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
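
(Not part of this diff: a hedged sketch of what a user-side override of the new hook could look like
in a LightningModule, mirroring the trainer-based signature above.)

    # sketch only, not in this PR; assumes `from apex import amp` at module level
    def backward(self, trainer, loss, optimizer, optimizer_idx):
        # keep apex loss scaling for GPU 16-bit, plain backward otherwise
        if trainer.precision == 16 and not trainer.on_tpu:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
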
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/auto_mix_precision.py
@@ -12,6 +12,9 @@

class TrainerAMPMixin(ABC):

def __init__(self):
    self.use_amp = None

def init_amp(self, use_amp):
    self.use_amp = use_amp and APEX_AVAILABLE
    if self.use_amp:
25 changes: 19 additions & 6 deletions pytorch_lightning/trainer/data_loading.py
@@ -36,6 +36,8 @@ def __init__(self):
self.use_ddp2 = None
self.shown_warnings = None
self.val_check_interval = None
self.use_tpu = None
self.tpu_local_core_rank = None

def _percent_range_check(self, name):
    value = getattr(self, name)
@@ -80,9 +82,10 @@ def init_train_dataloader(self, model):
self.val_check_batch = max(1, self.val_check_batch)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
needs_sampler = on_ddp or self.use_tpu
if needs_sampler and not isinstance(self.get_train_dataloader().sampler, DistributedSampler):
msg = """
You're using multiple gpus and multiple nodes without using a DistributedSampler
You're using multiple gpus and multiple nodes, or TPUs, without using a DistributedSampler
to assign a subset of your data to each process. To silence this warning, pass a
DistributedSampler to your DataLoader.

@@ -119,13 +122,14 @@ def init_val_dataloader(self, model):
self.num_val_batches = int(self.num_val_batches * self.val_percent_check)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_val_dataloaders() is not None:
needs_sampler = on_ddp or self.use_tpu
if needs_sampler and self.get_val_dataloaders() is not None:
for dataloader in self.get_val_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
Your val_dataloader(s) don't use DistributedSampler.

You're using multiple gpus and multiple nodes without using a
You're using multiple gpus and multiple nodes, or TPUs without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.

@@ -162,13 +166,14 @@ def init_test_dataloader(self, model):
self.num_test_batches = int(self.num_test_batches * self.test_percent_check)

on_ddp = self.use_ddp or self.use_ddp2
if on_ddp and self.get_test_dataloaders() is not None:
needs_sampler = on_ddp or self.use_tpu
if needs_sampler and self.get_test_dataloaders() is not None:
for dataloader in self.get_test_dataloaders():
if not isinstance(dataloader.sampler, DistributedSampler):
msg = """
Your `test_dataloader(s)` don't use DistributedSampler.

You're using multiple gpus and multiple nodes without using a
You're using multiple gpus and multiple nodes, or TPUs without using a
DistributedSampler to assign a subset of your data to each process.
To silence this warning, pass a DistributedSampler to your DataLoader.

@@ -210,6 +215,14 @@ def get_dataloaders(self, model):
self.get_test_dataloaders()
self.get_val_dataloaders()

# on TPUs load each dataloader only on process 0
# this will trigger the data downloads
if self.use_tpu:
    if self.tpu_local_core_rank == 0:
        self.get_train_dataloader()
        self.get_test_dataloaders()
        self.get_val_dataloaders()

# support IterableDataset for train data
self.is_iterable_train_dataloader = (
    EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -144,6 +144,7 @@ def __init__(self):
self.distributed_backend = None
self.use_amp = None
self.amp_level = None
self.use_tpu = None

@abstractmethod
def copy_trainer_model_properties(self, model):
@@ -160,6 +161,13 @@ def init_optimizers(self, optimizers):
    # this is just empty shell for code from other class
    pass

def init_tpu(self):
    # turn off all the GPU stuff
    self.distributed_backend = None

    # enable tpu
    self.use_tpu = True

def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
    # skip for CPU
    if self.num_gpus == 0: