Skip to content

Commit

Permalink
Support subscriptions extensions (#3554)
Browse files Browse the repository at this point in the history
* Support subscriptions extensions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add RELEASE.md

* mypy fixes.

* remove positional only / (py3.7)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reverts

* intial subscription test pass

* wip. pass manager to graphql core

* wip: migrate to new graphql core

* don't use supports resolve

* revert

* revert unneeded changes

* restore tests; inject execution_context on operation

* don't use class; depecate

* wip: ensure execution context is distinct.

* wip: first inital successful run after refactor

* wip: more tests pass

* better release.md

* all previous extensions tests pass

* update release.md

* improve tests readability

* test_subscription_success_many_fields pass

* test_subscription_first_yields_error

* test_extensino_results_are_cleared_between_yields

* test_extensino_results_are_cleared_between_yields

* fix extensions tests

* ai lints

* refactor; fix more tests

* ensure `on_execute` and `get_result` hooks are deterministically ordered.

* handle `on_execute` exceptions.

* move missing query error before parsing phase.

* wip: remove unneeded exception handler for `sync_execute`

* fix more tests

* refactor: separate subscription tests from normal tests.

* fix websocket tests

* fix mypy

* move to graphql-core origin@main

* nit

* fix: handle not awaitable result of `original_subscribe`

* add th `assert_next`

* nits

* nit

* nit

* nits

* reorder execute.py

* docs.

* use `.aclose`

* fix mypy issues.

* fix unused `else`

* move execution context injection upwards.

* nit

* add test for when extensios not return anything.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* wip: remove implements_get_result check

* working on tests

* schema tests pass

* lints

* fix more tests

* wip: working on tests.

* fix graphql-transport-ws

* graphql-ws tests pass

* tests were always running on 3.3 XD

* pass middleware manager only in 3.3

* fix some mypy issues

* fix more tests.

* fix graphql-ws-transport protocol behaviour on (pre)execution errors

* resolve optimization todos.

* add `long_runnning` subscription benchmark; update release.md; improve coverage.

* fix contextvar issue + few redundant changes.

* feat: add lazy loading for images in test_subscriptions benchmark

* typos in readme

* fix tests.

* recify reviews; nits & move `_implements_resolve` to `SchemaExtension`

* refactor: update GraphQL version check logic in utils/__init__.py

* rectify review comments

* rectify @DoctorJohn review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rectify @bellini666 comments

* Improve protocol handling for subscriptions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Lint

* Restore tests

* Update wrong return type

* Update doc

* Temporarily skip on_execute, see #3613

* Update multipart tests

* Fix bad merge

* Fix type

* remove unneeded async gen wrapper since we removed support for on_execute hook

* Update release notes

* Add tweet.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick Arminio <patrick.arminio@gmail.com>
  • Loading branch information
3 people authored Sep 10, 2024
1 parent fc25e04 commit 0dcf23d
Show file tree
Hide file tree
Showing 32 changed files with 1,268 additions and 631 deletions.
36 changes: 36 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
Release type: minor

This release adds support for schema-extensions in subscriptions.

Here's a small example of how to use them (they work the same way as query and
mutation extensions):

```python
import asyncio
from typing import AsyncIterator

import strawberry
from strawberry.extensions.base_extension import SchemaExtension


@strawberry.type
class Subscription:
@strawberry.subscription
async def notifications(self, info: strawberry.Info) -> AsyncIterator[str]:
for _ in range(3):
yield "Hello"


class MyExtension(SchemaExtension):
async def on_operation(self):
# This would run when the subscription starts
print("Subscription started")
yield
# The subscription has ended
print("Subscription ended")


schema = strawberry.Schema(
query=Query, subscription=Subscription, extensions=[MyExtension]
)
```
5 changes: 5 additions & 0 deletions TWEET.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
🆕 Release $version is out! Thanks to $contributor for the PR 👏

This release adds supports for schema extensions to subscriptions!

Get it here 👉 $release_url
1 change: 1 addition & 0 deletions docs/breaking-changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ title: List of breaking changes and deprecations

# List of breaking changes and deprecations

