diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75d3eeec..067aef78 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,11 +28,6 @@ repos: args: [--py37-plus] name: Upgrade code - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - - repo: https://github.com/psf/black rev: 23.3.0 hooks: @@ -54,15 +49,14 @@ repos: hooks: - id: yesqa additional_dependencies: - - flake8-docstrings - pep8-naming - - flake8-comprehensions - flake8-pytest-style - - flake8-return - - flake8-simplify + - flake8-bandit + - flake8-builtins + - flake8-bugbear - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.270 + rev: v0.0.272 hooks: - id: ruff args: ["--fix"] diff --git a/pyproject.toml b/pyproject.toml index 4b8dadb6..117cdbbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,14 +74,28 @@ ignore_missing_imports = true select = [ "E", "W", # see: https://pypi.org/project/pycodestyle "F", # see: https://pypi.org/project/pyflakes + "I", #see: https://pypi.org/project/isort/ "D", # see: https://pypi.org/project/pydocstyle "N", # see: https://pypi.org/project/pep8-naming + "S", # see: https://pypi.org/project/flake8-bandit ] extend-select = [ + "A", # see: https://pypi.org/project/flake8-builtins + "B", # see: https://pypi.org/project/flake8-bugbear "C4", # see: https://pypi.org/project/flake8-comprehensions "PT", # see: https://pypi.org/project/flake8-pytest-style "RET", # see: https://pypi.org/project/flake8-return "SIM", # see: https://pypi.org/project/flake8-simplify + "YTT", # see: https://pypi.org/project/flake8-2020 + "ANN", # see: https://pypi.org/project/flake8-annotations + "TID", # see: https://pypi.org/project/flake8-tidy-imports/ + "T10", # see: https://pypi.org/project/flake8-debugger + "Q", # see: https://pypi.org/project/flake8-quotes + "RUF", # Ruff-specific rules + "EXE", # see: https://pypi.org/project/flake8-executable + "ISC", # see: https://pypi.org/project/flake8-implicit-str-concat + "PIE", # see: https://pypi.org/project/flake8-pie + "PLE", # see: https://pypi.org/project/pylint/ ] ignore = [ "E731", @@ -100,20 +114,35 @@ extend-select = [ ignore-init-module-imports = true [tool.ruff.per-file-ignores] -"setup.py" = ["D100", "SIM115"] +"setup.py" = ["ANN202", "D100", "SIM115"] "__about__.py" = ["D100"] "__init__.py" = ["D100"] "src/**" = [ + "ANN101", # Missing type annotation for `self` in method + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN401", # Dynamically typed expressions (typing.Any) + "B905", # `zip()` without an explicit `strict=` parameter "D100", # Missing docstring in public module "D107", # Missing docstring in `__init__` ] "tests/**" = [ + "ANN001", # Missing type annotation for function argument + "ANN101", # Missing type annotation for `self` in method + "ANN201", # Missing return type annotation for public function + "ANN202", # Missing return type annotation for private function + "ANN204", # Missing return type annotation for special method + "ANN401", # Dynamically typed expressions (typing.Any) + "B905", # `zip()` without an explicit `strict=` parameter "D100", # Missing docstring in public module "D101", # Missing docstring in public class "D102", # Missing docstring in public method "D103", # Missing docstring in public function "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method "D107", # Missing docstring in `__init__` + "S101", # Use of `assert` detected + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "B028", # No explicit `stacklevel` keyword argument found ] [tool.ruff.pydocstyle] diff --git a/setup.py b/setup.py index a2390418..48548c87 100755 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ _PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements") -def _load_py_module(fname, pkg="lightning_utilities"): +def _load_py_module(fname: str, pkg: str = "lightning_utilities"): spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname)) py = module_from_spec(spec) spec.loader.exec_module(py) diff --git a/src/lightning_utilities/__init__.py b/src/lightning_utilities/__init__.py index 83f1ea55..cc46773d 100644 --- a/src/lightning_utilities/__init__.py +++ b/src/lightning_utilities/__init__.py @@ -2,7 +2,7 @@ import os -from lightning_utilities.__about__ import * # noqa: F401, F403 +from lightning_utilities.__about__ import * # noqa: F403 from lightning_utilities.core.apply_func import apply_to_collection from lightning_utilities.core.enums import StrEnum from lightning_utilities.core.imports import compare_version, module_available diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py index 5302ed2e..73a29256 100644 --- a/src/lightning_utilities/core/apply_func.py +++ b/src/lightning_utilities/core/apply_func.py @@ -3,7 +3,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # import dataclasses -from collections import defaultdict, OrderedDict +from collections import OrderedDict, defaultdict from copy import deepcopy from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union @@ -185,7 +185,8 @@ def apply_to_collections( is_namedtuple_ = is_namedtuple(data1) is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str) if (is_namedtuple_ or is_sequence) and data2 is not None: - assert len(data1) == len(data2), "Sequence collections have different sizes." + if len(data1) != len(data2): + raise ValueError("Sequence collections have different sizes.") out = [ apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for v1, v2 in zip(data1, data2) diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 7d3445ed..e1099891 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -59,7 +59,7 @@ def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key" try: return cls.from_str(value, source) except ValueError: - warnings.warn( + warnings.warn( # noqa: B028 UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.") ) return None diff --git a/src/lightning_utilities/install/__init__.py b/src/lightning_utilities/install/__init__.py index c2773ebb..63008e4b 100644 --- a/src/lightning_utilities/install/__init__.py +++ b/src/lightning_utilities/install/__init__.py @@ -1,5 +1,5 @@ """Generic Installation tools.""" -from lightning_utilities.install.requirements import load_requirements, Requirement +from lightning_utilities.install.requirements import Requirement, load_requirements __all__ = ["load_requirements", "Requirement"] diff --git a/src/lightning_utilities/install/requirements.py b/src/lightning_utilities/install/requirements.py index eb71084e..9c97c026 100644 --- a/src/lightning_utilities/install/requirements.py +++ b/src/lightning_utilities/install/requirements.py @@ -15,7 +15,8 @@ class _RequirementWithComment(Requirement): def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.comment = comment - assert pip_argument is None or pip_argument # sanity check that it's not an empty str + if not (pip_argument is None or pip_argument): # sanity check that it's not an empty str + raise RuntimeError(f"wrong pip argument: {pip_argument}") self.pip_argument = pip_argument self.strict = self.strict_string in comment.lower() @@ -109,8 +110,10 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str >>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['sphinx<6.0,>=4.0', ...] """ - assert unfreeze in {"none", "major", "all"} + if unfreeze not in {"none", "major", "all"}: + raise ValueError(f'unsupported option of "{unfreeze}"') path = Path(path_dir) / file_name - assert path.exists(), (path_dir, file_name, path) + if not path.exists(): + raise FileNotFoundError(f"missing file for {(path_dir, file_name, path)}") text = path.read_text() return [req.adjust(unfreeze) for req in _parse_requirements(text)] diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index da74e199..ca77bcc6 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -1,13 +1,15 @@ import dataclasses import numbers -from collections import defaultdict, namedtuple, OrderedDict +from collections import OrderedDict, defaultdict, namedtuple from dataclasses import InitVar from typing import Any, ClassVar, List, Optional import pytest +from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections from unittests.mocks import torch -from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections +_TENSOR_0 = torch.tensor(0) +_TENSOR_1 = torch.tensor(1) @dataclasses.dataclass @@ -29,10 +31,10 @@ class ModelExample: label: torch.Tensor some_constant: int = dataclasses.field(init=False) - def __post_init__(self): # noqa: D105 + def __post_init__(self): self.some_constant = 7 - def __eq__(self, o: object) -> bool: # noqa: D105 + def __eq__(self, o: object) -> bool: if not isinstance(o, ModelExample): return NotImplemented @@ -64,11 +66,11 @@ class WithInitVar: dummy: Any override: InitVar[Optional[Any]] = None - def __post_init__(self, override: Optional[Any]): # noqa: D105 + def __post_init__(self, override: Optional[Any]): if override is not None: self.dummy = override - def __eq__(self, o: object) -> bool: # noqa: D105 + def __eq__(self, o: object) -> bool: if not isinstance(o, WithInitVar): return NotImplemented if isinstance(self.dummy, torch.Tensor): @@ -79,15 +81,16 @@ def __eq__(self, o: object) -> bool: # noqa: D105 @dataclasses.dataclass class WithClassAndInitVar: - class_var: ClassVar[torch.Tensor] = torch.tensor(0) + class_var: ClassVar[torch.Tensor] = _TENSOR_0 dummy: Any - override: InitVar[Optional[Any]] = torch.tensor(1) + override: InitVar[Optional[Any]] = _TENSOR_1 - def __post_init__(self, override: Optional[Any]): # noqa: D105 + def __post_init__(self, override: Optional[Any]): if override is not None: self.dummy = override - def __eq__(self, o: object) -> bool: # noqa: D105 + def __eq__(self, o: object) -> bool: + """Equal.""" if not isinstance(o, WithClassAndInitVar): return NotImplemented if isinstance(self.dummy, torch.Tensor): @@ -206,7 +209,7 @@ def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""): # custom mappings class _CustomCollection(dict): - def __init__(self, initial_dict): + def __init__(self, initial_dict) -> None: super().__init__(initial_dict) to_reduce = _CustomCollection({"a": 1, "b": 2, "c": 3}) @@ -262,7 +265,7 @@ def fn(a, b): assert reduced == [1, 2, 3, 4] # different sizes - with pytest.raises(AssertionError, match="Sequence collections have different sizes"): + with pytest.raises(ValueError, match="Sequence collections have different sizes"): apply_to_collections([[1, 2], [3]], [4], int, fn) def fn(a, b): @@ -323,7 +326,7 @@ def fn(a, b): def test_apply_to_collection_frozen_dataclass(): @dataclasses.dataclass(frozen=True) class Foo: - input: int + var: int foo = Foo(0) with pytest.raises(ValueError, match="frozen dataclass was passed"): @@ -333,7 +336,7 @@ class Foo: def test_apply_to_collection_allow_frozen_dataclass(): @dataclasses.dataclass(frozen=True) class Foo: - input: int + var: int foo = Foo(0) result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True) diff --git a/tests/unittests/core/test_imports.py b/tests/unittests/core/test_imports.py index b5ff6cec..845a077c 100644 --- a/tests/unittests/core/test_imports.py +++ b/tests/unittests/core/test_imports.py @@ -2,14 +2,13 @@ import re import pytest - from lightning_utilities.core.imports import ( + ModuleAvailableCache, + RequirementCache, compare_version, get_dependency_min_version_spec, lazy_import, module_available, - ModuleAvailableCache, - RequirementCache, requires, ) diff --git a/tests/unittests/core/test_overrides.py b/tests/unittests/core/test_overrides.py index b89b074d..9c580791 100644 --- a/tests/unittests/core/test_overrides.py +++ b/tests/unittests/core/test_overrides.py @@ -1,8 +1,8 @@ from functools import partial, wraps +from typing import Any, Callable from unittest.mock import Mock import pytest - from lightning_utilities.core.overrides import is_overridden @@ -36,14 +36,14 @@ def bar(self): assert is_overridden("training_step", LightningModule(), parent=BoringModel) class WrappedModel(TestModel): - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any): obj = super().__new__(cls) obj.foo = cls.wrap(obj.foo) obj.bar = cls.wrap(obj.bar) return obj @staticmethod - def wrap(fn): + def wrap(fn) -> Callable: @wraps(fn) def wrapper(): fn() diff --git a/tests/unittests/core/test_rank_zero.py b/tests/unittests/core/test_rank_zero.py index 854365a7..523dbd43 100644 --- a/tests/unittests/core/test_rank_zero.py +++ b/tests/unittests/core/test_rank_zero.py @@ -1,5 +1,4 @@ import pytest - from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only diff --git a/tests/unittests/mocks.py b/tests/unittests/mocks.py index e3d7d472..f047085e 100644 --- a/tests/unittests/mocks.py +++ b/tests/unittests/mocks.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Any, Iterable from lightning_utilities.core.imports import package_available @@ -7,7 +7,7 @@ else: # minimal torch implementation to avoid installing torch in testing CI class TensorMock: - def __init__(self, data): + def __init__(self, data) -> None: self.data = data def __add__(self, other): @@ -32,7 +32,7 @@ def __iter__(self): """Iterate.""" return iter(self.data) - def __repr__(self): + def __repr__(self) -> str: """Return object representation.""" return repr(self.data) @@ -44,15 +44,15 @@ class TorchMock: Tensor = TensorMock @staticmethod - def tensor(data): + def tensor(data: Any) -> TensorMock: return TensorMock(data) @staticmethod - def equal(a, b): + def equal(a: Any, b: Any) -> bool: return a == b @staticmethod - def arange(*args): + def arange(*args: Any) -> TensorMock: return TensorMock(list(range(*args))) torch = TorchMock() diff --git a/tests/unittests/test/test_warnings.py b/tests/unittests/test/test_warnings.py index c280262b..340bd775 100644 --- a/tests/unittests/test/test_warnings.py +++ b/tests/unittests/test/test_warnings.py @@ -2,7 +2,6 @@ from re import escape import pytest - from lightning_utilities.test.warning import no_warning_call