Skip to content

Commit

Permalink
Merge pull request #1418 from kuzudb/pyg-remote-backend
Browse files Browse the repository at this point in the history
Remote backend for PyG
  • Loading branch information
mewim committed Apr 2, 2023
2 parents 8a5a826 + 6ee6705 commit 9cd0309
Show file tree
Hide file tree
Showing 26 changed files with 938 additions and 26 deletions.
2 changes: 1 addition & 1 deletion scripts/pip-package/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run(self):
install_requires=[],
ext_modules=[CMakeExtension(
name="kuzu", sourcedir=base_dir)],
description='KuzuDB Python API',
description='An in-process property graph database management system built for query speed and scalability.',
license='MIT',
long_description=open(os.path.join(base_dir, "README.md"), 'r').read(),
long_description_content_type="text/markdown",
Expand Down
8 changes: 6 additions & 2 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,15 @@ DataTypeID Types::dataTypeIDFromString(const std::string& dataTypeIDString) {

std::string Types::dataTypeToString(const DataType& dataType) {
switch (dataType.typeID) {
case VAR_LIST:
case FIXED_LIST: {
case VAR_LIST: {
assert(dataType.childType);
return dataTypeToString(*dataType.childType) + "[]";
}
case FIXED_LIST: {
assert(dataType.childType);
return dataTypeToString(*dataType.childType) + "[" +
std::to_string(dataType.fixedNumElementsInList) + "]";
}
case ANY:
case NODE:
case REL:
Expand Down
1 change: 1 addition & 0 deletions src/include/main/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ KUZU_API struct SystemConfig {
class Database {
friend class EmbeddedShell;
friend class Connection;
friend class StorageDriver;
friend class kuzu::testing::BaseGraphTest;

public:
Expand Down
2 changes: 2 additions & 0 deletions src/include/main/query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class QueryResult {
*/
KUZU_API void resetIterator();

processor::FactorizedTable* getTable() { return factorizedTable.get(); }

private:
void initResultTableAndIterator(std::shared_ptr<processor::FactorizedTable> factorizedTable_,
const std::vector<std::shared_ptr<binder::Expression>>& columns,
Expand Down
37 changes: 37 additions & 0 deletions src/include/main/storage_driver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <queue>

#include "common/types/types_include.h"
#include "database.h"

namespace kuzu {
namespace storage {
class Column;
}

namespace main {

/**
 * Gives callers direct access to stored node properties and table counts
 * without going through the query processor.
 *
 * The driver keeps raw pointers to the catalog and storage manager obtained
 * from the Database passed to the constructor; it never deletes them.
 */
class StorageDriver {
public:
    /** Captures the catalog and storage manager of `database`. */
    explicit StorageDriver(Database* database);

    ~StorageDriver();

    /**
     * Copies one property of a node table into a caller-provided buffer.
     *
     * @param nodeName     name of the node table to read from.
     * @param propertyName name of the property (column) to read.
     * @param offsets      array of `size` node offsets to fetch.
     * @param size         number of entries in `offsets`.
     * @param result       destination buffer; element i is written at
     *                     `result + i * <column element size>`.
     * @param numThreads   number of worker threads the scan is split across.
     */
    void scan(const std::string& nodeName, const std::string& propertyName,
        common::offset_t* offsets, size_t size, uint8_t* result, size_t numThreads);

    /** Returns the tuple count recorded in the statistics of node table `nodeName`. */
    uint64_t getNumNodes(const std::string& nodeName);

    /** Returns the tuple count recorded in the statistics of rel table `relName`. */
    uint64_t getNumRels(const std::string& relName);

private:
    /** Per-thread worker: forwards one slice of offsets to `column->scan`. */
    void scanColumn(
        storage::Column* column, common::offset_t* offsets, size_t size, uint8_t* result);

    catalog::Catalog* catalog;
    storage::StorageManager* storageManager;
};

} // namespace main
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/storage/storage_structure/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Column : public BaseColumnOrList {
: Column(structureIDAndFName, dataType, common::Types::getDataTypeSize(dataType),
bufferManager, wal){};

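// Exposed for the feature store: copies the element at each of the `size` node offsets into
// `result`, one element after another.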
void scan(const common::offset_t* nodeOffsets, size_t size, uint8_t* result);

virtual void read(transaction::Transaction* transaction, common::ValueVector* nodeIDVector,
common::ValueVector* resultVector);

Expand Down
3 changes: 2 additions & 1 deletion src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(kuzu_main
plan_printer.cpp
prepared_statement.cpp
query_result.cpp
query_summary.cpp)
query_summary.cpp
storage_driver.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_main>
Expand Down
64 changes: 64 additions & 0 deletions src/main/storage_driver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "main/storage_driver.h"

#include "storage/storage_manager.h"

using namespace kuzu::common;

namespace kuzu {
namespace main {

// Borrows the catalog and storage manager from `database`; both are held as raw pointers
// taken from the database's own members.
StorageDriver::StorageDriver(kuzu::main::Database* database)
    : catalog(database->catalog.get()), storageManager(database->storageManager.get()) {}

// The driver holds only raw pointers, so the defaulted destructor releases nothing.
StorageDriver::~StorageDriver() = default;

void StorageDriver::scan(const std::string& nodeName, const std::string& propertyName,
common::offset_t* offsets, size_t size, uint8_t* result, size_t numThreads) {
// Resolve files to read from
auto catalogContent = catalog->getReadOnlyVersion();
auto nodeTableID = catalogContent->getTableID(nodeName);
auto propertyID = catalogContent->getTableSchema(nodeTableID)->getPropertyID(propertyName);
auto nodeTable = storageManager->getNodesStore().getNodeTable(nodeTableID);
auto column = nodeTable->getPropertyColumn(propertyID);

auto current_buffer = result;
std::vector<std::thread> threads;
auto numElementsPerThread = size / numThreads + 1;
auto sizeLeft = size;
while (sizeLeft > 0) {
uint64_t sizeToRead = std::min(numElementsPerThread, sizeLeft);
threads.emplace_back(
&StorageDriver::scanColumn, this, column, offsets, sizeToRead, current_buffer);
offsets += sizeToRead;
current_buffer += sizeToRead * column->elementSize;
sizeLeft -= sizeToRead;
}
for (auto& thread : threads) {
thread.join();
}
}

// Looks up node table `nodeName` in the read-only catalog and returns the tuple count
// stored in that table's statistics.
uint64_t StorageDriver::getNumNodes(const std::string& nodeName) {
    const auto tableID = catalog->getReadOnlyVersion()->getTableID(nodeName);
    auto&& allNodeStatistics = storageManager->getNodesStore().getNodesStatisticsAndDeletedIDs();
    return allNodeStatistics.getNodeStatisticsAndDeletedIDs(tableID)->getNumTuples();
}

// Looks up rel table `relName` in the read-only catalog and returns the tuple count
// stored in that table's statistics.
uint64_t StorageDriver::getNumRels(const std::string& relName) {
    const auto tableID = catalog->getReadOnlyVersion()->getTableID(relName);
    auto&& allRelStatistics = storageManager->getRelsStore().getRelsStatistics();
    return allRelStatistics.getRelStatistics(tableID)->getNumTuples();
}

// Thread entry point used by scan(): hands one slice of `size` offsets, and the matching
// region of the output buffer, to the column's own scan routine.
void StorageDriver::scanColumn(
    storage::Column* column, common::offset_t* offsets, size_t size, uint8_t* result) {
    storage::Column& targetColumn = *column;
    targetColumn.scan(offsets, size, result);
}

} // namespace main
} // namespace kuzu
14 changes: 14 additions & 0 deletions src/storage/storage_structure/column.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ using namespace kuzu::transaction;
namespace kuzu {
namespace storage {

/**
 * Copies the element stored at each of the given node offsets into `result`.
 *
 * For every offset the containing page is located, resolved to the file handle and
 * physical page to read for a READ_ONLY access, and read through the buffer manager's
 * optimisticRead; the element is then copied to slot i of `result`.
 *
 * @param nodeOffsets array of `size` node offsets to read.
 * @param size        number of offsets.
 * @param result      output buffer of at least `size * elementSize` bytes.
 */
void Column::scan(const common::offset_t* nodeOffsets, size_t size, uint8_t* result) {
    // The index is size_t to match `size`: a 32-bit index would wrap (and never reach the
    // bound) for very large scans, and would overflow when multiplied by elementSize.
    for (size_t i = 0; i < size; ++i) {
        const auto nodeOffset = nodeOffsets[i];
        auto cursor = PageUtils::getPageElementCursorForPos(nodeOffset, numElementsPerPage);
        auto [fileHandleToPin, pageIdxToPin] =
            StorageStructureUtils::getFileHandleAndPhysicalPageIdxToPin(
                *fileHandle, cursor.pageIdx, *wal, TransactionType::READ_ONLY);
        uint8_t* const destination = result + i * static_cast<size_t>(elementSize);
        bufferManager.optimisticRead(*fileHandleToPin, pageIdxToPin, [&](uint8_t* frame) -> void {
            auto frameBytesOffset = getElemByteOffset(cursor.elemPosInPage);
            memcpy(destination, frame + frameBytesOffset, elementSize);
        });
    }
}

void Column::read(Transaction* transaction, common::ValueVector* nodeIDVector,
common::ValueVector* resultVector) {
if (nodeIDVector->state->isFlat()) {
Expand Down
3 changes: 2 additions & 1 deletion test/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ add_kuzu_test(main_test
csv_output_test.cpp
exception_test.cpp
prepare_test.cpp
result_value_test.cpp)
result_value_test.cpp
storage_driver_test.cpp)
28 changes: 28 additions & 0 deletions test/main/storage_driver_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "main/storage_driver.h"
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::testing;
using namespace kuzu::common;

// Scans the "ID" property of six "person" nodes, given in non-sequential offset order and
// split across three threads, and checks each value lands in the slot of its offset.
TEST_F(ApiTest, StorageDriverScan) {
    auto storageDriver = std::make_unique<StorageDriver>(database.get());
    constexpr size_t size = 6;
    // Typed arrays replace heap byte buffers reinterpreted through C-style casts.
    offset_t nodeOffsets[size] = {7, 0, 3, 1, 2, 6};
    const int64_t expectedIDs[size] = {10, 0, 5, 2, 3, 9};
    int64_t ids[size] = {};
    // scan() writes raw bytes, so the typed result array is exposed as a byte buffer.
    storageDriver->scan("person", "ID", nodeOffsets, size, reinterpret_cast<uint8_t*>(ids), 3);
    for (size_t i = 0; i < size; ++i) {
        ASSERT_EQ(ids[i], expectedIDs[i]) << "mismatch at index " << i;
    }
}
16 changes: 16 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "py_database.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"
#include "main/storage_driver.h"

class PyConnection {

Expand All @@ -21,8 +22,22 @@ class PyConnection {

py::str getNodePropertyNames(const std::string& tableName);

py::str getNodeTableNames();

py::str getRelPropertyNames(const std::string& tableName);

py::str getRelTableNames();

PyPreparedStatement prepare(const std::string& query);

uint64_t getNumNodes(const std::string& nodeName);

uint64_t getNumRels(const std::string& relName);

void getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
const std::string& srcTableName, const std::string& relName,
const std::string& dstTableName, size_t queryBatchSize);

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::list params);
Expand All @@ -33,5 +48,6 @@ class PyConnection {
kuzu::common::Value transformPythonValue(py::handle val);

private:
std::unique_ptr<StorageDriver> storageDriver;
std::unique_ptr<Connection> conn;
};
10 changes: 9 additions & 1 deletion tools/python_api/src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#pragma once

#include "main/kuzu.h"
#include "pybind_include.h"

#include "pybind_include.h"
#include "main/storage_driver.h"
#define PYBIND11_DETAILED_ERROR_MESSAGES
using namespace kuzu::main;


class PyDatabase {
friend class PyConnection;

Expand All @@ -19,6 +22,11 @@ class PyDatabase {

~PyDatabase() = default;

template<class T>
void scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads);

private:
std::unique_ptr<Database> database;
std::unique_ptr<StorageDriver> storageDriver;
};
Loading

0 comments on commit 9cd0309

Please sign in to comment.