Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graphql ws did not ignore parsing errors #3670

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Release type: minor

This release fixes a regression in the legacy GraphQL over WebSocket protocol.
Legacy protocol implementations should ignore client message parsing errors.
During a recent refactor, Strawberry changed this behavior to match the new protocol, where parsing errors must close the WebSocket connection.
The expected behavior is restored and adequately tested in this release.
15 changes: 11 additions & 4 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
Expand Down Expand Up @@ -86,16 +90,19 @@ def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None:
self.request = request
self.ws = ws

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, ignore_parsing_errors: bool
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
) -> AsyncGenerator[Dict[str, object], None]:
async for ws_message in self.ws:
if ws_message.type == http.WSMsgType.TEXT:
try:
yield ws_message.json()
except JSONDecodeError:
raise NonJsonMessageReceived()
if not ignore_parsing_errors:
raise NonJsonMessageReceived()

elif ws_message.type == http.WSMsgType.BINARY:
raise NonJsonMessageReceived()
raise NonTextMessageReceived()

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send_json(message)
Expand Down
21 changes: 15 additions & 6 deletions strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
Expand Down Expand Up @@ -85,13 +89,18 @@ class ASGIWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: WebSocket, response: WebSocket) -> None:
self.ws = response

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, ignore_parsing_errors: bool
) -> AsyncGenerator[Dict[str, object], None]:
try:
try:
while self.ws.application_state != WebSocketState.DISCONNECTED:
while self.ws.application_state != WebSocketState.DISCONNECTED:
try:
yield await self.ws.receive_json()
except (KeyError, JSONDecodeError):
raise NonJsonMessageReceived()
except JSONDecodeError: # noqa: PERF203
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
except KeyError:
raise NonTextMessageReceived()
except WebSocketDisconnect: # pragma: no cover
pass

Expand Down
11 changes: 7 additions & 4 deletions strawberry/channels/handlers/ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing_extensions import TypeGuard

from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL

Expand All @@ -31,20 +31,23 @@ class ChannelsWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None:
self.ws_consumer = response

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, ignore_parsing_errors: bool
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
) -> AsyncGenerator[Dict[str, object], None]:
while True:
message = await self.ws_consumer.message_queue.get()

if message["disconnected"]:
break

if message["message"] is None:
raise NonJsonMessageReceived()
raise NonTextMessageReceived()

try:
yield json.loads(message["message"])
except json.JSONDecodeError:
raise NonJsonMessageReceived()
if not ignore_parsing_errors:
raise NonJsonMessageReceived()

async def send_json(self, message: Mapping[str, object]) -> None:
serialized_message = json.dumps(message)
Expand Down
4 changes: 3 additions & 1 deletion strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ async def get_form_data(self) -> FormData: ...

class AsyncWebSocketAdapter(abc.ABC):
@abc.abstractmethod
def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ...
def iter_json(
self, ignore_parsing_errors: bool
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
) -> AsyncGenerator[Dict[str, object], None]: ...

@abc.abstractmethod
async def send_json(self, message: Mapping[str, object]) -> None: ...
Expand Down
4 changes: 4 additions & 0 deletions strawberry/http/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ def __init__(self, status_code: int, reason: str) -> None:
self.reason = reason


class NonTextMessageReceived(Exception):
pass


class NonJsonMessageReceived(Exception):
pass

Expand Down
29 changes: 21 additions & 8 deletions strawberry/litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
import warnings
from datetime import timedelta
from typing import (
Expand Down Expand Up @@ -37,7 +38,6 @@
from litestar.di import Provide
from litestar.exceptions import (
NotFoundException,
SerializationException,
ValidationException,
WebSocketDisconnect,
)
Expand All @@ -49,7 +49,11 @@
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
Expand Down Expand Up @@ -192,13 +196,22 @@ class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: WebSocket, response: WebSocket) -> None:
self.ws = response

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, ignore_parsing_errors: bool
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
) -> AsyncGenerator[Dict[str, object], None]:
try:
try:
while self.ws.connection_state != "disconnect":
yield await self.ws.receive_json()
except (SerializationException, ValueError):
raise NonJsonMessageReceived()
while self.ws.connection_state != "disconnect":
text = await self.ws.receive_text()

