From b59cdaa55509852d53bfda4ff99b73561dcc2be8 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sat, 12 Oct 2024 19:02:44 +0200 Subject: [PATCH 1/4] Fix graphql-ws did not ignore parsing errors --- strawberry/aiohttp/views.py | 15 +++++++--- strawberry/asgi/__init__.py | 21 ++++++++++---- strawberry/channels/handlers/ws_handler.py | 11 ++++--- strawberry/http/async_base_view.py | 4 ++- strawberry/http/exceptions.py | 4 +++ strawberry/litestar/controller.py | 29 ++++++++++++++----- .../graphql_transport_ws/handlers.py | 8 +++-- .../protocols/graphql_ws/handlers.py | 6 ++-- tests/websockets/test_graphql_transport_ws.py | 6 ++-- tests/websockets/test_graphql_ws.py | 13 +++++---- 10 files changed, 80 insertions(+), 37 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 884264dcb0..4ae165a104 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -26,7 +26,11 @@ AsyncHTTPRequestAdapter, AsyncWebSocketAdapter, ) -from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived +from strawberry.http.exceptions import ( + HTTPException, + NonJsonMessageReceived, + NonTextMessageReceived, +) from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( Context, @@ -86,16 +90,19 @@ def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None: self.request = request self.ws = ws - async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + async def iter_json( + self, ignore_parsing_errors: bool + ) -> AsyncGenerator[Dict[str, object], None]: async for ws_message in self.ws: if ws_message.type == http.WSMsgType.TEXT: try: yield ws_message.json() except JSONDecodeError: - raise NonJsonMessageReceived() + if not ignore_parsing_errors: + raise NonJsonMessageReceived() elif ws_message.type == http.WSMsgType.BINARY: - raise NonJsonMessageReceived() + raise NonTextMessageReceived() async def send_json(self, message: Mapping[str, object]) -> None: await self.ws.send_json(message) diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 5a3f01203d..2c519b41cd 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -33,7 +33,11 @@ AsyncHTTPRequestAdapter, AsyncWebSocketAdapter, ) -from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived +from strawberry.http.exceptions import ( + HTTPException, + NonJsonMessageReceived, + NonTextMessageReceived, +) from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( Context, @@ -85,13 +89,18 @@ class ASGIWebSocketAdapter(AsyncWebSocketAdapter): def __init__(self, request: WebSocket, response: WebSocket) -> None: self.ws = response - async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + async def iter_json( + self, ignore_parsing_errors: bool + ) -> AsyncGenerator[Dict[str, object], None]: try: - try: - while self.ws.application_state != WebSocketState.DISCONNECTED: + while self.ws.application_state != WebSocketState.DISCONNECTED: + try: yield await self.ws.receive_json() - except (KeyError, JSONDecodeError): - raise NonJsonMessageReceived() + except JSONDecodeError: # noqa: PERF203 + if not ignore_parsing_errors: + raise NonJsonMessageReceived() + except KeyError: + raise NonTextMessageReceived() except WebSocketDisconnect: # pragma: no cover pass diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index b267f7ea9b..dce0dff6ea 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -16,7 +16,7 @@ from typing_extensions import TypeGuard from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter -from strawberry.http.exceptions import NonJsonMessageReceived +from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL @@ -31,7 +31,9 @@ class ChannelsWebSocketAdapter(AsyncWebSocketAdapter): def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None: self.ws_consumer = response - async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + async def iter_json( + self, ignore_parsing_errors: bool + ) -> AsyncGenerator[Dict[str, object], None]: while True: message = await self.ws_consumer.message_queue.get() @@ -39,12 +41,13 @@ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: break if message["message"] is None: - raise NonJsonMessageReceived() + raise NonTextMessageReceived() try: yield json.loads(message["message"]) except json.JSONDecodeError: - raise NonJsonMessageReceived() + if not ignore_parsing_errors: + raise NonJsonMessageReceived() async def send_json(self, message: Mapping[str, object]) -> None: serialized_message = json.dumps(message) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index a7666018ef..241b7d3837 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -81,7 +81,9 @@ async def get_form_data(self) -> FormData: ... class AsyncWebSocketAdapter(abc.ABC): @abc.abstractmethod - def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ... + def iter_json( + self, ignore_parsing_errors: bool + ) -> AsyncGenerator[Dict[str, object], None]: ... @abc.abstractmethod async def send_json(self, message: Mapping[str, object]) -> None: ... diff --git a/strawberry/http/exceptions.py b/strawberry/http/exceptions.py index feddf77631..cef7ea135f 100644 --- a/strawberry/http/exceptions.py +++ b/strawberry/http/exceptions.py @@ -4,6 +4,10 @@ def __init__(self, status_code: int, reason: str) -> None: self.reason = reason +class NonTextMessageReceived(Exception): + pass + + class NonJsonMessageReceived(Exception): pass diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index fed3d2d45f..0d7a498f1e 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import warnings from datetime import timedelta from typing import ( @@ -37,7 +38,6 @@ from litestar.di import Provide from litestar.exceptions import ( NotFoundException, - SerializationException, ValidationException, WebSocketDisconnect, ) @@ -49,7 +49,11 @@ AsyncHTTPRequestAdapter, AsyncWebSocketAdapter, ) -from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived +from strawberry.http.exceptions import ( + HTTPException, + NonJsonMessageReceived, + NonTextMessageReceived, +) from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL @@ -192,13 +196,22 @@ class LitestarWebSocketAdapter(AsyncWebSocketAdapter): def __init__(self, request: WebSocket, response: WebSocket) -> None: self.ws = response - async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + async def iter_json( + self, ignore_parsing_errors: bool + ) -> AsyncGenerator[Dict[str, object], None]: try: - try: - while self.ws.connection_state != "disconnect": - yield await self.ws.receive_json() - except (SerializationException, ValueError): - raise NonJsonMessageReceived() + while self.ws.connection_state != "disconnect": + text = await self.ws.receive_text() + + # Litestar internally defaults to an empty string for non-text messages + if text == "": + raise NonTextMessageReceived() + + try: + yield json.loads(text) + except json.JSONDecodeError: + if not ignore_parsing_errors: + raise NonJsonMessageReceived() except WebSocketDisconnect: pass diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index f74e9def3d..a04dc4c892 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -15,7 +15,7 @@ from graphql import GraphQLError, GraphQLSyntaxError, parse -from strawberry.http.exceptions import NonJsonMessageReceived +from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -76,10 +76,12 @@ async def handle(self) -> Any: self.on_request_accepted() try: - async for message in self.websocket.iter_json(): + async for message in self.websocket.iter_json(ignore_parsing_errors=False): await self.handle_message(message) - except NonJsonMessageReceived: + except NonTextMessageReceived: await self.handle_invalid_message("WebSocket message type must be text") + except NonJsonMessageReceived: + await self.handle_invalid_message("WebSocket message must be valid JSON") finally: await self.shutdown() diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index fda3db829f..5a79d61de3 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -11,7 +11,7 @@ cast, ) -from strawberry.http.exceptions import NonJsonMessageReceived +from strawberry.http.exceptions import NonTextMessageReceived from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, GQL_CONNECTION_ACK, @@ -65,9 +65,9 @@ def __init__( async def handle(self) -> None: try: - async for message in self.websocket.iter_json(): + async for message in self.websocket.iter_json(ignore_parsing_errors=True): await self.handle_message(cast(OperationMessage, message)) - except NonJsonMessageReceived: + except NonTextMessageReceived: await self.websocket.close( code=1002, reason="WebSocket message type must be text" ) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 4dbea524f4..6a24ce7b5a 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -112,7 +112,7 @@ async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient): assert ws.close_reason == "Failed to parse message" -async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): +async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient): ws = ws_raw await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode()) @@ -123,7 +123,7 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): assert ws.close_reason == "WebSocket message type must be text" -async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): +async def test_non_json_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient): ws = ws_raw await ws.send_text("not valid json") @@ -131,7 +131,7 @@ async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - assert ws.close_reason == "WebSocket message type must be text" + assert ws.close_reason == "WebSocket message must be valid JSON" async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index daf56f90fb..b34394db48 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -281,7 +281,7 @@ async def test_subscription_syntax_error(ws: WebSocketClient): } -async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): +async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient): ws = ws_raw await ws.send_bytes(json.dumps({"type": GQL_CONNECTION_INIT}).encode()) @@ -292,15 +292,18 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): assert ws.close_reason == "WebSocket message type must be text" -async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): +async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): ws = ws_raw await ws.send_text("not valid json") + await ws.send_json({"type": GQL_CONNECTION_INIT}) - await ws.receive(timeout=2) + response = await ws.receive_json() + assert response["type"] == GQL_CONNECTION_ACK + + await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + await ws.receive(timeout=2) # receive close assert ws.closed - assert ws.close_code == 1002 - assert ws.close_reason == "WebSocket message type must be text" async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): From 0ad7d1d17ce6efed2e3c2e122afbf3e358bf08c4 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sun, 13 Oct 2024 19:36:32 +0200 Subject: [PATCH 2/4] Add release file --- RELEASE.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..bd4c01aeae --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,6 @@ +Release type: minor + +This release fixes a regression in the legacy GraphQL over WebSocket protocol. +Legacy protocol implementations should ignore client message parsing errors. +During a recent refactor, Strawberry changed this behavior to match the new protocol, where parsing errors must close the WebSocket connection. +The expected behavior is restored and adequately tested in this release. From ab7cf219070942e0c3198e31c5b49dfd59533ef4 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sun, 13 Oct 2024 19:54:56 +0200 Subject: [PATCH 3/4] Test on every stage of the protocol --- tests/websockets/test_graphql_ws.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index b34394db48..9c729c0f79 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -295,12 +295,36 @@ async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketCl async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_text("not valid json") + await ws.send_text("NOT VALID JSON") await ws.send_json({"type": GQL_CONNECTION_INIT}) response = await ws.receive_json() assert response["type"] == GQL_CONNECTION_ACK + await ws.send_text("NOT VALID JSON") + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) + + response = await ws.receive_json() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert response["payload"]["data"] == {"echo": "Hi"} + + await ws.send_text("NOT VALID JSON") + await ws.send_json({"type": GQL_STOP, "id": "demo"}) + + response = await ws.receive_json() + assert response["type"] == GQL_COMPLETE + assert response["id"] == "demo" + + await ws.send_text("NOT VALID JSON") await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) await ws.receive(timeout=2) # receive close assert ws.closed From 94b6e834ef30623374e0628be97d062bb2f961e1 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Mon, 21 Oct 2024 17:17:54 +0200 Subject: [PATCH 4/4] Make new arg keyword-only and add defaults --- strawberry/aiohttp/views.py | 2 +- strawberry/asgi/__init__.py | 2 +- strawberry/channels/handlers/ws_handler.py | 2 +- strawberry/http/async_base_view.py | 2 +- strawberry/litestar/controller.py | 2 +- .../subscriptions/protocols/graphql_transport_ws/handlers.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 4ae165a104..09f0b15156 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -91,7 +91,7 @@ def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None: self.ws = ws async def iter_json( - self, ignore_parsing_errors: bool + self, *, ignore_parsing_errors: bool = False ) -> AsyncGenerator[Dict[str, object], None]: async for ws_message in self.ws: if ws_message.type == http.WSMsgType.TEXT: diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 2c519b41cd..11f445ecef 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -90,7 +90,7 @@ def __init__(self, request: WebSocket, response: WebSocket) -> None: self.ws = response async def iter_json( - self, ignore_parsing_errors: bool + self, *, ignore_parsing_errors: bool = False ) -> AsyncGenerator[Dict[str, object], None]: try: while self.ws.application_state != WebSocketState.DISCONNECTED: diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index dce0dff6ea..1ec39be447 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -32,7 +32,7 @@ def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> N self.ws_consumer = response async def iter_json( - self, ignore_parsing_errors: bool + self, *, ignore_parsing_errors: bool = False ) -> AsyncGenerator[Dict[str, object], None]: while True: message = await self.ws_consumer.message_queue.get() diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 241b7d3837..e688c3b667 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -82,7 +82,7 @@ async def get_form_data(self) -> FormData: ... class AsyncWebSocketAdapter(abc.ABC): @abc.abstractmethod def iter_json( - self, ignore_parsing_errors: bool + self, *, ignore_parsing_errors: bool = False ) -> AsyncGenerator[Dict[str, object], None]: ... @abc.abstractmethod diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 0d7a498f1e..348168e7db 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -197,7 +197,7 @@ def __init__(self, request: WebSocket, response: WebSocket) -> None: self.ws = response async def iter_json( - self, ignore_parsing_errors: bool + self, *, ignore_parsing_errors: bool = False ) -> AsyncGenerator[Dict[str, object], None]: try: while self.ws.connection_state != "disconnect": diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index a04dc4c892..c8285564f9 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -76,7 +76,7 @@ async def handle(self) -> Any: self.on_request_accepted() try: - async for message in self.websocket.iter_json(ignore_parsing_errors=False): + async for message in self.websocket.iter_json(): await self.handle_message(message) except NonTextMessageReceived: await self.handle_invalid_message("WebSocket message type must be text")