Skip to content

Commit

Permalink
Adds tests. Fixes bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
payneio committed Oct 8, 2024
1 parent f343633 commit 7b3f9af
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 31 deletions.
34 changes: 20 additions & 14 deletions libraries/python/assistant-drive/assistant_drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# from pydantic_settings import BaseSettings

logger = logging.getLogger(__name__)
mime = magic.Magic(mime=True)


class DriveConfig(BaseModel):
Expand Down Expand Up @@ -57,7 +58,6 @@ def __str__(self) -> str:

@staticmethod
def from_bytes(content: BinaryIO, filename: str, dir: str | None = None) -> "FileMetadata":
mime = magic.Magic(mime=True)
content_type = mime.from_buffer(content.read())
content.seek(0)
size = len(content.read())
Expand Down Expand Up @@ -105,10 +105,10 @@ def _metadata_path_for(self, filename: str | None = None, dir: str | None = None

def _unique_filename(self, filename: str, dir: str | None) -> str:
"""Ensure filename is unique in the namespace by appending a counter."""
root_name, extension = os.path.splitext(filename)
base, extension = os.path.splitext(filename)
counter = 1
while self.file_exists(dir, filename):
filename = f"{root_name}({counter}){extension}"
while self.file_exists(filename, dir):
filename = f"{base}({counter}){extension}"
counter += 1

return filename
Expand All @@ -120,9 +120,12 @@ def add_bytes(
dir: str | None = None,
overwrite: bool = False,
) -> FileMetadata:
# Find index of stream.
idx = content.tell()

# If file exists, and asked to overwrite, use the same file id.
# If not asked to overwrite, generate a new file id and modified filename.
if self.file_exists(dir, filename):
if self.file_exists(filename, dir):
if not overwrite:
filename = self._unique_filename(filename, dir)

Expand All @@ -139,24 +142,27 @@ def add_bytes(
with open(metadata_path, "w") as f:
f.write(json.dumps(file.metadata.to_dict(), indent=2))

# Return the stream to the original index.
content.seek(idx)

return file.metadata

def delete(self, dir: str | None = None, filename: str | None = None) -> None:
file_path = self._path_for(dir, filename)
file_path = self._path_for(filename, dir)
if file_path.is_file():
file_path.unlink()
metadata_path = self._metadata_path_for(dir)
metadata_path = self._metadata_path_for(filename, dir)
metadata_path.unlink()
else:
dir_path = self._path_for(dir)
dir_path = self._path_for(dir=dir)
if dir_path.is_dir():
rmtree(dir_path)
metadata_path = self._metadata_path_for(dir)
metadata_path = self._metadata_path_for(dir=dir)
if metadata_path.is_dir():
rmtree(metadata_path)

def read_all_files(self, dir: str) -> Iterator[BinaryIO]:
dir_path = self._path_for(dir, "")
dir_path = self._path_for("", dir)
if not dir_path.is_dir():
return

Expand All @@ -165,7 +171,7 @@ def read_all_files(self, dir: str) -> Iterator[BinaryIO]:
yield f

def list(self, dir: str = "") -> Iterator[str]:
dir_path = self._path_for(dir, "")
dir_path = self._path_for("", dir)
if not dir_path.is_dir():
return

Expand All @@ -176,11 +182,11 @@ def list(self, dir: str = "") -> Iterator[str]:
yield file_path.name

@contextmanager
def read_file(self, dir: str, filename: str) -> Iterator[BinaryIO]:
file_path = self._path_for(dir, filename)
def read_file(self, filename: str, dir: str | None = None) -> Iterator[BinaryIO]:
file_path = self._path_for(filename, dir)
with open(file_path, "rb") as f:
yield f

def file_exists(self, dir: str | None, filename: str) -> bool:
def file_exists(self, filename: str, dir: str | None = None) -> bool:
file_path = self._path_for(filename, dir)
return file_path.exists()
117 changes: 117 additions & 0 deletions libraries/python/assistant-drive/assistant_drive/tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from io import BytesIO # Import BytesIO

import pytest
from assistant_drive import Drive, DriveConfig
from context import Context

file_content = BytesIO(b"Hello, World!") # Convert byte string to BytesIO


@pytest.fixture
def drive():
context: Context = Context(session_id="test_session")
config = DriveConfig(context=context)
drive = Drive(config)
drive.delete()

yield drive

drive.delete()


def test_add_bytes_to_root(drive):
# Add a file with a directory.
metadata = drive.add_bytes(file_content, "test.txt")

assert metadata.filename == "test.txt"
assert metadata.dir is None
assert metadata.content_type == "text/plain"
assert metadata.size == 13
assert metadata.created_at is not None
assert metadata.updated_at is not None

assert list(drive.list()) == ["test.txt"]


def test_add_bytes_to_directory(drive):
metadata = drive.add_bytes(file_content, "test.txt", "summaries")

assert metadata.filename == "test.txt"
assert metadata.dir == "summaries"
assert metadata.content_type == "text/plain"
assert metadata.size == 13
assert metadata.created_at is not None
assert metadata.updated_at is not None

assert list(drive.list()) == ["summaries"]
assert list(drive.list(dir="summaries")) == ["test.txt"]


def test_add_bytes_to_nested_directory(drive):
metadata = drive.add_bytes(file_content, "test.txt", "abc/summaries")

assert metadata.filename == "test.txt"
assert metadata.dir == "abc/summaries"
assert metadata.content_type == "text/plain"
assert metadata.size == 13
assert metadata.created_at is not None
assert metadata.updated_at is not None

assert list(drive.list()) == ["abc"]
assert list(drive.list(dir="abc")) == ["summaries"]
assert list(drive.list(dir="abc/summaries")) == ["test.txt"]


def test_exists(drive):
assert not drive.file_exists("test.txt", "summaries")
drive.add_bytes(file_content, "test.txt", "summaries")
assert drive.file_exists("test.txt", "summaries")


def test_read(drive):
drive.add_bytes(file_content, "test.txt", "summaries")
with drive.read_file("test.txt", "summaries") as f:
assert f.read() == b"Hello, World!"


def test_list(drive):
drive.add_bytes(file_content, "test.txt", "summaries")
assert list(drive.list(dir="summaries")) == ["test.txt"]

drive.add_bytes(file_content, "test2.txt", "summaries")
assert sorted(list(drive.list(dir="summaries"))) == ["test.txt", "test2.txt"]


def test_read_non_existent_file(drive):
with pytest.raises(FileNotFoundError):
with drive.read_file("test.txt", "summaries") as f:
f.read()


def test_no_overwrite(drive):
drive.add_bytes(file_content, "test.txt", "summaries")
metadata = drive.add_bytes(file_content, "test.txt", "summaries")
assert metadata.filename == "test(1).txt"
assert sorted(list(drive.list(dir="summaries"))) == sorted(["test.txt", "test(1).txt"])


def test_overwrite(drive):
drive.add_bytes(file_content, "test.txt", "summaries")
metadata = drive.add_bytes(BytesIO(b"XXX"), "test.txt", "summaries", overwrite=True)
assert metadata.filename == "test.txt"
with drive.read_file("test.txt", "summaries") as f:
assert f.read() == b"XXX"
assert list(drive.list(dir="summaries")) == ["test.txt"]


def test_delete(drive):
drive.add_bytes(file_content, "test.txt", "summaries")
drive.delete(dir="summaries", filename="test.txt")
assert list(drive.list(dir="summaries")) == []

# Add a file with the same name but overwrite.
metadata = drive.add_bytes(file_content, "test.txt", "summaries", overwrite=True)
assert metadata.filename == "test.txt"
assert sorted(list(drive.list(dir="summaries"))) == sorted(["test.txt"])

drive.delete()
24 changes: 7 additions & 17 deletions libraries/python/assistant-drive/usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['.metadata']\n",
"d33057f5-9b5a-4b77-86a4-1cc6a4fefa9d\n"
]
}
],
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
Expand All @@ -26,14 +17,13 @@
"config = DriveConfig(context=context)\n",
"drive = Drive(config)\n",
"\n",
"print(list(drive.list_files()))\n",
"print(list(drive.list()))\n",
"\n",
"# Create BinaryIO from text.\n",
"file_content = BytesIO(b\"Hello, World!\") # Convert byte string to BytesIO\n",
"file_id = drive.add_bytes(file_content, \"test.txt\")\n",
"print(file_id)\n",
"\n",
"# drive.add_file(\"test.txt\", \"Hello, World!\")"
"metadata = drive.add_bytes(file_content, \"test.txt\")\n",
"print(metadata)\n",
"\n"
]
}
],
Expand All @@ -53,7 +43,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.10"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 7b3f9af

Please sign in to comment.