Skip to content

Commit

Permalink
Use a separate strategy for process and thread based file readers (#1073)
Browse files Browse the repository at this point in the history

This allows the thread file reader strategy to keep the file open as long as the context manager is active
  • Loading branch information
FasterSpeeding committed Mar 13, 2022
1 parent c4ed995 commit 35cfca3
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 87 deletions.
1 change: 1 addition & 0 deletions changes/1073.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The threaded file reader now persists the open file pointer while the context manager is active.
116 changes: 70 additions & 46 deletions hikari/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import asyncio
import base64
import concurrent.futures
import errno
import inspect
import io
import mimetypes
Expand Down Expand Up @@ -771,52 +772,52 @@ class FileReader(AsyncReader, abc.ABC):
"""The path to the resource to read."""


def _stat(path: pathlib.Path) -> os.stat_result:
def _stat(path: pathlib.Path) -> pathlib.Path:
# While paths will be implicitly resolved, we still need to explicitly
# call expanduser to deal with a ~ base.
try:
path = path.expanduser()
except RuntimeError:
pass # A home directory couldn't be resolved, so we'll just use the path as-is.

return path.stat()
# path.stat() will raise FileNotFoundError if the file doesn't exist
# (unlike is_dir) which is what we want here.
if stat.S_ISDIR(path.stat().st_mode):
raise IsADirectoryError(errno.EISDIR, "Cannot open the path specified as it is a directory", str(path))

return path


@attr.define(weakref_slot=False)
@typing.final
class _FileAsyncReaderContextManagerImpl(AsyncReaderContextManager[FileReader]):
impl: FileReader = attr.field()

async def __aenter__(self) -> FileReader:
loop = asyncio.get_running_loop()
class _ThreadedFileReaderContextManagerImpl(AsyncReaderContextManager[FileReader]):
executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = attr.field()
file: typing.Optional[typing.BinaryIO] = attr.field(default=None, init=False)
filename: str = attr.field()
path: pathlib.Path = attr.field()

# Will raise FileNotFoundError if the file doesn't exist (unlike is_dir),
# which is what we want here.
file_stats = await loop.run_in_executor(self.impl.executor, _stat, self.impl.path)
async def __aenter__(self) -> ThreadedFileReader:
if self.file:
raise RuntimeError("File is already open")

if stat.S_ISDIR(file_stats.st_mode):
raise IsADirectoryError(self.impl.path)

return self.impl
loop = asyncio.get_running_loop()
self.path = await loop.run_in_executor(self.executor, _stat, self.path)
self.file = typing.cast(io.BufferedReader, await loop.run_in_executor(self.executor, self.path.open, "rb"))
return ThreadedFileReader(self.filename, None, self.executor, self.path, self.file)

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc: typing.Optional[BaseException],
exc_tb: typing.Optional[types.TracebackType],
) -> None:
pass

if not self.file:
raise RuntimeError("File isn't open")

def _open_file(path: pathlib.Path) -> typing.BinaryIO:
# While paths will be implicitly resolved, we still need to explicitly
# call expanduser to deal with a ~ base.
try:
path = path.expanduser()
except RuntimeError:
pass # A home directory couldn't be resolved, so we'll just use the path as-is.

return path.open("rb")
loop = asyncio.get_running_loop()
file = self.file
self.file = None
await loop.run_in_executor(self.executor, file.close)


