Fix some pyright member access errors in training module (#2121)
* Fix pyright member access errors in training module

* Fix Trainer instantiation error due to inheritance order

* Add GH workflow for pyright

* Fix more pyright errors in trainer module

* Add pyrightconfig and setup python environment in type-check workflow

* Exclude pyrightconfig.json

* suggestions

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
uditarora and Borda authored Jun 12, 2020
1 parent 9045b6c commit 08573d0
Showing 18 changed files with 192 additions and 37 deletions.
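Editor's note: most of the changes below follow a single pattern. The Trainer is assembled from several mixin classes, and each mixin touches attributes and methods that only exist on the fully composed Trainer, which pyright reports as unknown members. The fix is to declare those attributes as class-level annotations and the cross-mixin methods as abstract "empty shell" stubs on the mixin itself. The following is a minimal sketch of that idea with invented class names, not the repository's own code; the second commit message above presumably refers to the fact that the mixin providing the concrete implementation has to come before the one holding the abstract stub, otherwise the composed class can no longer be instantiated.

from abc import ABC, abstractmethod
from typing import Optional


class DataLoadingMixin(ABC):
    # Declared only so static checkers know these members exist;
    # the composed trainer class assigns the real values.
    num_nodes: int
    distributed_backend: Optional[str]

    @abstractmethod
    def is_overridden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def world_size(self) -> int:
        # Without the annotations above, pyright flags `self.num_nodes` and
        # `self.distributed_backend` as not being known members of the mixin.
        return self.num_nodes if self.distributed_backend else 1


class ModelHooksMixin(ABC):
    def is_overridden(self, *args):
        return False


# The mixin with the concrete `is_overridden` comes first in the bases, so the
# abstract stub is overridden and the class stays instantiable.
class MiniTrainer(ModelHooksMixin, DataLoadingMixin):
    def __init__(self) -> None:
        self.num_nodes = 1
        self.distributed_backend = None


print(MiniTrainer().world_size())  # -> 1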
67 changes: 67 additions & 0 deletions .github/workflows/python-type-check.yml
@@ -0,0 +1,67 @@
name: "Python static type checking"
on:
# Trigger the workflow on push or pull request,
# but only for the master branch
push:
branches:
- master
pull_request:
branches:
- master

jobs:
python_type_checking:
name: Python static type checking with Pyright
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-18.04]
python-version: [3.7]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 15
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}

# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
- name: Get pip cache
id: pip-cache
run: |
python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)"
- name: Cache pip
uses: actions/cache@v1
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-extra.txt') }}
restore-keys: |
${{ runner.os }}-${{ matrix.python-version }}-pip-
- name: Install dependencies
run: |
# python -m pip install --upgrade --user pip
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q
HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install -r ./tests/requirements-devel.txt -q
# pip install tox coverage
python --version
pip --version
pip list
shell: bash

- name: Set up node
uses: actions/setup-node@v1

- name: Install pyright
run: |
npm install pyright
- name: Run type checking
run: |
$(npm bin)/pyright --project .pyrightconfig.json
2 changes: 1 addition & 1 deletion .mergify.yml
@@ -10,7 +10,7 @@ pull_request_rules:
       # no requested chnages from any reviewer
       - "#changes-requested-reviews-by=0"
       # this serves as ALL check has to pass as we have actually 27 tests in total
-      - "#status-success>=29"
+      - "#status-success>=30"
       # this is just in case since we rely on GPU tests (note: redundand to the above)
       - status-success=continuous-integration/drone/pr
       # this is patter-like, unofrunatly serves as `any(...)` (note: redundand to the above)
33 changes: 33 additions & 0 deletions .pyrightconfig.json
@@ -0,0 +1,33 @@
{
    "include": [
        "pytorch_lightning"
    ],

    "ignore": [
        "pytorch_lightning/__init__.py",
        "pytorch_lightning/callbacks",
        "pytorch_lightning/core",
        "pytorch_lightning/loggers",
        "pytorch_lightning/logging",
        "pytorch_lightning/metrics",
        "pytorch_lightning/overrides",
        "pytorch_lightning/profiler",
        "pytorch_lightning/pt_overrides",
        "pytorch_lightning/root_module",
        "pytorch_lightning/utilities",
        "pytorch_lightning/trainer/data_loading.py",
        "pytorch_lightning/trainer/deprecated_api.py",
        "pytorch_lightning/trainer/distrib_parts.py",
        "pytorch_lightning/trainer/evaluation_loop.py",
        "pytorch_lightning/trainer/logging.py",
        "pytorch_lightning/trainer/lr_finder.py",
        "pytorch_lightning/trainer/optimizers.py",
        "pytorch_lightning/trainer/supporters.py",
        "pytorch_lightning/trainer/trainer.py",
        "pytorch_lightning/trainer/training_io.py",
        "pytorch_lightning/trainer/training_loop.py",
        "pytorch_lightning/trainer/training_tricks.py"
    ],

    "reportMissingImports": false
}
3 changes: 3 additions & 0 deletions MANIFEST.in
@@ -36,6 +36,9 @@ include requirements-extra.txt
 exclude *.yml
 exclude *.yaml

+# Exclude pyright config
+exclude .pyrightconfig.json
+
 prune .git
 prune .github
 prune .circleci
12 changes: 8 additions & 4 deletions pytorch_lightning/trainer/callback_config.py
@@ -1,6 +1,6 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Union, List
+from typing import List, Callable, Optional


 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
@@ -14,10 +14,10 @@ class TrainerCallbackConfigMixin(ABC):
     # the proper values/initialisation should be done in child class
     callbacks: List[Callback]
     default_root_dir: str
-    logger: Union[LightningLoggerBase, bool]
-    weights_save_path: str
+    logger: LightningLoggerBase
+    weights_save_path: Optional[str]
     ckpt_path: str
-    checkpoint_callback: ModelCheckpoint
+    checkpoint_callback: Optional[ModelCheckpoint]

     @property
     @abstractmethod
@@ -28,6 +28,10 @@ def slurm_job_id(self) -> int:
     def save_checkpoint(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def is_overridden(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     def configure_checkpoint_callback(self):
         """
         Weight path set in this priority:
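Editor's note: the attribute changes above (logger, weights_save_path, checkpoint_callback) show a second recurring fix. A Union[..., bool]-style annotation forces pyright to validate every attribute access against bool as well, whereas a plain or Optional type can be narrowed with an ordinary None check. A hedged, self-contained sketch with invented names, not the library's API:

from typing import Optional


class FakeCheckpoint:
    def save(self, path: str) -> None:
        print(f"saved to {path}")


class CallbackConfigMixin:
    # Optional because the callback may be disabled; a `bool` member of a Union
    # here would make `self.checkpoint_callback.save(...)` a reported error.
    checkpoint_callback: Optional[FakeCheckpoint] = None
    weights_save_path: Optional[str] = None

    def save_weights(self) -> None:
        if self.checkpoint_callback is not None and self.weights_save_path is not None:
            # Both names are narrowed to non-None types inside this branch.
            self.checkpoint_callback.save(self.weights_save_path)


mixin = CallbackConfigMixin()
mixin.checkpoint_callback = FakeCheckpoint()
mixin.weights_save_path = "weights.ckpt"
mixin.save_weights()  # -> saved to weights.ckpt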
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,6 +1,6 @@
 import platform
 from abc import ABC, abstractmethod
-from typing import Union, List, Tuple, Callable
+from typing import Union, List, Tuple, Callable, Optional

 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
@@ -76,6 +76,9 @@ class TrainerDataLoadingMixin(ABC):
     val_percent_check: float
     test_percent_check: float
     replace_sampler_ddp: bool
+    num_nodes: int
+    num_processes: int
+    distributed_backend: Optional[str]

     @abstractmethod
     def is_overridden(self, *args):
@@ -143,6 +146,7 @@ def _get_distributed_sampler(self, dataloader):
             'ddp2': self.num_nodes,
             'ddp_cpu': self.num_processes * self.num_nodes
         }
+        assert self.distributed_backend is not None
         kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
         sampler = DistributedSampler(dataloader.dataset, **kwargs)
         return sampler
@@ -197,7 +201,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)

-    def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int, List[DataLoader]]:
+    def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[Union[int, float], List[DataLoader]]:
         """Generic method to reset a dataloader for evaluation.

         Args:
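Editor's note: the `assert self.distributed_backend is not None` added above is the usual way to tell a static checker that an Optional attribute is guaranteed to be set by the time a code path runs. A small illustrative sketch of the same narrowing, using hypothetical names rather than the repository's classes:

from typing import Dict, Optional


class SamplerMixin:
    # Filled in during setup, so Optional until then.
    distributed_backend: Optional[str] = None
    num_nodes: int = 1
    num_processes: int = 1

    def world_size(self) -> int:
        sizes: Dict[str, int] = {
            "ddp": self.num_nodes * self.num_processes,
            "ddp_cpu": self.num_processes * self.num_nodes,
        }
        # Narrows Optional[str] to str for the lookup below; without it pyright
        # reports the key as possibly None, and a None key would also fail at runtime.
        assert self.distributed_backend is not None
        return sizes[self.distributed_backend]


sampler = SamplerMixin()
sampler.distributed_backend = "ddp"
print(sampler.world_size())  # -> 1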
17 changes: 14 additions & 3 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -116,7 +116,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Union, List, Optional, Callable, Tuple
 import subprocess
 import sys
 from time import sleep
@@ -151,16 +151,19 @@ class TrainerDDPMixin(ABC):
     # the proper values/initialisation should be done in child class
     on_gpu: bool
     num_gpu_nodes: int
+    gpus: List[int]
     logger: Union[LightningLoggerBase, bool]
     checkpoint_callback: Union[ModelCheckpoint, bool]
     data_parallel_device_ids: ...
-    distributed_backend: str
+    distributed_backend: Optional[str]
     amp_level: str
     use_tpu: bool
     default_root_dir: str
     use_native_amp: bool
     progress_bar_callback: ...
     num_processes: int
+    num_nodes: int
+    node_rank: int

     @property
     @abstractmethod
@@ -181,7 +184,15 @@ def run_pretrain_routine(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def init_optimizers(self, *args):
+    def init_optimizers(self, *args) -> Tuple[List, List, List]:
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def reinit_scheduler_properties(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
+    @abstractmethod
+    def save_checkpoint(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     def init_tpu(self):
27 changes: 18 additions & 9 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -10,10 +10,10 @@
 import time
 import random
 import torch
-from typing import Union, Callable, Any, List, Optional
+from typing import Union, Callable, Any, List, Optional, Tuple

 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning import _logger as log
+from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
@@ -63,9 +63,10 @@ class TrainerDPMixin(ABC):
     use_tpu: bool
     use_native_amp: bool
     data_parallel_device_ids: ...
-    logger: Union[LightningLoggerBase, bool]
     progress_bar_callback: ...
-    tpu_id: int
+    tpu_id: Optional[int]
+    on_colab_kaggle: str
+    save_spawn_weights: Callable

     @property
     @abstractmethod
@@ -77,7 +78,15 @@ def run_pretrain_routine(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def init_optimizers(self, *args):
+    def init_optimizers(self, *args) -> Tuple[List, List, List]:
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def get_model(self) -> LightningModule:
+        """Warning: this is just empty shell for code implemented in other class."""
+
+    @abstractmethod
+    def reinit_scheduler_properties(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     def copy_trainer_model_properties(self, model):
@@ -294,7 +303,7 @@ def filter_named_parameters(model, optimizer):
         hvd.join()


-def normalize_parse_gpu_string_input(s):
+def normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
     if isinstance(s, str):
         if s == '-1':
             return -1
@@ -369,7 +378,7 @@ def sanitize_gpu_ids(gpus: List[int]) -> List[int]:
     return gpus


-def parse_gpu_ids(gpus: Union[int, str, List]) -> Optional[List[int]]:
+def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
     """
     Parses the GPU ids given in the format as accepted by the
     :class:`~pytorch_lightning.trainer.Trainer`.
@@ -404,10 +413,10 @@ def parse_gpu_ids(gpus: Union[int, str, List]) -> Optional[List[int]]:

     gpus = normalize_parse_gpu_string_input(gpus)
     gpus = normalize_parse_gpu_input_to_list(gpus)
-    gpus = sanitize_gpu_ids(gpus)

     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
+    gpus = sanitize_gpu_ids(gpus)

     return gpus

10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -124,12 +124,13 @@

 from abc import ABC, abstractmethod
 from pprint import pprint
-from typing import Callable
+from typing import Callable, Optional

 import torch
 from torch.utils.data import DataLoader

 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.profiler.profilers import BaseProfiler
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn
@@ -159,6 +160,8 @@ class TrainerEvaluationLoopMixin(ABC):
     use_dp: bool
     use_ddp2: bool
     use_horovod: bool
+    use_amp: bool
+    use_native_amp: bool
     single_gpu: bool
     data_parallel_device_ids: ...
     model: LightningModule
@@ -174,7 +177,8 @@ class TrainerEvaluationLoopMixin(ABC):
     val_dataloaders: DataLoader
     use_tpu: bool
     reload_dataloaders_every_epoch: ...
-    tpu_id: int
+    tpu_id: Optional[int]
+    profiler: BaseProfiler

     # Callback system
     on_validation_batch_start: Callable
@@ -191,7 +195,7 @@ def copy_trainer_model_properties(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def get_model(self):
+    def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
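Editor's note: annotating the abstract `get_model` stub with a return type, as the hunk above does, is what lets pyright resolve attribute access on whatever the stub returns. A minimal hedged sketch of the same idea with stand-in classes, not the library's code:

from abc import ABC, abstractmethod


class Net:
    def __init__(self) -> None:
        self.hparams = {"lr": 0.01}


class EvaluationLoopMixin(ABC):
    @abstractmethod
    def get_model(self) -> Net:
        """Warning: this is just empty shell for code implemented in other class."""

    def log_lr(self) -> None:
        # Because the stub is annotated, pyright knows the result has `.hparams`.
        print(self.get_model().hparams["lr"])


class MiniTrainer(EvaluationLoopMixin):
    def __init__(self) -> None:
        self.model = Net()

    def get_model(self) -> Net:
        return self.model


MiniTrainer().log_lr()  # -> 0.01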
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/logging.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Union, Iterable
+from typing import Iterable, Optional

 import torch

@@ -15,7 +15,7 @@ class TrainerLoggingMixin(ABC):
     current_epoch: int
     on_gpu: bool
     log_gpu_memory: ...
-    logger: Union[LightningLoggerBase, bool]
+    logger: Optional[LightningLoggerBase]
     progress_bar_metrics: ...
     global_step: int
     proc_rank: int