diff --git a/CHANGES/3383.bugfix b/CHANGES/3383.bugfix new file mode 100644 index 00000000000..98bfe3d8c60 --- /dev/null +++ b/CHANGES/3383.bugfix @@ -0,0 +1 @@ +Fix task cancellation when ``sendfile()`` syscall is used by static file handling. diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index ccae5dac710..be4dbc73bce 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -2,6 +2,7 @@ import mimetypes import os import pathlib +from functools import partial from typing import (IO, TYPE_CHECKING, Any, Awaitable, Callable, List, # noqa Optional, Union, cast) @@ -35,9 +36,15 @@ class SendfileStreamWriter(StreamWriter): def __init__(self, protocol: BaseProtocol, loop: asyncio.AbstractEventLoop, + fobj: IO[Any], + count: int, on_chunk_sent: _T_OnChunkSent=None) -> None: super().__init__(protocol, loop, on_chunk_sent) self._sendfile_buffer = [] # type: List[bytes] + self._fobj = fobj + self._count = count + self._offset = fobj.tell() + self._in_fd = fobj.fileno() def _write(self, chunk: bytes) -> None: # we overwrite StreamWriter._write, so nothing can be appended to @@ -46,54 +53,57 @@ def _write(self, chunk: bytes) -> None: self.output_size += len(chunk) self._sendfile_buffer.append(chunk) - def _sendfile_cb(self, fut: 'asyncio.Future[None]', - out_fd: int, in_fd: int, - offset: int, count: int, - loop: asyncio.AbstractEventLoop, - registered: bool) -> None: - if registered: - loop.remove_writer(out_fd) + def _sendfile_cb(self, fut: 'asyncio.Future[None]', out_fd: int) -> None: if fut.cancelled(): return + try: + if self._do_sendfile(out_fd): + set_result(fut, None) + except Exception as exc: + set_exception(fut, exc) + def _do_sendfile(self, out_fd: int) -> bool: try: - n = os.sendfile(out_fd, in_fd, offset, count) - if n == 0: # EOF reached - n = count + n = os.sendfile(out_fd, + self._in_fd, + self._offset, + self._count) + if n == 0: # in_fd EOF reached + n = self._count except (BlockingIOError, InterruptedError): n = 0 - except Exception as exc: - set_exception(fut, exc) - return + self.output_size += n + self._offset += n + self._count -= n + assert self._count >= 0 + return self._count == 0 - if n < count: - loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd, - offset + n, count - n, loop, True) - else: - set_result(fut, None) + def _done_fut(self, out_fd: int, fut: 'asyncio.Future[None]') -> None: + self.loop.remove_writer(out_fd) - async def sendfile(self, fobj: IO[Any], count: int) -> None: + async def sendfile(self) -> None: assert self.transport is not None out_socket = self.transport.get_extra_info('socket').dup() out_socket.setblocking(False) out_fd = out_socket.fileno() - in_fd = fobj.fileno() - offset = fobj.tell() loop = self.loop data = b''.join(self._sendfile_buffer) try: await loop.sock_sendall(out_socket, data) - fut = loop.create_future() - self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False) - await fut + if not self._do_sendfile(out_fd): + fut = loop.create_future() + fut.add_done_callback(partial(self._done_fut, out_fd)) + loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd) + await fut + except asyncio.CancelledError: + raise except Exception: server_logger.debug('Socket error') self.transport.close() finally: out_socket.close() - self.output_size += count await super().write_eof() async def write_eof(self, chunk: bytes=b'') -> None: @@ -139,12 +149,14 @@ async def _sendfile_system(self, request: 'BaseRequest', else: writer = SendfileStreamWriter( request.protocol, - request._loop + request._loop, + fobj, + count ) request._payload_writer = writer await super().prepare(request) - await writer.sendfile(fobj, count) + await writer.sendfile() return writer diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index 241e7db74df..f0849f15c8a 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -2,73 +2,7 @@ from aiohttp import hdrs from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web_fileresponse import FileResponse, SendfileStreamWriter - - -def test_static_handle_eof(loop) -> None: - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = loop.create_future() - m_os.sendfile.return_value = 0 - writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) - assert fut.done() - assert fut.result() is None - assert not fake_loop.add_writer.called - assert not fake_loop.remove_writer.called - - -def test_static_handle_again(loop) -> None: - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = loop.create_future() - m_os.sendfile.side_effect = BlockingIOError() - writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) - assert not fut.done() - fake_loop.add_writer.assert_called_with(out_fd, - writer._sendfile_cb, - fut, out_fd, in_fd, 0, 100, - fake_loop, True) - assert not fake_loop.remove_writer.called - - -def test_static_handle_exception(loop) -> None: - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = loop.create_future() - exc = OSError() - m_os.sendfile.side_effect = exc - writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) - assert fut.done() - assert exc is fut.exception() - assert not fake_loop.add_writer.called - assert not fake_loop.remove_writer.called - - -def test__sendfile_cb_return_on_cancelling(loop) -> None: - fake_loop = mock.Mock() - with mock.patch('aiohttp.web_fileresponse.os') as m_os: - out_fd = 30 - in_fd = 31 - fut = loop.create_future() - fut.cancel() - writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop) - writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) - assert fut.done() - assert not fake_loop.add_writer.called - assert not fake_loop.remove_writer.called - assert not m_os.sendfile.called +from aiohttp.web_fileresponse import FileResponse def test_using_gzip_if_header_present_and_file_available(loop) -> None: diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 410878e9de9..3d71de99129 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,6 +1,7 @@ import asyncio import os import pathlib +import socket import zlib import pytest @@ -324,7 +325,7 @@ def test_static_route_path_existence_check() -> None: async def test_static_file_huge(aiohttp_client, tmpdir) -> None: filename = 'huge_data.unknown_mime_type' - # fill 100MB file + # fill 20MB file with tmpdir.join(filename).open('w') as f: for i in range(1024*20): f.write(chr(i % 64 + 0x20) * 1024) @@ -751,3 +752,68 @@ async def handler(request): assert 'application/octet-stream' == resp.headers['Content-Type'] assert resp.headers.get('Content-Encoding') == 'deflate' await resp.release() + + +async def test_static_file_huge_cancel(aiohttp_client, tmpdir) -> None: + filename = 'huge_data.unknown_mime_type' + + # fill 100MB file + with tmpdir.join(filename).open('w') as f: + for i in range(1024*20): + f.write(chr(i % 64 + 0x20) * 1024) + + task = None + + async def handler(request): + nonlocal task + task = request.task + # reduce send buffer size + tr = request.transport + sock = tr.get_extra_info('socket') + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename)))) + return ret + + app = web.Application() + + app.router.add_get('/', handler) + client = await aiohttp_client(app) + + resp = await client.get('/') + assert resp.status == 200 + task.cancel() + await asyncio.sleep(0) + data = b'' + while True: + try: + data += await resp.content.read(1024) + except aiohttp.ClientPayloadError: + break + assert len(data) < 1024 * 1024 * 20 + + +async def test_static_file_huge_error(aiohttp_client, tmpdir) -> None: + filename = 'huge_data.unknown_mime_type' + + # fill 20MB file + with tmpdir.join(filename).open('wb') as f: + f.seek(20*1024*1024) + f.write(b'1') + + async def handler(request): + # reduce send buffer size + tr = request.transport + sock = tr.get_extra_info('socket') + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename)))) + return ret + + app = web.Application() + + app.router.add_get('/', handler) + client = await aiohttp_client(app) + + resp = await client.get('/') + assert resp.status == 200 + # raise an exception on server side + resp.close()