Skip to content

Commit

Permalink
Fix support for connection Upgrade and CONNECT when some data in the …
Browse files Browse the repository at this point in the history
…stream has been read. (#882)

* Add a starting point for the work

* Add draft tests

* Support connection `Upgrade` and `CONNECT`.

* Update CHANGELOG.md

* Remove private state assertions

* Add Async prefix

* Update CHANGELOG.md

Co-authored-by: Tom Christie <tom@tomchristie.com>

* Update tests/_async/test_http11.py

Co-authored-by: T-256 <132141463+T-256@users.noreply.github.com>

---------

Co-authored-by: Tom Christie <tom@tomchristie.com>
Co-authored-by: T-256 <132141463+T-256@users.noreply.github.com>
Co-authored-by: Tom Christie <tom.christie@krakentechnologies.ltd>
  • Loading branch information
4 people committed Feb 20, 2024
1 parent c468024 commit accae7b
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## Unreleased

- Fix support for connection Upgrade and CONNECT when some data in the stream has been read. (#882)

## 1.0.3 (February 13th, 2024)

- Fix support for async cancellations. (#880)
Expand Down
50 changes: 47 additions & 3 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
AsyncIterable,
AsyncIterator,
List,
Expand Down Expand Up @@ -107,6 +109,7 @@ async def handle_async_request(self, request: Request) -> Response:
status,
reason_phrase,
headers,
trailing_data,
) = await self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
Expand All @@ -115,14 +118,22 @@ async def handle_async_request(self, request: Request) -> Response:
headers,
)

network_stream = self._network_stream

# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -167,7 +178,7 @@ async def _send_event(

async def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand All @@ -187,7 +198,9 @@ async def _receive_response_headers(
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()

return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data

return http_version, event.status_code, event.reason, headers, trailing_data

async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
Expand Down Expand Up @@ -340,3 +353,34 @@ async def aclose(self) -> None:
self._closed = True
async with Trace("response_closed", logger, self._request):
await self._connection._response_closed()


class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return await self._stream.read(max_bytes, timeout)

async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
await self._stream.write(buffer, timeout)

async def aclose(self) -> None:
await self._stream.aclose()

async def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> AsyncNetworkStream:
return await self._stream.start_tls(ssl_context, server_hostname, timeout)

def get_extra_info(self, info: str) -> Any:
return self._stream.get_extra_info(info)
50 changes: 47 additions & 3 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import enum
import logging
import ssl
import time
from types import TracebackType
from typing import (
Any,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -107,6 +109,7 @@ def handle_request(self, request: Request) -> Response:
status,
reason_phrase,
headers,
trailing_data,
) = self._receive_response_headers(**kwargs)
trace.return_value = (
http_version,
Expand All @@ -115,14 +118,22 @@ def handle_request(self, request: Request) -> Response:
headers,
)

network_stream = self._network_stream

# CONNECT or Upgrade request
if (status == 101) or (
(request.method == b"CONNECT") and (200 <= status < 300)
):
network_stream = HTTP11UpgradeStream(network_stream, trailing_data)

return Response(
status=status,
headers=headers,
content=HTTP11ConnectionByteStream(self, request),
extensions={
"http_version": http_version,
"reason_phrase": reason_phrase,
"network_stream": self._network_stream,
"network_stream": network_stream,
},
)
except BaseException as exc:
Expand Down Expand Up @@ -167,7 +178,7 @@ def _send_event(

def _receive_response_headers(
self, request: Request
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]:
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand All @@ -187,7 +198,9 @@ def _receive_response_headers(
# raw header casing, rather than the enforced lowercase headers.
headers = event.headers.raw_items()

return http_version, event.status_code, event.reason, headers
trailing_data, _ = self._h11_state.trailing_data

return http_version, event.status_code, event.reason, headers, trailing_data

def _receive_response_body(self, request: Request) -> Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
Expand Down Expand Up @@ -340,3 +353,34 @@ def close(self) -> None:
self._closed = True
with Trace("response_closed", logger, self._request):
self._connection._response_closed()


class HTTP11UpgradeStream(NetworkStream):
def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
self._stream = stream
self._leading_data = leading_data

def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
if self._leading_data:
buffer = self._leading_data[:max_bytes]
self._leading_data = self._leading_data[max_bytes:]
return buffer
else:
return self._stream.read(max_bytes, timeout)

def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
self._stream.write(buffer, timeout)

def close(self) -> None:
self._stream.close()

def start_tls(
self,
ssl_context: ssl.SSLContext,
server_hostname: Optional[str] = None,
timeout: Optional[float] = None,
) -> NetworkStream:
return self._stream.start_tls(ssl_context, server_hostname, timeout)

def get_extra_info(self, info: str) -> Any:
return self._stream.get_extra_info(info)
51 changes: 51 additions & 0 deletions tests/_async/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,57 @@ async def test_http11_upgrade_connection():
assert content == b"..."


@pytest.mark.anyio
async def test_http11_upgrade_with_trailing_data():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.
In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
in the h11.Connection object.
https://h11.readthedocs.io/en/latest/api.html#switching-protocols
"""
origin = httpcore.Origin(b"wss", b"example.com", 443)
stream = httpcore.AsyncMockStream(
# The first element of this mock network stream buffer simulates networking
# in which response headers and data are received at once.
# This means that "foobar" becomes trailing data.
[
(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: custom\r\n"
b"\r\n"
b"foobar"
),
b"baz",
]
)
async with httpcore.AsyncHTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
async with conn.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]

content = await network_stream.read(max_bytes=3)
assert content == b"foo"
content = await network_stream.read(max_bytes=3)
assert content == b"bar"
content = await network_stream.read(max_bytes=3)
assert content == b"baz"

# Lazy tests for AsyncHTTP11UpgradeStream
await network_stream.write(b"spam")
invalid = network_stream.get_extra_info("invalid")
assert invalid is None
await network_stream.aclose()


@pytest.mark.anyio
async def test_http11_early_hints():
"""
Expand Down
51 changes: 51 additions & 0 deletions tests/_sync/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,57 @@ def test_http11_upgrade_connection():



def test_http11_upgrade_with_trailing_data():
"""
HTTP "101 Switching Protocols" indicates an upgraded connection.
In `CONNECT` and `Upgrade:` requests, we need to handover the trailing data
in the h11.Connection object.
https://h11.readthedocs.io/en/latest/api.html#switching-protocols
"""
origin = httpcore.Origin(b"wss", b"example.com", 443)
stream = httpcore.MockStream(
# The first element of this mock network stream buffer simulates networking
# in which response headers and data are received at once.
# This means that "foobar" becomes trailing data.
[
(
b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: custom\r\n"
b"\r\n"
b"foobar"
),
b"baz",
]
)
with httpcore.HTTP11Connection(
origin=origin, stream=stream, keepalive_expiry=5.0
) as conn:
with conn.stream(
"GET",
"wss://example.com/",
headers={"Connection": "upgrade", "Upgrade": "custom"},
) as response:
assert response.status == 101
network_stream = response.extensions["network_stream"]

content = network_stream.read(max_bytes=3)
assert content == b"foo"
content = network_stream.read(max_bytes=3)
assert content == b"bar"
content = network_stream.read(max_bytes=3)
assert content == b"baz"

# Lazy tests for HTTP11UpgradeStream
network_stream.write(b"spam")
invalid = network_stream.get_extra_info("invalid")
assert invalid is None
network_stream.close()



def test_http11_early_hints():
"""
HTTP "103 Early Hints" is an interim response.
Expand Down

0 comments on commit accae7b

Please sign in to comment.