Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to select individual TPU core to train on #1729

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8995fbc
added tpu_id
Apr 19, 2020
bd9e88c
train on individual tpu
Apr 26, 2020
1daadfa
parallel loader if tpu_id is None
May 3, 2020
e4d49d0
removed progress_bar_refresh_rate
May 4, 2020
0ed38cd
chlog
Borda May 5, 2020
725ef5d
replaced num_tpu_cores with tpu_cores
May 6, 2020
c0a4f9d
set tpu_id to None if int
May 6, 2020
f25d516
changed num_tpu_cores to tpu_cores in docs
May 6, 2020
a93c6bc
Merge branch 'master' into feature/1539_tpu_train_parallel
lezwon May 7, 2020
b22f485
updated docs
May 9, 2020
cdda262
Merge branch 'master' into feature/1539_tpu_train_parallel
lezwon May 9, 2020
0669ad2
updated __init__.py
May 9, 2020
2253b9f
Update pytorch_lightning/trainer/__init__.py
Borda May 10, 2020
67c5688
check if tpu_cores is a list
lezwon May 13, 2020
ec278d1
xla device conditional
May 10, 2020
100071b
num_tpu_cores deprecation
May 13, 2020
8adb0a9
removed duplicate warning
May 13, 2020
34f2209
Merge remote-tracking branch 'official/master' into feature/1539_tpu_…
May 13, 2020
f779d01
fixed pep8 error
May 13, 2020
dafe174
Revert "removed duplicate warning"
May 14, 2020
4c6958e
deprecated api update
May 14, 2020
5c0db30
fixed recursion error
May 14, 2020
c7a9b4e
fixed tests
May 14, 2020
83e5d99
fixed flake errors
May 14, 2020
230831e
Merge remote-tracking branch 'official/master' into feature/1539_tpu_…
May 14, 2020
59e0b49
removed current_tpu_index
May 14, 2020
f22d90d
Merge branch 'master' into feature/1539_tpu_train_parallel
williamFalcon May 17, 2020
940f70b
Update CHANGELOG.md
Borda May 17, 2020
ec300ee
Update trainer.py
Borda May 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ class TrainerDPMixin(ABC):
data_parallel_device_ids: ...
logger: Union[LightningLoggerBase, bool]
progress_bar_callback: ...
tpu_id: int

@property
@abstractmethod
Expand Down Expand Up @@ -443,7 +444,7 @@ def __transfer_data_to_device(self, batch, device, gpu_id=None):
if device == 'tpu' and XLA_AVAILABLE:
# base case: object can be directly moved using `to`
if callable(getattr(batch, 'to', None)):
return batch.to(xm.xla_device())
return batch.to(xm.xla_device(self.tpu_id))

if device == 'gpu':
# base case: object can be directly moved using `cuda` or `to`
Expand Down Expand Up @@ -498,7 +499,7 @@ def single_gpu_train(self, model):

def tpu_train(self, tpu_core_idx, model):
# put model on tpu
model.to(xm.xla_device())
model.to(xm.xla_device(self.tpu_id))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this now makes it ONLY possible to train on 1 core no? not multiple cores

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so... @lezwon ^^

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have noticed that if self.tpu_id is None and I use xmp.spawn, the model trains at the same speed it trains when all cores are being used. So I assumed that all cores are being used. I could add some logging to confirm. Or just add a conditional for xm.xla_device() maybe?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ONLY when the user requests a specific TPU index should we use
model.to(xm.xla_device(self.tpu_id)) otherwise, leave it as it was.

@Borda we need TPU tests to make sure this PR doesn't break functionality


# get the appropriate tpu ranks
self.tpu_local_core_rank = xm.get_local_ordinal()
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class TrainerEvaluationLoopMixin(ABC):
val_dataloaders: DataLoader
use_tpu: bool
reload_dataloaders_every_epoch: ...
tpu_id: int

# Callback system
on_validation_batch_start: Callable
Expand Down Expand Up @@ -249,8 +250,8 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
dl_outputs = []

# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
device = xm.xla_device()
if self.use_tpu and self.tpu_id is None:
device = xm.xla_device(self.tpu_id)
dataloader = xla_pl.ParallelLoader(dataloader, [device])
dataloader = dataloader.per_device_loader(device)

Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
gpus: Optional[Union[List[int], str, int]] = None,
auto_select_gpus: bool = False,
num_tpu_cores: Optional[int] = None,
tpu_id: Optional[int] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so I can use only one TPU? not several with indexes like GPU?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not as per my knowledge. xla_distributed only supports 1 or 8 cores. We can't selectively choose the cores.
Ref: https://pytorch.org/xla/release/1.5/index.html#torch_xla.distributed.xla_multiprocessing.spawn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tpu_id is not needed...

Copy link
Contributor Author

@lezwon lezwon May 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamFalcon I have replaced it with tpu_cores as you suggested. Valid values are 1/8/[<1-(max_cores)>]

log_gpu_memory: Optional[str] = None,
progress_bar_refresh_rate: int = 1,
overfit_pct: float = 0.0,
Expand Down Expand Up @@ -321,6 +322,8 @@ def __init__(
self.num_tpu_cores = num_tpu_cores
assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8'

self.tpu_id = tpu_id

if num_processes != 1 and distributed_backend != "ddp_cpu":
rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
self.num_processes = num_processes
Expand Down Expand Up @@ -775,7 +778,10 @@ def fit(
self.model = model

# train
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
if self.tpu_id is not None:
self.tpu_train(self.tpu_id, model)
else:
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

# load weights if not interrupted
self.load_spawn_weights(model)
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ class TrainerTrainLoopMixin(ABC):
total_batch_idx: int
checkpoint_callback: ...
terminate_on_nan: bool
tpu_id: int

# Callback system
callbacks: List[Callback]
Expand Down Expand Up @@ -393,8 +394,8 @@ def run_training_epoch(self):
train_dataloader = self.train_dataloader

# on TPU we have to wrap it under the ParallelLoader
if self.use_tpu:
device = xm.xla_device()
if self.use_tpu and self.tpu_id is None:
device = xm.xla_device(self.tpu_id)
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)

Expand Down