diff --git a/flake8_async/visitors/visitors.py b/flake8_async/visitors/visitors.py index 07ceaab..060c020 100644 --- a/flake8_async/visitors/visitors.py +++ b/flake8_async/visitors/visitors.py @@ -104,17 +104,29 @@ def visit_With(self, node: ast.With | ast.AsyncWith): continue var_name = item.optional_vars.id - # check for trio.open_nursery + # check for trio.open_nursery and anyio.create_task_group nursery = get_matching_call( - item.context_expr, "open_nursery", base=("trio",) - ) + item.context_expr, "open_nursery", base="trio" + ) or get_matching_call(item.context_expr, "create_task_group", base="anyio") + start_methods: tuple[str, ...] = ("start", "start_soon") + if nursery is None: + # check for asyncio.TaskGroup + nursery = get_matching_call( + item.context_expr, "TaskGroup", base="asyncio" + ) + if nursery is None: + continue + start_methods = ("create_task",) + + body_call = node.body[0].value + if isinstance(body_call, ast.Await): + body_call = body_call.value # `isinstance(..., ast.Call)` is done in get_matching_call - body_call = cast("ast.Call", node.body[0].value) + body_call = cast("ast.Call", body_call) if ( - nursery is not None - and get_matching_call(body_call, "start", "start_soon", base=var_name) + get_matching_call(body_call, *start_methods, base=var_name) # check for presence of as parameter and not any( (isinstance(n, ast.Name) and n.id == var_name) diff --git a/tests/eval_files/async112.py b/tests/eval_files/async112.py index ce4ad93..2837f70 100644 --- a/tests/eval_files/async112.py +++ b/tests/eval_files/async112.py @@ -1,7 +1,7 @@ # type: ignore # ASYNC112: Nursery body with only a call to nursery.start[_soon] and not passing itself as a parameter can be replaced with a regular function call. -# ASYNCIO_NO_ERROR - # TODO: expand check to work with asyncio.TaskGroup -# ANYIO_NO_ERROR - # TODO: expand check to work with anyio.TaskGroup +# ASYNCIO_NO_ERROR +# ANYIO_NO_ERROR import functools from functools import partial @@ -81,9 +81,9 @@ async def foo(): n.start_soon(lambda n: n + 1) -# body isn't a call to n.start +# body is a call to await n.start async def foo_1(): - with trio.open_nursery(...) as n: + with trio.open_nursery(...) as n: # error: 9, "n" await n.start(...) diff --git a/tests/eval_files/async112_anyio.py b/tests/eval_files/async112_anyio.py new file mode 100644 index 0000000..ec5eb7c --- /dev/null +++ b/tests/eval_files/async112_anyio.py @@ -0,0 +1,28 @@ +# main tests in async112.py +# this only tests anyio.create_task_group in particular +# BASE_LIBRARY anyio +# ASYNCIO_NO_ERROR +# TRIO_NO_ERROR + +import anyio + + +async def bar(*args): ... + + +async def foo(): + async with anyio.create_task_group() as tg: # error: 15, "tg" + await tg.start_soon(bar()) + + async with anyio.create_task_group() as tg: + await tg.start(bar(tg)) + + async with anyio.create_task_group() as tg: # error: 15, "tg" + tg.start_soon(bar()) + + async with anyio.create_task_group() as tg: + tg.start_soon(bar(tg)) + + # will not trigger on create_task + async with anyio.create_task_group() as tg: + tg.create_task(bar()) # type: ignore[attr-defined] diff --git a/tests/eval_files/async112_asyncio.py b/tests/eval_files/async112_asyncio.py new file mode 100644 index 0000000..5d620a9 --- /dev/null +++ b/tests/eval_files/async112_asyncio.py @@ -0,0 +1,24 @@ +# main tests in async112.py +# this only tests asyncio.TaskGroup in particular +# BASE_LIBRARY asyncio +# ANYIO_NO_ERROR +# TRIO_NO_ERROR +# TaskGroup introduced in 3.11, we run typechecks with 3.9 +# mypy: disable-error-code=attr-defined + +import asyncio + + +async def bar(*args): ... + + +async def foo(): + async with asyncio.TaskGroup() as tg: # error: 15, "tg" + tg.create_task(bar()) + + async with asyncio.TaskGroup() as tg: + tg.create_task(bar(tg)) + + # will not trigger on start / start_soon + async with asyncio.TaskGroup() as tg: + tg.start(bar())