@attr.define(weakref_slot=False)
Expand All @@ -828,32 +829,44 @@ class ThreadedFileReader(FileReader):
do not need to be pickled to be communicated.
"""

_pointer: typing.BinaryIO = attr.field()

async def __aiter__(self) -> typing.AsyncGenerator[typing.Any, bytes]:
loop = asyncio.get_running_loop()

fp = await loop.run_in_executor(self.executor, _open_file, self.path)
while True:
chunk = await loop.run_in_executor(self.executor, self._pointer.read, _MAGIC)
yield chunk
if len(chunk) < _MAGIC:
break

try:
while True:
chunk = await loop.run_in_executor(self.executor, fp.read, _MAGIC)
yield chunk
if len(chunk) < _MAGIC:
break

finally:
await loop.run_in_executor(self.executor, fp.close)
@attr.define(weakref_slot=False)
@typing.final
class _MultiProcessingFileReaderContextManagerImpl(AsyncReaderContextManager[FileReader]):
executor: concurrent.futures.ProcessPoolExecutor = attr.field()
file: typing.Optional[typing.BinaryIO] = attr.field(default=None, init=False)
filename: str = attr.field()
path: pathlib.Path = attr.field()

async def __aenter__(self) -> MultiprocessingFileReader:
loop = asyncio.get_running_loop()

path = await loop.run_in_executor(self.executor, _stat, self.path)
return MultiprocessingFileReader(self.filename, None, self.executor, path)

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc: typing.Optional[BaseException],
exc_tb: typing.Optional[types.TracebackType],
) -> None:
pass

def _read_all(path: pathlib.Path) -> bytes:
# While paths will be implicitly resolved, we still need to explicitly
# call expanduser to deal with a ~ base.
try:
path = path.expanduser()
except RuntimeError:
pass # A home directory couldn't be resolved, so we'll just use the path as-is.

with path.open("rb") as fp:
return fp.read()
def _read_all(path: pathlib.Path) -> bytes:
with path.open("rb") as file:
return file.read()


@attr.define(slots=False, weakref_slot=False) # Do not slot (pickle)
Expand Down Expand Up @@ -952,12 +965,23 @@ def stream(
AsyncReaderContextManager[FileReader]
An async context manager that when entered, produces the
data stream.
Raises
------
IsADirectoryError
If the file's path leads to a directory.
FileNotFoundError
If the file doesn't exist.
"""
# When this is None, asyncio falls back to its default executor, which is always a
# thread pool executor, so treating None as threaded is safe enough.
is_threaded = executor is None or isinstance(executor, concurrent.futures.ThreadPoolExecutor)
impl = ThreadedFileReader if is_threaded else MultiprocessingFileReader
return _FileAsyncReaderContextManagerImpl(impl(self.filename, None, executor, self.path))
if executor is None or isinstance(executor, concurrent.futures.ThreadPoolExecutor):
return _ThreadedFileReaderContextManagerImpl(executor, self.filename, self.path)

if not isinstance(executor, concurrent.futures.ProcessPoolExecutor):
raise TypeError("The executor must be a ProcessPoolExecutor, ThreadPoolExecutor, or `builtins.None`.")

return _MultiProcessingFileReaderContextManagerImpl(executor, self.filename, self.path)


########################################################################
Expand Down
166 changes: 125 additions & 41 deletions tests/hikari/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@

import base64
import concurrent.futures
import contextlib
import pathlib
import random
import tempfile
import typing

import mock
import pytest
Expand All @@ -52,68 +50,154 @@ def test___exit__(self, reader):
pytest.fail(exc)


class Test_FileAsyncReaderContextManagerImpl:
@pytest.mark.parametrize(
"executor", [concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor]
)
class Test_ThreadedFileReaderContextManagerImpl:
@pytest.mark.asyncio()
async def test_context_manager(self, executor: typing.Callable[[], concurrent.futures.Executor]):
mock_reader = mock.Mock(executor=executor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)
async def test_enter_dunder_method_when_already_open(self):
manager = files._ThreadedFileReaderContextManagerImpl(mock.Mock(), "ea", pathlib.Path("ea"))
manager.file = mock.Mock()
with pytest.raises(RuntimeError, match="File is already open"):
await manager.__aenter__()

with tempfile.NamedTemporaryFile() as file:
mock_reader.path = pathlib.Path(file.name)
@pytest.mark.asyncio()
async def test_exit_dunder_method_when_not_open(self):
manager = files._ThreadedFileReaderContextManagerImpl(mock.Mock(), "ea", pathlib.Path("ea"))

with pytest.raises(RuntimeError, match="File isn't open"):
await manager.__aexit__(None, None, None)

@pytest.mark.asyncio()
async def test_context_manager(self):
executor = concurrent.futures.ThreadPoolExecutor()
mock_data = b"meeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee" * 50

# A try/finally is used to delete the file rather than relying on delete=True behaviour,
# because on Windows the file cannot be accessed by other processes if delete is True.
file = tempfile.NamedTemporaryFile("wb", delete=False)
path = pathlib.Path(file.name)
try:
with file:
file.write(mock_data)

context_manager = files._ThreadedFileReaderContextManagerImpl(executor, "meow.txt", path)

async with context_manager as reader:
assert reader is mock_reader
data = await reader.read()

assert reader.filename == "meow.txt"
assert reader.path == path
assert reader.executor is executor
assert data == mock_data

finally:
path.unlink()

@mock.patch.object(pathlib.Path, "expanduser", side_effect=RuntimeError)
@pytest.mark.asyncio()
async def test_context_manager_when_expandname_raises_runtime_error(self):
async def test_context_manager_when_expandname_raises_runtime_error(self, expanduser: mock.Mock):
# We can't mock patch stuff in other processes easily (if at all) so
# for this test we only run it threaded.
mock_reader = mock.Mock(executor=concurrent.futures.ThreadPoolExecutor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)
# for this test we have to cheat and use a thread pool executor.
executor = concurrent.futures.ThreadPoolExecutor()

stack = contextlib.ExitStack()
file = stack.enter_context(tempfile.NamedTemporaryFile())
expandname = stack.enter_context(mock.patch.object(pathlib.Path, "expanduser", side_effect=RuntimeError))
# A try/finally is used to delete the file rather than relying on delete=True behaviour,
# because on Windows the file cannot be accessed by other processes if delete is True.
with tempfile.NamedTemporaryFile(delete=False) as file:
pass

