Skip to content

Commit

Permalink
ruff: C4-PT-RET-SIM (#117)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Mar 6, 2023
1 parent 3a51bbc commit c01bf05
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 21 deletions.
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ repos:
hooks:
- id: yesqa
additional_dependencies:
- "flake8-docstrings"
- flake8-docstrings
- pep8-naming
- flake8-comprehensions
- flake8-pytest-style
- flake8-return
- flake8-simplify

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.253
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ ignore_missing_imports = true
"D", # see: https://pypi.org/project/pydocstyle
"N", # see: https://pypi.org/project/pep8-naming
]
extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions
"PT", # see: https://pypi.org/project/flake8-pytest-style
"RET", # see: https://pypi.org/project/flake8-return
"SIM", # see: https://pypi.org/project/flake8-simplify
]
ignore = [
"E731",
]
Expand Down
6 changes: 4 additions & 2 deletions src/lightning_utilities/cli/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:


def _replace_min(fname: str) -> None:
req = open(fname).read().replace(">=", "==")
open(fname, "w").write(req)
with open(fname) as fo:
req = fo.read().replace(">=", "==")
with open(fname, "w") as fw:
fw.write(req)


def replace_oldest_ver(req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def apply_to_collections(
"""
if data1 is None:
if data2 is None:
return
return None
# in case they were passed reversed
data1, data2 = data2, None

Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") ->
if requested string does not match any option based on selected source.
"""
if source in ("key", "any"):
for enum_key in cls.__members__.keys():
for enum_key in cls.__members__:
if enum_key.lower() == value.lower():
return cls[enum_key]
if source in ("value", "any"):
Expand Down
12 changes: 7 additions & 5 deletions tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __eq__(self, o: object) -> bool:
"""Perform equal operation."""
if not isinstance(o, WithClassVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
if isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)

return self.dummy == o.dummy
Expand All @@ -71,7 +71,7 @@ def __post_init__(self, override: Optional[Any]): # noqa: D105
def __eq__(self, o: object) -> bool: # noqa: D105
if not isinstance(o, WithInitVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
if isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)

return self.dummy == o.dummy
Expand All @@ -90,7 +90,7 @@ def __post_init__(self, override: Optional[Any]): # noqa: D105
def __eq__(self, o: object) -> bool: # noqa: D105
if not isinstance(o, WithClassAndInitVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
if isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)

return self.dummy == o.dummy
Expand Down Expand Up @@ -173,9 +173,10 @@ def test_recursive_application_to_collection():
assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
assert dataclasses.is_dataclass(actual) and not isinstance(
actual, type
assert dataclasses.is_dataclass(
actual
), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
assert not isinstance(actual, type)
for field in dataclasses.fields(actual):
if dataclasses.is_dataclass(field.type):
_assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
Expand Down Expand Up @@ -224,6 +225,7 @@ def test_apply_to_collection_include_none():
def fn(x):
if isinstance(x, float):
return x
return None

reduced = apply_to_collection(to_reduce, (int, float), fn)
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/core/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MyEnum(StrEnum):
assert MyEnum.NUM in (32, "32")

# key-based
assert MyEnum.NUM == MyEnum.from_str("num")
assert MyEnum.from_str("num") == MyEnum.NUM

# collections
assert MyEnum.BAZ not in ("FOO", "BAR")
Expand Down
10 changes: 6 additions & 4 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,14 @@ def test_lazy_import():
def callback_fcn():
raise ValueError

with pytest.raises(ValueError):
math = lazy_import("math", callback=callback_fcn)
math = lazy_import("math", callback=callback_fcn)
with pytest.raises(ValueError, match=""): # noqa: PT011
math.floor(5.1)
with pytest.raises(ModuleNotFoundError):
module = lazy_import("asdf")

module = lazy_import("asdf")
with pytest.raises(ModuleNotFoundError, match="No module named 'asdf'"):
print(module)

os = lazy_import("os")
assert os.getcwd()

Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/test/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ def test_no_warning_call():
with no_warning_call():
...

with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")):
with no_warning_call():
warnings.warn("foo")
with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")), no_warning_call():
warnings.warn("foo")

with no_warning_call(DeprecationWarning):
warnings.warn("foo")

class MyDeprecationWarning(DeprecationWarning):
...

with pytest.raises(AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")):
with pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning):
warnings.warn("bar", category=MyDeprecationWarning)
with pytest.raises(
AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")
), pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning):
warnings.warn("bar", category=MyDeprecationWarning)

0 comments on commit c01bf05

Please sign in to comment.