Skip to content

Commit

Permalink
Add database close methods for Node.js and Python APIs (#3289)
Browse files Browse the repository at this point in the history
* Add database close method to Python API

* Fix test_database_close on Windows

* Add database close functionality to Node.js API

* Fix Python linter errors

* Fix database close exception message in Python API
  • Loading branch information
mewim committed Apr 17, 2024
1 parent 12127d3 commit c95edc7
Show file tree
Hide file tree
Showing 15 changed files with 665 additions and 294 deletions.
3 changes: 2 additions & 1 deletion tools/nodejs_api/src_cpp/include/node_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class NodeDatabase : public Napi::ObjectWrap<NodeDatabase> {
private:
Napi::Value InitAsync(const Napi::CallbackInfo& info);
void InitCppDatabase();
void setLoggingLevel(const Napi::CallbackInfo& info);
void SetLoggingLevel(const Napi::CallbackInfo& info);
void Close(const Napi::CallbackInfo& info);
static Napi::Value GetVersion(const Napi::CallbackInfo& info);
static Napi::Value GetStorageVersion(const Napi::CallbackInfo& info);

Expand Down
11 changes: 9 additions & 2 deletions tools/nodejs_api/src_cpp/node_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Napi::Object NodeDatabase::Init(Napi::Env env, Napi::Object exports) {
Napi::Function t = DefineClass(env, "NodeDatabase",
{
InstanceMethod("initAsync", &NodeDatabase::InitAsync),
InstanceMethod("setLoggingLevel", &NodeDatabase::setLoggingLevel),
InstanceMethod("setLoggingLevel", &NodeDatabase::SetLoggingLevel),
InstanceMethod("close", &NodeDatabase::Close),
StaticMethod("getVersion", &NodeDatabase::GetVersion),
StaticMethod("getStorageVersion", &NodeDatabase::GetStorageVersion),
});
Expand Down Expand Up @@ -51,14 +52,20 @@ void NodeDatabase::InitCppDatabase() {
this->database = std::make_shared<Database>(databasePath, systemConfig);
}

void NodeDatabase::setLoggingLevel(const Napi::CallbackInfo& info) {
void NodeDatabase::SetLoggingLevel(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

auto loggingLevel = info[0].As<Napi::String>().Utf8Value();
database->setLoggingLevel(std::move(loggingLevel));
}

void NodeDatabase::Close(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
database.reset();
}

Napi::Value NodeDatabase::GetVersion(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
Expand Down
29 changes: 29 additions & 0 deletions tools/nodejs_api/src_js/database.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Database {
);
this._isInitialized = false;
this._initPromise = null;
this._isClosed = false;
}

/**
Expand Down Expand Up @@ -99,6 +100,9 @@ class Database {
* @returns {KuzuNative.NodeDatabase} the underlying native database.
*/
async _getDatabase() {
if (this._isClosed) {
throw new Error("Database is closed.");
}
await this.init();
return this._database;
}
Expand Down Expand Up @@ -126,6 +130,31 @@ class Database {
}
this._loggingLevel = loggingLevel;
}

/**
* Close the database.
*/
async close() {
if (this._isClosed) {
return;
}
if (!this._isInitialized) {
if (this._initPromise) {
// Database is initializing, wait for it to finish first.
await this._initPromise;
} else {
// Database is not initialized, simply mark it as closed and initialized.
this._isInitialized = true;
this._isClosed = true;
delete this._database;
return;
}
}
// Database is initialized, close it.
this._database.close();
delete this._database;
this._isClosed = true;
}
}

module.exports = Database;
6 changes: 4 additions & 2 deletions tools/nodejs_api/test/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ chai.config.includeStack = true;

const TEST_INSTALLED = process.env.TEST_INSTALLED || false;
if (TEST_INSTALLED) {
console.log("Testing installed package...");
global.kuzu = require("kuzu");
global.kuzuPath = require.resolve("kuzu");
console.log("Testing installed version @", kuzuPath);
} else {
console.log("Testing locally built version...");
global.kuzu = require("../build/");
global.kuzuPath = require.resolve("../build/");
console.log("Testing locally built version @", kuzuPath);
}

const tmp = require("tmp");
Expand Down
149 changes: 147 additions & 2 deletions tools/nodejs_api/test/test_database.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,34 @@ const { assert } = require("chai");
const tmp = require("tmp");
const process = require("process");

const spwan = require("child_process").spawn;

const openDatabaseOnSubprocess = (dbPath) => {
return new Promise((resolve, _) => {
const node = process.argv[0];
const code = `
(async() => {
const kuzu = require("${kuzuPath}");
const db = new kuzu.Database("${dbPath}", 1 << 28);
await db.init();
console.log("Database initialized.");
})();
`;
const child = spwan(node, ["-e", code]);
let stdout = "";
let stderr = "";
child.stdout.on("data", (data) => {
stdout += data;
});
child.stderr.on("data", (data) => {
stderr += data;
});
child.on("close", (code) => {
resolve({ code, stdout, stderr });
});
});
};

describe("Database constructor", function () {
it("should create a database with a valid path and buffer size", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
Expand Down Expand Up @@ -104,8 +132,7 @@ describe("Database constructor", function () {
}
return resolve(path);
});
}
);
});
const testDb = new kuzu.Database(
tmpDbPath,
1 << 28 /* 256MB */,
Expand Down Expand Up @@ -212,3 +239,121 @@ describe("Set logging level", function () {
}
});
});

