Skip to content

Commit

Permalink
refactor 1/n for v1.0.0 (#2704)
Browse files Browse the repository at this point in the history
* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator

* reactor into gpu accelerator
  • Loading branch information
williamFalcon committed Jul 25, 2020
1 parent 5dc8c1d commit 071e09f
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 23 deletions.
1 change: 1 addition & 0 deletions .pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"pytorch_lightning/__init__.py",
"pytorch_lightning/callbacks",
"pytorch_lightning/core",
"pytorch_lightning/accelerators",
"pytorch_lightning/loggers",
"pytorch_lightning/logging",
"pytorch_lightning/metrics",
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
exclude_patterns = [
'api/pytorch_lightning.rst',
'api/pl_examples.*',
'api/pytorch_lightning.accelerators.*',
'api/modules.rst',
'PULL_REQUEST_TEMPLATE.md',

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
46 changes: 46 additions & 0 deletions pytorch_lightning/accelerators/gpu_accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class GPUAccelerator(object):

def __init__(self, trainer):
self.trainer = trainer

def setup(self, model):
# call setup
if not self.trainer.testing:
self.trainer.setup('fit')
model.setup('fit')

model.cuda(self.trainer.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.trainer.use_amp and not native_amp_available:
self._setup_nvidia_apex(model)

def _setup_nvidia_apex(self, model):
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
14 changes: 14 additions & 0 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Lightning supports model training on a cluster managed by SLURM in the following cases:
Expand Down
36 changes: 14 additions & 22 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Root module for all distributed operations in Lightning.
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
Expand Down Expand Up @@ -165,28 +179,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
return model.transfer_batch_to_device(batch, device)
return move_data_to_device(batch, device)

def single_gpu_train(self, model):
# call setup
if not self.testing:
self.setup('fit')
model.setup('fit')

model.cuda(self.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

# TODO: remove with dropping NVIDIA AMP support
if self.use_amp and not NATIVE_AMP_AVALAIBLE:
# An example
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

results = self.run_pretrain_routine(model)
return results

def tpu_train(self, tpu_core_idx, model):
# call setup after the ddp process has connected
if not self.testing:
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import List, Tuple

Expand Down
20 changes: 19 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import warnings
Expand Down Expand Up @@ -37,6 +51,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator

# warnings to ignore in trainer
warnings.filterwarnings(
Expand Down Expand Up @@ -646,6 +661,7 @@ def __init__(
# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.accelerator = None

# Callback system
self.on_init_end()
Expand Down Expand Up @@ -1057,7 +1073,9 @@ def fit(
results = self.horovod_train(model)

elif self.single_gpu:
results = self.single_gpu_train(model)
self.accelerator = GPUAccelerator(self)
self.accelerator.setup(model)
results = self.run_pretrain_routine(model)

elif self.use_tpu: # pragma: no-cover
rank_zero_info(f'training on {self.tpu_cores} TPU cores')
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The lightning training loop handles everything except the actual computations of your model.
To decide what will happen in your training loop, define the `training_step` function.
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import sys
from abc import ABC, abstractmethod
Expand Down

0 comments on commit 071e09f

Please sign in to comment.