Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ruff: C4-PT-RET-SIM #117

Merged
merged 3 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ repos:
hooks:
- id: yesqa
additional_dependencies:
- "flake8-docstrings"
- flake8-docstrings
- pep8-naming
- flake8-comprehensions
- flake8-pytest-style
- flake8-return
- flake8-simplify

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.253
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ ignore_missing_imports = true
"D", # see: https://pypi.org/project/pydocstyle
"N", # see: https://pypi.org/project/pep8-naming
]
extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions
"PT", # see: https://pypi.org/project/flake8-pytest-style
"RET", # see: https://pypi.org/project/flake8-return
"SIM", # see: https://pypi.org/project/flake8-simplify
]
ignore = [
"E731",
]
Expand Down
6 changes: 4 additions & 2 deletions src/lightning_utilities/cli/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None:


def _replace_min(fname: str) -> None:
req = open(fname).read().replace(">=", "==")
open(fname, "w").write(req)
with open(fname) as fo:
req = fo.read().replace(">=", "==")
with open(fname, "w") as fw:
fw.write(req)


def replace_oldest_ver(req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/core/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def apply_to_collections(
"""
if data1 is None:
if data2 is None:
return
return None
# in case they were passed reversed
data1, data2 = data2, None

Expand Down
2 changes: 1 addition & 1 deletion src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") ->
if requested string does not match any option based on selected source.
"""
if source in ("key", "any"):
for enum_key in cls.__members__.keys():
for enum_key in cls.__members__:
if enum_key.lower() == value.lower():
return cls[enum_key]
if source in ("value", "any"):
Expand Down
12 changes: 7 additions & 5 deletions tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __eq__(self, o: object) -> bool:
"""Perform equal operation."""
if not isinstance(o, WithClassVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
if isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)

return self.dummy == o.dummy
Expand All @@ -71,7 +71,7 @@ def __post_init__(self, override: Optional[Any]): # noqa: D105
def __eq__(self, o: object) -> bool: # noqa: D105
if not isinstance(o, WithInitVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
if isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)

return self.dummy == o.dummy
Expand All @@ -90,7 +90,7 @@ def __post_init__(self, override: Optional[Any]): # noqa: D105
def __eq__(self, o: object) -> bool: # noqa: D105
if not isinstance(o, WithClassAndInitVar):
return NotImplemented
elif isinstance(self.dummy, torch.Tensor):
if isinstance(self.dummy, torch.Tensor):
return torch.equal(self.dummy, o.dummy)

return self.dummy == o.dummy
Expand Down Expand Up @@ -173,9 +173,10 @@ def test_recursive_application_to_collection():
assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"

def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
assert dataclasses.is_dataclass(actual) and not isinstance(
actual, type
assert dataclasses.is_dataclass(
actual
), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
assert not isinstance(actual, type)
for field in dataclasses.fields(actual):
if dataclasses.is_dataclass(field.type):
_assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
Expand Down Expand Up @@ -224,6 +225,7 @@ def test_apply_to_collection_include_none():
def fn(x):
if isinstance(x, float):
return x
return None

reduced = apply_to_collection(to_reduce, (int, float), fn)
assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/core/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MyEnum(StrEnum):
assert MyEnum.NUM in (32, "32")

# key-based
assert MyEnum.NUM == MyEnum.from_str("num")
assert MyEnum.from_str("num") == MyEnum.NUM

# collections
assert MyEnum.BAZ not in ("FOO", "BAR")
Expand Down
10 changes: 6 additions & 4 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,14 @@ def test_lazy_import():
def callback_fcn():
raise ValueError

with pytest.raises(ValueError):
math = lazy_import("math", callback=callback_fcn)
math = lazy_import("math", callback=callback_fcn)
with pytest.raises(ValueError, match=""): # noqa: PT011
math.floor(5.1)
with pytest.raises(ModuleNotFoundError):
module = lazy_import("asdf")

module = lazy_import("asdf")
with pytest.raises(ModuleNotFoundError, match="No module named 'asdf'"):
print(module)

os = lazy_import("os")
assert os.getcwd()

Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/test/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ def test_no_warning_call():
with no_warning_call():
...

with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")):
with no_warning_call():
warnings.warn("foo")
with pytest.raises(AssertionError, match=escape("`Warning` was raised: UserWarning('foo')")), no_warning_call():
warnings.warn("foo")

with no_warning_call(DeprecationWarning):
warnings.warn("foo")

class MyDeprecationWarning(DeprecationWarning):
...

with pytest.raises(AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")):
with pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning):
warnings.warn("bar", category=MyDeprecationWarning)
with pytest.raises(
AssertionError, match=escape("`DeprecationWarning` was raised: MyDeprecationWarning('bar')")
), pytest.warns(DeprecationWarning), no_warning_call(DeprecationWarning):
warnings.warn("bar", category=MyDeprecationWarning)