Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new error code: TRIO912 unnecessary checkpoints + autofix #183

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions flake8_trio/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class LoopState:
nodes_needing_checkpoints: list[cst.Return | cst.Yield] = field(
default_factory=list
)
possibly_redundant_lowlevel_checkpoints: list[cst.BaseExpression] = field(
default_factory=list
)

def copy(self):
return LoopState(
Expand All @@ -66,6 +69,7 @@ def copy(self):
uncheckpointed_before_break=self.uncheckpointed_before_break.copy(),
artificial_errors=self.artificial_errors.copy(),
nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(),
possibly_redundant_lowlevel_checkpoints=self.possibly_redundant_lowlevel_checkpoints.copy(),
)


Expand Down Expand Up @@ -214,6 +218,22 @@ def leave_Yield(
leave_Return = leave_Yield # type: ignore


# class RemoveLowlevelCheckpoints(cst.CSTTransformer):
# def __init__(self, stmts_to_remove: set[cst.Await]):
# self.stmts_to_remove = stmts_to_remove
#
# def leave_Await(self, original_node: cst.Await, updated_node: cst.Await) -> cst.Await:
# # return original node to preserve identity
# return original_node
#
# # for some reason you can't just return RemovalSentinel from Await, so we have to
# # visit the possible wrappers and modify their bodies instead
#
# def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# new_body = [stmt for stmt in updated_node.body.body if stmt not in self.stmts_to_remove]
# return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))


@error_class_cst
@disabled_by_default
class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
Expand All @@ -226,16 +246,27 @@ class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
"{0} from async iterable with no guaranteed checkpoint since {1.name} "
"on line {1.lineno}."
),
"TRIO912": "Redundant checkpoint with no effect on program execution.",
}

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.has_yield = False
self.safe_decorator = False
self.async_function = False
self.uncheckpointed_statements: set[Statement] = set()
self.comp_unknown = False

self.uncheckpointed_statements: set[Statement] = set()
self.checkpointed_by_lowlevel = False

# value == False, not redundant (or not determined to be redundant yet)
# value == True, there were no uncheckpointed statements when we encountered it
# value = expr/stmt, made redundant by the given expr/stmt
self.lowlevel_checkpoints: dict[
cst.Await, cst.BaseStatement | cst.BaseExpression | bool
] = {}
self.lowlevel_checkpoint_updated_nodes: dict[cst.Await, cst.Await] = {}

self.loop_state = LoopState()
self.try_state = TryState()

Expand All @@ -258,6 +289,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
"safe_decorator",
"async_function",
"uncheckpointed_statements",
"lowlevel_checkpoints",
"loop_state",
"try_state",
copy=True,
Expand Down Expand Up @@ -299,8 +331,31 @@ def leave_FunctionDef(
indentedblock = updated_node.body.with_changes(body=new_body)
updated_node = updated_node.with_changes(body=indentedblock)

res: cst.FunctionDef = updated_node
to_remove: set[cst.Await] = set()
for expr, value in self.lowlevel_checkpoints.items():
if value != False:
self.error(expr, error_code="TRIO912")
if self.should_autofix():
to_remove.add(self.lowlevel_checkpoint_updated_nodes.pop(expr))

if to_remove:
new_body = []
for stmt in updated_node.body.body:
if not m.matches(
stmt,
m.SimpleStatementLine(
[m.Expr(m.MatchIfTrue(lambda x: x in to_remove))]
),
):
new_body.append(stmt) # type: ignore
assert new_body != updated_node.body.body
res = updated_node.with_changes(
body=updated_node.body.with_changes(body=new_body)
)

self.restore_state(original_node)
return updated_node # noqa: R504
return res

# error if function exit/return/yields with uncheckpointed statements
# returns a bool indicating if any real (i.e. not artificial) errors were raised
Expand Down Expand Up @@ -372,12 +427,48 @@ def error_91x(
error_code="TRIO911" if self.has_yield else "TRIO910",
)

def is_lowlevel_checkpoint(self, node: cst.BaseExpression) -> bool:
# TODO: match against both libraries if both are imported
return m.matches(
node,
m.Call(
m.Attribute(
m.Attribute(m.Name(self.library[0]), m.Name("lowlevel")),
m.Name("checkpoint"),
)
),
)

def visit_Await(self, node: cst.Await) -> None:
# do a match against the awaited expr
# if that is trio.lowlevel.checkpoint, and uncheckpointed statements
# are empty, raise TRIO912.
if self.is_lowlevel_checkpoint(node.expression):
if not self.uncheckpointed_statements:
self.lowlevel_checkpoints[node] = True
elif self.uncheckpointed_statements == {ARTIFICIAL_STATEMENT}:
self.loop_state.possibly_redundant_lowlevel_checkpoints.append(node)
else:
self.lowlevel_checkpoints[node] = False
# if trio.lowlevel.checkpoint and *not* empty, take note of it in a special list.
elif not self.uncheckpointed_statements:
for expr, value in self.lowlevel_checkpoints.items():
if value == False:
self.lowlevel_checkpoints[expr] = node

# if this is not a trio.lowlevel.checkpoint, and there are no uncheckpointed statements, check if there is a lowlevel checkpoint in the special list. If so, raise a TRIO912 for it and remove it.

def leave_Await(
self, original_node: cst.Await, updated_node: cst.Await
) -> cst.Await:
# the expression being awaited is not checkpointed
# so only set checkpoint after the await node

# TODO: dirty hack to get identity right, the logic in visit should maybe be
# moved/split into the leave
if original_node in self.lowlevel_checkpoints:
self.lowlevel_checkpoint_updated_nodes[original_node] = updated_node

# all nodes are now checkpointed
self.uncheckpointed_statements = set()
return updated_node
Expand Down Expand Up @@ -494,6 +585,10 @@ def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try:
self.restore_state(original_node)
return updated_node

# if a previous lowlevel checkpoint is marked as redundant after all bodies, then
# it's redundant.
# If any body marks it as necessary, then it's necessary.
# Otherwise, it keeps it's state from before.
def leave_If_test(self, node: cst.If | cst.IfExp) -> None:
if not self.async_function:
return
Expand Down Expand Up @@ -604,6 +699,11 @@ def leave_While_body(self, node: cst.For | cst.While):
if not any_error:
self.loop_state.nodes_needing_checkpoints = []

# but lowlevel checkpoints are redundant
for expr in self.loop_state.possibly_redundant_lowlevel_checkpoints:
self.error(expr, error_code="TRIO912")
# self.possibly_redundant_lowlevel_checkpoints.clear()

# replace artificial statements in else with prebody uncheckpointed statements
# non-artificial stmts before continue/break/at body end will already be in them
for stmts in (
Expand Down Expand Up @@ -654,6 +754,12 @@ def leave_While_orelse(self, node: cst.For | cst.While):
# reset break & continue in case of nested loops
self.outer[node]["uncheckpointed_statements"] = self.uncheckpointed_statements

# TODO: if this loop always checkpoints
# e.g. from being an async for, or being guaranteed to run once, or other stuff.
# then we can warn about redundant checkpoints before the loop.
# ... except if the reason we always checkpoint is due to redundant checkpoints
# we're about to remove.... :thinking:

leave_For_orelse = leave_While_orelse

def leave_While(
Expand Down
Loading