Skip to content

Commit

Permalink
Fix close race that prevented the close code from reaching the client (
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Aug 12, 2024
1 parent 3de518a commit 4f41d05
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGES/8680.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a race closing the server-side WebSocket where the close code would not reach the client. -- by :user:`bdraco`.
28 changes: 12 additions & 16 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,23 +431,10 @@ async def close(
if self._writer is None:
raise RuntimeError("Call .prepare() first")

self._cancel_heartbeat()
reader = self._reader
assert reader is not None

# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting and not self._closed:
if not self._close_wait:
assert self._loop is not None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE)
await self._close_wait

if self._closed:
return False

self._set_closed()

try:
await self._writer.close(code, message)
writer = self._payload_writer
Expand All @@ -462,12 +449,21 @@ async def close(
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True

reader = self._reader
assert reader is not None
# we need to break `receive()` cycle before we can call
# `reader.read()` as `close()` may be called from different task
if self._waiting:
assert self._loop is not None
assert self._close_wait is None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE)
await self._close_wait

if self._closing:
self._close_transport()
return True

reader = self._reader
assert reader is not None
try:
async with async_timeout.timeout(self._timeout):
msg = await reader.read()
Expand Down
60 changes: 60 additions & 0 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import asyncio
import contextlib
import sys
import weakref
from typing import Any, Optional

import pytest

import aiohttp
from aiohttp import WSServerHandshakeError, web
from aiohttp.http import WSCloseCode, WSMsgType
from aiohttp.pytest_plugin import AiohttpClient


async def test_websocket_can_prepare(loop: Any, aiohttp_client: Any) -> None:
Expand Down Expand Up @@ -1019,3 +1021,61 @@ async def handler(request):
await ws.close(code=WSCloseCode.OK, message="exit message")

await closed


async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None:
"""Test that the client websocket gets the close message when the server is shutting down."""
url = "/ws"
app = web.Application()
websockets = web.AppKey("websockets", weakref.WeakSet)
app[websockets] = weakref.WeakSet()

# need for send signal shutdown server
shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet)
app[shutdown_websockets] = weakref.WeakSet()

async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
websocket = web.WebSocketResponse()
await websocket.prepare(request)
request.app[websockets].add(websocket)
request.app[shutdown_websockets].add(websocket)

try:
async for message in websocket:
await websocket.send_json({"ok": True, "message": message.json()})
finally:
request.app[websockets].discard(websocket)

return websocket

async def on_shutdown(app: web.Application) -> None:
while app[shutdown_websockets]:
websocket = app[shutdown_websockets].pop()
await websocket.close(
code=aiohttp.WSCloseCode.GOING_AWAY,
message="Server shutdown",
)

app.router.add_get(url, websocket_handler)
app.on_shutdown.append(on_shutdown)

client = await aiohttp_client(app)

websocket = await client.ws_connect(url)

message = {"message": "hi"}
await websocket.send_json(message)
reply = await websocket.receive_json()
assert reply == {"ok": True, "message": message}

await app.shutdown()

assert websocket.closed is False

reply = await websocket.receive()

assert reply.type is aiohttp.http.WSMsgType.CLOSE
assert reply.data == aiohttp.WSCloseCode.GOING_AWAY
assert reply.extra == "Server shutdown"

assert websocket.closed is True

0 comments on commit 4f41d05

Please sign in to comment.