Skip to content

Commit

Permalink
Bind manual database close methods to Python APIs (#3435)
Browse files Browse the repository at this point in the history
* Bind manual database close methods to Python APIs

* Add new line

* Run clang-format

* Fix linter

---------

Co-authored-by: CI Bot <mewim@users.noreply.github.com>
  • Loading branch information
mewim and mewim committed May 2, 2024
1 parent a5f3627 commit 1ec7425
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 12 deletions.
2 changes: 2 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class PyConnection {

explicit PyConnection(PyDatabase* pyDatabase, uint64_t numThreads);

void close();

~PyConnection() = default;

void setQueryTimeout(uint64_t timeoutInMS);
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using namespace kuzu;
void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def("close", &PyConnection::close)
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::dict())
.def("query", &PyConnection::query, py::arg("statement"))
Expand All @@ -44,6 +45,10 @@ PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) {
}
}

void PyConnection::close() {
conn.reset();
}

void PyConnection::setQueryTimeout(uint64_t timeoutInMS) {
conn->setQueryTimeOut(timeoutInMS);
}
Expand Down
35 changes: 35 additions & 0 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
from .query_result import QueryResult

if TYPE_CHECKING:
import sys
from types import TracebackType

from .database import Database

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class Connection:
"""Connection to a database."""
Expand All @@ -29,6 +37,7 @@ def __init__(self, database: Database, num_threads: int = 0):
self._connection: Any = None # (type: _kuzu.Connection from pybind11)
self.database = database
self.num_threads = num_threads
self.is_closed = False
self.init_connection()

def __getstate__(self) -> dict[str, Any]:
Expand All @@ -41,6 +50,9 @@ def __getstate__(self) -> dict[str, Any]:

def init_connection(self) -> None:
"""Establish a connection to the database, if not already initalised."""
if self.is_closed:
error_msg = "Connection is closed."
raise RuntimeError(error_msg)
self.database.init_database()
if self._connection is None:
self._connection = _kuzu.Connection(self.database._database, self.num_threads) # type: ignore[union-attr]
Expand All @@ -58,6 +70,29 @@ def set_max_threads_for_exec(self, num_threads: int) -> None:
self.init_connection()
self._connection.set_max_threads_for_exec(num_threads)

def close(self) -> None:
"""
Close the connection.
Note: Call to this method is optional. The connection will be closed
automatically when the object goes out of scope.
"""
if self._connection is not None:
self._connection.close()
self._connection = None
self.is_closed = True

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.close()

def execute(
self,
query: str | PreparedStatement,
Expand Down
13 changes: 11 additions & 2 deletions tools/python_api/src_py/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,16 @@ def _scan_node_table(
raise ValueError(msg)

def close(self) -> None:
"""Close the database."""
"""
Close the database. Once the database is closed, the lock on the database
files is released and the database can be opened in another process.
Note: Call to this method is not required. The Python garbage collector
will automatically close the database when no references to the database
object exist. It is recommended not to call this method explicitly. If you
decide to manually close the database, make sure that all the QueryResult
and Connection objects are closed before calling this method.
"""
if self.is_closed:
return
self.is_closed = True
Expand All @@ -275,5 +284,5 @@ def check_for_database_close(self) -> None:
"""
if not self.is_closed:
return
msg = "Query result is closed"
msg = "Database is closed"
raise RuntimeError(msg)
29 changes: 29 additions & 0 deletions tools/python_api/test/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import kuzu
import pytest

if TYPE_CHECKING:
from pathlib import Path


def test_connection_close(tmp_path: Path, build_dir: Path) -> None:
db_path = tmp_path / "test_connection_close.kuzu"
db = kuzu.Database(database_path=db_path, read_only=False)
conn = kuzu.Connection(db)
conn.close()
assert conn.is_closed
pytest.raises(RuntimeError, conn.execute, "RETURN 1")
db.close()


def test_connection_close_context_manager(tmp_path: Path, build_dir: Path) -> None:
db_path = tmp_path / "test_connection_close_context_manager.kuzu"
with kuzu.Database(database_path=db_path, read_only=False) as db:
with kuzu.Connection(db) as conn:
pass
assert conn.is_closed
pytest.raises(RuntimeError, conn.execute, "RETURN 1")
assert db.is_closed
6 changes: 4 additions & 2 deletions tools/python_api/test/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ def test_exception(conn_db_readonly: ConnDB) -> None:

def test_db_path_exception() -> None:
path = ""
error_message = ("IO exception: Failed to create directory due to: IO exception: Directory cannot be created. "
"Check if it exists and remove it.")
error_message = (
"IO exception: Failed to create directory due to: IO exception: Directory cannot be created. "
"Check if it exists and remove it."
)
with pytest.raises(RuntimeError, match=error_message):
kuzu.Database(path)

Expand Down
12 changes: 4 additions & 8 deletions tools/python_api/test/test_scan_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,7 @@ def test_scan_all_null(tmp_path: Path) -> None:
def test_copy_from_scan_pandas_result(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
df = pd.DataFrame({
"name": ["Adam", "Karissa", "Zhang", "Noura"],
"age": [30, 40, 50, 25]
})
df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
conn.execute("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY (name));")
conn.execute("COPY Person FROM (LOAD FROM df WHERE age < 30 RETURN *);")
result = conn.execute("match (p:Person) return p.*")
Expand All @@ -353,10 +350,9 @@ def test_copy_from_scan_pandas_result(tmp_path: Path) -> None:
def test_scan_from_py_arrow_pandas(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
df = pd.DataFrame({
"name": ["Adam", "Karissa", "Zhang", "Noura"],
"age": [30, 40, 50, 25]
}).convert_dtypes(dtype_backend="pyarrow")
df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]}).convert_dtypes(
dtype_backend="pyarrow"
)
result = conn.execute("LOAD FROM df RETURN *;")
assert result.get_next() == ["Adam", 30]
assert result.get_next() == ["Karissa", 40]
Expand Down

0 comments on commit 1ec7425

Please sign in to comment.