Skip to content

Commit

Permalink
Add new send_frame method to WebSockets (#9348)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 29, 2024
1 parent 9314651 commit 2628256
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGES/9348.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added :py:meth:`~aiohttp.ClientWebSocketResponse.send_frame` and :py:meth:`~aiohttp.web.WebSocketResponse.send_frame` for WebSockets -- by :user:`bdraco`.
12 changes: 10 additions & 2 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,23 @@ async def ping(self, message: bytes = b"") -> None:
async def pong(self, message: bytes = b"") -> None:
await self._writer.pong(message)

async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
await self._writer.send_frame(message, opcode, compress)

async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send(data, binary=False, compress=compress)
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)

async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send(data, binary=True, compress=compress)
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

async def send_json(
self,
Expand Down
22 changes: 4 additions & 18 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def __init__(
self._output_size = 0
self._compressobj: Any = None # actually compressobj

async def _send_frame(
async def send_frame(
self, message: bytes, opcode: int, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket with message as its payload."""
Expand Down Expand Up @@ -710,32 +710,18 @@ def _write(self, data: bytes) -> None:

async def pong(self, message: bytes = b"") -> None:
"""Send pong message."""
await self._send_frame(message, WSMsgType.PONG)
await self.send_frame(message, WSMsgType.PONG)

async def ping(self, message: bytes = b"") -> None:
"""Send ping message."""
await self._send_frame(message, WSMsgType.PING)

async def send(
self,
message: Union[str, bytes],
binary: bool = False,
compress: Optional[int] = None,
) -> None:
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode("utf-8")
if binary:
await self._send_frame(message, WSMsgType.BINARY, compress)
else:
await self._send_frame(message, WSMsgType.TEXT, compress)
await self.send_frame(message, WSMsgType.PING)

async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode("utf-8")
try:
await self._send_frame(
await self.send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
)
finally:
Expand Down
14 changes: 12 additions & 2 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,19 +404,29 @@ async def pong(self, message: bytes = b"") -> None:
raise RuntimeError("Call .prepare() first")
await self._writer.pong(message)

async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.send_frame(message, opcode, compress)

async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send(data, binary=False, compress=compress)
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)

async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send(data, binary=True, compress=compress)
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)

async def send_json(
self,
Expand Down
26 changes: 26 additions & 0 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,32 @@ manually.
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: send_frame(message, opcode, compress=None)
:async:

Send a :const:`~aiohttp.WSMsgType` message *message* to peer.

This method is low-level and should be used with caution as it
only accepts bytes which must conform to the correct message type
for *message*.

It is recommended to use the :meth:`send_str`, :meth:`send_bytes`
or :meth:`send_json` methods instead of this method.

The primary use case for this method is to send bytes that are
have already been encoded without having to decode and
re-encode them.

:param bytes message: message to send.

:param ~aiohttp.WSMsgType opcode: opcode of the message.

:param int compress: sets specific level of compression for
single message,
``None`` for not overriding per-socket setting.

.. versionadded:: 3.11

.. method:: close(*, code=WSCloseCode.OK, message=b'')
:async:

Expand Down
30 changes: 28 additions & 2 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,8 @@ and :ref:`aiohttp-web-signals` handlers::

To enable back-pressure from slow websocket clients treat methods
:meth:`ping`, :meth:`pong`, :meth:`send_str`,
:meth:`send_bytes`, :meth:`send_json` as coroutines. By
default write buffer size is set to 64k.
:meth:`send_bytes`, :meth:`send_json`, :meth:`send_frame` as coroutines.
By default write buffer size is set to 64k.

:param bool autoping: Automatically send
:const:`~aiohttp.WSMsgType.PONG` on
Expand Down Expand Up @@ -1149,6 +1149,32 @@ and :ref:`aiohttp-web-signals` handlers::
The method is converted into :term:`coroutine`,
*compress* parameter added.

.. method:: send_frame(message, opcode, compress=None)
:async:

Send a :const:`~aiohttp.WSMsgType` message *message* to peer.

This method is low-level and should be used with caution as it
only accepts bytes which must conform to the correct message type
for *message*.

It is recommended to use the :meth:`send_str`, :meth:`send_bytes`
or :meth:`send_json` methods instead of this method.

The primary use case for this method is to send bytes that are
have already been encoded without having to decode and
re-encode them.

:param bytes message: message to send.

:param ~aiohttp.WSMsgType opcode: opcode of the message.

:param int compress: sets specific level of compression for
single message,
``None`` for not overriding per-socket setting.

.. versionadded:: 3.11

.. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True)
:async:

Expand Down
18 changes: 14 additions & 4 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ async def test_send_data_after_close(
(resp.send_str, ("s",)),
(resp.send_bytes, (b"b",)),
(resp.send_json, ({},)),
(resp.send_frame, (b"", aiohttp.WSMsgType.BINARY)),
):
with pytest.raises(exc): # Verify exc can be caught with both classes
await meth(*args)
Expand Down Expand Up @@ -775,19 +776,28 @@ async def test_ws_connect_deflate_per_message(
m_req.return_value = loop.create_future()
m_req.return_value.set_result(mresp)
writer = WebSocketWriter.return_value = mock.Mock()
send = writer.send = make_mocked_coro()
send_frame = writer.send_frame = make_mocked_coro()

session = aiohttp.ClientSession()
resp = await session.ws_connect("http://test.org")

await resp.send_str("string", compress=-1)
send.assert_called_with("string", binary=False, compress=-1)
send_frame.assert_called_with(
b"string", aiohttp.WSMsgType.TEXT, compress=-1
)

await resp.send_bytes(b"bytes", compress=15)
send.assert_called_with(b"bytes", binary=True, compress=15)
send_frame.assert_called_with(
b"bytes", aiohttp.WSMsgType.BINARY, compress=15
)

await resp.send_json([{}], compress=-9)
send.assert_called_with("[{}]", binary=False, compress=-9)
send_frame.assert_called_with(
b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9
)

await resp.send_frame(b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9)
send_frame.assert_called_with(b"[{}]", aiohttp.WSMsgType.TEXT, -9)

await session.close()

Expand Down
22 changes: 22 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,28 @@ async def handler(request: web.Request) -> web.WebSocketResponse:
await resp.close()


async def test_send_recv_frame(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)

data = await ws.receive()
await ws.send_frame(data.data, data.type)
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)
client = await aiohttp_client(app)
resp = await client.ws_connect("/")
await resp.send_frame(b"test", WSMsgType.BINARY)

data = await resp.receive()
assert data.data == b"test"
assert data.type is WSMsgType.BINARY
await resp.close()


async def test_ping_pong(aiohttp_client: AiohttpClient) -> None:
loop = asyncio.get_event_loop()
closed = loop.create_future()
Expand Down
29 changes: 25 additions & 4 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ async def test_nonstarted_pong() -> None:
await ws.pong()


async def test_nonstarted_send_frame() -> None:
ws = web.WebSocketResponse()
with pytest.raises(RuntimeError):
await ws.send_frame(b"string", WSMsgType.TEXT)


async def test_nonstarted_send_str() -> None:
ws = web.WebSocketResponse()
with pytest.raises(RuntimeError):
Expand Down Expand Up @@ -268,6 +274,18 @@ async def test_send_json_closed(make_request: _RequestMaker) -> None:
await ws.send_json({"type": "json"})


async def test_send_frame_closed(make_request: _RequestMaker) -> None:
req = make_request("GET", "/")
ws = web.WebSocketResponse()
await ws.prepare(req)
assert ws._reader is not None
ws._reader.feed_data(WS_CLOSED_MESSAGE)
await ws.close()

with pytest.raises(ConnectionError):
await ws.send_frame(b'{"type": "json"}', WSMsgType.TEXT)


async def test_ping_closed(make_request: _RequestMaker) -> None:
req = make_request("GET", "/")
ws = web.WebSocketResponse()
Expand Down Expand Up @@ -560,15 +578,18 @@ async def test_send_with_per_message_deflate(
req = make_request("GET", "/")
ws = web.WebSocketResponse()
await ws.prepare(req)
with mock.patch.object(ws._writer, "send", autospec=True, spec_set=True) as m:
with mock.patch.object(ws._writer, "send_frame", autospec=True, spec_set=True) as m:
await ws.send_str("string", compress=15)
m.assert_called_with("string", binary=False, compress=15)
m.assert_called_with(b"string", WSMsgType.TEXT, compress=15)

await ws.send_bytes(b"bytes", compress=0)
m.assert_called_with(b"bytes", binary=True, compress=0)
m.assert_called_with(b"bytes", WSMsgType.BINARY, compress=0)

await ws.send_json("[{}]", compress=9)
m.assert_called_with('"[{}]"', binary=False, compress=9)
m.assert_called_with(b'"[{}]"', WSMsgType.TEXT, compress=9)

await ws.send_frame(b"[{}]", WSMsgType.TEXT, compress=9)
m.assert_called_with(b"[{}]", WSMsgType.TEXT, compress=9)


async def test_no_transfer_encoding_header(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse:

ws = await client.ws_connect("/", protocols=("eggs", "bar"))

await ws._writer._send_frame(b"", WSMsgType.CLOSE)
await ws._writer.send_frame(b"", WSMsgType.CLOSE)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE
Expand Down
Loading

0 comments on commit 2628256

Please sign in to comment.