Skip to content

Commit

Permalink
Remote backend for PyG
Browse files Browse the repository at this point in the history
  • Loading branch information
mewim committed Mar 29, 2023
1 parent 0a882b3 commit 52bd939
Show file tree
Hide file tree
Showing 24 changed files with 846 additions and 17 deletions.
2 changes: 1 addition & 1 deletion scripts/pip-package/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run(self):
install_requires=[],
ext_modules=[CMakeExtension(
name="kuzu", sourcedir=base_dir)],
description='KuzuDB Python API',
description='An in-process property graph database management system built for query speed and scalability.',
license='MIT',
long_description=open(os.path.join(base_dir, "README.md"), 'r').read(),
long_description_content_type="text/markdown",
Expand Down
8 changes: 6 additions & 2 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,15 @@ DataTypeID Types::dataTypeIDFromString(const std::string& dataTypeIDString) {

std::string Types::dataTypeToString(const DataType& dataType) {
switch (dataType.typeID) {
case VAR_LIST:
case FIXED_LIST: {
case VAR_LIST: {
assert(dataType.childType);
return dataTypeToString(*dataType.childType) + "[]";
}
case FIXED_LIST: {
assert(dataType.childType);
return dataTypeToString(*dataType.childType) + "[" +
std::to_string(dataType.fixedNumElementsInList) + "]";
}
case ANY:
case NODE:
case REL:
Expand Down
1 change: 1 addition & 0 deletions src/include/main/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ KUZU_API struct SystemConfig {
class Database {
friend class EmbeddedShell;
friend class Connection;
friend class StorageDriver;
friend class kuzu::testing::BaseGraphTest;

public:
Expand Down
58 changes: 58 additions & 0 deletions src/include/main/storage_driver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#pragma once

#include <queue>

#include "common/types/types_include.h"
#include "database.h"

namespace kuzu {
namespace storage {
class Column;
}

namespace main {

//// Note: pooling is tricky to figure out "Busy state"
// class ThreadPool {
// public:
// ThreadPool(size_t numThreads);
// ~ThreadPool();
//
// template<typename F, typename... Args>
// void enqueue(F&& f, Args&&... args) {
// {
// std::unique_lock<std::mutex> lock(queueMtx);
// tasks.emplace([=] { std::invoke(f, args...); });
// }
// condition.notify_one();
// }
//
// private:
// std::vector<std::thread> threads;
// std::queue<std::function<void()>> tasks;
// std::mutex queueMtx;
// std::condition_variable condition;
// bool stop = false;
//};

class StorageDriver {
public:
explicit StorageDriver(Database* database, size_t numThreads = 1);

~StorageDriver();

std::pair<std::unique_ptr<uint8_t[]>, size_t> scan(const std::string& nodeName,
const std::string& propertyName, common::offset_t* offsets, size_t size);

private:
void scanColumn(
storage::Column* column, common::offset_t* offsets, size_t size, uint8_t* result);

private:
catalog::Catalog* catalog;
storage::StorageManager* storageManager;
size_t numThreads;
};

} // namespace main
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/storage/storage_structure/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Column : public BaseColumnOrList {
: Column(structureIDAndFName, dataType, common::Types::getDataTypeSize(dataType),
bufferManager, wal){};

// Expose for feature store
void scan(const common::offset_t* nodeOffsets, size_t size, uint8_t* result);

virtual void read(transaction::Transaction* transaction, common::ValueVector* nodeIDVector,
common::ValueVector* resultVector);

Expand Down
3 changes: 2 additions & 1 deletion src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(kuzu_main
plan_printer.cpp
prepared_statement.cpp
query_result.cpp
query_summary.cpp)
query_summary.cpp
storage_driver.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_main>
Expand Down
82 changes: 82 additions & 0 deletions src/main/storage_driver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "main/storage_driver.h"

#include "catalog/catalog.h"
#include "storage/storage_manager.h"

using namespace kuzu::common;

namespace kuzu {
namespace main {

// ThreadPool::ThreadPool(size_t numThreads) {
// for (auto i = 0u; i < numThreads; ++i) {
// threads.emplace_back([this] {
// while (true) {
// std::function<void()> task;
// {
// std::unique_lock<std::mutex> lck(queueMtx);
// condition.wait(lck, [this] { return stop || !tasks.empty(); });
// if (stop && tasks.empty()) {
// return;
// }
// task = std::move(tasks.front());
// tasks.pop();
// }
// task();
// }
// });
// }
//}
//
// ThreadPool::~ThreadPool() {
// {
// std::unique_lock<std::mutex> lck(queueMtx);
// stop = true;
// }
// condition.notify_all();
// for (auto& thread : threads) {
// thread.join();
// }
//}

StorageDriver::StorageDriver(kuzu::main::Database* database, size_t numThreads)
: catalog{database->catalog.get()}, storageManager{database->storageManager.get()},
numThreads{numThreads} {}

StorageDriver::~StorageDriver() = default;

std::pair<std::unique_ptr<uint8_t[]>, size_t> StorageDriver::scan(const std::string& nodeName,
const std::string& propertyName, common::offset_t* offsets, size_t size) {
// Resolve files to read from
auto catalogContent = catalog->getReadOnlyVersion();
auto nodeTableID = catalogContent->getTableID(nodeName);
auto propertyID = catalogContent->getTableSchema(nodeTableID)->getPropertyID(propertyName);
auto nodeTable = storageManager->getNodesStore().getNodeTable(nodeTableID);
auto column = nodeTable->getPropertyColumn(propertyID);

auto bufferSize = column->elementSize * size;
auto result = std::make_unique<uint8_t[]>(bufferSize);
auto buffer = result.get();
std::vector<std::thread> threads;
auto numElementsPerThread = size / numThreads + 1;
auto sizeLeft = size;
while (sizeLeft > 0) {
auto sizeToRead = std::min(numElementsPerThread, sizeLeft);
threads.emplace_back(&StorageDriver::scanColumn, this, column, offsets, sizeToRead, buffer);
offsets += sizeToRead;
buffer += sizeToRead * column->elementSize;
sizeLeft -= sizeToRead;
}
for (auto& thread : threads) {
thread.join();
}
return std::make_pair(std::move(result), bufferSize);
}

void StorageDriver::scanColumn(
storage::Column* column, common::offset_t* offsets, size_t size, uint8_t* result) {
column->scan(offsets, size, result);
}

} // namespace main
} // namespace kuzu
14 changes: 14 additions & 0 deletions src/storage/storage_structure/column.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ using namespace kuzu::transaction;
namespace kuzu {
namespace storage {

void Column::scan(const common::offset_t* nodeOffsets, size_t size, uint8_t* result) {
for (auto i = 0u; i < size; ++i) {
auto nodeOffset = nodeOffsets[i];
auto cursor = PageUtils::getPageElementCursorForPos(nodeOffset, numElementsPerPage);
auto [fileHandleToPin, pageIdxToPin] =
StorageStructureUtils::getFileHandleAndPhysicalPageIdxToPin(
*fileHandle, cursor.pageIdx, *wal, TransactionType::READ_ONLY);
auto frame = bufferManager.pin(*fileHandleToPin, pageIdxToPin);
auto frameBytesOffset = getElemByteOffset(cursor.elemPosInPage);
memcpy(result + i * elementSize, frame + frameBytesOffset, elementSize);
bufferManager.unpin(*fileHandleToPin, pageIdxToPin);
}
}

void Column::read(Transaction* transaction, common::ValueVector* nodeIDVector,
common::ValueVector* resultVector) {
if (nodeIDVector->state->isFlat()) {
Expand Down
3 changes: 2 additions & 1 deletion test/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ add_kuzu_test(main_test
csv_output_test.cpp
exception_test.cpp
prepare_test.cpp
result_value_test.cpp)
result_value_test.cpp
storage_driver_test.cpp)
26 changes: 26 additions & 0 deletions test/main/storage_driver_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "main/storage_driver.h"
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::testing;
using namespace kuzu::common;

TEST_F(ApiTest, StorageDriverScan) {
auto storageDriver = std::make_unique<StorageDriver>(database.get(), 3);
auto size = 6;
auto nodeOffsetsBuffer = std::make_unique<uint8_t[]>(sizeof(offset_t) * size);
auto nodeOffsets = (offset_t*)nodeOffsetsBuffer.get();
nodeOffsets[0] = 7;
nodeOffsets[1] = 0;
nodeOffsets[2] = 3;
nodeOffsets[3] = 1;
nodeOffsets[4] = 2;
nodeOffsets[5] = 6;
auto result = storageDriver->scan("person", "ID", nodeOffsets, size);
auto ids = (int64_t*)result.first.get();
ASSERT_EQ(ids[0], 10);
ASSERT_EQ(ids[1], 0);
ASSERT_EQ(ids[2], 5);
ASSERT_EQ(ids[3], 2);
ASSERT_EQ(ids[4], 3);
ASSERT_EQ(ids[5], 9);
}
6 changes: 6 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ class PyConnection {

py::str getNodePropertyNames(const std::string& tableName);

py::str getNodeTableNames();

py::str getRelPropertyNames(const std::string& tableName);

py::str getRelTableNames();

PyPreparedStatement prepare(const std::string& query);

private:
Expand Down
8 changes: 7 additions & 1 deletion tools/python_api/src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#pragma once

#include "main/kuzu.h"
#include "main/storage_driver.h"
#include "pybind_include.h"

#define PYBIND11_DETAILED_ERROR_MESSAGES
using namespace kuzu::main;

class PyDatabase {
Expand All @@ -17,8 +18,13 @@ class PyDatabase {

explicit PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize);

template<class T>
py::array_t<T> scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, int numThreads);

~PyDatabase() = default;

private:
std::unique_ptr<Database> database;
std::unique_ptr<StorageDriver> storageDriver;
};
15 changes: 15 additions & 0 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ void PyConnection::initialize(py::handle& m) {
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"))
.def("get_node_table_names", &PyConnection::getNodeTableNames)
.def("get_rel_property_names", &PyConnection::getRelPropertyNames, py::arg("table_name"))
.def("get_rel_table_names", &PyConnection::getRelTableNames)
.def("prepare", &PyConnection::prepare, py::arg("query"));
PyDateTime_IMPORT;
}
Expand Down Expand Up @@ -51,6 +54,18 @@ py::str PyConnection::getNodePropertyNames(const std::string& tableName) {
return conn->getNodePropertyNames(tableName);
}

py::str PyConnection::getNodeTableNames() {
return conn->getNodeTableNames();
}

py::str PyConnection::getRelPropertyNames(const std::string& tableName) {
return conn->getRelPropertyNames(tableName);
}

py::str PyConnection::getRelTableNames() {
return conn->getRelTableNames();
}

PyPreparedStatement PyConnection::prepare(const std::string& query) {
auto preparedStatement = conn->prepare(query);
PyPreparedStatement pyPreparedStatement;
Expand Down
40 changes: 39 additions & 1 deletion tools/python_api/src_cpp/py_database.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
#include "include/py_database.h"

#include <memory>

using namespace kuzu::common;

void PyDatabase::initialize(py::handle& m) {
py::class_<PyDatabase>(m, "Database")
.def(py::init<const std::string&, uint64_t>(), py::arg("database_path"),
py::arg("buffer_pool_size") = 0)
.def("set_logging_level", &PyDatabase::setLoggingLevel, py::arg("logging_level"));
.def("set_logging_level", &PyDatabase::setLoggingLevel, py::arg("logging_level"))
.def("scan_node_table_as_int64", &PyDatabase::scanNodeTable<std::int64_t>,
py::return_value_policy::take_ownership, py::arg("table_name"), py::arg("prop_name"),
py::arg("indices"), py::arg("num_threads"))
.def("scan_node_table_as_int32", &PyDatabase::scanNodeTable<std::int32_t>,
py::return_value_policy::take_ownership, py::arg("table_name"), py::arg("prop_name"),
py::arg("indices"), py::arg("num_threads"))
.def("scan_node_table_as_int16", &PyDatabase::scanNodeTable<std::int16_t>,
py::return_value_policy::take_ownership, py::arg("table_name"), py::arg("prop_name"),
py::arg("indices"), py::arg("num_threads"))
.def("scan_node_table_as_double", &PyDatabase::scanNodeTable<std::double_t>,
py::return_value_policy::take_ownership, py::arg("table_name"), py::arg("prop_name"),
py::arg("indices"), py::arg("num_threads"))
.def("scan_node_table_as_float", &PyDatabase::scanNodeTable<std::float_t>,
py::return_value_policy::take_ownership, py::arg("table_name"), py::arg("prop_name"),
py::arg("indices"), py::arg("num_threads"))
.def("scan_node_table_as_bool", &PyDatabase::scanNodeTable<bool>,
py::return_value_policy::take_ownership, py::arg("table_name"), py::arg("prop_name"),
py::arg("indices"), py::arg("num_threads"));
}

PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize) {
Expand All @@ -16,3 +36,21 @@ PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize)
}
database = std::make_unique<Database>(databasePath, systemConfig);
}

template<class T>
py::array_t<T> PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, int numThreads) {
auto buf = indices.request(false);
auto ptr = static_cast<uint64_t*>(buf.ptr);
auto size = indices.size();
auto nodeOffsets = (offset_t*)ptr;
if (!storageDriver) {
storageDriver = std::make_unique<StorageDriver>(database.get());
}
auto scanResult = storageDriver->scan(tableName, propName, nodeOffsets, size);
auto buffer = (T*)(scanResult.first).get();
auto bufferSize = scanResult.second;
auto numberOfItems = bufferSize / sizeof(T);
return py::array_t<T>(py::buffer_info(buffer, sizeof(T), py::format_descriptor<T>::format(), 1,
std::vector<size_t>{numberOfItems}, std::vector<size_t>{sizeof(T)}));
}
3 changes: 2 additions & 1 deletion tools/python_api/src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
.attr("timedelta")(py::arg("days") = days,
py::arg("microseconds") = intervalVal.micros));
}
case VAR_LIST: {
case VAR_LIST:
case FIXED_LIST: {
auto& listVal = value.getListValReference();
py::list list;
for (auto i = 0u; i < listVal.size(); ++i) {
Expand Down
4 changes: 1 addition & 3 deletions tools/python_api/src_py/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from ._kuzu import *
# The following imports will override C++ implementations with Python
# implementations.
from .database import *
from .connection import *
from .query_result import *
Loading

0 comments on commit 52bd939

Please sign in to comment.