From e5b1e21e6f0a8118e61edd259d6bddc3c7e0805a Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Mar 2023 15:07:19 +0100 Subject: [PATCH 1/3] ruff: C4-PT-RET-SIM --- .pre-commit-config.yaml | 5 +++++ pyproject.toml | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5bac87fb..5e21d3c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,6 +54,11 @@ repos: - id: yesqa additional_dependencies: - "flake8-docstrings" + - "pep8-naming" + #- flake8-comprehensions + #- flake8-pytest-style + #- flake8-return + #- flake8-simplify - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.253 diff --git a/pyproject.toml b/pyproject.toml index 48a45a4b..4b8dadb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,12 @@ ignore_missing_imports = true "D", # see: https://pypi.org/project/pydocstyle "N", # see: https://pypi.org/project/pep8-naming ] +extend-select = [ + "C4", # see: https://pypi.org/project/flake8-comprehensions + "PT", # see: https://pypi.org/project/flake8-pytest-style + "RET", # see: https://pypi.org/project/flake8-return + "SIM", # see: https://pypi.org/project/flake8-simplify +] ignore = [ "E731", ] From 2f631526508f18ce4fb5aef8c4e083f94f55bb3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:08:10 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_utilities/core/apply_func.py | 2 +- src/lightning_utilities/core/enums.py | 2 +- tests/unittests/core/test_apply_func.py | 1 + tests/unittests/core/test_enums.py | 2 +- tests/unittests/test/test_warnings.py | 5 ++--- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py index 2815f71e..5302ed2e 100644 --- a/src/lightning_utilities/core/apply_func.py +++ b/src/lightning_utilities/core/apply_func.py @@ -163,7 +163,7 @@ def apply_to_collections( """ if data1 is None: if data2 is None: - return + return None # in case they were passed reversed data1, data2 = data2, None diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 17a96bbf..7d3445ed 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -44,7 +44,7 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> if requested string does not match any option based on selected source. """ if source in ("key", "any"): - for enum_key in cls.__members__.keys(): + for enum_key in cls.__members__: if enum_key.lower() == value.lower(): return cls[enum_key] if source in ("value", "any"): diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index 796048e1..ccfc7723 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -224,6 +224,7 @@ def test_apply_to_collection_include_none(): def fn(x): if isinstance(x, float): return x + return None reduced = apply_to_collection(to_reduce, (int, float), fn) assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})] diff --git a/tests/unittests/core/test_enums.py b/tests/unittests/core/test_enums.py index 3bd3638c..36871592 100644 --- a/tests/unittests/core/test_enums.py +++ b/tests/unittests/core/test_enums.py @@ -19,7 +19,7 @@ class MyEnum(StrEnum): assert MyEnum.NUM in (32, "32") # key-based - assert MyEnum.NUM == MyEnum.from_str("num") + assert MyEnum.from_str("num") == MyEnum.NUM # collections assert MyEnum.BAZ not in ("FOO", "BAR") diff --git a/tests/unittests/test/test_warnings.py b/tests/unittests/test/test_warnings.py index 3b5954ed..6db627dd 100644 --- a/tests/unittests/test/test_warnings.py +++ b/tests/unittests/test/test_warnings.py @@ -10,9 +10,8 @@ def test_no_warning_call(): with no_warning_call(): ... - with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")): - with no_warning_call(): - warnings.warn("foo") + with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")), no_warning_call(): + warnings.warn("foo") with no_warning_call(DeprecationWarning): warnings.warn("foo") From d99aad5f49473cba4b88fceb1f48c8966bf8fdad Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Mar 2023 15:15:33 +0100 Subject: [PATCH 3/3] fixing --- .pre-commit-config.yaml | 12 ++++++------ src/lightning_utilities/cli/dependencies.py | 6 ++++-- tests/unittests/core/test_apply_func.py | 11 ++++++----- tests/unittests/core/test_imports.py | 10 ++++++---- tests/unittests/test/test_warnings.py | 7 ++++--- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e21d3c1..7cf96ff0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,12 +53,12 @@ repos: hooks: - id: yesqa additional_dependencies: - - "flake8-docstrings" - - "pep8-naming" - #- flake8-comprehensions - #- flake8-pytest-style - #- flake8-return - #- flake8-simplify + - flake8-docstrings + - pep8-naming + - flake8-comprehensions + - flake8-pytest-style + - flake8-return + - flake8-simplify - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.253 diff --git a/src/lightning_utilities/cli/dependencies.py b/src/lightning_utilities/cli/dependencies.py index ec45ec72..7bb2c901 100644 --- a/src/lightning_utilities/cli/dependencies.py +++ b/src/lightning_utilities/cli/dependencies.py @@ -38,8 +38,10 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None: def _replace_min(fname: str) -> None: - req = open(fname).read().replace(">=", "==") - open(fname, "w").write(req) + with open(fname) as fo: + req = fo.read().replace(">=", "==") + with open(fname, "w") as fw: + fw.write(req) def replace_oldest_ver(req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None: diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index ccfc7723..da74e199 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -53,7 +53,7 @@ def __eq__(self, o: object) -> bool: """Perform equal operation.""" if not isinstance(o, WithClassVar): return NotImplemented - elif isinstance(self.dummy, torch.Tensor): + if isinstance(self.dummy, torch.Tensor): return torch.equal(self.dummy, o.dummy) return self.dummy == o.dummy @@ -71,7 +71,7 @@ def __post_init__(self, override: Optional[Any]): # noqa: D105 def __eq__(self, o: object) -> bool: # noqa: D105 if not isinstance(o, WithInitVar): return NotImplemented - elif isinstance(self.dummy, torch.Tensor): + if isinstance(self.dummy, torch.Tensor): return torch.equal(self.dummy, o.dummy) return self.dummy == o.dummy @@ -90,7 +90,7 @@ def __post_init__(self, override: Optional[Any]): # noqa: D105 def __eq__(self, o: object) -> bool: # noqa: D105 if not isinstance(o, WithClassAndInitVar): return NotImplemented - elif isinstance(self.dummy, torch.Tensor): + if isinstance(self.dummy, torch.Tensor): return torch.equal(self.dummy, o.dummy) return self.dummy == o.dummy @@ -173,9 +173,10 @@ def test_recursive_application_to_collection(): assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result" def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""): - assert dataclasses.is_dataclass(actual) and not isinstance( - actual, type + assert dataclasses.is_dataclass( + actual ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass" + assert not isinstance(actual, type) for field in dataclasses.fields(actual): if dataclasses.is_dataclass(field.type): _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested") diff --git a/tests/unittests/core/test_imports.py b/tests/unittests/core/test_imports.py index 059fb9aa..b5ff6cec 100644 --- a/tests/unittests/core/test_imports.py +++ b/tests/unittests/core/test_imports.py @@ -85,12 +85,14 @@ def test_lazy_import(): def callback_fcn(): raise ValueError - with pytest.raises(ValueError): - math = lazy_import("math", callback=callback_fcn) + math = lazy_import("math", callback=callback_fcn) + with pytest.raises(ValueError, match=""): # noqa: PT011 math.floor(5.1) - with pytest.raises(ModuleNotFoundError): - module = lazy_import("asdf") + + module = lazy_import("asdf") + with pytest.raises(ModuleNotFoundError, match="No module named 'asdf'"): print(module) + os = lazy_import("os") assert os.getcwd() diff --git a/tests/unittests/test/test_warnings.py b/tests/unittests/test/test_warnings.py index 6db627dd..c280262b 100644 --- a/tests/unittests/test/test_warnings.py +++ b/tests/unittests/test/test_warnings.py @@ -19,6 +19,7 @@ def test_no_warning_call(): class MyDeprecationWarning(DeprecationWarning): ... - with pytest.raises(AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")): - with pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning): - warnings.warn("bar", category=MyDeprecationWarning) + with pytest.raises( + AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')") + ), pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning): + warnings.warn("bar", category=MyDeprecationWarning)