Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #9340/8a97e03 backport][3.10] Use dunder writer internally in ClientResponse #9341

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ async def send(self, conn: "Connection") -> "ClientResponse":
self.response = response_class(
self.method,
self.original_url,
writer=self._writer,
writer=task,
continue100=self._continue,
timer=self._timer,
request_info=self.request_info,
Expand All @@ -773,9 +773,9 @@ async def send(self, conn: "Connection") -> "ClientResponse":
return self.response

async def close(self) -> None:
if self._writer is not None:
if self.__writer is not None:
try:
await self._writer
await self.__writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
Expand All @@ -785,11 +785,11 @@ async def close(self) -> None:
raise

def terminate(self) -> None:
if self._writer is not None:
if self.__writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer.remove_done_callback(self.__reset_writer)
self._writer = None
self.__writer.cancel()
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = None

async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
for trace in self._traces:
Expand Down Expand Up @@ -845,8 +845,8 @@ def __init__(

self._real_url = url
self._url = url.with_fragment(None)
self._body: Any = None
self._writer: Optional[asyncio.Task[None]] = writer
self._body: Optional[bytes] = None
self._writer = writer
self._continue = continue100 # None by default
self._closed = True
self._history: Tuple[ClientResponse, ...] = ()
Expand Down Expand Up @@ -874,10 +874,16 @@ def __reset_writer(self, _: object = None) -> None:

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
"""The writer task for streaming data.

_writer is only provided for backwards compatibility
for subclasses that may need to access it.
"""
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
"""Set the writer task for streaming data."""
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
Expand Down Expand Up @@ -1128,16 +1134,16 @@ def raise_for_status(self) -> None:

def _release_connection(self) -> None:
if self._connection is not None:
if self._writer is None:
if self.__writer is None:
self._connection.release()
self._connection = None
else:
self._writer.add_done_callback(lambda f: self._release_connection())
self.__writer.add_done_callback(lambda f: self._release_connection())

async def _wait_released(self) -> None:
if self._writer is not None:
if self.__writer is not None:
try:
await self._writer
await self.__writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
Expand All @@ -1148,8 +1154,8 @@ async def _wait_released(self) -> None:
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
self._writer.cancel()
if self.__writer is not None:
self.__writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand All @@ -1159,9 +1165,9 @@ def _notify_content(self) -> None:
self._released = True

async def wait_for_close(self) -> None:
if self._writer is not None:
if self.__writer is not None:
try:
await self._writer
await self.__writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
Expand Down Expand Up @@ -1189,7 +1195,7 @@ async def read(self) -> bytes:
protocol = self._connection and self._connection.protocol
if protocol is None or not protocol.upgraded:
await self._wait_released() # Underlying connection released
return self._body # type: ignore[no-any-return]
return self._body

def get_encoding(self) -> str:
ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
Expand Down Expand Up @@ -1222,9 +1228,7 @@ async def text(self, encoding: Optional[str] = None, errors: str = "strict") ->
if encoding is None:
encoding = self.get_encoding()

return self._body.decode( # type: ignore[no-any-return,union-attr]
encoding, errors=errors
)
return self._body.decode(encoding, errors=errors) # type: ignore[union-attr]

async def json(
self,
Expand Down
Loading