From 4404a960bb034b149f7a691e7fec6e8ca9f31710 Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Fri, 25 Sep 2020 17:47:44 +0900 Subject: [PATCH] add @trio_async_generator --- CHANGELOG.md | 6 +++ README.md | 4 +- setup.py | 9 +++- src/trio_util/__init__.py | 1 + src/trio_util/_trio_async_generator.py | 73 ++++++++++++++++++++++++++ test-requirements.txt | 4 +- tests/test_trio_async_generator.py | 64 ++++++++++++++++++++++ 7 files changed, 155 insertions(+), 6 deletions(-) create mode 100644 src/trio_util/_trio_async_generator.py create mode 100644 tests/test_trio_async_generator.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d1b6177..171f832 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Release history +## trio-util (pending) +### Added +- `@trio_async_generator` is a decorator which adapts a generator containing + Trio constructs for safe use. (Normally, it's not allowed to yield from a + nursery or cancel scope when implementing async generators.) + ## trio-util 0.2.0 (2020-09-09) ### Added - `AsyncValue.wait_value() / wait_transition()` additionally accept a plain diff --git a/README.md b/README.md index 1de9405..ec47021 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,10 @@ including: * `AsyncBool`, `AsyncValue` - value wrappers with the ability to wait for a specific value or transition * `AsyncDictionary` - dictionary with waitable get and pop - * `periodic` - a periodic loop which accounts for its own execution + * `periodic` - periodic loop which accounts for its own execution time + * `@trio_async_generator` - decorator which adapts a generator containing + Trio constructs for safe use * `azip`, `azip_longest` - async zip with parallel iteration * `UnqueuedRepeatedEvent`, `MailboxRepeatedEvent` - if you really, really want to reuse an event diff --git a/setup.py b/setup.py index 981c748..a7c889c 100644 --- a/setup.py +++ b/setup.py @@ -19,8 +19,10 @@ * AsyncBool, AsyncValue - value wrappers with the ability to wait for a specific value or transition * AsyncDictionary - dictionary with waitable get and pop - * periodic - a periodic loop which accounts for its own execution + * periodic - periodic loop which accounts for its own execution time + * @trio_async_generator - decorator which adapts a generator containing + Trio constructs for safe use * azip, azip_longest - async zip with parallel iteration * UnqueuedRepeatedEvent, MailboxRepeatedEvent - if you really, really want to reuse an event @@ -34,7 +36,10 @@ license='MIT', packages=[pkg_name], package_dir={'': 'src'}, - install_requires=['trio >= 0.11.0'], + install_requires=[ + 'async_generator', + 'trio >= 0.11.0' + ], python_requires='>=3.7', classifiers=[ 'Development Status :: 3 - Alpha', diff --git a/src/trio_util/__init__.py b/src/trio_util/__init__.py index a663d20..1466ace 100644 --- a/src/trio_util/__init__.py +++ b/src/trio_util/__init__.py @@ -8,6 +8,7 @@ from ._periodic import periodic from ._repeated_event import UnqueuedRepeatedEvent, MailboxRepeatedEvent from ._task_stats import TaskStats +from ._trio_async_generator import trio_async_generator def _metadata_fix(): # don't do this for Sphinx case because it breaks "bysource" member ordering diff --git a/src/trio_util/_trio_async_generator.py b/src/trio_util/_trio_async_generator.py new file mode 100644 index 0000000..12c4637 --- /dev/null +++ b/src/trio_util/_trio_async_generator.py @@ -0,0 +1,73 @@ +import functools +import sys +from contextlib import asynccontextmanager + +import trio +from async_generator import aclosing + + +def trio_async_generator(wrapped): + """async generator pattern which supports Trio nurseries and cancel scopes + + Decorator which adapts an async generator using Trio constructs for safe use. + (Normally, it's not allowed to yield from a nursery or cancel scope when + implementing async generators.) + + Though the wrapped function is written as a normal async generator, usage + of the wrapper is different: the wrapper is an async context manager + providing the async generator to be iterated. + + Synopsis:: + + >>> @trio_async_generator + >>> async def my_generator(): + >>> # yield values, possibly from a nursery or cancel scope + >>> # ... + >>> + >>> + >>> async with my_generator() as agen: + >>> async for value in agen: + >>> print(value) + + Implementation: "The idea is that instead of pushing and popping the + generator from the stack of the task that's consuming it, you instead run + the generator code as a second task that feeds the consumer task values." + See https://github.com/python-trio/trio/issues/638#issuecomment-431954073 + + ISSUE: pylint is confused by this implementation, and every use will + trigger not-async-context-manager + """ + @asynccontextmanager + @functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + send_channel, receive_channel = trio.open_memory_channel(0) + async with trio.open_nursery() as nursery: + async def adapter(): + async with send_channel, aclosing(wrapped(*args, **kwargs)) as agen: + while True: + try: + # Advance underlying async generator to next yield + value = await agen.__anext__() + except StopAsyncIteration: + break + while True: + try: + # Forward the yielded value into the send channel + try: + await send_channel.send(value) + except trio.BrokenResourceError: + return + break + except BaseException: # pylint: disable=broad-except + # If send_channel.send() raised (e.g. Cancelled), + # throw the raised exception back into the generator, + # and get the next yielded value to forward. + try: + value = await agen.athrow(*sys.exc_info()) + except StopAsyncIteration: + return + + nursery.start_soon(adapter, name=wrapped) + async with receive_channel: + yield receive_channel + return wrapper diff --git a/test-requirements.txt b/test-requirements.txt index 852b9be..402050c 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,11 +5,10 @@ # pip-compile --output-file=test-requirements.txt setup.py test-requirements.in # astroid==2.4.1 # via pylint -async-generator==1.10 # via pytest-trio, trio +async-generator==1.10 # via pytest-trio, trio, trio_util (setup.py) attrs==19.3.0 # via outcome, pytest, trio coverage==5.1 # via pytest-cov idna==2.9 # via trio -importlib-metadata==1.6.0 # via pluggy, pytest isort==4.3.21 # via pylint lazy-object-proxy==1.4.3 # via astroid mccabe==0.6.1 # via pylint @@ -34,4 +33,3 @@ typed-ast==1.4.1 # via astroid, mypy typing-extensions==3.7.4.2 # via mypy wcwidth==0.1.9 # via pytest wrapt==1.12.1 # via astroid -zipp==3.1.0 # via importlib-metadata diff --git a/tests/test_trio_async_generator.py b/tests/test_trio_async_generator.py new file mode 100644 index 0000000..22fab56 --- /dev/null +++ b/tests/test_trio_async_generator.py @@ -0,0 +1,64 @@ +from math import inf + +import trio + +from trio_util._trio_async_generator import trio_async_generator + +# pylint: disable=not-async-context-manager + + +@trio_async_generator +async def squares_in_range(start, stop, timeout=inf, max_timeout_count=1): + timeout_count = 0 + for i in range(start, stop): + with trio.move_on_after(timeout) as cancel_scope: + yield i ** 2 + await trio.sleep(0) + if cancel_scope.cancelled_caught: + timeout_count += 1 + if timeout_count == max_timeout_count: + break + + +async def test_trio_agen_full_iteration(): + last = None + async with squares_in_range(0, 50) as squares: + async for square in squares: + last = square + assert last == 49 ** 2 + + +async def test_trio_agen_caller_exits(): + async with squares_in_range(0, 50) as squares: + async for square in squares: + if square >= 400: + return + assert False + + +async def test_trio_agen_caller_cancelled(autojump_clock): + with trio.move_on_after(1): + async with squares_in_range(0, 50) as squares: + async for square in squares: + assert square == 0 + # the sleep will be cancelled by move_on_after above + await trio.sleep(10) + + +async def test_trio_agen_aborts_yield(autojump_clock): + async with squares_in_range(0, 50, timeout=.5, max_timeout_count=1) as squares: + async for square in squares: + assert square == 0 + # timeout in the generator will be triggered and it will abort iteration + await trio.sleep(1) + + +async def test_trio_agen_aborts_yield_and_continues(autojump_clock): + async with squares_in_range(0, 50, timeout=.5, max_timeout_count=99) as squares: + _sum = 0 + async for square in squares: + _sum += square + if square == 5 ** 2: + # this will cause the next iteration (6 ** 2) to time out + await trio.sleep(.6) + assert _sum == sum(i ** 2 for i in range(0, 50)) - 6 ** 2