diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index 27604fcb3..625b3f5a2 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -7,6 +7,7 @@ from traitlets import Type from ..utils import run_sync +from ..utils import uses_run_sync from jupyter_client.channels import HBChannel from jupyter_client.channels import ZMQSocketChannel from jupyter_client.client import KernelClient @@ -25,6 +26,7 @@ def _(self, *args, **kwargs): return _ +@uses_run_sync class BlockingKernelClient(KernelClient): """A KernelClient with blocking APIs diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 19afd0db7..e03b83553 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -33,6 +33,7 @@ from .provisioning import KernelProvisionerFactory as KPF from .utils import ensure_async from .utils import run_sync +from .utils import uses_run_sync from jupyter_client import KernelClient from jupyter_client import kernelspec @@ -85,6 +86,7 @@ async def wrapper(self, *args, **kwargs): return t.cast(F, wrapper) +@uses_run_sync class KernelManager(ConnectionFileMixin): """Manages a single kernel in a subprocess on this host. @@ -636,6 +638,7 @@ async def _async_wait(self, pollinterval: float = 0.1) -> None: await asyncio.sleep(pollinterval) +@uses_run_sync(enable=False) class AsyncKernelManager(KernelManager): # the class to create with our `client` method client_class: DottedObjectName = DottedObjectName( diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py index 7dceb9448..0c0a23544 100644 --- a/jupyter_client/multikernelmanager.py +++ b/jupyter_client/multikernelmanager.py @@ -24,6 +24,7 @@ from .manager import KernelManager from .utils import ensure_async from .utils import run_sync +from .utils import uses_run_sync class DuplicateKernelError(Exception): @@ -50,6 +51,7 @@ def wrapped( return wrapped +@uses_run_sync class MultiKernelManager(LoggingConfigurable): """A class for managing multiple kernels.""" @@ -529,6 +531,7 @@ def new_kernel_id(self, **kwargs: t.Any) -> str: return str(uuid.uuid4()) +@uses_run_sync(enable=False) class AsyncMultiKernelManager(MultiKernelManager): kernel_manager_class = DottedObjectName( diff --git a/jupyter_client/utils.py b/jupyter_client/utils.py index 585bf1b17..9fa2ce1f1 100644 --- a/jupyter_client/utils.py +++ b/jupyter_client/utils.py @@ -4,10 +4,48 @@ - vendor functions from ipython_genutils that should be retired at some point. """ import asyncio +import functools import inspect import os +def uses_run_sync(__cls=None, *, enable=True): + """decorator for classes that uses `run_sync` to wrap methods + + This will ensure that nest_asyncio is applied in a timely manner. + + If an inheriting class becomes async again, it can call with + `enable=False` to prevent the nest_asyncio patching. + """ + + def perform_wrap(cls): + orig_init = cls.__init__ + + @functools.wraps(orig_init) + def __init__(self, *args, **kwargs): + if cls._uses_run_sync: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop: + import nest_asyncio # type: ignore + + nest_asyncio.apply(loop) + return orig_init(self, *args, **kwargs) + + cls._uses_run_sync = uses_sync + cls.__init__ = __init__ + return cls + + if __cls: + uses_sync = True + return perform_wrap(__cls) + else: + uses_sync = enable + return perform_wrap + + def run_sync(coro): def wrapped(*args, **kwargs): try: @@ -19,7 +57,7 @@ def wrapped(*args, **kwargs): except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - import nest_asyncio # type: ignore + import nest_asyncio nest_asyncio.apply(loop) future = asyncio.ensure_future(coro(*args, **kwargs), loop=loop)