# Litestar internally defaults to an empty string for non-text messages
if text == "":
raise NonTextMessageReceived()

try:
yield json.loads(text)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR (so nothing to change here), but this got me thinking if we should allow for different json libs to be used in places like this, like orjson.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good point, I'll create a PR for it!

except json.JSONDecodeError:
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
except WebSocketDisconnect:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from graphql import GraphQLError, GraphQLSyntaxError, parse

from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
Expand Down Expand Up @@ -76,10 +76,12 @@ async def handle(self) -> Any:
self.on_request_accepted()

try:
async for message in self.websocket.iter_json():
async for message in self.websocket.iter_json(ignore_parsing_errors=False):
await self.handle_message(message)
except NonJsonMessageReceived:
except NonTextMessageReceived:
await self.handle_invalid_message("WebSocket message type must be text")
except NonJsonMessageReceived:
await self.handle_invalid_message("WebSocket message must be valid JSON")
finally:
await self.shutdown()

Expand Down
6 changes: 3 additions & 3 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
cast,
)

from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.http.exceptions import NonTextMessageReceived
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
Expand Down Expand Up @@ -65,9 +65,9 @@ def __init__(

async def handle(self) -> None:
try:
async for message in self.websocket.iter_json():
async for message in self.websocket.iter_json(ignore_parsing_errors=True):
await self.handle_message(cast(OperationMessage, message))
except NonJsonMessageReceived:
except NonTextMessageReceived:
await self.websocket.close(
code=1002, reason="WebSocket message type must be text"
)
Expand Down
6 changes: 3 additions & 3 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient):
assert ws.close_reason == "Failed to parse message"


async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode())
Expand All @@ -123,15 +123,15 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
assert ws.close_reason == "WebSocket message type must be text"


async def test_ws_messages_must_be_json(ws_raw: WebSocketClient):
async def test_non_json_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_text("not valid json")

await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
assert ws.close_reason == "WebSocket message type must be text"
assert ws.close_reason == "WebSocket message must be valid JSON"


async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
Expand Down
39 changes: 33 additions & 6 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ async def test_subscription_syntax_error(ws: WebSocketClient):
}


async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_bytes(json.dumps({"type": GQL_CONNECTION_INIT}).encode())
Expand All @@ -292,15 +292,42 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
assert ws.close_reason == "WebSocket message type must be text"


async def test_ws_messages_must_be_json(ws_raw: WebSocketClient):
async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_text("not valid json")
await ws.send_text("NOT VALID JSON")
await ws.send_json({"type": GQL_CONNECTION_INIT})

await ws.receive(timeout=2)
response = await ws.receive_json()
assert response["type"] == GQL_CONNECTION_ACK

await ws.send_text("NOT VALID JSON")
await ws.send_json(
{
"type": GQL_START,
"id": "demo",
"payload": {
"query": 'subscription { echo(message: "Hi") }',
},
}
)

response = await ws.receive_json()
assert response["type"] == GQL_DATA
assert response["id"] == "demo"
assert response["payload"]["data"] == {"echo": "Hi"}

await ws.send_text("NOT VALID JSON")
await ws.send_json({"type": GQL_STOP, "id": "demo"})

response = await ws.receive_json()
assert response["type"] == GQL_COMPLETE
assert response["id"] == "demo"

await ws.send_text("NOT VALID JSON")
await ws.send_json({"type": GQL_CONNECTION_TERMINATE})
await ws.receive(timeout=2) # receive close
assert ws.closed
assert ws.close_code == 1002
assert ws.close_reason == "WebSocket message type must be text"


async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
Expand Down
Loading