Skip to content

Commit

Permalink
Add new send_frame method to WebSockets (#9348)
Browse files Browse the repository at this point in the history
(cherry picked from commit 2628256)
  • Loading branch information
bdraco committed Sep 29, 2024
1 parent 1ca7244 commit 6c8b526
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGES/9348.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added :py:meth:`~aiohttp.ClientWebSocketResponse.send_frame` and :py:meth:`~aiohttp.web.WebSocketResponse.send_frame` for WebSockets -- by :user:`bdraco`.
12 changes: 10 additions & 2 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,23 @@ async def ping(self, message: bytes = b"") -> None:
async def pong(self, message: bytes = b"") -> None:
await self._writer.pong(message)

async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
await self._writer.send_frame(message, opcode, compress)

async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send(data, binary=False, compress=compress)
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)

async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send(data, binary=True, compress=compress)
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

async def send_json(
self,
Expand Down
22 changes: 4 additions & 18 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def __init__(
self._output_size = 0
self._compressobj: Any = None # actually compressobj

async def _send_frame(
async def send_frame(
self, message: bytes, opcode: int, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket with message as its payload."""
Expand Down Expand Up @@ -727,34 +727,20 @@ async def pong(self, message: Union[bytes, str] = b"") -> None:
"""Send pong message."""
if isinstance(message, str):
message = message.encode("utf-8")
await self._send_frame(message, WSMsgType.PONG)
await self.send_frame(message, WSMsgType.PONG)

async def ping(self, message: Union[bytes, str] = b"") -> None:
"""Send ping message."""
if isinstance(message, str):
message = message.encode("utf-8")
await self._send_frame(message, WSMsgType.PING)

async def send(
self,
message: Union[str, bytes],
binary: bool = False,
compress: Optional[int] = None,
) -> None:
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode("utf-8")
if binary:
await self._send_frame(message, WSMsgType.BINARY, compress)
else:
await self._send_frame(message, WSMsgType.TEXT, compress)
await self.send_frame(message, WSMsgType.PING)

async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode("utf-8")
try:
await self._send_frame(
await self.send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
)
finally:
Expand Down
14 changes: 12 additions & 2 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,19 +379,29 @@ async def pong(self, message: bytes = b"") -> None:
raise RuntimeError("Call .prepare() first")
await self._writer.pong(message)

async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.send_frame(message, opcode, compress)

async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send(data, binary=False, compress=compress)
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)

async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send(data, binary=True, compress=compress)
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

async def send_json(
self,
Expand Down
26 changes: 26 additions & 0 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,32 @@ manually.
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: send_frame(message, opcode, compress=None)
:async:

Send a :const:`~aiohttp.WSMsgType` message *message* to peer.

This method is low-level and should be used with caution as it
only accepts bytes which must conform to the correct message type
for *message*.

It is recommended to use the :meth:`send_str`, :meth:`send_bytes`
or :meth:`send_json` methods instead of this method.

The primary use case for this method is to send bytes that are
have already been encoded without having to decode and
re-encode them.

:param bytes message: message to send.

:param ~aiohttp.WSMsgType opcode: opcode of the message.

:param int compress: sets specific level of compression for
single message,
``None`` for not overriding per-socket setting.

.. versionadded:: 3.11

.. method:: close(*, code=WSCloseCode.OK, message=b'')
:async:

Expand Down
30 changes: 28 additions & 2 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -966,8 +966,8 @@ and :ref:`aiohttp-web-signals` handlers::

To enable back-pressure from slow websocket clients treat methods
:meth:`ping`, :meth:`pong`, :meth:`send_str`,
:meth:`send_bytes`, :meth:`send_json` as coroutines. By
default write buffer size is set to 64k.
:meth:`send_bytes`, :meth:`send_json`, :meth:`send_frame` as coroutines.
By default write buffer size is set to 64k.

:param bool autoping: Automatically send
:const:`~aiohttp.WSMsgType.PONG` on
Expand Down Expand Up @@ -1181,6 +1181,32 @@ and :ref:`aiohttp-web-signals` handlers::
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: send_frame(message, opcode, compress=None)
:async:

Send a :const:`~aiohttp.WSMsgType` message *message* to peer.

This method is low-level and should be used with caution as it
only accepts bytes which must conform to the correct message type
for *message*.

It is recommended to use the :meth:`send_str`, :meth:`send_bytes`
or :meth:`send_json` methods instead of this method.

The primary use case for this method is to send bytes that are
have already been encoded without having to decode and
re-encode them.

:param bytes message: message to send.

:param ~aiohttp.WSMsgType opcode: opcode of the message.

:param int compress: sets specific level of compression for
single message,
``None`` for not overriding per-socket setting.

.. versionadded:: 3.11

.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
:async:

Expand Down
18 changes: 14 additions & 4 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ async def test_send_data_after_close(
(resp.send_str, ("s",)),
(resp.send_bytes, (b"b",)),
(resp.send_json, ({},)),
(resp.send_frame, (b"", aiohttp.WSMsgType.BINARY)),
):
with pytest.raises(exc): # Verify exc can be caught with both classes
await meth(*args)
Expand Down Expand Up @@ -725,19 +726,28 @@ async def test_ws_connect_deflate_per_message(loop, ws_key, key_data) -> None:
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)
writer = WebSocketWriter.return_value = mock.Mock()
send = writer.send = make_mocked_coro()
send_frame = writer.send_frame = make_mocked_coro()

session = aiohttp.ClientSession(loop=loop)
resp = await session.ws_connect("http://test.org")

await resp.send_str("string", compress=-1)
send.assert_called_with("string", binary=False, compress=-1)
send_frame.assert_called_with(
b"string", aiohttp.WSMsgType.TEXT, compress=-1
)

await resp.send_bytes(b"bytes", compress=15)
send.assert_called_with(b"bytes", binary=True, compress=15)
send_frame.assert_called_with(
b"bytes", aiohttp.WSMsgType.BINARY, compress=15
)

await resp.send_json([{}], compress=-9)
send.assert_called_with("[{}]", binary=False, compress=-9)
send_frame.assert_called_with(
b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9
)

await resp.send_frame(b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9)
send_frame.assert_called_with(b"[{}]", aiohttp.WSMsgType.TEXT, -9)

await session.close()

Expand Down
24 changes: 23 additions & 1 deletion tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,29 @@ async def handler(request):
await resp.close()


async def test_ping_pong(aiohttp_client) -> None:
async def test_send_recv_frame(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)

data = await ws.receive()
await ws.send_frame(data.data, data.type)
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)
resp = await client.ws_connect("/")
await resp.send_frame(b"test", WSMsgType.BINARY)

data = await resp.receive()
assert data.data == b"test"
assert data.type is WSMsgType.BINARY
await resp.close()


async def test_ping_pong(aiohttp_client: AiohttpClient) -> None:
loop = asyncio.get_event_loop()
closed = loop.create_future()

Expand Down
34 changes: 27 additions & 7 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ async def test_nonstarted_pong() -> None:
await ws.pong()


async def test_nonstarted_send_frame() -> None:
ws = WebSocketResponse()
with pytest.raises(RuntimeError):
await ws.send_frame(b"string", WSMsgType.TEXT)


async def test_nonstarted_send_str() -> None:
ws = WebSocketResponse()
with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -277,6 +283,18 @@ async def test_send_json_closed(make_request) -> None:
await ws.send_json({"type": "json"})


async def test_send_frame_closed(make_request) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
assert ws._reader is not None
ws._reader.feed_data(WS_CLOSED_MESSAGE)
await ws.close()

with pytest.raises(ConnectionError):
await ws.send_frame(b'{"type": "json"}', WSMsgType.TEXT)


async def test_ping_closed(make_request) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
Expand Down Expand Up @@ -536,16 +554,18 @@ async def test_send_with_per_message_deflate(make_request, mocker) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
writer_send = ws._writer.send = make_mocked_coro()
with mock.patch.object(ws._writer, "send_frame", autospec=True, spec_set=True) as m:
await ws.send_str("string", compress=15)
m.assert_called_with(b"string", WSMsgType.TEXT, compress=15)

await ws.send_str("string", compress=15)
writer_send.assert_called_with("string", binary=False, compress=15)
await ws.send_bytes(b"bytes", compress=0)
m.assert_called_with(b"bytes", WSMsgType.BINARY, compress=0)

await ws.send_bytes(b"bytes", compress=0)
writer_send.assert_called_with(b"bytes", binary=True, compress=0)
await ws.send_json("[{}]", compress=9)
m.assert_called_with(b'"[{}]"', WSMsgType.TEXT, compress=9)

await ws.send_json("[{}]", compress=9)
writer_send.assert_called_with('"[{}]"', binary=False, compress=9)
await ws.send_frame(b"[{}]", WSMsgType.TEXT, compress=9)
m.assert_called_with(b"[{}]", WSMsgType.TEXT, compress=9)


async def test_no_transfer_encoding_header(make_request, mocker) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ async def handler(request):

ws: web.WebSocketResponse = await client.ws_connect("/", protocols=("eggs", "bar"))

await ws._writer._send_frame(b"", WSMsgType.CLOSE)
await ws._writer.send_frame(b"", WSMsgType.CLOSE)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE
Expand Down
Loading

0 comments on commit 6c8b526

Please sign in to comment.