describe("Database close", function () {
it("should allow initializing a new database after closing", async function () {
if (process.platform === "win32") {
this._runnable.title += " (skipped: not implemented on Windows)";
this.skip();
}
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
let subProcessResult = await openDatabaseOnSubprocess(tmpDbPath);
assert.notEqual(subProcessResult.code, 0);
assert.include(
subProcessResult.stderr,
"Error: IO exception: Could not set lock on file"
);
await testDb.close();
subProcessResult = await openDatabaseOnSubprocess(tmpDbPath);
assert.equal(subProcessResult.code, 0);
assert.isEmpty(subProcessResult.stderr);
assert.include(subProcessResult.stdout, "Database initialized.");
});

it("should throw error if the database is closed", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
await testDb.close();
try {
await testDb._getDatabase();
assert.fail("No error thrown when the database is closed.");
} catch (e) {
assert.equal(e.message, "Database is closed.");
}
});

it("should close the database if it is initialized", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
assert.isTrue(testDb._isInitialized);
assert.exists(testDb._database);
await testDb.close();
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
});

it("should close the database if it is not initialized", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
assert.isFalse(testDb._isInitialized);
await testDb.close();
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
assert.isTrue(testDb._isInitialized);
});

it("should close a initializing database", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await Promise.all([testDb.init(), testDb.close()]);
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
assert.isTrue(testDb._isInitialized);
});

it("should gracefully close a database multiple times", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
await Promise.all([testDb.close(), testDb.close(), testDb.close()]);
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
assert.isTrue(testDb._isInitialized);
});
});
2 changes: 2 additions & 0 deletions tools/python_api/src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class PyDatabase {

~PyDatabase();

void close();

template<class T>
void scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads);
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_cpp/py_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void PyDatabase::initialize(py::handle& m) {
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
.def("scan_node_table_as_bool", &PyDatabase::scanNodeTable<bool>, py::arg("table_name"),
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
.def("close", &PyDatabase::close)
.def_static("get_version", &PyDatabase::getVersion)
.def_static("get_storage_version", &PyDatabase::getStorageVersion);
}
Expand Down Expand Up @@ -59,6 +60,10 @@ PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,

PyDatabase::~PyDatabase() {}

void PyDatabase::close() {
database.reset();
}

template<class T>
void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads) {
Expand Down
Loading

0 comments on commit c95edc7

Please sign in to comment.