Skip to content

Commit

Permalink
Add a test for the task error handler
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Jun 21, 2023
1 parent c5c2211 commit a8f064c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,14 @@ async def operation_task(self, operation: Operation) -> None:
except asyncio.CancelledError:
raise
except Exception as error:
# Log any unhandled exceptions in the operation task
await self.handle_task_exception(error)
# cleanup in case of something really unexpected
finally:
# add this task to a list to be reaped later
# Clenaup. Remove the operation from the list of active operations
if operation.id in self.operations:
del self.operations[operation.id]
# TODO: Stop collecting background tasks, not necessary.
# Add this task to a list to be reaped later
self.completed_tasks.append(task)

async def handle_operation(
Expand Down
39 changes: 38 additions & 1 deletion tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import time
from datetime import timedelta
from typing import AsyncGenerator, Type
from typing import Any, AsyncGenerator, Type
from unittest.mock import patch

try:
Expand All @@ -16,6 +16,9 @@
from pytest_mock import MockerFixture

from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
BaseGraphQLTransportWSHandler,
)
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
Expand Down Expand Up @@ -1068,3 +1071,37 @@ async def slow_get_context(ctxt):
id="sub1", payload={"data": {"valueFromContext": "slow"}}
).as_dict()
)


async def test_task_error_handler(ws: WebSocketClient):
"""
Test that error handling works
"""
# can't use a simple Event here, because the handler may run
# on a different thread
wakeup = False

# a replacement method which causes an error in th eTask
async def op(*args: Any, **kwargs: Any):
nonlocal wakeup
wakeup = True
raise ZeroDivisionError("test")

with patch.object(BaseGraphQLTransportWSHandler, "task_logger") as logger:
with patch.object(BaseGraphQLTransportWSHandler, "handle_operation", op):
# send any old subscription request. It will raise an error
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query="subscription { conditionalFail(sleep:0) }"
),
).as_dict()
)

# wait for the error to be logged
while not wakeup:
await asyncio.sleep(0.01)
# and another little bit, for the thread to finish
await asyncio.sleep(0.01)
assert logger.exception.called

0 comments on commit a8f064c

Please sign in to comment.