Skip to content

Commit

Permalink
Merge pull request #2019 from kuzudb/arrow-schema-fix
Browse files Browse the repository at this point in the history
Arrow schema fixes
  • Loading branch information
andyfengHKU committed Sep 12, 2023
2 parents c578381 + 2dcc60c commit 32f38fb
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 10 deletions.
24 changes: 15 additions & 9 deletions src/common/arrow/arrow_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@ namespace kuzu {
namespace common {

// Releases the resources attached to an ArrowSchema.
//
// Does nothing when `schema` is null or when its `release` member is already null (the schema
// has been released before). Otherwise the `release` member is cleared first, which marks the
// schema as released, and then the ArrowSchemaHolder stored in `private_data` is deleted.
//
// The guard has to test for a *cleared* release pointer: returning early when the pointer is
// still set would skip the cleanup for every live schema, so the holder would never be freed.
static void releaseArrowSchema(ArrowSchema* schema) {
    if (!schema || !schema->release) {
        return;
    }
    schema->release = nullptr;
    auto holder = static_cast<ArrowSchemaHolder*>(schema->private_data);
    delete holder;
}

// Duplicates `name` into a NUL-terminated character buffer, hands ownership of that buffer to
// rootHolder.ownedTypeNames, and returns a pointer to the stored copy. The pointer refers to
// the buffer held by the holder, not to the caller's string.
static const char* copyName(ArrowSchemaHolder& rootHolder, const std::string& name) {
    const auto numChars = name.length();
    auto ownedCopy = std::make_unique<char[]>(numChars + 1);
    // c_str() is terminated at index length(), so copying one extra byte also writes the
    // trailing '\0' into the last slot of the buffer.
    std::memcpy(ownedCopy.get(), name.c_str(), numChars + 1);
    auto& ownedNames = rootHolder.ownedTypeNames;
    ownedNames.push_back(std::move(ownedCopy));
    return ownedNames.back().get();
}

void ArrowConverter::initializeChild(ArrowSchema& child, const std::string& name) {
//! Child is cleaned up by parent
child.private_data = nullptr;
Expand Down Expand Up @@ -44,13 +55,7 @@ void ArrowConverter::setArrowFormatForStruct(
child.children = &rootHolder.nestedChildrenPtr.back()[0];
for (auto i = 0u; i < child.n_children; i++) {
initializeChild(*child.children[i]);
auto structFieldName = childrenTypesInfo[i]->name;
auto structFieldNameLength = structFieldName.length();
std::unique_ptr<char[]> namePtr = std::make_unique<char[]>(structFieldNameLength + 1);
std::memcpy(namePtr.get(), structFieldName.c_str(), structFieldNameLength);
namePtr[structFieldNameLength] = '\0';
rootHolder.ownedTypeNames.push_back(std::move(namePtr));
child.children[i]->name = rootHolder.ownedTypeNames.back().get();
child.children[i]->name = copyName(rootHolder, childrenTypesInfo[i]->name);
setArrowFormat(rootHolder, *child.children[i], *childrenTypesInfo[i]);
}
}
Expand Down Expand Up @@ -130,7 +135,8 @@ std::unique_ptr<ArrowSchema> ArrowConverter::toArrowSchema(

for (auto i = 0u; i < columnCount; i++) {
auto& child = rootHolder->children[i];
initializeChild(child, typesInfo[i]->name);
initializeChild(child);
child.name = copyName(*rootHolder, typesInfo[i]->name);
setArrowFormat(*rootHolder, child, *typesInfo[i]);
}

Expand Down
1 change: 1 addition & 0 deletions test/main/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_kuzu_test(main_test
arrow_test.cpp
config_test.cpp
connection_test.cpp
csv_output_test.cpp
Expand Down
27 changes: 27 additions & 0 deletions test/main/arrow_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::common;
using namespace kuzu::testing;

// Test fixture for the Arrow export tests below. It adds no members of its own; the `conn`
// object used by the tests is presumably provided by ApiTest (confirm in main_test_helper.h).
class ArrowTest : public ApiTest {};

// Fetches the first Arrow chunk of a query that filters on fName = 'Bob' and checks its length,
// null count, child count and the bytes of the returned string value.
TEST_F(ArrowTest, getArrowResult) {
    const char* cypher = "MATCH (a:person) WHERE a.fName = 'Bob' RETURN a.fName";
    auto queryResult = conn->query(cypher);
    auto chunk = queryResult->getNextArrowChunk(1);
    ASSERT_EQ(chunk->length, 1);
    ASSERT_EQ(chunk->null_count, 0);
    ASSERT_EQ(chunk->n_children, 1);
    // FIXME: the string length is hard-coded to 3 here.
    // NOTE(review): in Arrow's variable-size binary layout the length of a value is derived
    // from the offsets buffer (buffers[1]); confirm the offset width for this column's format
    // before reading it from there.
    const auto* stringData = static_cast<const char*>(chunk->children[0]->buffers[2]);
    ASSERT_EQ(std::string(stringData, 3), "Bob");
    // The chunk is released explicitly through its own release callback.
    chunk->release(chunk.get());
}

// Checks that the Arrow schema of a single-column query has exactly one child and that the
// child carries the column alias ("NAME") given in the RETURN clause.
TEST_F(ArrowTest, getArrowSchema) {
    const char* cypher = "MATCH (a:person) RETURN a.fName as NAME";
    auto queryResult = conn->query(cypher);
    auto arrowSchema = queryResult->getArrowSchema();
    ASSERT_EQ(arrowSchema->n_children, 1);
    const std::string columnName(arrowSchema->children[0]->name);
    ASSERT_EQ(columnName, "NAME");
    // The schema is released explicitly through its own release callback.
    arrowSchema->release(arrowSchema.get());
}
9 changes: 8 additions & 1 deletion tools/python_api/src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,15 @@ kuzu::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize) {

auto typesInfo = queryResult->getColumnTypesInfo();
auto schema = ArrowConverter::toArrowSchema(typesInfo);
auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
// Prevent arrow from releasing the schema until it gets passed to the table
// It seems like you are expected to pass a new schema for each RecordBatch
auto release = schema->release;
schema->release = [](ArrowSchema*) {};

py::list batches = getArrowChunks(*schema, chunkSize);
auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());

schema->release = release;
return py::cast<kuzu::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
}

Expand Down

0 comments on commit 32f38fb

Please sign in to comment.