- [Version 0.240.0 - 10 September 2024](./breaking-changes/0.240.0.md)
- [Version 0.236.0 - 17 July 2024](./breaking-changes/0.236.0.md)
- [Version 0.233.0 - 29 May 2024](./breaking-changes/0.233.0.md)
- [Version 0.217.0 - 18 December 2023](./breaking-changes/0.217.0.md)
Expand Down
36 changes: 36 additions & 0 deletions docs/breaking-changes/0.240.0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---
title: 0.240.0 Breaking Changes
slug: breaking-changes/0.240.0
---

# v0.240.0 updates `Schema.subscribe`'s signature

In order to support schema extensions in subscriptions and errors that can be
raised before the execution of the subscription, we had to update the signature
of `Schema.subscribe`.

Previously it was:

```python
async def subscribe(
self,
query: str,
variable_values: Optional[Dict[str, Any]] = None,
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]:
```

Now it is:

```python
async def subscribe(
self,
query: Optional[str],
variable_values: Optional[Dict[str, Any]] = None,
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> Union[AsyncGenerator[ExecutionResult, None], PreExecutionError]:
```
9 changes: 2 additions & 7 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
PYTHON_VERSIONS = ["3.12", "3.11", "3.10", "3.9", "3.8"]
GQL_CORE_VERSIONS = [
"3.2.3",
"3.3.0",
"3.3.0a6",
]

