From a3028b080542540b2a835721fdbe99b2f237deed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 7 Jun 2023 15:15:57 +0000 Subject: [PATCH] schema.subscribe again, returns a Union[ExecutionResult, AsyncGen] --- docs/operations/testing.md | 5 +- strawberry/schema/__init__.py | 4 +- strawberry/schema/base.py | 13 +--- strawberry/schema/execute.py | 62 +++++++++++----- strawberry/schema/schema.py | 16 +---- .../graphql_transport_ws/handlers.py | 64 ++++++++--------- .../protocols/graphql_ws/handlers.py | 70 ++++++++++--------- tests/schema/test_permission.py | 6 +- tests/schema/test_subscription.py | 6 +- 9 files changed, 123 insertions(+), 123 deletions(-) diff --git a/docs/operations/testing.md b/docs/operations/testing.md index 175ee98a5b..0e53ff7f5c 100644 --- a/docs/operations/testing.md +++ b/docs/operations/testing.md @@ -134,9 +134,10 @@ async def test_subscription(): } """ + sub = await schema.subscribe(query) + index = 0 - async for ok, result in schema.subscribe(query): - assert ok + async for result in sub: assert not result.errors assert result.data == {"count": index} diff --git a/strawberry/schema/__init__.py b/strawberry/schema/__init__.py index cbc3bf134a..5cf633ac21 100644 --- a/strawberry/schema/__init__.py +++ b/strawberry/schema/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseSchema, SubscribeSingleResult +from .base import BaseSchema from .schema import Schema -__all__ = ["BaseSchema", "Schema", "SubscribeSingleResult"] +__all__ = ["BaseSchema", "Schema"] diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index 5c1170ebee..ed1340a627 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -32,15 +32,6 @@ from .config import StrawberryConfig -class SubscribeSingleResult(RuntimeError): - """Raised when Schema.subscribe() returns a single execution result, instead of a - subscription generator, typically as a result of validation errors. - """ - - def __init__(self, value: ExecutionResult) -> None: - self.value = value - - class BaseSchema(Protocol): config: StrawberryConfig schema_converter: GraphQLCoreConverter @@ -74,14 +65,14 @@ def execute_sync( raise NotImplementedError @abstractmethod - def subscribe( + async def subscribe( self, query: str, variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> AsyncGenerator[ExecutionResult, None]: + ) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 8861baf638..acae82887a 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -28,7 +28,6 @@ from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.types import ExecutionResult -from .base import SubscribeSingleResult from .exceptions import InvalidOperationTypeError if TYPE_CHECKING: @@ -309,19 +308,43 @@ async def subscribe( extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], execution_context: ExecutionContext, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], -) -> AsyncGenerator[ExecutionResult, None]: - """ - The graphql-core subscribe function returns either an ExecutionResult or an - AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error - during parsing or validation. - Because we need to maintain execution context, we cannot return an - async generator, we must _be_ an async generator. So we yield a - (bool, ExecutionResult) tuple, where the bool indicates whether the result is an - potentially multiple execution result or a single result. - A False value indicates an single result, most likely an intial - failure (and no more values will be yielded) whereas a True value indicates a - successful subscription, and more values may be yielded. - """ +) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]: + # The graphql-core subscribe function returns either an ExecutionResult or an + # AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error + # during parsing or validation. + # We repeat that pattern here, but to maintain the context of the extensions + # context manager, we must delegate to an inner async generator. The inner + # generator yields an initial result, either a None, or an ExecutionResult, + # to indicate the two different cases. + + asyncgen = _subscribe( + schema, + extensions=extensions, + execution_context=execution_context, + process_errors=process_errors, + ) + # start the generator + first = await asyncgen.__anext__() + if first is not None: + # Single result. Close the generator to exit any context managers + await asyncgen.aclose() + return first + else: + # return the started generator. Cast away the Optional[] type + return cast(AsyncGenerator[ExecutionResult, None], asyncgen) + + +async def _subscribe( + schema: GraphQLSchema, + *, + extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], +) -> AsyncGenerator[Optional[ExecutionResult], None]: + # This Async generator first yields either a single ExecutionResult or None. + # If None is yielded, then the subscription has failed and the generator should + # be closed. + # Otherwise, if None is yielded, the subscription can continue. extensions_runner = SchemaExtensionsRunner( execution_context=execution_context, @@ -338,7 +361,8 @@ async def subscribe( execution_context, process_errors, extensions_runner ) if error_result is not None: - raise SubscribeSingleResult(error_result) + yield error_result + return # pragma: no cover async with extensions_runner.executing(): # currently original_subscribe is an async function. A future release @@ -373,12 +397,12 @@ async def subscribe( ) if isinstance(result, GraphQLExecutionResult): - raise SubscribeSingleResult( - await process_subscribe_result( - execution_context, process_errors, extensions_runner, result - ) + yield await process_subscribe_result( + execution_context, process_errors, extensions_runner, result ) + return # pragma: no cover + yield None # signal that we are returning an async generator aiterator = result.__aiter__() try: async for result in aiterator: diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index bc8208855a..cd6222b62f 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -60,15 +60,6 @@ } -class SubscribeSingleResult(RuntimeError): - """Raised when Schema.subscribe() returns a single execution result, instead of a - subscription generator, typically as a result of validation errors. - """ - - def __init__(self, value: ExecutionResult) -> None: - self.value = value - - class Schema(BaseSchema): def __init__( self, @@ -301,7 +292,7 @@ async def subscribe( context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> AsyncGenerator[ExecutionResult, None]: + ) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]: execution_context = ExecutionContext( query=query, schema=self, @@ -311,13 +302,12 @@ async def subscribe( provided_operation_name=operation_name, ) - async for result in subscribe( + return await subscribe( self._schema, extensions=self.get_extensions(), execution_context=execution_context, process_errors=self.process_errors, - ): - yield result + ) def _warn_for_federation_directives(self): """Raises a warning if the schema has any federation directives.""" diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 601642623f..b867ea99f4 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -9,7 +9,6 @@ from graphql import GraphQLError, GraphQLSyntaxError, parse from graphql.error.graphql_error import format_error as format_graphql_error -from strawberry.schema import SubscribeSingleResult from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -21,6 +20,7 @@ SubscribeMessage, SubscribeMessagePayload, ) +from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType from strawberry.unset import UNSET from strawberry.utils.debug import pretty_print_graphql_operation @@ -33,7 +33,6 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( GraphQLTransportMessage, ) - from strawberry.types import ExecutionResult class BaseGraphQLTransportWSHandler(ABC): @@ -227,7 +226,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: # Get an AsyncGenerator yielding the results if operation_type == OperationType.SUBSCRIPTION: - result_source = self.schema.subscribe( + result_source = await self.schema.subscribe( query=message.payload.query, variable_values=message.payload.variables, operation_name=message.payload.operationName, @@ -237,22 +236,26 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: else: # create AsyncGenerator returning a single result async def get_result_source(): - raise SubscribeSingleResult( - await self.schema.execute( - query=message.payload.query, - variable_values=message.payload.variables, - context_value=context, - root_value=root_value, - operation_name=message.payload.operationName, - ) + yield await self.schema.execute( + query=message.payload.query, + variable_values=message.payload.variables, + context_value=context, + root_value=root_value, + operation_name=message.payload.operationName, ) - # need a yield here to turn this into an async generator - yield None # pragma: no cover result_source = get_result_source() operation = Operation(self, message.id) + # Handle initial validation errors + if isinstance(result_source, ExecutionResult): + assert result_source.errors + payload = [format_graphql_error(result_source.errors[0])] + await self.send_message(ErrorMessage(id=message.id, payload=payload)) + self.schema.process_errors(result_source.errors) + return + # Create task to handle this subscription, reserve the operation ID self.subscriptions[message.id] = result_source self.tasks[message.id] = asyncio.create_task( @@ -294,15 +297,19 @@ async def handle_async_results( operation: Operation, ) -> None: try: - try: - async for result in result_source: - await self.send_result(operation, result) - if result.errors: - return # terminate subscription - except SubscribeSingleResult as single_result: - await self.send_result(operation, single_result.value) - finally: - await result_source.aclose() + async for result in result_source: + if result.errors: + error_payload = [format_graphql_error(err) for err in result.errors] + error_message = ErrorMessage(id=operation.id, payload=error_payload) + await operation.send_message(error_message) + self.schema.process_errors(result.errors) + return + else: + next_payload = {"data": result.data} + if result.extensions: + next_payload["extensions"] = result.extensions + next_message = NextMessage(id=operation.id, payload=next_payload) + await operation.send_message(next_message) except asyncio.CancelledError: # CancelledErrors are expected during task cleanup. raise @@ -316,19 +323,6 @@ async def handle_async_results( self.schema.process_errors([error]) return - async def send_result(self, operation: Operation, result: ExecutionResult) -> None: - if result.errors: - error_payload = [format_graphql_error(err) for err in result.errors] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - self.schema.process_errors(result.errors) - else: - next_payload = {"data": result.data} - if result.extensions: - next_payload["extensions"] = result.extensions - next_message = NextMessage(id=operation.id, payload=next_payload) - await operation.send_message(next_message) - def forget_id(self, id: str) -> None: # de-register the operation id making it immediately available # for re-use diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index d6666c8fc5..e125e6c211 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -8,7 +8,6 @@ from graphql import GraphQLError from graphql.error.graphql_error import format_error as format_graphql_error -from strawberry.schema import SubscribeSingleResult from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, GQL_CONNECTION_ACK, @@ -21,6 +20,7 @@ GQL_START, GQL_STOP, ) +from strawberry.types import ExecutionResult from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: @@ -124,13 +124,26 @@ async def handle_start(self, message: OperationMessage) -> None: if self.debug: pretty_print_graphql_operation(operation_name, query, variables) - result_source = self.schema.subscribe( - query=query, - variable_values=variables, - operation_name=operation_name, - context_value=context, - root_value=root_value, - ) + try: + result_source = await self.schema.subscribe( + query=query, + variable_values=variables, + operation_name=operation_name, + context_value=context, + root_value=root_value, + ) + except GraphQLError as error: + error_payload = format_graphql_error(error) + await self.send_message(GQL_ERROR, operation_id, error_payload) + self.schema.process_errors([error]) + return + + if isinstance(result_source, ExecutionResult): + assert result_source.errors + error_payload = format_graphql_error(result_source.errors[0]) + await self.send_message(GQL_ERROR, operation_id, error_payload) + self.schema.process_errors(result_source.errors) + return self.subscriptions[operation_id] = result_source result_handler = self.handle_async_results(result_source, operation_id) @@ -152,29 +165,19 @@ async def handle_async_results( operation_id: str, ) -> None: try: - try: - async for result in result_source: - payload = {"data": result.data} - if result.errors: - payload["errors"] = [ - format_graphql_error(err) for err in result.errors - ] - if result.extensions: - payload["extensions"] = result.extensions - await self.send_message(GQL_DATA, operation_id, payload) - # log errors after send_message to prevent potential - # slowdown of sending result - if result.errors: - self.schema.process_errors(result.errors) - except SubscribeSingleResult as single_result: - result = single_result.value - assert result.errors - error_payload = format_graphql_error(result.errors[0]) - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors(result.errors) - return - finally: - await result_source.aclose() + async for result in result_source: + payload = {"data": result.data} + if result.errors: + payload["errors"] = [ + format_graphql_error(err) for err in result.errors + ] + if result.extensions: + payload["extensions"] = result.extensions + await self.send_message(GQL_DATA, operation_id, payload) + # log errors after send_message to prevent potential + # slowdown of sending result + if result.errors: + self.schema.process_errors(result.errors) except asyncio.CancelledError: # CancelledErrors are expected during task cleanup. pass @@ -188,16 +191,17 @@ async def handle_async_results( {"data": None, "errors": [format_graphql_error(error)]}, ) self.schema.process_errors([error]) + finally: + await result_source.aclose() await self.send_message(GQL_COMPLETE, operation_id, None) async def cleanup_operation(self, operation_id: str) -> None: - iterator = self.subscriptions.pop(operation_id) + self.subscriptions.pop(operation_id) task = self.tasks.pop(operation_id) task.cancel() with suppress(BaseException): await task - await iterator.aclose() async def send_message( self, diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 99a58ef1ed..aade03ec3f 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -4,7 +4,6 @@ import strawberry from strawberry.permission import BasePermission -from strawberry.schema import SubscribeSingleResult from strawberry.types import Info @@ -76,10 +75,7 @@ async def user(self, info) -> typing.AsyncGenerator[str, None]: query = "subscription { user }" - with pytest.raises(SubscribeSingleResult) as err: - async for result in schema.subscribe(query): - pass - result = err.value.value + result = await schema.subscribe(query) assert result.errors[0].message == "You are not authorized" diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index 09b25918bb..95191d0192 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -25,7 +25,7 @@ async def example(self) -> typing.AsyncGenerator[str, None]: query = "subscription { example }" - async for result in schema.subscribe(query): + async for result in await schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi" @@ -46,7 +46,7 @@ async def example(self, name: str) -> typing.AsyncGenerator[str, None]: query = 'subscription { example(name: "Nina") }' - async for result in schema.subscribe(query): + async for result in await schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi Nina" @@ -87,6 +87,6 @@ class Subscription: query = "subscription { example }" - async for result in schema.subscribe(query): + async for result in await schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi"