diff --git a/CHANGES/9137.bugfix.rst b/CHANGES/9137.bugfix.rst new file mode 100644 index 00000000000..d99802095bd --- /dev/null +++ b/CHANGES/9137.bugfix.rst @@ -0,0 +1,2 @@ +Added :exc:`aiohttp.ClientConnectionResetError`. Client code that previously threw :exc:`ConnectionResetError` +will now throw this -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 0dc4484388d..af5243106c0 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -6,6 +6,7 @@ from .client import ( BaseConnector, ClientConnectionError, + ClientConnectionResetError, ClientConnectorCertificateError, ClientConnectorError, ClientConnectorSSLError, @@ -117,6 +118,7 @@ # client "BaseConnector", "ClientConnectionError", + "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorError", "ClientConnectorSSLError", diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index dc1f24f99cd..2fc2fa65885 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, cast +from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay @@ -85,7 +86,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: async def _drain_helper(self) -> None: if not self.connected: - raise ConnectionResetError("Connection lost") + raise ClientConnectionResetError("Connection lost") if not self._paused: return waiter = self._drain_waiter diff --git a/aiohttp/client.py b/aiohttp/client.py index abeac1d228c..9c2fd8073a3 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -42,6 +42,7 @@ from .abc import AbstractCookieJar from .client_exceptions import ( ClientConnectionError, + ClientConnectionResetError, ClientConnectorCertificateError, ClientConnectorError, ClientConnectorSSLError, @@ -107,6 +108,7 @@ __all__ = ( # client_exceptions "ClientConnectionError", + "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorError", "ClientConnectorSSLError", diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index eb5e1b09692..f9711bc2e71 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -5,7 +5,6 @@ from multidict import MultiMapping -from .http_parser import RawResponseMessage from .typedefs import StrOrURL try: @@ -18,12 +17,14 @@ if TYPE_CHECKING: from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo + from .http_parser import RawResponseMessage else: - RequestInfo = ClientResponse = ConnectionKey = None + RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None __all__ = ( "ClientError", "ClientConnectionError", + "ClientConnectionResetError", "ClientOSError", "ClientConnectorError", "ClientProxyConnectionError", @@ -126,6 +127,10 @@ class ClientConnectionError(ClientError): """Base class for client socket errors.""" +class ClientConnectionResetError(ClientConnectionError, ConnectionResetError): + """ConnectionResetError""" + + class ClientOSError(ClientConnectionError, OSError): """OSError error.""" diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 78aac0cc26e..dd17b7675de 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -25,6 +25,7 @@ ) from .base_protocol import BaseProtocol +from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS, set_exception from .streams import DataQueue @@ -609,7 +610,7 @@ async def _send_frame( ) -> None: """Send a frame over the websocket with message as its payload.""" if self._closing and not (opcode & WSMsgType.CLOSE): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") # RSV are the reserved bits in the frame header. They are used to # indicate that the frame is using an extension. @@ -704,7 +705,7 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor: def _write(self, data: bytes) -> None: if self.transport.is_closing(): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") self.transport.write(data) async def pong(self, message: Union[bytes, str] = b"") -> None: diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index d6b02e6f566..f54fa0f0774 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -8,6 +8,7 @@ from .abc import AbstractStreamWriter from .base_protocol import BaseProtocol +from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor from .helpers import NO_EXTENSIONS @@ -72,7 +73,7 @@ def _write(self, chunk: bytes) -> None: self.output_size += size transport = self.transport if not self._protocol.connected or transport is None or transport.is_closing(): - raise ConnectionResetError("Cannot write to closing transport") + raise ClientConnectionResetError("Cannot write to closing transport") transport.write(chunk) async def write( diff --git a/docs/client_reference.rst b/docs/client_reference.rst index d3768e581d2..89cf923ee74 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -2193,6 +2193,10 @@ Connection errors Derived from :exc:`ClientError` +.. class:: ClientConnectionResetError + + Derived from :exc:`ClientConnectionError` and :exc:`ConnectionResetError` + .. class:: ClientOSError Subset of connection errors that are initiated by an :exc:`OSError` @@ -2279,6 +2283,8 @@ Hierarchy of exceptions * :exc:`ClientConnectionError` + * :exc:`ClientConnectionResetError` + * :exc:`ClientOSError` * :exc:`ClientConnectorError` diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 7521bd24ac5..b4c3f6820be 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -2,15 +2,19 @@ import base64 import hashlib import os -from typing import Mapping +from typing import Mapping, Type from unittest import mock import pytest import aiohttp -from aiohttp import client, hdrs -from aiohttp.client_exceptions import ServerDisconnectedError -from aiohttp.client_ws import ClientWSTimeout +from aiohttp import ( + ClientConnectionResetError, + ClientWSTimeout, + ServerDisconnectedError, + client, + hdrs, +) from aiohttp.http import WS_KEY from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro @@ -535,8 +539,12 @@ async def test_close_exc2( await resp.close() +@pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError)) async def test_send_data_after_close( - ws_key: bytes, key_data: bytes, loop: asyncio.AbstractEventLoop + exc: Type[Exception], + ws_key: bytes, + key_data: bytes, + loop: asyncio.AbstractEventLoop, ) -> None: mresp = mock.Mock() mresp.status = 101 @@ -562,7 +570,7 @@ async def test_send_data_after_close( (resp.send_bytes, (b"b",)), (resp.send_json, ({},)), ): - with pytest.raises(ConnectionResetError): + with pytest.raises(exc): # Verify exc can be caught with both classes await meth(*args) diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 6659859369f..dc86691bb1c 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -6,7 +6,7 @@ import pytest import aiohttp -from aiohttp import ServerTimeoutError, WSMsgType, hdrs, web +from aiohttp import ClientConnectionResetError, ServerTimeoutError, WSMsgType, hdrs, web from aiohttp.client_ws import ClientWSTimeout from aiohttp.http import WSCloseCode from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer @@ -681,7 +681,7 @@ async def handler(request: web.Request) -> NoReturn: # would cancel the heartbeat task and we wouldn't get a ping assert resp._conn is not None with mock.patch.object( - resp._conn.transport, "write", side_effect=ConnectionResetError + resp._conn.transport, "write", side_effect=ClientConnectionResetError ), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping: await resp.receive() ping_count = ping.call_count diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index d1db6bf40ea..825e858457a 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -7,7 +7,7 @@ import pytest from multidict import CIMultiDict -from aiohttp import http +from aiohttp import ClientConnectionResetError, http from aiohttp.base_protocol import BaseProtocol from aiohttp.test_utils import make_mocked_coro @@ -301,7 +301,7 @@ async def test_write_to_closing_transport( await msg.write(b"Before closing") transport.is_closing.return_value = True # type: ignore[attr-defined] - with pytest.raises(ConnectionResetError): + with pytest.raises(ClientConnectionResetError): await msg.write(b"After closing") @@ -310,7 +310,7 @@ async def test_write_to_closed_transport( transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: - """Test that writing to a closed transport raises ConnectionResetError. + """Test that writing to a closed transport raises ClientConnectionResetError. The StreamWriter checks to see if protocol.transport is None before writing to the transport. If it is None, it raises ConnectionResetError. @@ -320,7 +320,9 @@ async def test_write_to_closed_transport( await msg.write(b"Before transport close") protocol.transport = None - with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"): + with pytest.raises( + ClientConnectionResetError, match="Cannot write to closing transport" + ): await msg.write(b"After transport closed")