COMMON_PYTEST_OPTIONS = [
Expand Down Expand Up @@ -44,12 +44,7 @@


def _install_gql_core(session: Session, version: str) -> None:
# hack for better workflow names # noqa: FIX004
if version == "3.2.3":
session._session.install(f"graphql-core=={version}") # type: ignore
session._session.install(
"https://github.com/graphql-python/graphql-core/archive/876aef67b6f1e1f21b3b5db94c7ff03726cb6bdf.zip"
) # type: ignore
session._session.install(f"graphql-core=={version}")


gql_core_parametrize = nox.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions strawberry/channels/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ async def subscribe(
message_type = response["type"]
if message_type == NextMessage.type:
payload = NextMessage(**response).payload
ret = ExecutionResult(payload["data"], None)
ret = ExecutionResult(payload.get("data"), None)
if "errors" in payload:
ret.errors = self.process_errors(payload["errors"])
ret.errors = self.process_errors(payload.get("errors") or [])
ret.extensions = payload.get("extensions", None)
yield ret
elif message_type == ErrorMessage.type:
Expand Down
13 changes: 10 additions & 3 deletions strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ class LifecycleStep(Enum):
class SchemaExtension:
execution_context: ExecutionContext

def __init__(self, *, execution_context: ExecutionContext) -> None:
self.execution_context = execution_context

# to support extensions that still use the old signature
# we have an optional argument here for ease of initialization.
def __init__(
self, *, execution_context: ExecutionContext | None = None
) -> None: ...
def on_operation( # type: ignore
self,
) -> AsyncIteratorOrIterator[None]: # pragma: no cover
Expand Down Expand Up @@ -61,6 +63,11 @@ def resolve(
def get_results(self) -> AwaitableOrValue[Dict[str, Any]]:
return {}

@classmethod
def _implements_resolve(cls) -> bool:
"""Whether the extension implements the resolve method."""
return cls.resolve is not SchemaExtension.resolve


Hook = Callable[[SchemaExtension], AsyncIteratorOrIterator[None]]

Expand Down
45 changes: 8 additions & 37 deletions strawberry/extensions/runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

from graphql import MiddlewareManager
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from strawberry.extensions.context import (
ExecutingContextManager,
Expand All @@ -13,39 +11,22 @@
)
from strawberry.utils.await_maybe import await_maybe

from . import SchemaExtension

if TYPE_CHECKING:
from strawberry.types import ExecutionContext

from . import SchemaExtension


class SchemaExtensionsRunner:
extensions: List[SchemaExtension]

def __init__(
self,
execution_context: ExecutionContext,
extensions: Optional[
List[Union[Type[SchemaExtension], SchemaExtension]]
] = None,
extensions: Optional[List[SchemaExtension]] = None,
) -> None:
self.execution_context = execution_context

if not extensions:
extensions = []

init_extensions: List[SchemaExtension] = []

for extension in extensions:
# If the extension has already been instantiated then set the
# `execution_context` attribute
if isinstance(extension, SchemaExtension):
extension.execution_context = execution_context
init_extensions.append(extension)
else:
init_extensions.append(extension(execution_context=execution_context))

self.extensions = init_extensions
self.extensions = extensions or []

def operation(self) -> OperationContextManager:
return OperationContextManager(self.extensions)
Expand All @@ -61,29 +42,19 @@ def executing(self) -> ExecutingContextManager:

def get_extensions_results_sync(self) -> Dict[str, Any]:
data: Dict[str, Any] = {}

for extension in self.extensions:
if inspect.iscoroutinefunction(extension.get_results):
msg = "Cannot use async extension hook during sync execution"
raise RuntimeError(msg)

data.update(extension.get_results()) # type: ignore

return data

async def get_extensions_results(self) -> Dict[str, Any]:
async def get_extensions_results(self, ctx: ExecutionContext) -> Dict[str, Any]:
data: Dict[str, Any] = {}

for extension in self.extensions:
results = await await_maybe(extension.get_results())
data.update(results)
data.update(await await_maybe(extension.get_results()))

data.update(ctx.extensions_results)
return data

def as_middleware_manager(self, *additional_middlewares: Any) -> MiddlewareManager:
middlewares = tuple(self.extensions) + additional_middlewares

return MiddlewareManager(*middlewares)


__all__ = ["SchemaExtensionsRunner"]
3 changes: 2 additions & 1 deletion strawberry/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
from typing_extensions import TypedDict
from typing_extensions import Literal, TypedDict

if TYPE_CHECKING:
from strawberry.types import ExecutionResult
Expand Down Expand Up @@ -33,6 +33,7 @@ class GraphQLRequestData:
query: Optional[str]
variables: Optional[Dict[str, Any]]
operation_name: Optional[str]
protocol: Literal["http", "multipart-subscription"] = "http"


def parse_query_params(params: Dict[str, str]) -> Dict[str, Any]:
Expand Down
16 changes: 16 additions & 0 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Tuple,
Union,
)
from typing_extensions import Literal

from graphql import GraphQLError

Expand Down Expand Up @@ -121,6 +122,15 @@ async def execute_operation(

assert self.schema

if request_data.protocol == "multipart-subscription":
return await self.schema.subscribe(
request_data.query, # type: ignore
variable_values=request_data.variables,
context_value=context,
root_value=root_value,
operation_name=request_data.operation_name,
)

return await self.schema.execute(
request_data.query,
root_value=root_value,
Expand Down Expand Up @@ -312,21 +322,27 @@ async def parse_http_body(
) -> GraphQLRequestData:
content_type, params = parse_content_type(request.content_type or "")

protocol: Literal["http", "multipart-subscription"] = "http"

if request.method == "GET":
data = self.parse_query_params(request.query_params)
if self._is_multipart_subscriptions(content_type, params):
protocol = "multipart-subscription"
elif "application/json" in content_type:
data = self.parse_json(await request.get_body())
elif content_type == "multipart/form-data":
data = await self.parse_multipart(request)
elif self._is_multipart_subscriptions(content_type, params):
data = await self.parse_multipart_subscriptions(request)
protocol = "multipart-subscription"
else:
raise HTTPException(400, "Unsupported content type")

return GraphQLRequestData(
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
protocol=protocol,
)

async def process_result(
Expand Down
6 changes: 3 additions & 3 deletions strawberry/schema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from strawberry.types import (
ExecutionContext,
ExecutionResult,
SubscriptionExecutionResult,
)
from strawberry.types.base import StrawberryObjectDefinition
from strawberry.types.enum import EnumDefinition
Expand All @@ -24,6 +23,7 @@
from strawberry.types.union import StrawberryUnion

from .config import StrawberryConfig
from .subscribe import SubscriptionResult


class BaseSchema(Protocol):
Expand All @@ -43,7 +43,7 @@ async def execute(
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
allowed_operation_types: Optional[Iterable[OperationType]] = None,
) -> Union[ExecutionResult, SubscriptionExecutionResult]:
) -> ExecutionResult:
raise NotImplementedError

@abstractmethod
Expand All @@ -66,7 +66,7 @@ async def subscribe(
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> Any:
) -> SubscriptionResult:
raise NotImplementedError

@abstractmethod
Expand Down
Loading

0 comments on commit 0dcf23d

Please sign in to comment.