From 602630ac1855e38ef06361c68f6e216375a06180 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Feb 2022 15:42:04 -0800 Subject: [PATCH] bpo-46752: Add TaskGroup; add Task..cancelled(),.uncancel() (GH-31270) asyncio/taskgroups.py is an adaptation of taskgroup.py from EdgeDb, with the following key changes: - Allow creating new tasks as long as the last task hasn't finished - Raise [Base]ExceptionGroup (directly) rather than TaskGroupError deriving from MultiError - Instead of monkey-patching the parent task's cancel() method, add a new public API to Task The Task class has a new internal flag, `_cancel_requested`, which is set when `.cancel()` is called successfully. The `.cancelling()` method returns the value of this flag. Further `.cancel()` calls while this flag is set return False. To reset this flag, call `.uncancel()`. Thus, a Task that catches and ignores `CancelledError` should call `.uncancel()` if it wants to be cancellable again; until it does so, it is deemed to be busy with uninterruptible cleanup. This new Task API helps solve the problem where TaskGroup needs to distinguish between whether the parent task being cancelled "from the outside" vs. "from inside". Co-authored-by: Yury Selivanov Co-authored-by: Andrew Svetlov --- Lib/asyncio/__init__.py | 1 + Lib/asyncio/base_tasks.py | 2 +- Lib/asyncio/taskgroups.py | 235 ++++++ Lib/asyncio/tasks.py | 16 +- Lib/test/test_asyncio/test_taskgroups.py | 694 ++++++++++++++++++ Lib/test/test_asyncio/test_tasks.py | 45 ++ .../2022-02-14-21-21-49.bpo-46752.m6ldTm.rst | 2 + Modules/_asynciomodule.c | 59 ++ Modules/clinic/_asynciomodule.c.h | 49 +- 9 files changed, 1100 insertions(+), 3 deletions(-) create mode 100644 Lib/asyncio/taskgroups.py create mode 100644 Lib/test/test_asyncio/test_taskgroups.py create mode 100644 Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 200b14c2a3f21e..db1124cc9bd1ee 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -17,6 +17,7 @@ from .streams import * from .subprocess import * from .tasks import * +from .taskgroups import * from .threads import * from .transports import * diff --git a/Lib/asyncio/base_tasks.py b/Lib/asyncio/base_tasks.py index 09bb171a2ce750..1d623899f69a9d 100644 --- a/Lib/asyncio/base_tasks.py +++ b/Lib/asyncio/base_tasks.py @@ -8,7 +8,7 @@ def _task_repr_info(task): info = base_futures._future_repr_info(task) - if task._must_cancel: + if task.cancelling() and not task.done(): # replace status info[0] = 'cancelling' diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py new file mode 100644 index 00000000000000..718277892c51c9 --- /dev/null +++ b/Lib/asyncio/taskgroups.py @@ -0,0 +1,235 @@ +# Adapted with permission from the EdgeDB project. + + +__all__ = ["TaskGroup"] + +import itertools +import textwrap +import traceback +import types +import weakref + +from . import events +from . import exceptions +from . import tasks + +class TaskGroup: + + def __init__(self, *, name=None): + if name is None: + self._name = f'tg-{_name_counter()}' + else: + self._name = str(name) + + self._entered = False + self._exiting = False + self._aborting = False + self._loop = None + self._parent_task = None + self._parent_cancel_requested = False + self._tasks = weakref.WeakSet() + self._unfinished_tasks = 0 + self._errors = [] + self._base_error = None + self._on_completed_fut = None + + def get_name(self): + return self._name + + def __repr__(self): + msg = f' bool: + assert isinstance(exc, BaseException) + return isinstance(exc, (SystemExit, KeyboardInterrupt)) + + def _abort(self): + self._aborting = True + + for t in self._tasks: + if not t.done(): + t.cancel() + + def _on_task_done(self, task): + self._unfinished_tasks -= 1 + assert self._unfinished_tasks >= 0 + + if self._on_completed_fut is not None and not self._unfinished_tasks: + if not self._on_completed_fut.done(): + self._on_completed_fut.set_result(True) + + if task.cancelled(): + return + + exc = task.exception() + if exc is None: + return + + self._errors.append(exc) + if self._is_base_error(exc) and self._base_error is None: + self._base_error = exc + + if self._parent_task.done(): + # Not sure if this case is possible, but we want to handle + # it anyways. + self._loop.call_exception_handler({ + 'message': f'Task {task!r} has errored out but its parent ' + f'task {self._parent_task} is already completed', + 'exception': exc, + 'task': task, + }) + return + + self._abort() + if not self._parent_task.cancelling(): + # If parent task *is not* being cancelled, it means that we want + # to manually cancel it to abort whatever is being run right now + # in the TaskGroup. But we want to mark parent task as + # "not cancelled" later in __aexit__. Example situation that + # we need to handle: + # + # async def foo(): + # try: + # async with TaskGroup() as g: + # g.create_task(crash_soon()) + # await something # <- this needs to be canceled + # # by the TaskGroup, e.g. + # # foo() needs to be cancelled + # except Exception: + # # Ignore any exceptions raised in the TaskGroup + # pass + # await something_else # this line has to be called + # # after TaskGroup is finished. + self._parent_cancel_requested = True + self._parent_task.cancel() + + +_name_counter = itertools.count(1).__next__ diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 2bee5c050ded7d..c11d0daaefea7e 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -105,6 +105,7 @@ def __init__(self, coro, *, loop=None, name=None): else: self._name = str(name) + self._cancel_requested = False self._must_cancel = False self._fut_waiter = None self._coro = coro @@ -201,6 +202,9 @@ def cancel(self, msg=None): self._log_traceback = False if self.done(): return False + if self._cancel_requested: + return False + self._cancel_requested = True if self._fut_waiter is not None: if self._fut_waiter.cancel(msg=msg): # Leave self._fut_waiter; it may be a Task that @@ -212,6 +216,16 @@ def cancel(self, msg=None): self._cancel_message = msg return True + def cancelling(self): + return self._cancel_requested + + def uncancel(self): + if self._cancel_requested: + self._cancel_requested = False + return True + else: + return False + def __step(self, exc=None): if self.done(): raise exceptions.InvalidStateError( @@ -634,7 +648,7 @@ def _ensure_future(coro_or_future, *, loop=None): loop = events._get_event_loop(stacklevel=4) try: return loop.create_task(coro_or_future) - except RuntimeError: + except RuntimeError: if not called_wrap_awaitable: coro_or_future.close() raise diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py new file mode 100644 index 00000000000000..ea6ee2ed43d2f8 --- /dev/null +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -0,0 +1,694 @@ +# Adapted with permission from the EdgeDB project. + + +import asyncio + +from asyncio import taskgroups +import unittest + + +# To prevent a warning "test altered the execution environment" +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class MyExc(Exception): + pass + + +class MyBaseExc(BaseException): + pass + + +def get_error_types(eg): + return {type(exc) for exc in eg.exceptions} + + +class TestTaskGroup(unittest.IsolatedAsyncioTestCase): + + async def test_taskgroup_01(self): + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_02(self): + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + await asyncio.sleep(0.15) + t2 = g.create_task(foo2()) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_03(self): + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + await asyncio.sleep(0.15) + # cancel t1 explicitly, i.e. everything should continue + # working as expected. + t1.cancel() + + t2 = g.create_task(foo2()) + + self.assertTrue(t1.cancelled()) + self.assertEqual(t2.result(), 11) + + async def test_taskgroup_04(self): + + NUM = 0 + t2_cancel = False + t2 = None + + async def foo1(): + await asyncio.sleep(0.1) + 1 / 0 + + async def foo2(): + nonlocal NUM, t2_cancel + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + t2_cancel = True + raise + NUM += 1 + + async def runner(): + nonlocal NUM, t2 + + async with taskgroups.TaskGroup() as g: + g.create_task(foo1()) + t2 = g.create_task(foo2()) + + NUM += 10 + + with self.assertRaises(ExceptionGroup) as cm: + await asyncio.create_task(runner()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertEqual(NUM, 0) + self.assertTrue(t2_cancel) + self.assertTrue(t2.cancelled()) + + async def test_taskgroup_05(self): + + NUM = 0 + t2_cancel = False + runner_cancel = False + + async def foo1(): + await asyncio.sleep(0.1) + 1 / 0 + + async def foo2(): + nonlocal NUM, t2_cancel + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + t2_cancel = True + raise + NUM += 1 + + async def runner(): + nonlocal NUM, runner_cancel + + async with taskgroups.TaskGroup() as g: + g.create_task(foo1()) + g.create_task(foo1()) + g.create_task(foo1()) + g.create_task(foo2()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + runner_cancel = True + raise + + NUM += 10 + + # The 3 foo1 sub tasks can be racy when the host is busy - if the + # cancellation happens in the middle, we'll see partial sub errors here + with self.assertRaises(ExceptionGroup) as cm: + await asyncio.create_task(runner()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + self.assertEqual(NUM, 0) + self.assertTrue(t2_cancel) + self.assertTrue(runner_cancel) + + async def test_taskgroup_06(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + async with taskgroups.TaskGroup() as g: + for _ in range(5): + g.create_task(foo()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 5) + + async def test_taskgroup_07(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup() as g: + for _ in range(5): + g.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 15) + + async def test_taskgroup_08(self): + + async def foo(): + await asyncio.sleep(0.1) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g: + for _ in range(5): + g.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_09(self): + + t1 = t2 = None + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + nonlocal t1, t2 + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + await asyncio.sleep(0.1) + 1 / 0 + + try: + await runner() + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {ZeroDivisionError}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertTrue(t1.cancelled()) + self.assertTrue(t2.cancelled()) + + async def test_taskgroup_10(self): + + t1 = t2 = None + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + nonlocal t1, t2 + async with taskgroups.TaskGroup() as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + 1 / 0 + + try: + await runner() + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {ZeroDivisionError}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertTrue(t1.cancelled()) + self.assertTrue(t2.cancelled()) + + async def test_taskgroup_11(self): + + async def foo(): + await asyncio.sleep(0.1) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(): + async with taskgroups.TaskGroup() as g2: + for _ in range(5): + g2.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_12(self): + + async def foo(): + await asyncio.sleep(0.1) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g1: + g1.create_task(asyncio.sleep(10)) + + async with taskgroups.TaskGroup() as g2: + for _ in range(5): + g2.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_13(self): + + async def crash_after(t): + await asyncio.sleep(t) + raise ValueError(t) + + async def runner(): + async with taskgroups.TaskGroup(name='g1') as g1: + g1.create_task(crash_after(0.1)) + + async with taskgroups.TaskGroup(name='g2') as g2: + g2.create_task(crash_after(0.2)) + + r = asyncio.create_task(runner()) + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ValueError}) + + async def test_taskgroup_14(self): + + async def crash_after(t): + await asyncio.sleep(t) + raise ValueError(t) + + async def runner(): + async with taskgroups.TaskGroup(name='g1') as g1: + g1.create_task(crash_after(10)) + + async with taskgroups.TaskGroup(name='g2') as g2: + g2.create_task(crash_after(0.1)) + + r = asyncio.create_task(runner()) + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError}) + + async def test_taskgroup_15(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(name='g1') as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_16(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def nested_runner(): + async with taskgroups.TaskGroup(name='g1') as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + async def runner(): + t = asyncio.create_task(nested_runner()) + await t + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_17(self): + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 10) + + async def test_taskgroup_18(self): + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + # This isn't a good idea, but we have to support + # this weird case. + raise MyExc + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t),{MyExc}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertEqual(NUM, 10) + + async def test_taskgroup_19(self): + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise MyExc + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + r = asyncio.create_task(runner()) + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError}) + else: + self.fail('TasgGroupError was not raised') + + async def test_taskgroup_20(self): + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise KeyboardInterrupt + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(KeyboardInterrupt): + await runner() + + async def test_taskgroup_20a(self): + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise MyBaseExc + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(BaseExceptionGroup) as cm: + await runner() + + self.assertEqual( + get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError} + ) + + async def _test_taskgroup_21(self): + # This test doesn't work as asyncio, currently, doesn't + # correctly propagate KeyboardInterrupt (or SystemExit) -- + # those cause the event loop itself to crash. + # (Compare to the previous (passing) test -- that one raises + # a plain exception but raises KeyboardInterrupt in nested(); + # this test does it the other way around.) + + async def crash_soon(): + await asyncio.sleep(0.1) + raise KeyboardInterrupt + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise TypeError + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(KeyboardInterrupt): + await runner() + + async def test_taskgroup_21a(self): + + async def crash_soon(): + await asyncio.sleep(0.1) + raise MyBaseExc + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise TypeError + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(crash_soon()) + await nested() + + with self.assertRaises(BaseExceptionGroup) as cm: + await runner() + + self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError}) + + async def test_taskgroup_22(self): + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(foo1()) + g.create_task(foo2()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.05) + r.cancel() + + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_23(self): + + async def do_job(delay): + await asyncio.sleep(delay) + + async with taskgroups.TaskGroup() as g: + for count in range(10): + await asyncio.sleep(0.1) + g.create_task(do_job(0.3)) + if count == 5: + self.assertLess(len(g._tasks), 5) + await asyncio.sleep(1.35) + self.assertEqual(len(g._tasks), 0) + + async def test_taskgroup_24(self): + + async def root(g): + await asyncio.sleep(0.1) + g.create_task(coro1(0.1)) + g.create_task(coro1(0.2)) + + async def coro1(delay): + await asyncio.sleep(delay) + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(root(g)) + + await runner() + + async def test_taskgroup_25(self): + nhydras = 0 + + async def hydra(g): + nonlocal nhydras + nhydras += 1 + await asyncio.sleep(0.01) + g.create_task(hydra(g)) + g.create_task(hydra(g)) + + async def hercules(): + while nhydras < 10: + await asyncio.sleep(0.015) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup() as g: + g.create_task(hydra(g)) + g.create_task(hercules()) + + with self.assertRaises(ExceptionGroup) as cm: + await runner() + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + self.assertGreaterEqual(nhydras, 10) diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 8c4dceacdeec96..fe6bfb363f1c67 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -496,6 +496,51 @@ async def run(): # This also distinguishes from the initial has_cycle=None. self.assertEqual(has_cycle, False) + + def test_cancelling(self): + loop = asyncio.new_event_loop() + + async def task(): + await asyncio.sleep(10) + + try: + t = self.new_task(loop, task()) + self.assertFalse(t.cancelling()) + self.assertNotIn(" cancelling ", repr(t)) + self.assertTrue(t.cancel()) + self.assertTrue(t.cancelling()) + self.assertIn(" cancelling ", repr(t)) + self.assertFalse(t.cancel()) + + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + finally: + loop.close() + + def test_uncancel(self): + loop = asyncio.new_event_loop() + + async def task(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + asyncio.current_task().uncancel() + await asyncio.sleep(10) + + try: + t = self.new_task(loop, task()) + loop.run_until_complete(asyncio.sleep(0.01)) + self.assertTrue(t.cancel()) # Cancel first sleep + self.assertIn(" cancelling ", repr(t)) + loop.run_until_complete(asyncio.sleep(0.01)) + self.assertNotIn(" cancelling ", repr(t)) # after .uncancel() + self.assertTrue(t.cancel()) # Cancel second sleep + + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + finally: + loop.close() + def test_cancel(self): def gen(): diff --git a/Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst b/Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst new file mode 100644 index 00000000000000..f460600c8c9dde --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst @@ -0,0 +1,2 @@ +Add task groups to asyncio (structured concurrency, inspired by Trio's nurseries). +This also introduces a change to task cancellation, where a cancelled task can't be cancelled again until it calls .uncancel(). diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 72dbdb8902f121..6725e2eba79bc2 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -91,6 +91,7 @@ typedef struct { PyObject *task_context; int task_must_cancel; int task_log_destroy_pending; + int task_cancel_requested; } TaskObj; typedef struct { @@ -2039,6 +2040,7 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop, Py_CLEAR(self->task_fut_waiter); self->task_must_cancel = 0; self->task_log_destroy_pending = 1; + self->task_cancel_requested = 0; Py_INCREF(coro); Py_XSETREF(self->task_coro, coro); @@ -2205,6 +2207,11 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg) Py_RETURN_FALSE; } + if (self->task_cancel_requested) { + Py_RETURN_FALSE; + } + self->task_cancel_requested = 1; + if (self->task_fut_waiter) { PyObject *res; int is_true; @@ -2232,6 +2239,56 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg) Py_RETURN_TRUE; } +/*[clinic input] +_asyncio.Task.cancelling + +Return True if the task is in the process of being cancelled. + +This is set once .cancel() is called +and remains set until .uncancel() is called. + +As long as this flag is set, further .cancel() calls will be ignored, +until .uncancel() is called to reset it. +[clinic start generated code]*/ + +static PyObject * +_asyncio_Task_cancelling_impl(TaskObj *self) +/*[clinic end generated code: output=803b3af96f917d7e input=c50e50f9c3ca4676]*/ +/*[clinic end generated code]*/ +{ + if (self->task_cancel_requested) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + +/*[clinic input] +_asyncio.Task.uncancel + +Reset the flag returned by cancelling(). + +This should be used by tasks that catch CancelledError +and wish to continue indefinitely until they are cancelled again. + +Returns the previous value of the flag. +[clinic start generated code]*/ + +static PyObject * +_asyncio_Task_uncancel_impl(TaskObj *self) +/*[clinic end generated code: output=58184d236a817d3c input=5db95e28fcb6f7cd]*/ +/*[clinic end generated code]*/ +{ + if (self->task_cancel_requested) { + self->task_cancel_requested = 0; + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + /*[clinic input] _asyncio.Task.get_stack @@ -2455,6 +2512,8 @@ static PyMethodDef TaskType_methods[] = { _ASYNCIO_TASK_SET_RESULT_METHODDEF _ASYNCIO_TASK_SET_EXCEPTION_METHODDEF _ASYNCIO_TASK_CANCEL_METHODDEF + _ASYNCIO_TASK_CANCELLING_METHODDEF + _ASYNCIO_TASK_UNCANCEL_METHODDEF _ASYNCIO_TASK_GET_STACK_METHODDEF _ASYNCIO_TASK_PRINT_STACK_METHODDEF _ASYNCIO_TASK__MAKE_CANCELLED_ERROR_METHODDEF diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h index c472e652fb7c56..5648e14f337f7f 100644 --- a/Modules/clinic/_asynciomodule.c.h +++ b/Modules/clinic/_asynciomodule.c.h @@ -447,6 +447,53 @@ _asyncio_Task_cancel(TaskObj *self, PyObject *const *args, Py_ssize_t nargs, PyO return return_value; } +PyDoc_STRVAR(_asyncio_Task_cancelling__doc__, +"cancelling($self, /)\n" +"--\n" +"\n" +"Return True if the task is in the process of being cancelled.\n" +"\n" +"This is set once .cancel() is called\n" +"and remains set until .uncancel() is called.\n" +"\n" +"As long as this flag is set, further .cancel() calls will be ignored,\n" +"until .uncancel() is called to reset it."); + +#define _ASYNCIO_TASK_CANCELLING_METHODDEF \ + {"cancelling", (PyCFunction)_asyncio_Task_cancelling, METH_NOARGS, _asyncio_Task_cancelling__doc__}, + +static PyObject * +_asyncio_Task_cancelling_impl(TaskObj *self); + +static PyObject * +_asyncio_Task_cancelling(TaskObj *self, PyObject *Py_UNUSED(ignored)) +{ + return _asyncio_Task_cancelling_impl(self); +} + +PyDoc_STRVAR(_asyncio_Task_uncancel__doc__, +"uncancel($self, /)\n" +"--\n" +"\n" +"Reset the flag returned by cancelling().\n" +"\n" +"This should be used by tasks that catch CancelledError\n" +"and wish to continue indefinitely until they are cancelled again.\n" +"\n" +"Returns the previous value of the flag."); + +#define _ASYNCIO_TASK_UNCANCEL_METHODDEF \ + {"uncancel", (PyCFunction)_asyncio_Task_uncancel, METH_NOARGS, _asyncio_Task_uncancel__doc__}, + +static PyObject * +_asyncio_Task_uncancel_impl(TaskObj *self); + +static PyObject * +_asyncio_Task_uncancel(TaskObj *self, PyObject *Py_UNUSED(ignored)) +{ + return _asyncio_Task_uncancel_impl(self); +} + PyDoc_STRVAR(_asyncio_Task_get_stack__doc__, "get_stack($self, /, *, limit=None)\n" "--\n" @@ -871,4 +918,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs, exit: return return_value; } -/*[clinic end generated code: output=0d127162ac92e0c0 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=c02708a9d6a774cc input=a9049054013a1b77]*/