diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..7f194a4b85 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +Starting with this release, WebSocket logic now lives in the base class shared between all HTTP integrations. +This makes the behaviour of WebSockets much more consistent between integrations and easier to maintain. diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..0437a68073 --- /dev/null +++ b/TWEET.md @@ -0,0 +1,7 @@ +🚀 Starting with Strawberry $version, WebSocket logic now lives in the base +class shared across all HTTP integrations. More consistent behavior and easier +maintenance for WebSockets across integrations. 🎉 + +Thanks to $contributor for the PR 👏 + +$release_url diff --git a/strawberry/aiohttp/handlers/__init__.py b/strawberry/aiohttp/handlers/__init__.py deleted file mode 100644 index c769c4eec7..0000000000 --- a/strawberry/aiohttp/handlers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from strawberry.aiohttp.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, -) -from strawberry.aiohttp.handlers.graphql_ws_handler import GraphQLWSHandler - -__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"] diff --git a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py b/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index 52350199f7..0000000000 --- a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Dict - -from aiohttp import http, web -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - -if TYPE_CHECKING: - from datetime import timedelta - - from strawberry.schema import BaseSchema - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable[..., Dict[str, Any]], - get_root_value: Any, - request: web.Request, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._request = request - self._ws = web.WebSocketResponse(protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - - async def get_context(self) -> Any: - return await self._get_context(request=self._request, response=self._ws) # type: ignore - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._request) - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int, reason: str) -> None: - await self._ws.close(code=code, message=reason.encode()) - - async def handle_request(self) -> web.StreamResponse: - await self._ws.prepare(self._request) - self.on_request_accepted() - - try: - async for ws_message in self._ws: # type: http.WSMessage - if ws_message.type == http.WSMsgType.TEXT: - await self.handle_message(ws_message.json()) - else: - error_message = "WebSocket message type must be text" - await self.handle_invalid_message(error_message) - finally: - await self.shutdown() - - return self._ws - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/aiohttp/handlers/graphql_ws_handler.py b/strawberry/aiohttp/handlers/graphql_ws_handler.py deleted file mode 100644 index 677dd34884..0000000000 --- a/strawberry/aiohttp/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Optional - -from aiohttp import http, web -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler - -if TYPE_CHECKING: - from strawberry.schema import BaseSchema - from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - request: web.Request, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._request = request - self._ws = web.WebSocketResponse(protocols=[GRAPHQL_WS_PROTOCOL]) - - async def get_context(self) -> Any: - return await self._get_context(request=self._request, response=self._ws) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._request) - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - message = reason.encode() if reason else b"" - await self._ws.close(code=code, message=message) - - async def handle_request(self) -> Any: - await self._ws.prepare(self._request) - - try: - async for ws_message in self._ws: # type: http.WSMessage - if ws_message.type == http.WSMsgType.TEXT: - message: OperationMessage = ws_message.json() - await self.handle_message(message) - else: - await self.close( - code=1002, reason="WebSocket message type must be text" - ) - finally: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - return self._ws - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index f2154309f6..884264dcb0 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -4,6 +4,7 @@ import warnings from datetime import timedelta from io import BytesIO +from json.decoder import JSONDecodeError from typing import ( TYPE_CHECKING, Any, @@ -16,15 +17,16 @@ Union, cast, ) +from typing_extensions import TypeGuard -from aiohttp import web +from aiohttp import http, web from aiohttp.multipart import BodyPartReader -from strawberry.aiohttp.handlers import ( - GraphQLTransportWSHandler, - GraphQLWSHandler, +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, ) -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( Context, @@ -79,11 +81,36 @@ def content_type(self) -> Optional[str]: return self.headers.get("content-type") +class AioHTTPWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None: + self.request = request + self.ws = ws + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + async for ws_message in self.ws: + if ws_message.type == http.WSMsgType.TEXT: + try: + yield ws_message.json() + except JSONDecodeError: + raise NonJsonMessageReceived() + + elif ws_message.type == http.WSMsgType.BINARY: + raise NonJsonMessageReceived() + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send_json(message) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code=code, message=reason.encode()) + + class GraphQLView( AsyncBaseHTTPView[ web.Request, Union[web.Response, web.StreamResponse], web.Response, + web.Request, + web.WebSocketResponse, Context, RootValue, ] @@ -92,10 +119,9 @@ class GraphQLView( # bare handler function. _is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined] - graphql_transport_ws_handler_class = GraphQLTransportWSHandler - graphql_ws_handler_class = GraphQLWSHandler allow_queries_via_get = True request_adapter_class = AioHTTPRequestAdapter + websocket_adapter_class = AioHTTPWebSocketAdapter def __init__( self, @@ -138,48 +164,36 @@ async def render_graphql_ide(self, request: web.Request) -> web.Response: async def get_sub_response(self, request: web.Request) -> web.Response: return web.Response() - async def __call__(self, request: web.Request) -> web.StreamResponse: + def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]: ws = web.WebSocketResponse(protocols=self.subscription_protocols) - ws_test = ws.can_prepare(request) - - if not ws_test.ok: - try: - return await self.run(request=request) - except HTTPException as e: - return web.Response( - body=e.reason, - status=e.status_code, - ) - - if ws_test.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - return await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=self.get_context, # type: ignore - get_root_value=self.get_root_value, - request=request, - ).handle() - elif ws_test.protocol == GRAPHQL_WS_PROTOCOL: - return await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=self.get_context, - get_root_value=self.get_root_value, - request=request, - ).handle() - else: - await ws.prepare(request) - await ws.close(code=4406, message=b"Subprotocol not acceptable") - return ws + return ws.can_prepare(request).ok + + async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]: + ws = web.WebSocketResponse(protocols=self.subscription_protocols) + return ws.can_prepare(request).protocol + + async def create_websocket_response( + self, request: web.Request, subprotocol: Optional[str] + ) -> web.WebSocketResponse: + protocols = [subprotocol] if subprotocol else [] + ws = web.WebSocketResponse(protocols=protocols) + await ws.prepare(request) + return ws + + async def __call__(self, request: web.Request) -> web.StreamResponse: + try: + return await self.run(request=request) + except HTTPException as e: + return web.Response( + body=e.reason, + status=e.status_code, + ) async def get_root_value(self, request: web.Request) -> Optional[RootValue]: return None async def get_context( - self, request: web.Request, response: web.Response + self, request: web.Request, response: Union[web.Response, web.WebSocketResponse] ) -> Context: return {"request": request, "response": response} # type: ignore diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index d2647c6ee6..d10d207987 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -2,9 +2,11 @@ import warnings from datetime import timedelta +from json import JSONDecodeError from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, AsyncIterator, Callable, Dict, @@ -14,6 +16,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from starlette import status from starlette.requests import Request @@ -23,14 +26,14 @@ Response, StreamingResponse, ) -from starlette.websockets import WebSocket +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState -from strawberry.asgi.handlers import ( - GraphQLTransportWSHandler, - GraphQLWSHandler, +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, ) -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( Context, @@ -78,19 +81,41 @@ async def get_form_data(self) -> FormData: ) +class ASGIWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: WebSocket, response: WebSocket) -> None: + self.ws = response + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + try: + try: + while self.ws.application_state != WebSocketState.DISCONNECTED: + yield await self.ws.receive_json() + except (KeyError, JSONDecodeError): + raise NonJsonMessageReceived() + except WebSocketDisconnect: # pragma: no cover + pass + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send_json(message) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code=code, reason=reason) + + class GraphQL( AsyncBaseHTTPView[ - Union[Request, WebSocket], + Request, Response, Response, + WebSocket, + WebSocket, Context, RootValue, ] ): - graphql_transport_ws_handler_class = GraphQLTransportWSHandler - graphql_ws_handler_class = GraphQLWSHandler allow_queries_via_get = True - request_adapter_class = ASGIRequestAdapter # pyright: ignore + request_adapter_class = ASGIRequestAdapter + websocket_adapter_class = ASGIWebSocketAdapter def __init__( self, @@ -129,51 +154,25 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": - return await self.handle_http(scope, receive, send) + http_request = Request(scope=scope, receive=receive) - elif scope["type"] == "websocket": - ws = WebSocket(scope, receive=receive, send=send) - preferred_protocol = self.pick_preferred_protocol(ws) - - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=ws, - ).handle() - - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=ws, - ).handle() - - else: - # Subprotocol not acceptable - await ws.close(code=4406) + try: + response = await self.run(http_request) + except HTTPException as e: + response = PlainTextResponse(e.reason, status_code=e.status_code) + await response(scope, receive, send) + elif scope["type"] == "websocket": + ws_request = WebSocket(scope, receive=receive, send=send) + await self.run(ws_request) else: # pragma: no cover raise ValueError("Unknown scope type: {!r}".format(scope["type"])) - def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]: - protocols = ws["subprotocols"] - intersection = set(protocols) & set(self.protocols) - sorted_intersection = sorted(intersection, key=protocols.index) - return next(iter(sorted_intersection), None) - async def get_root_value(self, request: Union[Request, WebSocket]) -> Optional[Any]: return None async def get_context( - self, request: Union[Request, WebSocket], response: Response + self, request: Union[Request, WebSocket], response: Union[Response, WebSocket] ) -> Context: return {"request": request, "response": response} # type: ignore @@ -187,21 +186,6 @@ async def get_sub_response( return sub_response - async def handle_http( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - request = Request(scope=scope, receive=receive) - - try: - response = await self.run(request) - except HTTPException as e: - response = PlainTextResponse(e.reason, status_code=e.status_code) # pyright: ignore - - await response(scope, receive, send) - async def render_graphql_ide(self, request: Union[Request, WebSocket]) -> Response: return HTMLResponse(self.graphql_ide_html) @@ -239,3 +223,20 @@ async def create_streaming_response( **headers, }, ) + + def is_websocket_request( + self, request: Union[Request, WebSocket] + ) -> TypeGuard[WebSocket]: + return request.scope["type"] == "websocket" + + async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + protocols = request["subprotocols"] + intersection = set(protocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=protocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: WebSocket, subprotocol: Optional[str] + ) -> WebSocket: + await request.accept(subprotocol=subprotocol) + return request diff --git a/strawberry/asgi/handlers/__init__.py b/strawberry/asgi/handlers/__init__.py deleted file mode 100644 index 1891a06ad0..0000000000 --- a/strawberry/asgi/handlers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from strawberry.asgi.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, -) -from strawberry.asgi.handlers.graphql_ws_handler import GraphQLWSHandler - -__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"] diff --git a/strawberry/asgi/handlers/graphql_transport_ws_handler.py b/strawberry/asgi/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index 7cec132ffd..0000000000 --- a/strawberry/asgi/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable - -from starlette.websockets import WebSocketDisconnect, WebSocketState - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - -if TYPE_CHECKING: - from datetime import timedelta - - from starlette.websockets import WebSocket - - from strawberry.schema import BaseSchema - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context(request=self._ws, response=None) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int, reason: str) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> None: - await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL) - self.on_request_accepted() - - try: - while self._ws.application_state != WebSocketState.DISCONNECTED: - try: - message = await self._ws.receive_json() - except KeyError: # noqa: PERF203 - error_message = "WebSocket message type must be text" - await self.handle_invalid_message(error_message) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - await self.shutdown() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/asgi/handlers/graphql_ws_handler.py b/strawberry/asgi/handlers/graphql_ws_handler.py deleted file mode 100644 index 00a314bbd0..0000000000 --- a/strawberry/asgi/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Optional - -from starlette.websockets import WebSocketDisconnect, WebSocketState - -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler - -if TYPE_CHECKING: - from starlette.websockets import WebSocket - - from strawberry.schema import BaseSchema - from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context(request=self._ws, response=None) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL) - - try: - while self._ws.application_state != WebSocketState.DISCONNECTED: - try: - message = await self._ws.receive_json() - except KeyError: # noqa: PERF203 - await self.close( - code=1002, reason="WebSocket message type must be text" - ) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/channels/__init__.py b/strawberry/channels/__init__.py index 455513babb..f67fb25a82 100644 --- a/strawberry/channels/__init__.py +++ b/strawberry/channels/__init__.py @@ -1,6 +1,4 @@ -from .handlers.base import ChannelsConsumer, ChannelsWSConsumer -from .handlers.graphql_transport_ws_handler import GraphQLTransportWSHandler -from .handlers.graphql_ws_handler import GraphQLWSHandler +from .handlers.base import ChannelsConsumer from .handlers.http_handler import ( ChannelsRequest, GraphQLHTTPConsumer, @@ -12,10 +10,7 @@ __all__ = [ "ChannelsConsumer", "ChannelsRequest", - "ChannelsWSConsumer", "GraphQLProtocolTypeRouter", - "GraphQLWSHandler", - "GraphQLTransportWSHandler", "GraphQLHTTPConsumer", "GraphQLWSConsumer", "SyncGraphQLHTTPConsumer", diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py index ec2ffe6b2c..769ec569e5 100644 --- a/strawberry/channels/handlers/base.py +++ b/strawberry/channels/handlers/base.py @@ -16,7 +16,7 @@ from weakref import WeakSet from channels.consumer import AsyncConsumer -from channels.generic.websocket import AsyncJsonWebsocketConsumer +from channels.generic.websocket import AsyncWebsocketConsumer class ChannelsMessage(TypedDict, total=False): @@ -210,7 +210,7 @@ async def _listen_to_channel_generator( return -class ChannelsWSConsumer(ChannelsConsumer, AsyncJsonWebsocketConsumer): +class ChannelsWSConsumer(ChannelsConsumer, AsyncWebsocketConsumer): """Base channels websocket async consumer.""" diff --git a/strawberry/channels/handlers/graphql_transport_ws_handler.py b/strawberry/channels/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index db290f4ef8..0000000000 --- a/strawberry/channels/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Optional - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - -if TYPE_CHECKING: - from datetime import timedelta - - from strawberry.channels.handlers.base import ChannelsWSConsumer - from strawberry.schema import BaseSchema - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable, - get_root_value: Callable, - ws: ChannelsWSConsumer, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context( - request=self._ws, connection_params=self.connection_params - ) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - # TODO: We are using `self._ws.base_send` directly instead of `self._ws.close` - # because the later doesn't accept the `reason` argument. - await self._ws.base_send( - { - "type": "websocket.close", - "code": code, - "reason": reason or "", - } - ) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL) - self.on_request_accepted() - - async def handle_disconnect(self, code: int) -> None: - await self.shutdown() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/channels/handlers/graphql_ws_handler.py b/strawberry/channels/handlers/graphql_ws_handler.py deleted file mode 100644 index 6d967a1d15..0000000000 --- a/strawberry/channels/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Optional - -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler - -if TYPE_CHECKING: - from strawberry.channels.handlers.base import ChannelsWSConsumer - from strawberry.schema import BaseSchema - from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - ws: ChannelsWSConsumer, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context( - request=self._ws, connection_params=self.connection_params - ) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - # TODO: We are using `self._ws.base_send` directly instead of `self._ws.close` - # because the latler doesn't accept the `reason` argument. - await self._ws.base_send( - { - "type": "websocket.close", - "code": code, - "reason": reason or "", - } - ) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL) - - async def handle_disconnect(self, code: int) -> None: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - async def handle_invalid_message(self, error_message: str) -> None: - # This is not part of the BaseGraphQLWSHandler's interface, but the - # channels integration is a high level wrapper that forwards this to - # both us and the BaseGraphQLTransportWSHandler. - await self.close(code=1002, reason=error_message) - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index c7264a45b3..8d682eea74 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -15,7 +15,7 @@ Optional, Union, ) -from typing_extensions import assert_never +from typing_extensions import TypeGuard, assert_never from urllib.parse import parse_qs from django.conf import settings @@ -233,6 +233,8 @@ class GraphQLHTTPConsumer( ChannelsRequest, Union[ChannelsResponse, MultipartChannelsResponse], TemporalResponse, + ChannelsRequest, + TemporalResponse, Context, RootValue, ], @@ -298,6 +300,21 @@ async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse content=self.graphql_ide_html.encode(), content_type="text/html" ) + def is_websocket_request( + self, request: ChannelsRequest + ) -> TypeGuard[ChannelsRequest]: + return False + + async def pick_websocket_subprotocol( + self, request: ChannelsRequest + ) -> Optional[str]: + return None + + async def create_websocket_response( + self, request: ChannelsRequest, subprotocol: Optional[str] + ) -> TemporalResponse: + raise NotImplementedError + class SyncGraphQLHTTPConsumer( BaseGraphQLHTTPConsumer, diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 2991059afd..b267f7ea9b 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -1,20 +1,76 @@ from __future__ import annotations +import asyncio import datetime -from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union - +import json +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Dict, + Mapping, + Optional, + Tuple, + TypedDict, + Union, +) +from typing_extensions import TypeGuard + +from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter +from strawberry.http.exceptions import NonJsonMessageReceived +from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from .base import ChannelsConsumer, ChannelsWSConsumer -from .graphql_transport_ws_handler import GraphQLTransportWSHandler -from .graphql_ws_handler import GraphQLWSHandler +from .base import ChannelsWSConsumer if TYPE_CHECKING: - from strawberry.http.typevars import Context, RootValue + from strawberry.http import GraphQLHTTPResponse from strawberry.schema import BaseSchema -class GraphQLWSConsumer(ChannelsWSConsumer): +class ChannelsWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None: + self.ws_consumer = response + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + while True: + message = await self.ws_consumer.message_queue.get() + + if message["disconnected"]: + break + + if message["message"] is None: + raise NonJsonMessageReceived() + + try: + yield json.loads(message["message"]) + except json.JSONDecodeError: + raise NonJsonMessageReceived() + + async def send_json(self, message: Mapping[str, object]) -> None: + serialized_message = json.dumps(message) + await self.ws_consumer.send(serialized_message) + + async def close(self, code: int, reason: str) -> None: + await self.ws_consumer.close(code=code, reason=reason) + + +class MessageQueueData(TypedDict): + message: Union[str, None] + disconnected: bool + + +class GraphQLWSConsumer( + ChannelsWSConsumer, + AsyncBaseHTTPView[ + "GraphQLWSConsumer", + "GraphQLWSConsumer", + "GraphQLWSConsumer", + "GraphQLWSConsumer", + "GraphQLWSConsumer", + Context, + RootValue, + ], +): """A channels websocket consumer for GraphQL. This handles the connections, then hands off to the appropriate @@ -39,9 +95,7 @@ class GraphQLWSConsumer(ChannelsWSConsumer): ``` """ - graphql_transport_ws_handler_class = GraphQLTransportWSHandler - graphql_ws_handler_class = GraphQLWSHandler - _handler: Union[GraphQLWSHandler, GraphQLTransportWSHandler] + websocket_adapter_class = ChannelsWebSocketAdapter def __init__( self, @@ -63,70 +117,71 @@ def __init__( self.keep_alive_interval = keep_alive_interval self.debug = debug self.protocols = subscription_protocols + self.message_queue: asyncio.Queue[MessageQueueData] = asyncio.Queue() + self.run_task: Optional[asyncio.Task] = None super().__init__() - def pick_preferred_protocol( - self, accepted_subprotocols: Sequence[str] - ) -> Optional[str]: - intersection = set(accepted_subprotocols) & set(self.protocols) - sorted_intersection = sorted(intersection, key=accepted_subprotocols.index) - return next(iter(sorted_intersection), None) - async def connect(self) -> None: - preferred_protocol = self.pick_preferred_protocol(self.scope["subprotocols"]) - - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - self._handler = self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=self, - ) - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - self._handler = self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=self, - ) - else: - # Subprotocol not acceptable - return await self.close(code=4406) + self.run_task = asyncio.create_task(self.run(self)) - await self._handler.handle() - return None - - async def receive(self, *args: str, **kwargs: Any) -> None: - # Overriding this so that we can pass the errors to handle_invalid_message - try: - await super().receive(*args, **kwargs) - except ValueError: - reason = "WebSocket message type must be text" - await self._handler.handle_invalid_message(reason) - - async def receive_json(self, content: Any, **kwargs: Any) -> None: - await self._handler.handle_message(content) + async def receive( + self, text_data: Optional[str] = None, bytes_data: Optional[bytes] = None + ) -> None: + if text_data: + self.message_queue.put_nowait({"message": text_data, "disconnected": False}) + else: + self.message_queue.put_nowait({"message": None, "disconnected": False}) async def disconnect(self, code: int) -> None: - await self._handler.handle_disconnect(code) + self.message_queue.put_nowait({"message": None, "disconnected": True}) + assert self.run_task + await self.run_task - async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]: return None async def get_context( - self, request: ChannelsConsumer, connection_params: Any + self, request: GraphQLWSConsumer, response: GraphQLWSConsumer ) -> Context: return { "request": request, - "connection_params": connection_params, "ws": request, } # type: ignore + @property + def allow_queries_via_get(self) -> bool: + return False + + async def get_sub_response(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer: + raise NotImplementedError + + def create_response( + self, response_data: GraphQLHTTPResponse, sub_response: GraphQLWSConsumer + ) -> GraphQLWSConsumer: + raise NotImplementedError + + async def render_graphql_ide(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer: + raise NotImplementedError + + def is_websocket_request( + self, request: GraphQLWSConsumer + ) -> TypeGuard[GraphQLWSConsumer]: + return True + + async def pick_websocket_subprotocol( + self, request: GraphQLWSConsumer + ) -> Optional[str]: + protocols = request.scope["subprotocols"] + intersection = set(protocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=protocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: GraphQLWSConsumer, subprotocol: Optional[str] + ) -> GraphQLWSConsumer: + await request.accept(subprotocol=subprotocol) + return request + __all__ = ["GraphQLWSConsumer"] diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 132c822f78..457314d93b 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -13,6 +13,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from asgiref.sync import markcoroutinefunction from django.core.serializers.json import DjangoJSONEncoder @@ -258,7 +259,13 @@ def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: class AsyncGraphQLView( BaseView, AsyncBaseHTTPView[ - HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue + HttpRequest, + HttpResponseBase, + TemporalHttpResponse, + HttpRequest, + TemporalHttpResponse, + Context, + RootValue, ], View, ): @@ -312,5 +319,16 @@ async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: return response + def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]: + return False + + async def pick_websocket_subprotocol(self, request: HttpRequest) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: HttpRequest, subprotocol: Optional[str] + ) -> TemporalHttpResponse: + raise NotImplementedError + __all__ = ["GraphQLView", "AsyncGraphQLView"] diff --git a/strawberry/fastapi/handlers/__init__.py b/strawberry/fastapi/handlers/__init__.py deleted file mode 100644 index 20f336f5ff..0000000000 --- a/strawberry/fastapi/handlers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from strawberry.fastapi.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, -) -from strawberry.fastapi.handlers.graphql_ws_handler import GraphQLWSHandler - -__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"] diff --git a/strawberry/fastapi/handlers/graphql_transport_ws_handler.py b/strawberry/fastapi/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index 817f6996ac..0000000000 --- a/strawberry/fastapi/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from strawberry.asgi.handlers import ( - GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler, -) -from strawberry.fastapi.context import BaseContext - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - async def get_context(self) -> Any: - context = await self._get_context() - if isinstance(context, BaseContext): - context.connection_params = self.connection_params - return context - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/fastapi/handlers/graphql_ws_handler.py b/strawberry/fastapi/handlers/graphql_ws_handler.py deleted file mode 100644 index 0c43bbbd6e..0000000000 --- a/strawberry/fastapi/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any - -from strawberry.asgi.handlers import GraphQLWSHandler as BaseGraphQLWSHandler -from strawberry.fastapi.context import BaseContext - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - async def get_context(self) -> Any: - context = await self._get_context() - if isinstance(context, BaseContext): - context.connection_params = self.connection_params - return context - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index badcfa33e0..3ed8e6a4a0 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -17,6 +17,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from starlette import status from starlette.background import BackgroundTasks # noqa: TCH002 @@ -34,10 +35,9 @@ from fastapi.datastructures import Default from fastapi.routing import APIRoute from fastapi.utils import generate_unique_id -from strawberry.asgi import ASGIRequestAdapter +from strawberry.asgi import ASGIRequestAdapter, ASGIWebSocketAdapter from strawberry.exceptions import InvalidCustomContext from strawberry.fastapi.context import BaseContext, CustomContext -from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import process_result from strawberry.http.async_base_view import AsyncBaseHTTPView from strawberry.http.exceptions import HTTPException @@ -58,12 +58,14 @@ class GraphQLRouter( - AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], APIRouter + AsyncBaseHTTPView[ + Request, Response, Response, WebSocket, WebSocket, Context, RootValue + ], + APIRouter, ): - graphql_ws_handler_class = GraphQLWSHandler - graphql_transport_ws_handler_class = GraphQLTransportWSHandler allow_queries_via_get = True request_adapter_class = ASGIRequestAdapter + websocket_adapter_class = ASGIWebSocketAdapter @staticmethod async def __get_root_value() -> None: @@ -261,44 +263,7 @@ async def websocket_endpoint( # pyright: ignore context: Context = Depends(self.context_getter), root_value: RootValue = Depends(self.root_value_getter), ) -> None: - async def _get_context() -> Context: - return context - - async def _get_root_value() -> RootValue: - return root_value - - preferred_protocol = self.pick_preferred_protocol(websocket) - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=_get_context, - get_root_value=_get_root_value, - ws=websocket, - ).handle() - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=_get_context, - get_root_value=_get_root_value, - ws=websocket, - ).handle() - else: - # Code 4406 is "Subprotocol not acceptable" - await websocket.close(code=4406) - - def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]: - protocols = ws["subprotocols"] - intersection = set(protocols) & set(self.protocols) - return min( - intersection, - key=lambda i: protocols.index(i), - default=None, - ) + await self.run(request=websocket, context=context, root_value=root_value) async def render_graphql_ide(self, request: Request) -> HTMLResponse: return HTMLResponse(self.graphql_ide_html) @@ -309,12 +274,12 @@ async def process_result( return process_result(result) async def get_context( - self, request: Request, response: Response + self, request: Union[Request, WebSocket], response: Union[Response, WebSocket] ) -> Context: # pragma: no cover raise ValueError("`get_context` is not used by FastAPI GraphQL Router") async def get_root_value( - self, request: Request + self, request: Union[Request, WebSocket] ) -> Optional[RootValue]: # pragma: no cover raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router") @@ -350,5 +315,22 @@ async def create_streaming_response( }, ) + def is_websocket_request( + self, request: Union[Request, WebSocket] + ) -> TypeGuard[WebSocket]: + return request.scope["type"] == "websocket" + + async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + protocols = request["subprotocols"] + intersection = set(protocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=protocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: WebSocket, subprotocol: Optional[str] + ) -> WebSocket: + await request.accept(subprotocol=subprotocol) + return request + __all__ = ["GraphQLRouter"] diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index d952eb6aa9..2dc15d6d6c 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -9,6 +9,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from flask import Request, Response, render_template_string, request from flask.views import View @@ -159,7 +160,9 @@ async def get_form_data(self) -> FormData: class AsyncGraphQLView( BaseGraphQLView, - AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], + AsyncBaseHTTPView[ + Request, Response, Response, Request, Response, Context, RootValue + ], View, ): methods = ["GET", "POST"] @@ -187,6 +190,17 @@ async def dispatch_request(self) -> ResponseReturnValue: # type: ignore async def render_graphql_ide(self, request: Request) -> Response: return render_template_string(self.graphql_ide_html) # type: ignore + def is_websocket_request(self, request: Request) -> TypeGuard[Request]: + return False + + async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: Request, subprotocol: Optional[str] + ) -> Response: + raise NotImplementedError + __all__ = [ "GraphQLView", diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index c9f1e6ae49..a7666018ef 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -2,6 +2,7 @@ import asyncio import contextlib import json +from datetime import timedelta from typing import ( Any, AsyncGenerator, @@ -13,8 +14,10 @@ Optional, Tuple, Union, + cast, + overload, ) -from typing_extensions import Literal +from typing_extensions import Literal, TypeGuard from graphql import GraphQLError @@ -29,6 +32,11 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.schema.base import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) +from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.graphql import OperationType @@ -36,7 +44,15 @@ from .exceptions import HTTPException from .parse_content_type import parse_content_type from .types import FormData, HTTPMethod, QueryParams -from .typevars import Context, Request, Response, RootValue, SubResponse +from .typevars import ( + Context, + Request, + Response, + RootValue, + SubResponse, + WebSocketRequest, + WebSocketResponse, +) class AsyncHTTPRequestAdapter(abc.ABC): @@ -63,14 +79,42 @@ async def get_body(self) -> Union[str, bytes]: ... async def get_form_data(self) -> FormData: ... +class AsyncWebSocketAdapter(abc.ABC): + @abc.abstractmethod + def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ... + + @abc.abstractmethod + async def send_json(self, message: Mapping[str, object]) -> None: ... + + @abc.abstractmethod + async def close(self, code: int, reason: str) -> None: ... + + class AsyncBaseHTTPView( abc.ABC, BaseView[Request], - Generic[Request, Response, SubResponse, Context, RootValue], + Generic[ + Request, + Response, + SubResponse, + WebSocketRequest, + WebSocketResponse, + Context, + RootValue, + ], ): schema: BaseSchema graphql_ide: Optional[GraphQL_IDE] + debug: bool + keep_alive = False + keep_alive_interval: Optional[float] = None + connection_init_wait_timeout: timedelta = timedelta(minutes=1) request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter] + websocket_adapter_class: Callable[ + [WebSocketRequest, WebSocketResponse], AsyncWebSocketAdapter + ] + graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler + graphql_ws_handler_class = BaseGraphQLWSHandler @property @abc.abstractmethod @@ -80,10 +124,16 @@ def allow_queries_via_get(self) -> bool: ... async def get_sub_response(self, request: Request) -> SubResponse: ... @abc.abstractmethod - async def get_context(self, request: Request, response: SubResponse) -> Context: ... + async def get_context( + self, + request: Union[Request, WebSocketRequest], + response: Union[SubResponse, WebSocketResponse], + ) -> Context: ... @abc.abstractmethod - async def get_root_value(self, request: Request) -> Optional[RootValue]: ... + async def get_root_value( + self, request: Union[Request, WebSocketRequest] + ) -> Optional[RootValue]: ... @abc.abstractmethod def create_response( @@ -102,6 +152,21 @@ async def create_streaming_response( ) -> Response: raise ValueError("Multipart responses are not supported") + @abc.abstractmethod + def is_websocket_request( + self, request: Union[Request, WebSocketRequest] + ) -> TypeGuard[WebSocketRequest]: ... + + @abc.abstractmethod + async def pick_websocket_subprotocol( + self, request: WebSocketRequest + ) -> Optional[str]: ... + + @abc.abstractmethod + async def create_websocket_response( + self, request: WebSocketRequest, subprotocol: Optional[str] + ) -> WebSocketResponse: ... + async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] ) -> Union[ExecutionResult, SubscriptionExecutionResult]: @@ -167,35 +232,90 @@ def _handle_errors( ) -> None: """Hook to allow custom handling of errors, used by the Sentry Integration.""" + @overload async def run( self, request: Request, context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, - ) -> Response: - request_adapter = self.request_adapter_class(request) + ) -> Response: ... - if not self.is_request_allowed(request_adapter): - raise HTTPException(405, "GraphQL only supports GET and POST requests.") + @overload + async def run( + self, + request: WebSocketRequest, + context: Optional[Context] = UNSET, + root_value: Optional[RootValue] = UNSET, + ) -> WebSocketResponse: ... - if self.should_render_graphql_ide(request_adapter): - if self.graphql_ide: - return await self.render_graphql_ide(request) + async def run( + self, + request: Union[Request, WebSocketRequest], + context: Optional[Context] = UNSET, + root_value: Optional[RootValue] = UNSET, + ) -> Union[Response, WebSocketResponse]: + root_value = ( + await self.get_root_value(request) if root_value is UNSET else root_value + ) + + if self.is_websocket_request(request): + websocket_subprotocol = await self.pick_websocket_subprotocol(request) + websocket_response = await self.create_websocket_response( + request, websocket_subprotocol + ) + websocket = self.websocket_adapter_class(request, websocket_response) + + context = ( + await self.get_context(request, response=websocket_response) + if context is UNSET + else context + ) + + if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: + await self.graphql_transport_ws_handler_class( + websocket=websocket, + context=context, + root_value=root_value, + schema=self.schema, + debug=self.debug, + connection_init_wait_timeout=self.connection_init_wait_timeout, + ).handle() + elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL: + await self.graphql_ws_handler_class( + websocket=websocket, + context=context, + root_value=root_value, + schema=self.schema, + debug=self.debug, + keep_alive=self.keep_alive, + keep_alive_interval=self.keep_alive_interval, + ).handle() else: - raise HTTPException(404, "Not Found") + await websocket.close(4406, "Subprotocol not acceptable") + + return websocket_response + else: + request = cast(Request, request) + request_adapter = self.request_adapter_class(request) sub_response = await self.get_sub_response(request) context = ( await self.get_context(request, response=sub_response) if context is UNSET else context ) - root_value = ( - await self.get_root_value(request) if root_value is UNSET else root_value - ) assert context + if not self.is_request_allowed(request_adapter): + raise HTTPException(405, "GraphQL only supports GET and POST requests.") + + if self.should_render_graphql_ide(request_adapter): + if self.graphql_ide: + return await self.render_graphql_ide(request) + else: + raise HTTPException(404, "Not Found") + try: result = await self.execute_operation( request=request, context=context, root_value=root_value diff --git a/strawberry/http/exceptions.py b/strawberry/http/exceptions.py index d934696806..feddf77631 100644 --- a/strawberry/http/exceptions.py +++ b/strawberry/http/exceptions.py @@ -4,4 +4,8 @@ def __init__(self, status_code: int, reason: str) -> None: self.reason = reason +class NonJsonMessageReceived(Exception): + pass + + __all__ = ["HTTPException"] diff --git a/strawberry/http/typevars.py b/strawberry/http/typevars.py index a48cba848e..53a5d5ac33 100644 --- a/strawberry/http/typevars.py +++ b/strawberry/http/typevars.py @@ -3,8 +3,18 @@ Request = TypeVar("Request", contravariant=True) Response = TypeVar("Response") SubResponse = TypeVar("SubResponse") +WebSocketRequest = TypeVar("WebSocketRequest") +WebSocketResponse = TypeVar("WebSocketResponse") Context = TypeVar("Context") RootValue = TypeVar("RootValue") -__all__ = ["Request", "Response", "SubResponse", "Context", "RootValue"] +__all__ = [ + "Request", + "Response", + "SubResponse", + "WebSocketRequest", + "WebSocketResponse", + "Context", + "RootValue", +] diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index e5c27ffe87..fed3d2d45f 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -7,19 +7,19 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, AsyncIterator, Callable, Dict, FrozenSet, - List, Optional, - Set, Tuple, Type, TypedDict, Union, cast, ) +from typing_extensions import TypeGuard from msgspec import Struct @@ -35,23 +35,24 @@ ) from litestar.background_tasks import BackgroundTasks from litestar.di import Provide -from litestar.exceptions import NotFoundException, ValidationException +from litestar.exceptions import ( + NotFoundException, + SerializationException, + ValidationException, + WebSocketDisconnect, +) from litestar.response.streaming import Stream from litestar.status_codes import HTTP_200_OK from strawberry.exceptions import InvalidCustomContext -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, +) +from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws import ( - WS_4406_PROTOCOL_NOT_ACCEPTABLE, -) - -from .handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler, -) -from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler if TYPE_CHECKING: from collections.abc import Mapping @@ -152,22 +153,6 @@ class GraphQLResource(Struct): extensions: Optional[dict[str, object]] -class GraphQLWSHandler(BaseGraphQLWSHandler): - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - class LitestarRequestAdapter(AsyncHTTPRequestAdapter): def __init__(self, request: Request[Any, Any, Any]) -> None: self.request = request @@ -203,10 +188,37 @@ async def get_form_data(self) -> FormData: return FormData(form=multipart_data, files=multipart_data) +class LitestarWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: WebSocket, response: WebSocket) -> None: + self.ws = response + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + try: + try: + while self.ws.connection_state != "disconnect": + yield await self.ws.receive_json() + except (SerializationException, ValueError): + raise NonJsonMessageReceived() + except WebSocketDisconnect: + pass + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send_json(message) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code=code, reason=reason) + + class GraphQLController( Controller, AsyncBaseHTTPView[ - Request[Any, Any, Any], Response[Any], Response[Any], Context, RootValue + Request[Any, Any, Any], + Response[Any], + Response[Any], + WebSocket, + WebSocket, + Context, + RootValue, ], ): path: str = "" @@ -219,10 +231,7 @@ class GraphQLController( } request_adapter_class = LitestarRequestAdapter - graphql_ws_handler_class: Type[GraphQLWSHandler] = GraphQLWSHandler - graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = ( - GraphQLTransportWSHandler - ) + websocket_adapter_class = LitestarWebSocketAdapter allow_queries_via_get: bool = True graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"}) @@ -236,6 +245,23 @@ class GraphQLController( keep_alive: bool = False keep_alive_interval: float = 1 + def is_websocket_request( + self, request: Union[Request, WebSocket] + ) -> TypeGuard[WebSocket]: + return isinstance(request, WebSocket) + + async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + subprotocols = request.scope["subprotocols"] + intersection = set(subprotocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=subprotocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: WebSocket, subprotocol: Optional[str] + ) -> WebSocket: + await request.accept(subprotocols=subprotocol) + return request + async def execute_request( self, request: Request[Any, Any, Any], @@ -245,8 +271,6 @@ async def execute_request( try: return await self.run( request, - # TODO: check the dependency, above, can we make it so that - # we don't need to type ignore here? context=context, root_value=root_value, ) @@ -328,14 +352,29 @@ async def handle_http_post( root_value=root_value, ) + @websocket() + async def websocket_endpoint( + self, + socket: WebSocket, + context_ws: Any, + root_value: Any, + ) -> None: + await self.run( + request=socket, + context=context_ws, + root_value=root_value, + ) + async def get_context( - self, request: Request[Any, Any, Any], response: Response + self, + request: Union[Request[Any, Any, Any], WebSocket], + response: Union[Response, WebSocket], ) -> Context: # pragma: no cover msg = "`get_context` is not used by Litestar's controller" raise ValueError(msg) async def get_root_value( - self, request: Request[Any, Any, Any] + self, request: Union[Request[Any, Any, Any], WebSocket] ) -> RootValue | None: # pragma: no cover msg = "`get_root_value` is not used by Litestar's controller" raise ValueError(msg) @@ -343,54 +382,6 @@ async def get_root_value( async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response: return self.temporal_response - @websocket() - async def websocket_endpoint( - self, - socket: WebSocket, - context_ws: Any, - root_value: Any, - ) -> None: - async def _get_context() -> Any: - return context_ws - - async def _get_root_value() -> Any: - return root_value - - preferred_protocol = self.pick_preferred_protocol(socket) - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=_get_context, - get_root_value=_get_root_value, - ws=socket, - ).handle() - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=_get_context, - get_root_value=_get_root_value, - ws=socket, - ).handle() - else: - await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE) - - def pick_preferred_protocol(self, socket: WebSocket) -> str | None: - protocols: List[str] = socket.scope["subprotocols"] - intersection: Set[str] = set(protocols) & set(self.protocols) - return ( - min( - intersection, - key=lambda i: protocols.index(i) if i else "", - default=None, - ) - or None - ) - def make_graphql_controller( schema: BaseSchema, diff --git a/strawberry/litestar/handlers/__init__.py b/strawberry/litestar/handlers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/strawberry/litestar/handlers/graphql_transport_ws_handler.py b/strawberry/litestar/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index b5aa915d08..0000000000 --- a/strawberry/litestar/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,60 +0,0 @@ -from collections.abc import Callable -from datetime import timedelta -from typing import Any - -from litestar import WebSocket -from litestar.exceptions import SerializationException, WebSocketDisconnect -from strawberry.schema import BaseSchema -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int, reason: str) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> None: - await self._ws.accept(subprotocols=GRAPHQL_TRANSPORT_WS_PROTOCOL) - self.on_request_accepted() - - try: - while self._ws.connection_state != "disconnect": - try: - message = await self._ws.receive_json() - except (SerializationException, ValueError): # noqa: PERF203 - error_message = "WebSocket message type must be text" - await self.handle_invalid_message(error_message) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - await self.shutdown() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/litestar/handlers/graphql_ws_handler.py b/strawberry/litestar/handlers/graphql_ws_handler.py deleted file mode 100644 index ada421922f..0000000000 --- a/strawberry/litestar/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,66 +0,0 @@ -from collections.abc import Callable -from contextlib import suppress -from typing import Any, Optional - -from litestar import WebSocket -from litestar.exceptions import SerializationException, WebSocketDisconnect -from strawberry.schema import BaseSchema -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler -from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocols=GRAPHQL_WS_PROTOCOL) - - try: - while self._ws.connection_state != "disconnect": - try: - message = await self._ws.receive_json() - except (SerializationException, ValueError): # noqa: PERF203 - await self.close( - code=1002, reason="WebSocket message type must be text" - ) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 5aafcec514..c7dc1257fd 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,7 @@ import warnings from collections.abc import Mapping from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast +from typing_extensions import TypeGuard from quart import Request, Response, request from quart.views import View @@ -46,7 +47,9 @@ async def get_form_data(self) -> FormData: class GraphQLView( - AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], + AsyncBaseHTTPView[ + Request, Response, Response, Request, Response, Context, RootValue + ], View, ): _ide_subscription_enabled = False @@ -121,5 +124,16 @@ async def create_streaming_response( }, ) + def is_websocket_request(self, request: Request) -> TypeGuard[Request]: + return False + + async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: Request, subprotocol: Optional[str] + ) -> Response: + raise NotImplementedError + __all__ = ["GraphQLView"] diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index b62a63ba65..ee76d2e946 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -13,6 +13,7 @@ Type, cast, ) +from typing_extensions import TypeGuard from sanic.request import Request from sanic.response import HTTPResponse, html @@ -71,7 +72,15 @@ async def get_form_data(self) -> FormData: class GraphQLView( - AsyncBaseHTTPView[Request, HTTPResponse, TemporalResponse, Context, RootValue], + AsyncBaseHTTPView[ + Request, + HTTPResponse, + TemporalResponse, + Request, + TemporalResponse, + Context, + RootValue, + ], HTTPMethodView, ): """Class based view to handle GraphQL HTTP Requests. @@ -206,5 +215,16 @@ async def create_streaming_response( # corner case return None # type: ignore + def is_websocket_request(self, request: Request) -> TypeGuard[Request]: + return False + + async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: Request, subprotocol: Optional[str] + ) -> TemporalResponse: + raise NotImplementedError + __all__ = ["GraphQLView"] diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 7d19db8e98..f74e9def3d 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -2,7 +2,6 @@ import asyncio import logging -from abc import ABC, abstractmethod from contextlib import suppress from typing import ( TYPE_CHECKING, @@ -16,6 +15,7 @@ from graphql import GraphQLError, GraphQLSyntaxError, parse +from strawberry.http.exceptions import NonJsonMessageReceived from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -38,6 +38,7 @@ if TYPE_CHECKING: from datetime import timedelta + from strawberry.http.async_base_view import AsyncWebSocketAdapter from strawberry.schema import BaseSchema from strawberry.schema.subscribe import SubscriptionResult from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -45,15 +46,21 @@ ) -class BaseGraphQLTransportWSHandler(ABC): +class BaseGraphQLTransportWSHandler: task_logger: logging.Logger = logging.getLogger("strawberry.ws.task") def __init__( self, + websocket: AsyncWebSocketAdapter, + context: object, + root_value: object, schema: BaseSchema, debug: bool, connection_init_wait_timeout: timedelta, ) -> None: + self.websocket = websocket + self.context = context + self.root_value = root_value self.schema = schema self.debug = debug self.connection_init_wait_timeout = connection_init_wait_timeout @@ -65,28 +72,16 @@ def __init__( self.completed_tasks: List[asyncio.Task] = [] self.connection_params: Optional[Dict[str, Any]] = None - @abstractmethod - async def get_context(self) -> Any: - """Return the operations context.""" - - @abstractmethod - async def get_root_value(self) -> Any: - """Return the schemas root value.""" - - @abstractmethod - async def send_json(self, data: dict) -> None: - """Send the data JSON encoded to the WebSocket client.""" - - @abstractmethod - async def close(self, code: int, reason: str) -> None: - """Close the WebSocket with the passed code and reason.""" - - @abstractmethod - async def handle_request(self) -> Any: - """Handle the request this instance was created for.""" - async def handle(self) -> Any: - return await self.handle_request() + self.on_request_accepted() + + try: + async for message in self.websocket.iter_json(): + await self.handle_message(message) + except NonJsonMessageReceived: + await self.handle_invalid_message("WebSocket message type must be text") + finally: + await self.shutdown() async def shutdown(self) -> None: if self.connection_init_timeout_task: @@ -118,7 +113,7 @@ async def handle_connection_init_timeout(self) -> None: self.connection_timed_out = True reason = "Connection initialisation timeout" - await self.close(code=4408, reason=reason) + await self.websocket.close(code=4408, reason=reason) except Exception as error: await self.handle_task_exception(error) # pragma: no cover finally: @@ -189,14 +184,16 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: ) if not isinstance(payload, dict): - await self.close(code=4400, reason="Invalid connection init payload") + await self.websocket.close( + code=4400, reason="Invalid connection init payload" + ) return self.connection_params = payload if self.connection_init_received: reason = "Too many initialisation requests" - await self.close(code=4429, reason=reason) + await self.websocket.close(code=4429, reason=reason) return self.connection_init_received = True @@ -211,13 +208,13 @@ async def handle_pong(self, message: PongMessage) -> None: async def handle_subscribe(self, message: SubscribeMessage) -> None: if not self.connection_acknowledged: - await self.close(code=4401, reason="Unauthorized") + await self.websocket.close(code=4401, reason="Unauthorized") return try: graphql_document = parse(message.payload.query) except GraphQLSyntaxError as exc: - await self.close(code=4400, reason=exc.message) + await self.websocket.close(code=4400, reason=exc.message) return try: @@ -225,12 +222,14 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: graphql_document, message.payload.operationName ) except RuntimeError: - await self.close(code=4400, reason="Can't get GraphQL operation type") + await self.websocket.close( + code=4400, reason="Can't get GraphQL operation type" + ) return if message.id in self.operations: reason = f"Subscriber for {message.id} already exists" - await self.close(code=4409, reason=reason) + await self.websocket.close(code=4409, reason=reason) return if self.debug: # pragma: no cover @@ -240,26 +239,28 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: message.payload.variables, ) - context = await self.get_context() - if isinstance(context, dict): - context["connection_params"] = self.connection_params - root_value = await self.get_root_value() + if isinstance(self.context, dict): + self.context["connection_params"] = self.connection_params + elif hasattr(self.context, "connection_params"): + self.context.connection_params = self.connection_params + result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult] + # Get an AsyncGenerator yielding the results if operation_type == OperationType.SUBSCRIPTION: result_source = self.schema.subscribe( query=message.payload.query, variable_values=message.payload.variables, operation_name=message.payload.operationName, - context_value=context, - root_value=root_value, + context_value=self.context, + root_value=self.root_value, ) else: result_source = self.schema.execute( query=message.payload.query, variable_values=message.payload.variables, - context_value=context, - root_value=root_value, + context_value=self.context, + root_value=self.root_value, operation_name=message.payload.operationName, ) @@ -312,11 +313,11 @@ async def handle_complete(self, message: CompleteMessage) -> None: await self.cleanup_operation(operation_id=message.id) async def handle_invalid_message(self, error_message: str) -> None: - await self.close(code=4400, reason=error_message) + await self.websocket.close(code=4400, reason=error_message) async def send_message(self, message: GraphQLTransportMessage) -> None: data = message.as_dict() - await self.send_json(data) + await self.websocket.send_json(data) async def cleanup_operation(self, operation_id: str) -> None: if operation_id not in self.operations: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 0451e0934b..fda3db829f 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -1,11 +1,9 @@ from __future__ import annotations import asyncio -from abc import ABC, abstractmethod from contextlib import suppress from typing import ( TYPE_CHECKING, - Any, AsyncGenerator, Awaitable, Dict, @@ -13,6 +11,7 @@ cast, ) +from strawberry.http.exceptions import NonJsonMessageReceived from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, GQL_CONNECTION_ACK, @@ -25,29 +24,36 @@ GQL_START, GQL_STOP, ) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionInitPayload, + DataPayload, + OperationMessage, + OperationMessagePayload, + StartPayload, +) from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: + from strawberry.http.async_base_view import AsyncWebSocketAdapter from strawberry.schema import BaseSchema from strawberry.schema.subscribe import SubscriptionResult - from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionInitPayload, - DataPayload, - OperationMessage, - OperationMessagePayload, - StartPayload, - ) -class BaseGraphQLWSHandler(ABC): +class BaseGraphQLWSHandler: def __init__( self, + websocket: AsyncWebSocketAdapter, + context: object, + root_value: object, schema: BaseSchema, debug: bool, keep_alive: bool, - keep_alive_interval: float, + keep_alive_interval: Optional[float], ) -> None: + self.websocket = websocket + self.context = context + self.root_value = root_value self.schema = schema self.debug = debug self.keep_alive = keep_alive @@ -57,28 +63,22 @@ def __init__( self.tasks: Dict[str, asyncio.Task] = {} self.connection_params: Optional[ConnectionInitPayload] = None - @abstractmethod - async def get_context(self) -> Any: - """Return the operations context.""" - - @abstractmethod - async def get_root_value(self) -> Any: - """Return the schemas root value.""" - - @abstractmethod - async def send_json(self, data: OperationMessage) -> None: - """Send the data JSON encoded to the WebSocket client.""" - - @abstractmethod - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - """Close the WebSocket with the passed code and reason.""" - - @abstractmethod - async def handle_request(self) -> Any: - """Handle the request this instance was created for.""" - - async def handle(self) -> Any: - return await self.handle_request() + async def handle(self) -> None: + try: + async for message in self.websocket.iter_json(): + await self.handle_message(cast(OperationMessage, message)) + except NonJsonMessageReceived: + await self.websocket.close( + code=1002, reason="WebSocket message type must be text" + ) + finally: + if self.keep_alive_task: + self.keep_alive_task.cancel() + with suppress(BaseException): + await self.keep_alive_task + + for operation_id in list(self.subscriptions.keys()): + await self.cleanup_operation(operation_id) async def handle_message( self, @@ -99,22 +99,22 @@ async def handle_connection_init(self, message: OperationMessage) -> None: payload = message.get("payload") if payload is not None and not isinstance(payload, dict): error_message: OperationMessage = {"type": GQL_CONNECTION_ERROR} - await self.send_json(error_message) - await self.close() + await self.websocket.send_json(error_message) + await self.websocket.close(code=1000, reason="") return payload = cast(Optional["ConnectionInitPayload"], payload) self.connection_params = payload acknowledge_message: OperationMessage = {"type": GQL_CONNECTION_ACK} - await self.send_json(acknowledge_message) + await self.websocket.send_json(acknowledge_message) if self.keep_alive: keep_alive_handler = self.handle_keep_alive() self.keep_alive_task = asyncio.create_task(keep_alive_handler) async def handle_connection_terminate(self, message: OperationMessage) -> None: - await self.close() + await self.websocket.close(code=1000, reason="") async def handle_start(self, message: OperationMessage) -> None: operation_id = message["id"] @@ -123,10 +123,10 @@ async def handle_start(self, message: OperationMessage) -> None: operation_name = payload.get("operationName") variables = payload.get("variables") - context = await self.get_context() - if isinstance(context, dict): - context["connection_params"] = self.connection_params - root_value = await self.get_root_value() + if isinstance(self.context, dict): + self.context["connection_params"] = self.connection_params + elif hasattr(self.context, "connection_params"): + self.context.connection_params = self.connection_params if self.debug: pretty_print_graphql_operation(operation_name, query, variables) @@ -135,8 +135,8 @@ async def handle_start(self, message: OperationMessage) -> None: query=query, variable_values=variables, operation_name=operation_name, - context_value=context, - root_value=root_value, + context_value=self.context, + root_value=self.root_value, ) result_handler = self.handle_async_results(result_source, operation_id) @@ -147,9 +147,10 @@ async def handle_stop(self, message: OperationMessage) -> None: await self.cleanup_operation(operation_id) async def handle_keep_alive(self) -> None: + assert self.keep_alive_interval while True: data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE} - await self.send_json(data) + await self.websocket.send_json(data) await asyncio.sleep(self.keep_alive_interval) async def handle_async_results( @@ -191,7 +192,7 @@ async def send_message( data: OperationMessage = {"type": type_, "id": operation_id} if payload is not None: data["payload"] = payload - await self.send_json(data) + await self.websocket.send_json(data) async def send_data( self, execution_result: ExecutionResult, operation_id: str diff --git a/tests/aiohttp/__init__.py b/tests/aiohttp/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/aiohttp/app.py b/tests/aiohttp/app.py deleted file mode 100644 index ba43cea8fc..0000000000 --- a/tests/aiohttp/app.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any - -from aiohttp import web -from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler -from strawberry.aiohttp.views import GraphQLView -from tests.views.schema import Query, schema - - -class DebuggableGraphQLTransportWSHandler(GraphQLTransportWSHandler): - def get_tasks(self) -> list: - return [op.task for op in self.operations.values()] - - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = self.connection_init_timeout_task - return context - - -class DebuggableGraphQLWSHandler(GraphQLWSHandler): - def get_tasks(self) -> list: - return list(self.tasks.values()) - - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = None - return context - - -class MyGraphQLView(GraphQLView): - graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler - graphql_ws_handler_class = DebuggableGraphQLWSHandler - - async def get_root_value(self, request: web.Request) -> Query: - await super().get_root_value(request) # for coverage - return Query() - - -def create_app(**kwargs: Any) -> web.Application: - app = web.Application() - app.router.add_route("*", "/graphql", MyGraphQLView(schema=schema, **kwargs)) - - return app diff --git a/tests/aiohttp/test_websockets.py b/tests/aiohttp/test_websockets.py deleted file mode 100644 index 82d6a5db00..0000000000 --- a/tests/aiohttp/test_websockets.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Awaitable, Callable - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - -if TYPE_CHECKING: - from aiohttp.test_utils import TestClient - - -async def test_turning_off_graphql_ws( - aiohttp_client: Callable[..., Awaitable[TestClient]], -) -> None: - from .app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_turning_off_graphql_transport_ws( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_turning_off_all_ws_protocols( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app(subscription_protocols=[]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_unsupported_ws_protocol( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app(subscription_protocols=[]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=["imaginary-protocol"] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_clients_can_prefer_protocols( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.protocol == GRAPHQL_WS_PROTOCOL diff --git a/tests/asgi/app.py b/tests/asgi/app.py deleted file mode 100644 index a179a3210f..0000000000 --- a/tests/asgi/app.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Any, Dict, Optional, Union - -from starlette.requests import Request -from starlette.responses import Response -from starlette.websockets import WebSocket - -from strawberry.asgi import GraphQL as BaseGraphQL -from tests.views.schema import Query, schema - - -class GraphQL(BaseGraphQL): - async def get_root_value(self, request) -> Query: - return Query() - - async def get_context( - self, - request: Union[Request, WebSocket], - response: Optional[Response] = None, - ) -> Dict[str, Union[Request, WebSocket, Response, str, None]]: - return {"request": request, "response": response, "custom_value": "Hi"} - - -def create_app(**kwargs: Any) -> GraphQL: - return GraphQL(schema, **kwargs) diff --git a/tests/asgi/test_websockets.py b/tests/asgi/test_websockets.py deleted file mode 100644 index 511661358a..0000000000 --- a/tests/asgi/test_websockets.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - - -def test_turning_off_graphql_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_graphql_transport_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_all_ws_protocols(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_unsupported_ws_protocol(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", ["imaginary-protocol"]): - pass - - assert exc.value.code == 4406 - - -def test_clients_can_prefer_protocols(): - from starlette.testclient import TestClient - - from tests.asgi.app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - test_client = TestClient(app) - - with test_client.websocket_connect( - "/", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - with test_client.websocket_connect( - "/", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index 40a3eed7b6..8db1205fdf 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, AsyncGenerator import pytest @@ -21,7 +21,7 @@ @pytest.fixture -async def ws() -> Generator[WebsocketCommunicator, None, None]: +async def ws() -> AsyncGenerator[WebsocketCommunicator, None]: from channels.testing import WebsocketCommunicator from strawberry.channels import GraphQLWSConsumer diff --git a/tests/channels/test_testing.py b/tests/channels/test_testing.py index 921d8abf35..99aa9dd6c8 100644 --- a/tests/channels/test_testing.py +++ b/tests/channels/test_testing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generator +from typing import TYPE_CHECKING, Any, AsyncGenerator import pytest @@ -14,7 +14,7 @@ @pytest.fixture(params=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]) async def communicator( request: Any, -) -> Generator[GraphQLWebsocketCommunicator, None, None]: +) -> AsyncGenerator[GraphQLWebsocketCommunicator, None]: from strawberry.channels import GraphQLWSConsumer from strawberry.channels.testing import GraphQLWebsocketCommunicator diff --git a/tests/channels/test_ws_handler.py b/tests/channels/test_ws_handler.py deleted file mode 100644 index 88310bb617..0000000000 --- a/tests/channels/test_ws_handler.py +++ /dev/null @@ -1,54 +0,0 @@ -import pytest - -from tests.views.schema import schema - -try: - from channels.testing.websocket import WebsocketCommunicator - from strawberry.channels.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, - ) - from strawberry.channels.handlers.graphql_ws_handler import GraphQLWSHandler - from strawberry.channels.handlers.ws_handler import GraphQLWSConsumer -except ImportError: - pytestmark = pytest.mark.skip("Channels is not installed") - GraphQLWSHandler = None - GraphQLTransportWSHandler = None - -from strawberry.subscriptions import ( - GRAPHQL_TRANSPORT_WS_PROTOCOL, - GRAPHQL_WS_PROTOCOL, -) - - -async def test_wrong_protocol(): - GraphQLWSConsumer.as_asgi(schema=schema) - client = WebsocketCommunicator( - GraphQLWSConsumer.as_asgi(schema=schema), - "/graphql", - subprotocols=[ - "non-existing", - ], - ) - res = await client.connect() - assert res == (False, 4406) - - -@pytest.mark.parametrize( - ("protocol", "handler"), - [ - (GRAPHQL_TRANSPORT_WS_PROTOCOL, GraphQLTransportWSHandler), - (GRAPHQL_WS_PROTOCOL, GraphQLWSHandler), - ], -) -async def test_correct_protocol(protocol, handler): - consumer = GraphQLWSConsumer(schema=schema) - client = WebsocketCommunicator( - consumer, - "/graphql", - subprotocols=[ - protocol, - ], - ) - res = await client.connect() - assert res == (True, protocol) - assert isinstance(consumer._handler, handler) diff --git a/tests/fastapi/test_websockets.py b/tests/fastapi/test_websockets.py deleted file mode 100644 index de729c9f32..0000000000 --- a/tests/fastapi/test_websockets.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import Any - -import pytest - -import strawberry -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - - -def test_turning_off_graphql_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_graphql_transport_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_all_ws_protocols(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_unsupported_ws_protocol(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", ["imaginary-protocol"]): - pass - - assert exc.value.code == 4406 - - -def test_clients_can_prefer_protocols(): - from starlette.testclient import TestClient - - from tests.fastapi.app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - test_client = TestClient(app) - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL - - -def test_with_custom_encode_json(): - from starlette.testclient import TestClient - - from fastapi import FastAPI - from strawberry.fastapi.router import GraphQLRouter - - @strawberry.type - class Query: - @strawberry.field - def abc(self) -> str: - return "abc" - - class MyRouter(GraphQLRouter[None, None]): - def encode_json(self, response_data: Any): - return '"custom"' - - app = FastAPI() - schema = strawberry.Schema(query=Query) - graphql_app = MyRouter(schema=schema) - app.include_router(graphql_app, prefix="/graphql") - - test_client = TestClient(app) - response = test_client.post("/graphql", json={"query": "{ abc }"}) - - assert response.status_code == 200 - assert response.json() == "custom" diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 4979d7d480..dcf7abb5bd 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -10,7 +10,6 @@ from aiohttp.client_ws import ClientWebSocketResponse from aiohttp.http_websocket import WSMsgType from aiohttp.test_utils import TestClient, TestServer -from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE @@ -20,8 +19,8 @@ from ..context import get_context from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -30,16 +29,6 @@ ) -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - class GraphQLView(BaseGraphQLView): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler @@ -194,6 +183,9 @@ def __init__(self, ws: ClientWebSocketResponse): self.ws = ws self._reason: Optional[str] = None + async def send_text(self, payload: str) -> None: + await self.ws.send_str(payload) + async def send_json(self, payload: Dict[str, Any]) -> None: await self.ws.send_json(payload) @@ -213,6 +205,10 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any: async def close(self) -> None: await self.ws.close() + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.protocol + @property def closed(self) -> bool: return self.ws.closed diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 5734c8df16..b8f51bbab1 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -8,11 +8,10 @@ from starlette.requests import Request from starlette.responses import Response as StarletteResponse -from starlette.testclient import TestClient +from starlette.testclient import TestClient, WebSocketTestSession from starlette.websockets import WebSocket, WebSocketDisconnect from strawberry.asgi import GraphQL as BaseGraphQLView -from strawberry.asgi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult @@ -21,8 +20,8 @@ from ..context import get_context from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -31,16 +30,6 @@ ) -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - class GraphQLView(BaseGraphQLView): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler @@ -181,7 +170,7 @@ async def ws_connect( class AsgiWebSocketClient(WebSocketClient): - def __init__(self, ws: Any): + def __init__(self, ws: WebSocketTestSession): self.ws = ws self._closed: bool = False self._close_code: Optional[int] = None @@ -192,6 +181,9 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None: self._close_code = exc.code self._close_reason = exc.reason + async def send_text(self, payload: str) -> None: + self.ws.send_text(payload) + async def send_json(self, payload: Dict[str, Any]) -> None: self.ws.send_json(payload) @@ -224,6 +216,10 @@ async def close(self) -> None: self.ws.close() self._closed = True + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.accepted_subprotocol + @property def closed(self) -> bool: return self._closed diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index ff31e4111e..c1156fc3a6 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -21,6 +21,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) +from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler from strawberry.types import ExecutionResult logger = logging.getLogger("strawberry.test.http_client") @@ -237,7 +241,7 @@ def create_app(self, **kwargs: Any) -> None: """For use by websocket tests.""" raise NotImplementedError - async def ws_connect( + def ws_connect( self, url: str, *, @@ -260,6 +264,9 @@ class WebSocketClient(abc.ABC): def name(self) -> str: return "" + @abc.abstractmethod + async def send_text(self, payload: str) -> None: ... + @abc.abstractmethod async def send_json(self, payload: Dict[str, Any]) -> None: ... @@ -274,6 +281,10 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any: ... @abc.abstractmethod async def close(self) -> None: ... + @property + @abc.abstractmethod + def accepted_subprotocol(self) -> Optional[str]: ... + @property @abc.abstractmethod def closed(self) -> bool: ... @@ -290,7 +301,7 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]: yield await self.receive() -class DebuggableGraphQLTransportWSMixin: +class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): def on_init(self) -> None: """This method can be patched by unit tests to get the instance of the transport handler when it is initialized. @@ -298,26 +309,41 @@ def on_init(self) -> None: def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - DebuggableGraphQLTransportWSMixin.on_init(self) + self.original_context = kwargs.get("context", {}) + DebuggableGraphQLTransportWSHandler.on_init(self) def get_tasks(self) -> List: return [op.task for op in self.operations.values()] - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = self.connection_init_timeout_task - return context + @property + def context(self): + self.original_context["ws"] = self.websocket + self.original_context["get_tasks"] = self.get_tasks + self.original_context["connectionInitTimeoutTask"] = ( + self.connection_init_timeout_task + ) + return self.original_context + + @context.setter + def context(self, value): + self.original_context = value + +class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.original_context = self.context -class DebuggableGraphQLWSMixin: def get_tasks(self) -> List: return list(self.tasks.values()) - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = None - return context + @property + def context(self): + self.original_context["ws"] = self.websocket + self.original_context["get_tasks"] = self.get_tasks + self.original_context["connectionInitTimeoutTask"] = None + return self.original_context + + @context.setter + def context(self, value): + self.original_context = value diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 14abd5e4af..0e53ce8f82 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -23,6 +23,8 @@ from ..context import get_context from .base import ( JSON, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -63,25 +65,6 @@ def create_multipart_request_body( return headers, request_body -class DebuggableGraphQLTransportWSConsumer(GraphQLWSConsumer): - def get_tasks(self) -> List[Any]: - if hasattr(self._handler, "operations"): - return [op.task for op in self._handler.operations.values()] - else: - return list(self._handler.tasks.values()) - - async def get_context(self, *args: str, **kwargs: Any) -> object: - context = await super().get_context(*args, **kwargs) - context["ws"] = self._handler._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = getattr( - self._handler, "connection_init_timeout_task", None - ) - for key, val in get_context({}).items(): - context[key] = val - return context - - class DebuggableGraphQLHTTPConsumer(GraphQLHTTPConsumer): result_override: ResultOverrideFunction = None @@ -130,6 +113,16 @@ def process_result( return super().process_result(request, result) +class DebuggableGraphQLWSConsumer(GraphQLWSConsumer): + graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler + graphql_ws_handler_class = DebuggableGraphQLWSHandler + + async def get_context(self, request, response): + context = await super().get_context(request, response) + + return get_context(context) + + class ChannelsHttpClient(HttpClient): """A client to test websockets over channels.""" @@ -141,7 +134,7 @@ def __init__( result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, ): - self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( + self.ws_app = DebuggableGraphQLWSConsumer.as_asgi( schema=schema, keep_alive=False, ) @@ -156,9 +149,7 @@ def __init__( ) def create_app(self, **kwargs: Any) -> None: - self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, **kwargs - ) + self.ws_app = DebuggableGraphQLWSConsumer.as_asgi(schema=schema, **kwargs) async def _graphql_request( self, @@ -247,10 +238,13 @@ async def ws_connect( ) -> AsyncGenerator[WebSocketClient, None]: client = WebsocketCommunicator(self.ws_app, url, subprotocols=protocols) - res = await client.connect() - assert res == (True, protocols[0]) + connected, subprotocol_or_close_code = await client.connect() + assert connected + try: - yield ChannelsWebSocketClient(client) + yield ChannelsWebSocketClient( + client, accepted_subprotocol=subprotocol_or_close_code + ) finally: await client.disconnect() @@ -275,15 +269,21 @@ def __init__( class ChannelsWebSocketClient(WebSocketClient): - def __init__(self, client: WebsocketCommunicator): + def __init__( + self, client: WebsocketCommunicator, accepted_subprotocol: Optional[str] + ): self.ws = client self._closed: bool = False self._close_code: Optional[int] = None self._close_reason: Optional[str] = None + self._accepted_subprotocol = accepted_subprotocol def name(self) -> str: return "channels" + async def send_text(self, payload: str) -> None: + await self.ws.send_to(text_data=payload) + async def send_json(self, payload: Dict[str, Any]) -> None: await self.ws.send_json_to(payload) @@ -311,6 +311,10 @@ async def close(self) -> None: await self.ws.disconnect() self._closed = True + @property + def accepted_subprotocol(self) -> Optional[str]: + return self._accepted_subprotocol + @property def closed(self) -> bool: return self._closed diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index cddc43032f..b1b80625fa 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -11,7 +11,6 @@ from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket from fastapi.testclient import TestClient from strawberry.fastapi import GraphQLRouter as BaseGraphQLRouter -from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult @@ -21,8 +20,8 @@ from .asgi import AsgiWebSocketClient from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Response, ResultOverrideFunction, @@ -30,16 +29,6 @@ ) -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - def custom_context_dependency() -> str: return "Hi!" diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 0b99b43729..065b04f395 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -13,15 +13,14 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.litestar import make_graphql_controller -from strawberry.litestar.controller import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.types import ExecutionResult from tests.views.schema import Query, schema from ..context import get_context from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -42,16 +41,6 @@ async def get_root_value(request: Request = None): return Query() -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - class LitestarHttpClient(HttpClient): def __init__( self, @@ -190,6 +179,9 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None: self._closed = True self._close_code = exc.code + async def send_text(self, payload: str) -> None: + self.ws.send_text(payload) + async def send_json(self, payload: Dict[str, Any]) -> None: self.ws.send_json(payload) @@ -229,6 +221,10 @@ async def close(self) -> None: self.ws.close() self._closed = True + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.accepted_subprotocol + @property def closed(self) -> bool: return self._closed diff --git a/tests/litestar/test_websockets.py b/tests/litestar/test_websockets.py deleted file mode 100644 index b5554cb264..0000000000 --- a/tests/litestar/test_websockets.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - - -def test_turning_off_graphql_ws(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_graphql_transport_ws(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_all_ws_protocols(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_unsupported_ws_protocol(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", ["imaginary-protocol"]): - pass - - assert exc.value.code == 4406 - - -def test_clients_can_prefer_protocols(): - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - test_client = TestClient(app) - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 02f8366852..49e9c7ce32 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -24,7 +24,7 @@ SubscribeMessage, SubscribeMessagePayload, ) -from tests.http.clients.base import DebuggableGraphQLTransportWSMixin +from tests.http.clients.base import DebuggableGraphQLTransportWSHandler from tests.views.schema import MyExtension, Schema if TYPE_CHECKING: @@ -123,6 +123,17 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): ws.assert_reason("WebSocket message type must be text") +async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): + ws = ws_raw + + await ws.send_text("not valid json") + + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4400 + ws.assert_reason("WebSocket message type must be text") + + async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw @@ -879,7 +890,7 @@ def on_init(_handler): # cause an attribute error in the timeout task handler.connection_init_wait_timeout = None - with patch.object(DebuggableGraphQLTransportWSMixin, "on_init", on_init): + with patch.object(DebuggableGraphQLTransportWSHandler, "on_init", on_init): async with http_client.ws_connect( "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] ) as ws: diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 6752eaa7f8..1ad38c6d19 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -292,6 +292,17 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): ws.assert_reason("WebSocket message type must be text") +async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): + ws = ws_raw + + await ws.send_text("not valid json") + + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 1002 + ws.assert_reason("WebSocket message type must be text") + + async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw diff --git a/tests/websockets/test_websockets.py b/tests/websockets/test_websockets.py new file mode 100644 index 0000000000..767617d727 --- /dev/null +++ b/tests/websockets/test_websockets.py @@ -0,0 +1,82 @@ +from typing import Type + +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL +from tests.http.clients.base import HttpClient + + +async def test_turning_off_graphql_ws(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app(subscription_protocols=[]) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_generally_unsupported_subprotocols_are_rejected(http_client: HttpClient): + async with http_client.ws_connect( + "/graphql", protocols=["imaginary-protocol"] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_clients_can_prefer_subprotocols(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app( + subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] + ) as ws: + assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL + await ws.close() + assert ws.closed + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL + await ws.close() + assert ws.closed