Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote backend for PyG #1418

Merged
merged 1 commit into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/pip-package/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run(self):
install_requires=[],
ext_modules=[CMakeExtension(
name="kuzu", sourcedir=base_dir)],
description='KuzuDB Python API',
description='An in-process property graph database management system built for query speed and scalability.',
license='MIT',
long_description=open(os.path.join(base_dir, "README.md"), 'r').read(),
long_description_content_type="text/markdown",
Expand Down
8 changes: 6 additions & 2 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,15 @@ DataTypeID Types::dataTypeIDFromString(const std::string& dataTypeIDString) {

std::string Types::dataTypeToString(const DataType& dataType) {
switch (dataType.typeID) {
case VAR_LIST:
case FIXED_LIST: {
case VAR_LIST: {
assert(dataType.childType);
return dataTypeToString(*dataType.childType) + "[]";
}
case FIXED_LIST: {
assert(dataType.childType);
return dataTypeToString(*dataType.childType) + "[" +
std::to_string(dataType.fixedNumElementsInList) + "]";
}
case ANY:
case NODE:
case REL:
Expand Down
1 change: 1 addition & 0 deletions src/include/main/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ KUZU_API struct SystemConfig {
class Database {
friend class EmbeddedShell;
friend class Connection;
friend class StorageDriver;
friend class kuzu::testing::BaseGraphTest;

public:
Expand Down
2 changes: 2 additions & 0 deletions src/include/main/query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class QueryResult {
*/
KUZU_API void resetIterator();

processor::FactorizedTable* getTable() { return factorizedTable.get(); }

private:
void initResultTableAndIterator(std::shared_ptr<processor::FactorizedTable> factorizedTable_,
const std::vector<std::shared_ptr<binder::Expression>>& columns,
Expand Down
37 changes: 37 additions & 0 deletions src/include/main/storage_driver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <queue>

#include "common/types/types_include.h"
#include "database.h"

namespace kuzu {
namespace storage {
class Column;
}

namespace main {

class StorageDriver {
public:
explicit StorageDriver(Database* database);

~StorageDriver();

void scan(const std::string& nodeName, const std::string& propertyName,
common::offset_t* offsets, size_t size, uint8_t* result, size_t numThreads);

uint64_t getNumNodes(const std::string& nodeName);
uint64_t getNumRels(const std::string& relName);

private:
void scanColumn(
storage::Column* column, common::offset_t* offsets, size_t size, uint8_t* result);

private:
catalog::Catalog* catalog;
storage::StorageManager* storageManager;
};

} // namespace main
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/storage/storage_structure/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class Column : public BaseColumnOrList {
: Column(structureIDAndFName, dataType, common::Types::getDataTypeSize(dataType),
bufferManager, wal){};

// Expose for feature store
void scan(const common::offset_t* nodeOffsets, size_t size, uint8_t* result);

virtual void read(transaction::Transaction* transaction, common::ValueVector* nodeIDVector,
common::ValueVector* resultVector);

Expand Down
3 changes: 2 additions & 1 deletion src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(kuzu_main
plan_printer.cpp
prepared_statement.cpp
query_result.cpp
query_summary.cpp)
query_summary.cpp
storage_driver.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_main>
Expand Down
64 changes: 64 additions & 0 deletions src/main/storage_driver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "main/storage_driver.h"

#include "storage/storage_manager.h"

using namespace kuzu::common;

namespace kuzu {
namespace main {

StorageDriver::StorageDriver(kuzu::main::Database* database)
: catalog{database->catalog.get()}, storageManager{database->storageManager.get()} {}

StorageDriver::~StorageDriver() = default;

void StorageDriver::scan(const std::string& nodeName, const std::string& propertyName,
common::offset_t* offsets, size_t size, uint8_t* result, size_t numThreads) {
// Resolve files to read from
auto catalogContent = catalog->getReadOnlyVersion();
auto nodeTableID = catalogContent->getTableID(nodeName);
auto propertyID = catalogContent->getTableSchema(nodeTableID)->getPropertyID(propertyName);
auto nodeTable = storageManager->getNodesStore().getNodeTable(nodeTableID);
auto column = nodeTable->getPropertyColumn(propertyID);

auto current_buffer = result;
std::vector<std::thread> threads;
auto numElementsPerThread = size / numThreads + 1;
auto sizeLeft = size;
while (sizeLeft > 0) {
uint64_t sizeToRead = std::min(numElementsPerThread, sizeLeft);
threads.emplace_back(
&StorageDriver::scanColumn, this, column, offsets, sizeToRead, current_buffer);
offsets += sizeToRead;
current_buffer += sizeToRead * column->elementSize;
sizeLeft -= sizeToRead;
}
for (auto& thread : threads) {
thread.join();
}
}

uint64_t StorageDriver::getNumNodes(const std::string& nodeName) {
auto catalogContent = catalog->getReadOnlyVersion();
auto nodeTableID = catalogContent->getTableID(nodeName);
auto nodeStatistics = storageManager->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getNodeStatisticsAndDeletedIDs(nodeTableID);
return nodeStatistics->getNumTuples();
}

uint64_t StorageDriver::getNumRels(const std::string& relName) {
auto catalogContent = catalog->getReadOnlyVersion();
auto relTableID = catalogContent->getTableID(relName);
auto relStatistics =
storageManager->getRelsStore().getRelsStatistics().getRelStatistics(relTableID);
return relStatistics->getNumTuples();
}

void StorageDriver::scanColumn(
storage::Column* column, common::offset_t* offsets, size_t size, uint8_t* result) {
column->scan(offsets, size, result);
}

} // namespace main
} // namespace kuzu
14 changes: 14 additions & 0 deletions src/storage/storage_structure/column.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ using namespace kuzu::transaction;
namespace kuzu {
namespace storage {

void Column::scan(const common::offset_t* nodeOffsets, size_t size, uint8_t* result) {
for (auto i = 0u; i < size; ++i) {
auto nodeOffset = nodeOffsets[i];
auto cursor = PageUtils::getPageElementCursorForPos(nodeOffset, numElementsPerPage);
auto [fileHandleToPin, pageIdxToPin] =
StorageStructureUtils::getFileHandleAndPhysicalPageIdxToPin(
*fileHandle, cursor.pageIdx, *wal, TransactionType::READ_ONLY);
bufferManager.optimisticRead(*fileHandleToPin, pageIdxToPin, [&](uint8_t* frame) -> void {
auto frameBytesOffset = getElemByteOffset(cursor.elemPosInPage);
memcpy(result + i * elementSize, frame + frameBytesOffset, elementSize);
});
}
}

void Column::read(Transaction* transaction, common::ValueVector* nodeIDVector,
common::ValueVector* resultVector) {
if (nodeIDVector->state->isFlat()) {
Expand Down
3 changes: 2 additions & 1 deletion test/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ add_kuzu_test(main_test
csv_output_test.cpp
exception_test.cpp
prepare_test.cpp
result_value_test.cpp)
result_value_test.cpp
storage_driver_test.cpp)
28 changes: 28 additions & 0 deletions test/main/storage_driver_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "main/storage_driver.h"
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::testing;
using namespace kuzu::common;

TEST_F(ApiTest, StorageDriverScan) {
auto storageDriver = std::make_unique<StorageDriver>(database.get());
auto size = 6;
auto nodeOffsetsBuffer = std::make_unique<uint8_t[]>(sizeof(offset_t) * size);
auto nodeOffsets = (offset_t*)nodeOffsetsBuffer.get();
nodeOffsets[0] = 7;
nodeOffsets[1] = 0;
nodeOffsets[2] = 3;
nodeOffsets[3] = 1;
nodeOffsets[4] = 2;
nodeOffsets[5] = 6;
auto result = std::make_unique<uint8_t[]>(sizeof(int64_t) * size);
auto resultBuffer = (uint8_t*)result.get();
storageDriver->scan("person", "ID", nodeOffsets, size, resultBuffer, 3);
auto ids = (int64_t*)resultBuffer;
ASSERT_EQ(ids[0], 10);
ASSERT_EQ(ids[1], 0);
ASSERT_EQ(ids[2], 5);
ASSERT_EQ(ids[3], 2);
ASSERT_EQ(ids[4], 3);
ASSERT_EQ(ids[5], 9);
}
16 changes: 16 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "py_database.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"
#include "main/storage_driver.h"

class PyConnection {

Expand All @@ -21,8 +22,22 @@ class PyConnection {

py::str getNodePropertyNames(const std::string& tableName);

py::str getNodeTableNames();

py::str getRelPropertyNames(const std::string& tableName);

py::str getRelTableNames();

PyPreparedStatement prepare(const std::string& query);

uint64_t getNumNodes(const std::string& nodeName);

uint64_t getNumRels(const std::string& relName);

void getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
const std::string& srcTableName, const std::string& relName,
const std::string& dstTableName, size_t queryBatchSize);

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::list params);
Expand All @@ -33,5 +48,6 @@ class PyConnection {
kuzu::common::Value transformPythonValue(py::handle val);

private:
std::unique_ptr<StorageDriver> storageDriver;
std::unique_ptr<Connection> conn;
};
10 changes: 9 additions & 1 deletion tools/python_api/src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#pragma once

#include "main/kuzu.h"
#include "pybind_include.h"

#include "pybind_include.h"
#include "main/storage_driver.h"
#define PYBIND11_DETAILED_ERROR_MESSAGES
using namespace kuzu::main;


class PyDatabase {
friend class PyConnection;

Expand All @@ -19,6 +22,11 @@ class PyDatabase {

~PyDatabase() = default;

template<class T>
void scanNodeTable(const std::string& tableName, const std::string& propName,
const py::array_t<uint64_t>& indices, py::array_t<T>& result, int numThreads);

private:
std::unique_ptr<Database> database;
std::unique_ptr<StorageDriver> storageDriver;
};
Loading