Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add database close methods for Node.js and Python APIs #3289

Merged
merged 5 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tools/nodejs_api/src_cpp/include/node_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class NodeDatabase : public Napi::ObjectWrap<NodeDatabase> {
private:
Napi::Value InitAsync(const Napi::CallbackInfo& info);
void InitCppDatabase();
void setLoggingLevel(const Napi::CallbackInfo& info);
void SetLoggingLevel(const Napi::CallbackInfo& info);
void Close(const Napi::CallbackInfo& info);
static Napi::Value GetVersion(const Napi::CallbackInfo& info);
static Napi::Value GetStorageVersion(const Napi::CallbackInfo& info);

Expand Down
11 changes: 9 additions & 2 deletions tools/nodejs_api/src_cpp/node_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Napi::Object NodeDatabase::Init(Napi::Env env, Napi::Object exports) {
Napi::Function t = DefineClass(env, "NodeDatabase",
{
InstanceMethod("initAsync", &NodeDatabase::InitAsync),
InstanceMethod("setLoggingLevel", &NodeDatabase::setLoggingLevel),
InstanceMethod("setLoggingLevel", &NodeDatabase::SetLoggingLevel),
InstanceMethod("close", &NodeDatabase::Close),
StaticMethod("getVersion", &NodeDatabase::GetVersion),
StaticMethod("getStorageVersion", &NodeDatabase::GetStorageVersion),
});
Expand Down Expand Up @@ -51,14 +52,20 @@ void NodeDatabase::InitCppDatabase() {
this->database = std::make_shared<Database>(databasePath, systemConfig);
}

void NodeDatabase::setLoggingLevel(const Napi::CallbackInfo& info) {
void NodeDatabase::SetLoggingLevel(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

auto loggingLevel = info[0].As<Napi::String>().Utf8Value();
database->setLoggingLevel(std::move(loggingLevel));
}

void NodeDatabase::Close(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
database.reset();
}

Napi::Value NodeDatabase::GetVersion(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
Expand Down
29 changes: 29 additions & 0 deletions tools/nodejs_api/src_js/database.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Database {
);
this._isInitialized = false;
this._initPromise = null;
this._isClosed = false;
}

/**
Expand Down Expand Up @@ -99,6 +100,9 @@ class Database {
* @returns {KuzuNative.NodeDatabase} the underlying native database.
*/
async _getDatabase() {
if (this._isClosed) {
throw new Error("Database is closed.");
}
await this.init();
return this._database;
}
Expand Down Expand Up @@ -126,6 +130,31 @@ class Database {
}
this._loggingLevel = loggingLevel;
}

/**
* Close the database.
*/
async close() {
if (this._isClosed) {
return;
}
if (!this._isInitialized) {
if (this._initPromise) {
// Database is initializing, wait for it to finish first.
await this._initPromise;
} else {
// Database is not initialized, simply mark it as closed and initialized.
this._isInitialized = true;
this._isClosed = true;
delete this._database;
return;
}
}
// Database is initialized, close it.
this._database.close();
delete this._database;
this._isClosed = true;
}
}

module.exports = Database;
6 changes: 4 additions & 2 deletions tools/nodejs_api/test/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ chai.config.includeStack = true;

const TEST_INSTALLED = process.env.TEST_INSTALLED || false;
if (TEST_INSTALLED) {
console.log("Testing installed package...");
global.kuzu = require("kuzu");
global.kuzuPath = require.resolve("kuzu");
console.log("Testing installed version @", kuzuPath);
} else {
console.log("Testing locally built version...");
global.kuzu = require("../build/");
global.kuzuPath = require.resolve("../build/");
console.log("Testing locally built version @", kuzuPath);
}

const tmp = require("tmp");
Expand Down
149 changes: 147 additions & 2 deletions tools/nodejs_api/test/test_database.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,34 @@ const { assert } = require("chai");
const tmp = require("tmp");
const process = require("process");

const spwan = require("child_process").spawn;

const openDatabaseOnSubprocess = (dbPath) => {
return new Promise((resolve, _) => {
const node = process.argv[0];
const code = `
(async() => {
const kuzu = require("${kuzuPath}");
const db = new kuzu.Database("${dbPath}", 1 << 28);
await db.init();
console.log("Database initialized.");
})();
`;
const child = spwan(node, ["-e", code]);
let stdout = "";
let stderr = "";
child.stdout.on("data", (data) => {
stdout += data;
});
child.stderr.on("data", (data) => {
stderr += data;
});
child.on("close", (code) => {
resolve({ code, stdout, stderr });
});
});
};

describe("Database constructor", function () {
it("should create a database with a valid path and buffer size", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
Expand Down Expand Up @@ -104,8 +132,7 @@ describe("Database constructor", function () {
}
return resolve(path);
});
}
);
});
const testDb = new kuzu.Database(
tmpDbPath,
1 << 28 /* 256MB */,
Expand Down Expand Up @@ -212,3 +239,121 @@ describe("Set logging level", function () {
}
});
});

