diff --git a/tests/test_connector.py b/tests/test_connector.py index 63dff162f4..3293c6bfc8 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -21,7 +21,6 @@ Optional, Sequence, Tuple, - Union, ) from unittest import mock @@ -1064,6 +1063,13 @@ async def create_connection( established_connection.close() +@pytest.mark.parametrize( + ("request_url"), + [ + ("http://mocked.host"), + ("https://mocked.host"), + ], +) @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, @@ -1072,6 +1078,7 @@ async def create_connection( async def test_tcp_connector_multiple_hosts_one_timeout( start_connection: mock.Mock, loop: asyncio.AbstractEventLoop, + request_url: str, ) -> None: conn = aiohttp.TCPConnector() @@ -1079,38 +1086,43 @@ async def test_tcp_connector_multiple_hosts_one_timeout( ip2 = "192.168.1.2" ips = [ip1, ip2] ips_tried = [] + ips_success = [] + timeout_error = False + connected = False req = ClientRequest( "GET", - URL("https://mocked.host"), + URL(request_url), loop=loop, ) async def _resolve_host( - host: str, port: int, traces: Optional[Sequence[Trace]] = None + host: str, port: int, traces: object = None ) -> List[ResolveResult]: return [ { "hostname": host, "host": ip, "port": port, - "family": socket.AF_INET, + "family": socket.AF_INET6 if ":" in ip else socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } for ip in ips ] - timeout_error = False - connected = False - - async def create_connection( - *args: object, **kwargs: object - ) -> Tuple[mock.Mock, mock.Mock]: - nonlocal timeout_error, connected + async def start_connection( + addr_infos: Sequence[AddrInfoType], + *, + interleave: Optional[int] = None, + **kwargs: object, + ) -> socket.socket: + nonlocal timeout_error - ip = args[1] + addr_info = addr_infos[0] + addr_info_addr = addr_info[-1] + ip = addr_info_addr[0] ips_tried.append(ip) if ip == ip1: @@ -1118,35 +1130,52 @@ async def create_connection( raise asyncio.TimeoutError if ip == ip2: - connected = True - tr = create_mocked_conn(loop) - pr = create_mocked_conn(loop) + mock_socket = mock.create_autospec( + socket.socket, spec_set=True, instance=True + ) + mock_socket.getpeername.return_value = addr_info_addr + return mock_socket # type: ignore[no-any-return] - def get_extra_info(param: str) -> Union[bool, mock.Mock]: - if param == "sslcontext": - return True + assert False - if param == "ssl_object": - s = create_mocked_conn(loop) - s.getpeercert.return_value = b"foo" - return s + async def create_connection( + *args: object, sock: Optional[socket.socket] = None, **kwargs: object + ) -> Tuple[ResponseHandler, ResponseHandler]: + nonlocal connected - assert False + assert isinstance(sock, socket.socket) + addr_info = sock.getpeername() + ip = addr_info[0] + ips_success.append(ip) + connected = True - tr.get_extra_info = get_extra_info - return tr, pr + # Close the socket since we are not actually connecting + # and we don't want to leak it. + sock.close() + tr = create_mocked_conn(loop) + pr = create_mocked_conn(loop) + return tr, pr - assert False + with mock.patch.object( + conn, "_resolve_host", autospec=True, spec_set=True, side_effect=_resolve_host + ), mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + side_effect=create_connection, + ), mock.patch( + "aiohttp.connector.aiohappyeyeballs.start_connection", start_connection + ): + established_connection = await conn.connect(req, [], ClientTimeout()) - with mock.patch.object(conn, "_resolve_host", _resolve_host): - with mock.patch.object(conn._loop, "create_connection", create_connection): - established_connection = await conn.connect(req, [], ClientTimeout()) - assert ips == ips_tried + assert ips_tried == ips + assert ips_success == [ip2] - assert timeout_error - assert connected + assert timeout_error + assert connected - established_connection.close() + established_connection.close() async def test_tcp_connector_resolve_host(loop: asyncio.AbstractEventLoop) -> None: