Skip to content

Commit

Permalink
add asyncio/anyio taskgroup support to async112
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed May 1, 2024
1 parent 60ac4ca commit 8e2aeed
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 10 deletions.
24 changes: 18 additions & 6 deletions flake8_async/visitors/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,29 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
continue
var_name = item.optional_vars.id

# check for trio.open_nursery
# check for trio.open_nursery and anyio.create_task_group
nursery = get_matching_call(
item.context_expr, "open_nursery", base=("trio",)
)
item.context_expr, "open_nursery", base="trio"
) or get_matching_call(item.context_expr, "create_task_group", base="anyio")
start_methods: tuple[str, ...] = ("start", "start_soon")
if nursery is None:
# check for asyncio.TaskGroup
nursery = get_matching_call(
item.context_expr, "TaskGroup", base="asyncio"
)
if nursery is None:
continue
start_methods = ("create_task",)

body_call = node.body[0].value
if isinstance(body_call, ast.Await):
body_call = body_call.value

# `isinstance(..., ast.Call)` is done in get_matching_call
body_call = cast("ast.Call", node.body[0].value)
body_call = cast("ast.Call", body_call)

if (
nursery is not None
and get_matching_call(body_call, "start", "start_soon", base=var_name)
get_matching_call(body_call, *start_methods, base=var_name)
# check for presence of <X> as parameter
and not any(
(isinstance(n, ast.Name) and n.id == var_name)
Expand Down
8 changes: 4 additions & 4 deletions tests/eval_files/async112.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# type: ignore
# ASYNC112: Nursery body with only a call to nursery.start[_soon] and not passing itself as a parameter can be replaced with a regular function call.
# ASYNCIO_NO_ERROR - # TODO: expand check to work with asyncio.TaskGroup
# ANYIO_NO_ERROR - # TODO: expand check to work with anyio.TaskGroup
# ASYNCIO_NO_ERROR
# ANYIO_NO_ERROR
import functools
from functools import partial

Expand Down Expand Up @@ -81,9 +81,9 @@ async def foo():
n.start_soon(lambda n: n + 1)


# body isn't a call to n.start
# body is a call to await n.start
async def foo_1():
with trio.open_nursery(...) as n:
with trio.open_nursery(...) as n: # error: 9, "n"
await n.start(...)


Expand Down
28 changes: 28 additions & 0 deletions tests/eval_files/async112_anyio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# main tests in async112.py
# this only tests anyio.create_task_group in particular
# BASE_LIBRARY anyio
# ASYNCIO_NO_ERROR
# TRIO_NO_ERROR

import anyio


async def bar(*args): ...


async def foo():
async with anyio.create_task_group() as tg: # error: 15, "tg"
await tg.start_soon(bar())

async with anyio.create_task_group() as tg:
await tg.start(bar(tg))

async with anyio.create_task_group() as tg: # error: 15, "tg"
tg.start_soon(bar())

async with anyio.create_task_group() as tg:
tg.start_soon(bar(tg))

# will not trigger on create_task
async with anyio.create_task_group() as tg:
tg.create_task(bar()) # type: ignore[attr-defined]
24 changes: 24 additions & 0 deletions tests/eval_files/async112_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# main tests in async112.py
# this only tests asyncio.TaskGroup in particular
# BASE_LIBRARY asyncio
# ANYIO_NO_ERROR
# TRIO_NO_ERROR
# TaskGroup introduced in 3.11, we run typechecks with 3.9
# mypy: disable-error-code=attr-defined

import asyncio


async def bar(*args): ...


async def foo():
async with asyncio.TaskGroup() as tg: # error: 15, "tg"
tg.create_task(bar())

async with asyncio.TaskGroup() as tg:
tg.create_task(bar(tg))

# will not trigger on start / start_soon
async with asyncio.TaskGroup() as tg:
tg.start(bar())

0 comments on commit 8e2aeed

Please sign in to comment.