with file:
mock_reader.path = pathlib.Path(file.name)
path = pathlib.Path(file.name)
try:
context_manager = files._ThreadedFileReaderContextManagerImpl(executor, "filename.txt", path)

async with context_manager as reader:
assert reader is mock_reader
assert reader.path == path

expanduser.assert_called_once_with()

expandname.assert_called_once_with()
finally:
path.unlink()

@pytest.mark.parametrize(
"executor", [concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor]
)
@pytest.mark.asyncio()
async def test_context_manager_for_unknown_file(self, executor: typing.Callable[[], concurrent.futures.Executor]):
mock_reader = mock.Mock(executor=executor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)
async def test_context_manager_for_unknown_file(self):
executor = concurrent.futures.ThreadPoolExecutor()
path = pathlib.Path(base64.urlsafe_b64encode(random.getrandbits(512).to_bytes(64, "little")).decode())
context_manager = files._ThreadedFileReaderContextManagerImpl(executor, "ea.txt", path)

mock_reader.path = pathlib.Path(
base64.urlsafe_b64encode(random.getrandbits(512).to_bytes(64, "little")).decode()
)
with pytest.raises(FileNotFoundError): # noqa: PT012 - raises block should contain a single statement
async with context_manager:
...

@pytest.mark.asyncio()
async def test_test_context_manager_when_target_is_dir(self):
executor = concurrent.futures.ThreadPoolExecutor()

with tempfile.TemporaryDirectory() as name:
path = pathlib.Path(name)
context_manager = files._ThreadedFileReaderContextManagerImpl(executor, "meow.txt", path)

with pytest.raises(IsADirectoryError): # noqa: PT012 - raises block should contain a single statement
async with context_manager:
...


class Test_MultiProcessingFileReaderContextManagerImpl:
@pytest.mark.asyncio()
async def test_context_manager(self):
executor = concurrent.futures.ProcessPoolExecutor()
mock_data = b"kon'nichiwa i am yellow and blue da be meow da bayeet" * 50

# A try/finally is used to delete the file rather than relying on delete=True behaviour,
# because on Windows the file cannot be accessed by other processes if delete is True.
file = tempfile.NamedTemporaryFile("wb", delete=False)
path = pathlib.Path(file.name)
try:
with file:
file.write(mock_data)

context_manager = files._MultiProcessingFileReaderContextManagerImpl(executor, "filename.txt", path)

async with context_manager as reader:
data = await reader.read()

assert reader.filename == "filename.txt"
assert reader.path == path
assert reader.executor is executor
assert data == mock_data

finally:
path.unlink()

@mock.patch.object(pathlib.Path, "expanduser", side_effect=RuntimeError)
@pytest.mark.asyncio()
async def test_context_manager_when_expandname_raises_runtime_error(self, expanduser: mock.Mock):
# We can't mock patch stuff in other processes easily (if at all) so
# for this test we have to cheat and use a thread pool executor.
executor = concurrent.futures.ThreadPoolExecutor()

with tempfile.NamedTemporaryFile() as file:
path = pathlib.Path(file.name)
context_manager = files._MultiProcessingFileReaderContextManagerImpl(executor, "filename.txt", path)

async with context_manager as reader:
assert reader.path == path

expanduser.assert_called_once_with()

@pytest.mark.asyncio()
async def test_context_manager_for_unknown_file(self):
executor = concurrent.futures.ProcessPoolExecutor()
path = pathlib.Path(base64.urlsafe_b64encode(random.getrandbits(512).to_bytes(64, "little")).decode())
context_manager = files._MultiProcessingFileReaderContextManagerImpl(executor, "ea.txt", path)

with pytest.raises(FileNotFoundError): # noqa: PT012 - raises block should contain a single statement
async with context_manager:
...

@pytest.mark.parametrize(
"executor", [concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor]
)
@pytest.mark.asyncio()
async def test_test_context_manager_when_target_is_dir(
self, executor: typing.Callable[[], concurrent.futures.Executor]
):
mock_reader = mock.Mock(executor=executor())
context_manager = files._FileAsyncReaderContextManagerImpl(mock_reader)
async def test_test_context_manager_when_target_is_dir(self):
executor = concurrent.futures.ProcessPoolExecutor()

with tempfile.TemporaryDirectory() as name:
mock_reader.path = pathlib.Path(name)
path = pathlib.Path(name)
context_manager = files._MultiProcessingFileReaderContextManagerImpl(executor, "meow.txt", path)

with pytest.raises(IsADirectoryError): # noqa: PT012 - raises block should contain a single statement
async with context_manager:
Expand Down

0 comments on commit 35cfca3

Please sign in to comment.