Skip to content

Commit

Permalink
Use async schema extensions only for integrations which support it
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed May 31, 2023
1 parent cf08906 commit 87fbff0
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tests/http/clients/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView
from strawberry.http import GraphQLHTTPResponse
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema
from tests.views.schema import Query
from tests.views.schema import async_schema as schema

from ..context import get_context
from .base import (
Expand Down
3 changes: 2 additions & 1 deletion tests/http/clients/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from strawberry.asgi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.http import GraphQLHTTPResponse
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema
from tests.views.schema import Query
from tests.views.schema import async_schema as schema

from ..context import get_context
from .base import (
Expand Down
2 changes: 1 addition & 1 deletion tests/http/clients/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from channels.testing import WebsocketCommunicator
from strawberry.channels import GraphQLWSConsumer
from tests.views.schema import schema
from tests.views.schema import async_schema as schema

from ..context import get_context
from .base import (
Expand Down
3 changes: 2 additions & 1 deletion tests/http/clients/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.http import GraphQLHTTPResponse
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema
from tests.views.schema import Query
from tests.views.schema import async_schema as schema

from ..context import get_context
from .asgi import AsgiWebSocketClient
Expand Down
3 changes: 2 additions & 1 deletion tests/http/clients/starlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from strawberry.starlite import make_graphql_controller
from strawberry.starlite.controller import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema
from tests.views.schema import Query
from tests.views.schema import async_schema as schema

from ..context import get_context
from .base import (
Expand Down
46 changes: 44 additions & 2 deletions tests/views/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ class MyExtension(SchemaExtension):
def get_results(self) -> Dict[str, str]:
return {"example": "example"}

def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any):
self.resolve_called()
return _next(root, info, *args, **kwargs)

def resolve_called(self):
pass

def lifecycle_called(self, event, phase):
pass

def on_operation(self):
self.lifecycle_called("operation", "before")
yield
self.lifecycle_called("operation", "after")

def on_validate(self):
self.lifecycle_called("validate", "before")
yield
self.lifecycle_called("validate", "after")

def on_parse(self):
self.lifecycle_called("parse", "before")
yield
self.lifecycle_called("parse", "after")

def on_execute(self):
self.lifecycle_called("execute", "before")
yield
self.lifecycle_called("execute", "after")


class MyAsyncExtension(SchemaExtension):
def get_results(self) -> Dict[str, str]:
return {"example": "example"}

async def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any):
self.resolve_called()
result = _next(root, info, *args, **kwargs)
Expand All @@ -62,12 +97,12 @@ def on_validate(self):
yield
self.lifecycle_called("validate", "after")

async def on_parse(self):
def on_parse(self):
self.lifecycle_called("parse", "before")
yield
self.lifecycle_called("parse", "after")

def on_execute(self):
async def on_execute(self):
self.lifecycle_called("execute", "before")
yield
self.lifecycle_called("execute", "after")
Expand Down Expand Up @@ -301,3 +336,10 @@ async def conditional_fail(
subscription=Subscription,
extensions=[MyExtension],
)

async_schema = strawberry.Schema(
query=Query,
mutation=Mutation,
subscription=Subscription,
extensions=[MyAsyncExtension],
)
6 changes: 3 additions & 3 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from tests.http.clients.base import DebuggableGraphQLTransportWSMixin

from ..http.clients import HttpClient, WebSocketClient
from ..views.schema import MyExtension
from ..views.schema import MyAsyncExtension


@pytest_asyncio.fixture
Expand Down Expand Up @@ -856,8 +856,8 @@ async def test_extensions(ws: WebSocketClient):
resolve_called = Mock()
lifecycle_called = Mock()

with patch.object(MyExtension, "resolve_called", resolve_called):
with patch.object(MyExtension, "lifecycle_called", lifecycle_called):
with patch.object(MyAsyncExtension, "resolve_called", resolve_called):
with patch.object(MyAsyncExtension, "lifecycle_called", lifecycle_called):
await ws.send_json(
SubscribeMessage(
id="sub1",
Expand Down

0 comments on commit 87fbff0

Please sign in to comment.