Typing for tests 1/n (#6313)
* typing

* yapf

* typing
Borda authored Mar 9, 2021
1 parent fc6d402 commit 55dd3a4
Showing 16 changed files with 113 additions and 119 deletions.
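The changes follow one pattern: parameters injected by `@pytest.mark.parametrize` get explicit annotations, with `Optional[...]` wherever `None` appears among the parametrized values, and oversized `parametrize` lists are reflowed by yapf to one case per line. A minimal sketch of the annotation pattern, assuming only pytest and the standard library (the test name and body are illustrative, not taken from this commit):

```python
from typing import Optional

import pytest


@pytest.mark.parametrize("accelerator", [None, "ddp_spawn"])
def test_accelerator_flag(accelerator: Optional[str]) -> None:
    # Parametrized arguments are annotated like ordinary ones; since None
    # is among the values, the hint is Optional[str] rather than plain str.
    assert accelerator is None or isinstance(accelerator, str)
```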
12 changes: 7 additions & 5 deletions tests/accelerators/test_accelerator_connector.py
@@ -13,6 +13,7 @@
 # limitations under the License

 import os
+from typing import Optional
 from unittest import mock

 import pytest
@@ -30,6 +31,7 @@
     DDPSpawnPlugin,
     DDPSpawnShardedPlugin,
     DeepSpeedPlugin,
+    ParallelPlugin,
     PrecisionPlugin,
     SingleDevicePlugin,
 )
@@ -408,10 +410,8 @@ def test_ipython_incompatible_backend_error(*_):
     ["accelerator", "plugin"],
     [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
 )
-def test_plugin_accelerator_choice(accelerator, plugin):
-    """
-    Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
-    """
+def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str):
+    """Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent."""
     trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
     assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)

@@ -428,7 +428,9 @@ def test_plugin_accelerator_choice(accelerator, plugin):
 ])
 @mock.patch('torch.cuda.is_available', return_value=True)
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_multi_node_gpu(mock_is_available, mock_device_count, accelerator, plugin, tmpdir):
+def test_accelerator_choice_multi_node_gpu(
+    mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin
+):
     trainer = Trainer(
         accelerator=accelerator,
         default_root_dir=tmpdir,
2 changes: 1 addition & 1 deletion tests/callbacks/test_callback_hook_outputs.py
@@ -18,7 +18,7 @@


 @pytest.mark.parametrize("single_cb", [False, True])
-def test_train_step_no_return(tmpdir, single_cb):
+def test_train_step_no_return(tmpdir, single_cb: bool):
     """
     Tests that only training_step can be used
     """
21 changes: 12 additions & 9 deletions tests/callbacks/test_early_stopping.py
@@ -14,6 +14,7 @@
 import logging
 import os
 import pickle
+from typing import List, Optional
 from unittest import mock

 import cloudpickle
@@ -119,7 +120,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
         ([6, 5, 6, 5, 5, 5], 3, 4),
     ],
 )
-def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
+def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expected_stop_epoch: int):
     """Test to ensure that early stopping is not triggered before patience is exhausted."""

     class ModelOverrideValidationReturn(BoringModel):
@@ -142,7 +143,7 @@ def validation_epoch_end(self, outputs):
     assert trainer.current_epoch == expected_stop_epoch


-@pytest.mark.parametrize('validation_step', ['base', None])
+@pytest.mark.parametrize('validation_step_none', [True, False])
 @pytest.mark.parametrize(
     "loss_values, patience, expected_stop_epoch",
     [
@@ -151,7 +152,9 @@ def validation_epoch_end(self, outputs):
         ([6, 5, 6, 5, 5, 5], 3, 4),
     ],
 )
-def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
+def test_early_stopping_patience_train(
+    tmpdir, validation_step_none: bool, loss_values: list, patience: int, expected_stop_epoch: int
+):
     """Test to ensure that early stopping is not triggered before patience is exhausted."""

     class ModelOverrideTrainReturn(BoringModel):
@@ -163,7 +166,7 @@ def training_epoch_end(self, outputs):

     model = ModelOverrideTrainReturn()

-    if validation_step is None:
+    if validation_step_none:
         model.validation_step = None

     early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
@@ -254,7 +257,7 @@ def validation_epoch_end(self, outputs):


 @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
-def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs):
+def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
     """Excepted Behaviour:
     IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered,
     THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop.
@@ -386,10 +389,10 @@ def on_train_end(self) -> None:
                      marks=RunIf(skip_windows=True)),
     ],
 )
-def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
-    """
-    Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
-    """
+def test_multiple_early_stopping_callbacks(
+    tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
+):
+    """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""

     model = EarlyStoppingModel(expected_stop_epoch)

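Note that one change above goes beyond annotations: the `validation_step` parameter, previously parametrized over `['base', None]`, becomes the boolean `validation_step_none`, so the argument types cleanly as `bool` instead of an awkward `Optional[str]`. A hedged before/after sketch (function names and bodies are illustrative):

```python
import pytest


# Before: a value-or-None sentinel forces the vague hint Optional[str].
@pytest.mark.parametrize("validation_step", ["base", None])
def test_with_sentinel(validation_step):
    disable_validation = validation_step is None
    assert isinstance(disable_validation, bool)


# After: parametrizing the condition itself types cleanly as bool.
@pytest.mark.parametrize("validation_step_none", [True, False])
def test_with_flag(validation_step_none: bool):
    assert isinstance(validation_step_none, bool)
```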
8 changes: 3 additions & 5 deletions tests/callbacks/test_lr_monitor.py
@@ -51,10 +51,8 @@ def test_lr_monitor_single_lr(tmpdir):


 @pytest.mark.parametrize('opt', ['SGD', 'Adam'])
-def test_lr_monitor_single_lr_with_momentum(tmpdir, opt):
-    """
-    Test that learning rates and momentum are extracted and logged for single lr scheduler.
-    """
+def test_lr_monitor_single_lr_with_momentum(tmpdir, opt: str):
+    """Test that learning rates and momentum are extracted and logged for single lr scheduler."""

     class LogMomentumModel(BoringModel):

@@ -170,7 +168,7 @@ def test_lr_monitor_no_logger(tmpdir):


 @pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
-def test_lr_monitor_multi_lrs(tmpdir, logging_interval):
+def test_lr_monitor_multi_lrs(tmpdir, logging_interval: str):
     """ Test that learning rates are extracted and logged for multi lr schedulers. """
     tutils.reset_seed()

29 changes: 20 additions & 9 deletions tests/callbacks/test_progress_bar.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import sys
+from typing import Optional, Union
 from unittest import mock
 from unittest.mock import ANY, call, Mock

@@ -36,7 +37,7 @@
         ([ProgressBar(refresh_rate=2)], 1),
     ]
 )
-def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
+def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
     """Test different ways the progress bar can be turned on."""

     trainer = Trainer(
@@ -60,7 +61,7 @@ def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
         ([ModelCheckpoint(dirpath='../trainer')], 0),
     ]
 )
-def test_progress_bar_off(tmpdir, callbacks, refresh_rate):
+def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int]):
     """Test different ways the progress bar can be turned off."""

     trainer = Trainer(
@@ -165,7 +166,7 @@ def test_progress_bar_fast_dev_run(tmpdir):


 @pytest.mark.parametrize('refresh_rate', [0, 1, 50])
-def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
+def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
     """Test that the three progress bars get correctly updated when using different refresh rates."""

     model = BoringModel()
@@ -219,7 +220,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal


 @pytest.mark.parametrize('limit_val_batches', (0, 5))
-def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches):
+def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
     """
     Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
     """
@@ -309,7 +310,9 @@ def init_test_tqdm(self):
         [5, 2, 6, [6, 1], [2]],
     ]
 )
-def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, refresh_rate, train_deltas, val_deltas):
+def test_main_progress_bar_update_amount(
+    tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_deltas: list, val_deltas: list
+):
     """
     Test that the main progress updates with the correct amount together with the val progress. At the end of
     the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate.
@@ -336,7 +339,7 @@ def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, ref
     [3, 1, [1, 1, 1]],
     [5, 3, [3, 2]],
 ])
-def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
+def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, test_deltas: list):
     """
     Test that test progress updates with the correct amount.
     """
@@ -379,10 +382,18 @@ def training_step(self, batch, batch_idx):


 @pytest.mark.parametrize(
-    "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'],
-                            ['10000', '10000'], ['abc', 'abc']]
+    "input_num, expected", [
+        [1, '1'],
+        [1.0, '1.000'],
+        [0.1, '0.100'],
+        [1e-3, '0.001'],
+        [1e-5, '1e-5'],
+        ['1.0', '1.000'],
+        ['10000', '10000'],
+        ['abc', 'abc'],
+    ]
 )
-def test_tqdm_format_num(input_num, expected):
+def test_tqdm_format_num(input_num: Union[str, int, float], expected: str):
     """ Check that the specialized tqdm.format_num appends 0 to floats and strings """
     assert tqdm.format_num(input_num) == expected

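`test_tqdm_format_num` above accepts ints, floats, and strings from a single `parametrize` list, which is why its hint is a `Union`. A minimal sketch of annotating such heterogeneous parametrizations (illustrative, not from the commit):

```python
from typing import Union

import pytest


@pytest.mark.parametrize("value, expected", [(1, "1"), (1.0, "1.0"), ("abc", "abc")])
def test_mixed_types(value: Union[str, int, float], expected: str) -> None:
    # One Union hint covers parametrize cases of several runtime types.
    assert isinstance(value, (str, int, float))
    assert str(value) == expected
```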
11 changes: 7 additions & 4 deletions tests/callbacks/test_pruning.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from collections import OrderedDict
 from logging import INFO
+from typing import Union

 import pytest
 import torch
@@ -144,7 +145,8 @@ def test_pruning_misconfiguration():
 )
 @pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True])
 def test_pruning_callback(
-    tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn, use_lottery_ticket_hypothesis
+    tmpdir, use_global_unstructured: bool, parameters_to_prune: bool,
+    pruning_fn: Union[str, pytorch_prune.BasePruningMethod], use_lottery_ticket_hypothesis: bool
 ):
     train_with_pruning_callback(
         tmpdir,
@@ -158,7 +160,7 @@ def test_pruning_callback(
 @RunIf(special=True)
 @pytest.mark.parametrize("parameters_to_prune", [False, True])
 @pytest.mark.parametrize("use_global_unstructured", [False, True])
-def test_pruning_callback_ddp(tmpdir, use_global_unstructured, parameters_to_prune):
+def test_pruning_callback_ddp(tmpdir, use_global_unstructured: bool, parameters_to_prune: bool):
     train_with_pruning_callback(
         tmpdir,
         parameters_to_prune=parameters_to_prune,
@@ -179,7 +181,7 @@ def test_pruning_callback_ddp_cpu(tmpdir):


 @pytest.mark.parametrize("resample_parameters", (False, True))
-def test_pruning_lth_callable(tmpdir, resample_parameters):
+def test_pruning_lth_callable(tmpdir, resample_parameters: bool):
     model = TestModel()

     class ModelPruningTestCallback(ModelPruning):
@@ -218,7 +220,7 @@ def apply_lottery_ticket_hypothesis(self):


 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
-def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
+def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
     seed_everything(0)
     model = TestModel()
     pruning_kwargs = {
@@ -228,6 +230,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     }
     p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **pruning_kwargs)
     p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **pruning_kwargs)
+
     trainer = Trainer(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
5 changes: 3 additions & 2 deletions tests/callbacks/test_quantization.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+from typing import Callable, Union

 import pytest
 import torch
@@ -28,7 +29,7 @@
 @pytest.mark.parametrize("observe", ['average', pytest.param('histogram', marks=RunIf(min_torch="1.5"))])
 @pytest.mark.parametrize("fuse", [True, False])
 @RunIf(quantization=True)
-def test_quantization(tmpdir, observe, fuse):
+def test_quantization(tmpdir, observe: str, fuse: bool):
     """Parity test for quant model"""
     seed_everything(42)
     dm = RegressDataModule()
@@ -122,7 +123,7 @@ def custom_trigger_last(trainer):
     ]
 )
 @RunIf(quantization=True)
-def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
+def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], expected_count: int):
     """Test how many times the quant is called"""
     dm = RegressDataModule()
     qmodel = RegressionModel()
4 changes: 2 additions & 2 deletions tests/callbacks/test_stochastic_weight_avg.py
@@ -136,7 +136,7 @@ def test_swa_callback_1_gpu(tmpdir):

 @RunIf(min_torch="1.6.0")
 @pytest.mark.parametrize("batchnorm", (True, False))
-def test_swa_callback(tmpdir, batchnorm):
+def test_swa_callback(tmpdir, batchnorm: bool):
     train_with_swa(tmpdir, batchnorm=batchnorm)


@@ -155,7 +155,7 @@ def test_swa_raises():
 @pytest.mark.parametrize('stochastic_weight_avg', [False, True])
 @pytest.mark.parametrize('use_callbacks', [False, True])
 @RunIf(min_torch="1.6.0")
-def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
+def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool):
     """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""

     class TestModel(BoringModel):
12 changes: 8 additions & 4 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
@@ -51,7 +51,7 @@ def test_mc_called(tmpdir):
     ['epochs', 'val_check_interval', 'expected'],
     [(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)],
 )
-def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, expected):
+def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int):

     model = BoringModel()
     trainer = Trainer(
@@ -68,9 +68,13 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,


 @mock.patch('torch.save')
-@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 1.0, 2),
-                                                                             (2, 1, 0.25, 4), (2, 2, 0.3, 7)])
-def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected):
+@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [
+    (1, 1, 1.0, 1),
+    (2, 2, 1.0, 2),
+    (2, 1, 0.25, 4),
+    (2, 2, 0.3, 7),
+])
+def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int):

     class TestModel(BoringModel):

2 changes: 1 addition & 1 deletion tests/checkpointing/test_legacy_checkpoints.py
@@ -57,7 +57,7 @@
         "1.2.2",
     ]
 )
-def test_resume_legacy_checkpoints(tmpdir, pl_version):
+def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
     path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)

     # todo: make this as mock, so it is cleaner...
(Diffs for the remaining 6 changed files are not shown.)
