Skip to content

Commit

Permalink
schema.subscribe again, returns a Union[ExecutionResult, AsyncGen]
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Jun 8, 2023
1 parent c2fba2b commit a3028b0
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 123 deletions.
5 changes: 3 additions & 2 deletions docs/operations/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ async def test_subscription():
}
"""

sub = await schema.subscribe(query)

index = 0
async for ok, result in schema.subscribe(query):
assert ok
async for result in sub:
assert not result.errors
assert result.data == {"count": index}

Expand Down
4 changes: 2 additions & 2 deletions strawberry/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import BaseSchema, SubscribeSingleResult
from .base import BaseSchema
from .schema import Schema

__all__ = ["BaseSchema", "Schema", "SubscribeSingleResult"]
__all__ = ["BaseSchema", "Schema"]
13 changes: 2 additions & 11 deletions strawberry/schema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@
from .config import StrawberryConfig


class SubscribeSingleResult(RuntimeError):
"""Raised when Schema.subscribe() returns a single execution result, instead of a
subscription generator, typically as a result of validation errors.
"""

def __init__(self, value: ExecutionResult) -> None:
self.value = value


class BaseSchema(Protocol):
config: StrawberryConfig
schema_converter: GraphQLCoreConverter
Expand Down Expand Up @@ -74,14 +65,14 @@ def execute_sync(
raise NotImplementedError

@abstractmethod
def subscribe(
async def subscribe(
self,
query: str,
variable_values: Optional[Dict[str, Any]] = None,
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]:
raise NotImplementedError

@abstractmethod
Expand Down
62 changes: 43 additions & 19 deletions strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from strawberry.extensions.runner import SchemaExtensionsRunner
from strawberry.types import ExecutionResult

from .base import SubscribeSingleResult
from .exceptions import InvalidOperationTypeError

if TYPE_CHECKING:
Expand Down Expand Up @@ -309,19 +308,43 @@ async def subscribe(
extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]],
execution_context: ExecutionContext,
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None],
) -> AsyncGenerator[ExecutionResult, None]:
"""
The graphql-core subscribe function returns either an ExecutionResult or an
AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error
during parsing or validation.
Because we need to maintain execution context, we cannot return an
async generator, we must _be_ an async generator. So we yield a
(bool, ExecutionResult) tuple, where the bool indicates whether the result is an
potentially multiple execution result or a single result.
A False value indicates an single result, most likely an intial
failure (and no more values will be yielded) whereas a True value indicates a
successful subscription, and more values may be yielded.
"""
) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]:
# The graphql-core subscribe function returns either an ExecutionResult or an
# AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error
# during parsing or validation.
# We repeat that pattern here, but to maintain the context of the extensions
# context manager, we must delegate to an inner async generator. The inner
# generator yields an initial result, either a None, or an ExecutionResult,
# to indicate the two different cases.

asyncgen = _subscribe(
schema,
extensions=extensions,
execution_context=execution_context,
process_errors=process_errors,
)
# start the generator
first = await asyncgen.__anext__()
if first is not None:
# Single result. Close the generator to exit any context managers
await asyncgen.aclose()
return first
else:
# return the started generator. Cast away the Optional[] type
return cast(AsyncGenerator[ExecutionResult, None], asyncgen)


async def _subscribe(
schema: GraphQLSchema,
*,
extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]],
execution_context: ExecutionContext,
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None],
) -> AsyncGenerator[Optional[ExecutionResult], None]:
# This Async generator first yields either a single ExecutionResult or None.
# If None is yielded, then the subscription has failed and the generator should
# be closed.
# Otherwise, if None is yielded, the subscription can continue.

extensions_runner = SchemaExtensionsRunner(
execution_context=execution_context,
Expand All @@ -338,7 +361,8 @@ async def subscribe(
execution_context, process_errors, extensions_runner
)
if error_result is not None:
raise SubscribeSingleResult(error_result)
yield error_result
return # pragma: no cover

async with extensions_runner.executing():
# currently original_subscribe is an async function. A future release
Expand Down Expand Up @@ -373,12 +397,12 @@ async def subscribe(
)

if isinstance(result, GraphQLExecutionResult):
raise SubscribeSingleResult(
await process_subscribe_result(
execution_context, process_errors, extensions_runner, result
)
yield await process_subscribe_result(
execution_context, process_errors, extensions_runner, result
)
return # pragma: no cover

yield None # signal that we are returning an async generator
aiterator = result.__aiter__()
try:
async for result in aiterator:
Expand Down
16 changes: 3 additions & 13 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,6 @@
}


class SubscribeSingleResult(RuntimeError):
"""Raised when Schema.subscribe() returns a single execution result, instead of a
subscription generator, typically as a result of validation errors.
"""

def __init__(self, value: ExecutionResult) -> None:
self.value = value


class Schema(BaseSchema):
def __init__(
self,
Expand Down Expand Up @@ -301,7 +292,7 @@ async def subscribe(
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> AsyncGenerator[ExecutionResult, None]:
) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]:
execution_context = ExecutionContext(
query=query,
schema=self,
Expand All @@ -311,13 +302,12 @@ async def subscribe(
provided_operation_name=operation_name,
)

async for result in subscribe(
return await subscribe(
self._schema,
extensions=self.get_extensions(),
execution_context=execution_context,
process_errors=self.process_errors,
):
yield result
)

def _warn_for_federation_directives(self):
"""Raises a warning if the schema has any federation directives."""
Expand Down
64 changes: 29 additions & 35 deletions strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from graphql import GraphQLError, GraphQLSyntaxError, parse
from graphql.error.graphql_error import format_error as format_graphql_error

from strawberry.schema import SubscribeSingleResult
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
Expand All @@ -21,6 +20,7 @@
SubscribeMessage,
SubscribeMessagePayload,
)
from strawberry.types import ExecutionResult
from strawberry.types.graphql import OperationType
from strawberry.unset import UNSET
from strawberry.utils.debug import pretty_print_graphql_operation
Expand All @@ -33,7 +33,6 @@
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
GraphQLTransportMessage,
)
from strawberry.types import ExecutionResult


class BaseGraphQLTransportWSHandler(ABC):
Expand Down Expand Up @@ -227,7 +226,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:

# Get an AsyncGenerator yielding the results
if operation_type == OperationType.SUBSCRIPTION:
result_source = self.schema.subscribe(
result_source = await self.schema.subscribe(
query=message.payload.query,
variable_values=message.payload.variables,
operation_name=message.payload.operationName,
Expand All @@ -237,22 +236,26 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
else:
# create AsyncGenerator returning a single result
async def get_result_source():
raise SubscribeSingleResult(
await self.schema.execute(
query=message.payload.query,
variable_values=message.payload.variables,
context_value=context,
root_value=root_value,
operation_name=message.payload.operationName,
)
yield await self.schema.execute(
query=message.payload.query,
variable_values=message.payload.variables,
context_value=context,
root_value=root_value,
operation_name=message.payload.operationName,
)
# need a yield here to turn this into an async generator
yield None # pragma: no cover

result_source = get_result_source()

operation = Operation(self, message.id)

# Handle initial validation errors
if isinstance(result_source, ExecutionResult):
assert result_source.errors
payload = [format_graphql_error(result_source.errors[0])]
await self.send_message(ErrorMessage(id=message.id, payload=payload))
self.schema.process_errors(result_source.errors)
return

# Create task to handle this subscription, reserve the operation ID
self.subscriptions[message.id] = result_source
self.tasks[message.id] = asyncio.create_task(
Expand Down Expand Up @@ -294,15 +297,19 @@ async def handle_async_results(
operation: Operation,
) -> None:
try:
try:
async for result in result_source:
await self.send_result(operation, result)
if result.errors:
return # terminate subscription
except SubscribeSingleResult as single_result:
await self.send_result(operation, single_result.value)
finally:
await result_source.aclose()
async for result in result_source:
if result.errors:
error_payload = [format_graphql_error(err) for err in result.errors]
error_message = ErrorMessage(id=operation.id, payload=error_payload)
await operation.send_message(error_message)
self.schema.process_errors(result.errors)
return
else:
next_payload = {"data": result.data}
if result.extensions:
next_payload["extensions"] = result.extensions
next_message = NextMessage(id=operation.id, payload=next_payload)
await operation.send_message(next_message)
except asyncio.CancelledError:
# CancelledErrors are expected during task cleanup.
raise
Expand All @@ -316,19 +323,6 @@ async def handle_async_results(
self.schema.process_errors([error])
return

async def send_result(self, operation: Operation, result: ExecutionResult) -> None:
if result.errors:
error_payload = [format_graphql_error(err) for err in result.errors]
error_message = ErrorMessage(id=operation.id, payload=error_payload)
await operation.send_message(error_message)
self.schema.process_errors(result.errors)
else:
next_payload = {"data": result.data}
if result.extensions:
next_payload["extensions"] = result.extensions
next_message = NextMessage(id=operation.id, payload=next_payload)
await operation.send_message(next_message)

def forget_id(self, id: str) -> None:
# de-register the operation id making it immediately available
# for re-use
Expand Down
Loading

0 comments on commit a3028b0

Please sign in to comment.