describe("Database close", function () {
it("should allow initializing a new database after closing", async function () {
if (process.platform === "win32") {
this._runnable.title += " (skipped: not implemented on Windows)";
this.skip();
}
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
let subProcessResult = await openDatabaseOnSubprocess(tmpDbPath);
assert.notEqual(subProcessResult.code, 0);
assert.include(
subProcessResult.stderr,
"Error: IO exception: Could not set lock on file"
);
await testDb.close();
subProcessResult = await openDatabaseOnSubprocess(tmpDbPath);
assert.equal(subProcessResult.code, 0);
assert.isEmpty(subProcessResult.stderr);
assert.include(subProcessResult.stdout, "Database initialized.");
});

it("should throw error if the database is closed", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
await testDb.close();
try {
await testDb._getDatabase();
assert.fail("No error thrown when the database is closed.");
} catch (e) {
assert.equal(e.message, "Database is closed.");
}
});

it("should close the database if it is initialized", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
assert.isTrue(testDb._isInitialized);
assert.exists(testDb._database);
await testDb.close();
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
});

it("should close the database if it is not initialized", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
assert.isFalse(testDb._isInitialized);
await testDb.close();
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
assert.isTrue(testDb._isInitialized);
});

it("should close a initializing database", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await Promise.all([testDb.init(), testDb.close()]);
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
assert.isTrue(testDb._isInitialized);
});

it("should gracefully close a database multiple times", async function () {
const tmpDbPath = await new Promise((resolve, reject) => {
tmp.dir({ unsafeCleanup: true }, (err, path, _) => {
if (err) {
return reject(err);
}
return resolve(path);
});
});
const testDb = new kuzu.Database(tmpDbPath, 1 << 28 /* 256MB */);
await testDb.init();
await Promise.all([testDb.close(), testDb.close(), testDb.close()]);
assert.notExists(testDb._database);
assert.isTrue(testDb._isClosed);
assert.isTrue(testDb._isInitialized);
});
});
2 changes: 2 additions & 0 deletions tools/python_api/src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class PyDatabase {

~PyDatabase();

void close();

template<class T>
void scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads);
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_cpp/py_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ void PyDatabase::initialize(py::handle& m) {
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
.def("scan_node_table_as_bool", &PyDatabase::scanNodeTable<bool>, py::arg("table_name"),
py::arg("prop_name"), py::arg("indices"), py::arg("np_array"), py::arg("num_threads"))
.def("close", &PyDatabase::close)
.def_static("get_version", &PyDatabase::getVersion)
.def_static("get_storage_version", &PyDatabase::getStorageVersion);
}
Expand Down Expand Up @@ -59,6 +60,10 @@ PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,

PyDatabase::~PyDatabase() {}

void PyDatabase::close() {
database.reset();
}

template<class T>
void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads) {
Expand Down
Loading
Loading