Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 1, 2024
2 parents 6029cd6 + 803d818 commit b630750
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGES/9365.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Changed ``ClientRequest.connection_key`` to be a `NamedTuple` to improve client performance -- by :user:`bdraco`.
3 changes: 3 additions & 0 deletions CHANGES/9368.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed proxy headers being used in the ``ConnectionKey`` hash when a proxy was not being used -- by :user:`bdraco`.

If default headers are used, they are also used for proxy headers. This could have led to creating connections that were not needed when one was already available.
6 changes: 4 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,6 @@ async def _request(

# Merge with default headers and transform to CIMultiDict
headers = self._prepare_headers(headers)
proxy_headers = self._prepare_headers(proxy_headers)

try:
url = self._build_url(str_or_url)
Expand All @@ -492,7 +491,10 @@ async def _request(
if proxy_auth is None:
proxy_auth = self._default_proxy_auth

if proxy is not None:
if proxy is None:
proxy_headers = None
else:
proxy_headers = self._prepare_headers(proxy_headers)
try:
proxy = URL(proxy)
except ValueError as e:
Expand Down
104 changes: 69 additions & 35 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -150,8 +151,13 @@ def check(self, transport: asyncio.Transport) -> None:
SSL_ALLOWED_TYPES = (bool,)


@dataclasses.dataclass(frozen=True)
class ConnectionKey:
_SSL_SCHEMES = frozenset(("https", "wss"))


# ConnectionKey is a NamedTuple because it is used as a key in a dict
# and a set in the connector. Since a NamedTuple is a tuple it uses
# the fast native tuple __hash__ and __eq__ implementation in CPython.
class ConnectionKey(NamedTuple):
# the key should contain an information about used proxy / TLS
# to prevent reusing wrong connections from a pool
host: str
Expand Down Expand Up @@ -232,7 +238,7 @@ def __init__(
if params:
url = url.extend_query(params)
self.original_url = url
self.url = url.with_fragment(None)
self.url = url.with_fragment(None) if url.raw_fragment else url
self.method = method.upper()
self.chunked = chunked
self.loop = loop
Expand Down Expand Up @@ -287,24 +293,24 @@ def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
writer.add_done_callback(self.__reset_writer)

def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")
return self.url.scheme in _SSL_SCHEMES

@property
def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
return self._ssl

@property
def connection_key(self) -> ConnectionKey:
proxy_headers = self.proxy_headers
if proxy_headers:
if proxy_headers := self.proxy_headers:
h: Optional[int] = hash(tuple(proxy_headers.items()))
else:
h = None
url = self.url
return ConnectionKey(
self.host,
self.port,
self.is_ssl(),
self.ssl,
url.raw_host or "",
url.port,
url.scheme in _SSL_SCHEMES,
self._ssl,
self.proxy,
self.proxy_auth,
h,
Expand Down Expand Up @@ -332,9 +338,8 @@ def update_host(self, url: URL) -> None:
raise InvalidURL(url)

# basic auth info
username, password = url.user, url.password
if username or password:
self.auth = helpers.BasicAuth(username or "", password or "")
if url.raw_user or url.raw_password:
self.auth = helpers.BasicAuth(url.user or "", url.password or "")

def update_version(self, version: Union[http.HttpVersion, str]) -> None:
"""Convert request version to two elements tuple.
Expand All @@ -355,25 +360,45 @@ def update_headers(self, headers: Optional[LooseHeaders]) -> None:
"""Update request headers."""
self.headers: CIMultiDict[str] = CIMultiDict()

# add host
netloc = self.url.host_subcomponent
assert netloc is not None
# See https://github.com/aio-libs/aiohttp/issues/3636.
netloc = netloc.rstrip(".")
if self.url.port is not None and not self.url.is_default_port():
netloc += ":" + str(self.url.port)
self.headers[hdrs.HOST] = netloc

if headers:
if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
headers = headers.items()

for key, value in headers: # type: ignore[misc]
# A special case for Host header
if key.lower() == "host":
self.headers[key] = value
else:
self.headers.add(key, value)
# Build the host header
host = self.url.host_subcomponent

# host_subcomponent is None when the URL is a relative URL.
# but we know we do not have a relative URL here.
assert host is not None

if host[-1] == ".":
# Remove all trailing dots from the netloc as while
# they are valid FQDNs in DNS, TLS validation fails.
# See https://github.com/aio-libs/aiohttp/issues/3636.
# To avoid string manipulation we only call rstrip if
# the last character is a dot.
host = host.rstrip(".")

# If explicit port is not None, it means that the port was
# explicitly specified in the URL. In this case we check
# if its not the default port for the scheme and add it to
# the host header. We check explicit_port first because
# yarl caches explicit_port and its likely to already be
# in the cache and non-default port URLs are far less common.
explicit_port = self.url.explicit_port
if explicit_port is not None and not self.url.is_default_port():
host = f"{host}:{explicit_port}"

self.headers[hdrs.HOST] = host

if not headers:
return

if isinstance(headers, (dict, MultiDictProxy, MultiDict)):
headers = headers.items()

for key, value in headers: # type: ignore[misc]
# A special case for Host header
if key.lower() == "host":
self.headers[key] = value
else:
self.headers.add(key, value)

def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None:
if skip_auto_headers is not None:
Expand Down Expand Up @@ -514,7 +539,10 @@ def update_body_from_data(self, body: Any) -> None:
def update_expect_continue(self, expect: bool = False) -> None:
if expect:
self.headers[hdrs.EXPECT] = "100-continue"
elif self.headers.get(hdrs.EXPECT, "").lower() == "100-continue":
elif (
hdrs.EXPECT in self.headers
and self.headers[hdrs.EXPECT].lower() == "100-continue"
):
expect = True

if expect:
Expand All @@ -526,10 +554,16 @@ def update_proxy(
proxy_auth: Optional[BasicAuth],
proxy_headers: Optional[LooseHeaders],
) -> None:
self.proxy = proxy
if proxy is None:
self.proxy_auth = None
self.proxy_headers = None
return

if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
raise ValueError("proxy_auth must be None or BasicAuth() tuple")
self.proxy = proxy
self.proxy_auth = proxy_auth

if proxy_headers is not None and not isinstance(
proxy_headers, (MultiDict, MultiDictProxy)
):
Expand Down Expand Up @@ -759,7 +793,7 @@ def __init__(
self.cookies = SimpleCookie()

self._real_url = url
self._url = url.with_fragment(None)
self._url = url.with_fragment(None) if url.raw_fragment else url
self._body: Optional[bytes] = None
self._writer = writer
self._continue = continue100 # None by default
Expand Down
5 changes: 2 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import dataclasses
import functools
import logging
import random
Expand Down Expand Up @@ -1312,8 +1311,8 @@ async def _create_proxy_connection(
# asyncio handles this perfectly
proxy_req.method = hdrs.METH_CONNECT
proxy_req.url = req.url
key = dataclasses.replace(
req.connection_key, proxy=None, proxy_auth=None, proxy_headers_hash=None
key = req.connection_key._replace(
proxy=None, proxy_auth=None, proxy_headers_hash=None
)
conn = Connection(self, key, proto, self._loop)
proxy_resp = await proxy_req.send(conn)
Expand Down
13 changes: 8 additions & 5 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
if url.user is None and url.password is None:
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return None
return cls(url.user or "", url.password or "", encoding=encoding)

Expand All @@ -172,11 +174,12 @@ def encode(self) -> str:


def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
auth = BasicAuth.from_url(url)
if auth is None:
"""Remove user and password from URL if present and return BasicAuth object."""
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return url, None
else:
return url.with_user(None), auth
return url.with_user(None), BasicAuth(url.user or "", url.password or "")


def netrc_from_env() -> Optional[netrc.netrc]:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,3 +1474,30 @@ def test_basicauth_from_empty_netrc(
"""Test that no Authorization header is sent when netrc is empty"""
req = make_request("get", "http://example.com", trust_env=True)
assert hdrs.AUTHORIZATION not in req.headers


async def test_connection_key_with_proxy() -> None:
"""Verify the proxy headers are included in the ConnectionKey when a proxy is used."""
proxy = URL("http://proxy.example.com")
req = ClientRequest(
"GET",
URL("http://example.com"),
proxy=proxy,
proxy_headers={"X-Proxy": "true"},
loop=asyncio.get_running_loop(),
)
assert req.connection_key.proxy_headers_hash is not None
await req.close()


async def test_connection_key_without_proxy() -> None:
"""Verify the proxy headers are not included in the ConnectionKey when a proxy is used."""
# If proxy is unspecified, proxy_headers should be ignored
req = ClientRequest(
"GET",
URL("http://example.com"),
proxy_headers={"X-Proxy": "true"},
loop=asyncio.get_running_loop(),
)
assert req.connection_key.proxy_headers_hash is None
await req.close()
6 changes: 6 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ def test_basic_auth_no_user_from_url() -> None:
assert auth.password == "pass"


def test_basic_auth_no_auth_from_url() -> None:
url = URL("http://example.com")
auth = helpers.BasicAuth.from_url(url)
assert auth is None


def test_basic_auth_from_not_url() -> None:
with pytest.raises(TypeError):
helpers.BasicAuth.from_url("http://user:pass@example.com") # type: ignore[arg-type]
Expand Down

0 comments on commit b630750

Please sign in to comment.