fix flake8 for new plugins (#5951)
* flake8

* fix cyclic import

* isort
awaelchli authored Feb 18, 2021
1 parent 6cc1a06 commit fc9bb53
Showing 17 changed files with 38 additions and 56 deletions.
20 changes: 7 additions & 13 deletions pytorch_lightning/accelerators/accelerator.py
@@ -11,21 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Iterable, Optional, Union

 import torch
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.plugins.precision import (
-    ApexMixedPrecisionPlugin,
-    MixedPrecisionPlugin,
-    NativeMixedPrecisionPlugin,
-    PrecisionPlugin,
-)
+from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
@@ -64,7 +58,7 @@ def __init__(
         self.lr_schedulers = None
         self.optimizer_frequencies = None

-    def setup(self, trainer: "Trainer", model: LightningModule) -> None:
+    def setup(self, trainer, model: LightningModule) -> None:
         """
         Connects the plugins to the training process, creates optimizers
@@ -76,13 +70,13 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.connect_precision_plugin(self.precision_plugin)

-    def start_training(self, trainer: 'Trainer'):
+    def start_training(self, trainer):
         self.training_type_plugin.start_training(trainer)

-    def start_testing(self, trainer: 'Trainer'):
+    def start_testing(self, trainer):
         self.training_type_plugin.start_testing(trainer)

-    def start_predicting(self, trainer: 'Trainer'):
+    def start_predicting(self, trainer):
         self.training_type_plugin.start_predicting(trainer)

     def pre_dispatch(self) -> None:
@@ -310,7 +304,7 @@ def on_train_end(self) -> None:
         """Hook to do something at the end of the training"""
         pass

-    def setup_optimizers(self, trainer: "Trainer"):
+    def setup_optimizers(self, trainer):
         """creates optimizers and schedulers

         Args:
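The typing change above is the "fix cyclic import" part of the commit: the quoted 'Trainer' annotations reference a name pyflakes wants defined, but importing Trainer at module level would pull the trainer package back into the accelerator module. The commit simply drops the annotations; below is a minimal sketch of the TYPE_CHECKING alternative that the removed TYPE_CHECKING entry hints at, assuming a guarded import like this (the layout is illustrative, not copied from the file).

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static type checkers; never executed at runtime,
    # so no import cycle is created.
    from pytorch_lightning.trainer.trainer import Trainer


class Accelerator:
    def setup(self, trainer: "Trainer", model) -> None:
        ...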
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/base_plugin.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple
+from typing import Generator, Optional, Sequence, Tuple

-import torch
 from torch.nn import Module


3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -24,7 +24,8 @@

 class PrecisionPlugin(Plugin):
     """ Plugin handling the precision-specific parts of the training.
-    The static classattributes EPSILON and precision must be overwritten in child-classes and their default values reflect fp32 training
+    The static classattributes EPSILON and precision must be overwritten in child-classes and their
+    default values reflect fp32 training.
     """
     EPSILON = 1e-6
     precision = 32
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/tpu_bfloat.py
@@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin):

     def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
         os.environ["XLA_USE_BF16"] = str(1)
-        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
\ No newline at end of file
+        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
30 changes: 15 additions & 15 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -1,15 +1,15 @@
-from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
-from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
-from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
-from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
-from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
-from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
-from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
-from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
-from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin
-from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
-from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
-from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
-from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type.ddp import DDPPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin  # noqa: F401
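Every import in this __init__.py exists only to re-export a plugin class, so once the package is linted, flake8 reports F401 ("imported but unused") on each line; the `# noqa: F401` marker silences exactly that code on that line. A minimal sketch of the same pattern with stand-in modules (the names below are illustrative, not from the repository):

# hypothetical package __init__.py whose only job is re-exporting names
from collections import OrderedDict  # noqa: F401  (re-exported, not used here)
from functools import lru_cache  # noqa: F401

# Listing the re-exported names in __all__ is the other common way to tell
# pyflakes the imports are intentional.
__all__ = ["OrderedDict", "lru_cache"]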
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -27,7 +27,6 @@
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
-from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
@@ -120,7 +119,7 @@ def _call_children_scripts(self):
         command = sys.argv
         try:
             full_path = path_lib(command[0])
-        except Exception as e:
+        except Exception:
             full_path = os.path.abspath(command[0])

         command[0] = full_path
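Dropping `as e` addresses flake8 F841 ("local variable is assigned to but never used"), since the handler never inspects the exception object. A self-contained sketch of the same shape, with `Path` standing in for the `path_lib` alias used in the hunk (that substitution is an assumption):

import os
import sys
from pathlib import Path  # stand-in for the path_lib alias in the hunk

command = list(sys.argv)
try:
    full_path = str(Path(command[0]))
except Exception:  # binding "as e" here would trip F841: the name is never read
    full_path = os.path.abspath(command[0])
print(full_path)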
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -263,7 +263,7 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs

-    def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]:
+    def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]:
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled
         # via `_initialize_deepspeed_train`
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -101,14 +101,14 @@ def start_training(self, trainer):
         hvd.join()

     def start_testing(self, trainer):
-        with ExitStack() as stack:
+        with ExitStack():
             self._results = trainer.run_test()

         # Make sure all workers have finished training before returning to the user
         hvd.join()

     def start_predicting(self, trainer):
-        with ExitStack() as stack:
+        with ExitStack():
             # set up training routine
             self._results = trainer.run_predict()

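Same F841 story: `stack` was bound but nothing was ever registered on it, so the name is dead; entering the context manager without a target keeps the behaviour. A tiny runnable sketch:

from contextlib import ExitStack

# Nothing is pushed onto the stack, so there is nothing worth naming;
# "with ExitStack() as stack:" would leave an unused local behind (F841).
with ExitStack():
    total = sum(range(5))
print(total)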
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/training_type/parallel.py
@@ -24,7 +24,6 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
-from pytorch_lightning.utilities import rank_zero_info
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp


3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/rpc.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import List, Optional, Sequence
+from typing import List, Optional

 import torch

-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.utilities import _RPC_AVAILABLE
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -329,11 +329,11 @@ def post_training_step(self):
         if self.main_rpc_process:
             super().post_training_step()

-    def start_training(self, trainer: 'Trainer') -> None:
+    def start_training(self, trainer) -> None:
         if self.main_rpc_process:
             super().start_training(trainer)

-    def start_testing(self, trainer: 'Trainer') -> None:
+    def start_testing(self, trainer) -> None:
         if self.main_rpc_process:
             super().start_testing(trainer)

1 change: 0 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -1,7 +1,6 @@
 from typing import Optional

-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

1 change: 0 additions & 1 deletion pytorch_lightning/plugins/training_type/single_device.py
@@ -1,7 +1,6 @@
 from typing import Any, Union

 import torch
-from torch._C import device

 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

4 changes: 1 addition & 3 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -1,4 +1,3 @@
-import io
 import os
 from typing import Optional, Union

@@ -11,7 +10,6 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device

 if _TPU_AVAILABLE:
-    import torch_xla
     import torch_xla.core.xla_model as xm


@@ -68,4 +66,4 @@ def on_save(self, checkpoint: dict) -> dict:

     @property
     def is_distributed(self):
-        return False
\ No newline at end of file
+        return False
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -1,7 +1,7 @@
 import io
 import os
 import re
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
 import torch.multiprocessing as mp
@@ -206,7 +206,7 @@ def post_dispatch(self) -> None:
         # restore main state with best weights
         best_path = self.mp_queue.get()
         last_path = self.mp_queue.get()
-        results = self.mp_queue.get()
+        self._results = self.mp_queue.get()

         # transfer back the best path to the trainer
         if self.lightning_module.trainer.checkpoint_callback is not None:
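Beyond silencing F841, assigning to `self._results` means the value pulled from the queue is actually handed back to the trainer instead of dying in a local variable. A toy sketch of the difference, using a plain queue and an illustrative class in place of the plugin (names are hypothetical):

import queue

results_queue = queue.Queue()
results_queue.put([{"test_acc": 0.98}])

class ToyPlugin:
    def post_dispatch(self):
        # results = results_queue.get()   # F841: local never read, value lost
        self._results = results_queue.get()  # kept on the plugin for the caller

plugin = ToyPlugin()
plugin.post_dispatch()
print(plugin._results)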
5 changes: 0 additions & 5 deletions setup.cfg
@@ -67,11 +67,6 @@ exclude =
     *.egg
     build
     temp
-    # TODO: temporary until accelerator refactor finished
-    pytorch_lightning/accelerators/accelerator.py
-    pytorch_lightning/plugins/training_type
-    pytorch_lightning/plugins/precision
-    pytorch_lightning/plugins/base_plugin.py

 select = E,W,F
 doctests = True
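Removing these exclude entries is what makes the rest of the commit necessary: the accelerator and plugin packages are now linted like the rest of the code base, so their flake8 findings had to be fixed. A hedged sketch of checking those paths through flake8's legacy API (run from the repository root with flake8 installed; treating these four paths as the full set is an assumption based on the lines removed above):

from flake8.api import legacy as flake8

# The same paths that were dropped from the exclude list in setup.cfg.
style_guide = flake8.get_style_guide()
report = style_guide.check_files([
    "pytorch_lightning/accelerators/accelerator.py",
    "pytorch_lightning/plugins/training_type",
    "pytorch_lightning/plugins/precision",
    "pytorch_lightning/plugins/base_plugin.py",
])
print(report.total_errors)  # expected to be 0 once this commit is applied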
4 changes: 2 additions & 2 deletions tests/accelerators/test_tpu_backend.py
@@ -79,7 +79,7 @@ def test_if_test_works_after_train(tmpdir):
     model = BoringModel()
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model)
-    assert trainer.test(model) == 1
+    assert len(trainer.test(model)) == 1


 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@@ -119,4 +119,4 @@ def on_post_move_to_device(self):
     assert result

     assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
-    assert trainer.test(model) == 1
+    assert len(trainer.test(model)) == 1
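The updated assertions reflect that Trainer.test returns a list with one metrics dict per test dataloader rather than a bare integer, so the tests check the list's length. A stand-in sketch of the shape being asserted (the fake_test helper is hypothetical; the real test needs a TPU machine and the BoringModel fixture):

def fake_test(model):
    # Trainer.test-style return value: one dict of metrics per test dataloader.
    return [{"test_acc": 0.98}]

results = fake_test("model")
assert len(results) == 1          # mirrors the updated assertion
assert "test_acc" in results[0]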
