Merge pull request #2055 from kuzudb/python-arrow-fix
Use a different schema for each RecordBatch in python getAsArrow
benjaminwinger committed Sep 19, 2023
2 parents f822c56 + 6cc9364 commit 6943574
Showing 3 changed files with 14 additions and 15 deletions.
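
For context on the fix: pyarrow's RecordBatch._import_from_c (the private C Data Interface hook this binding calls) takes ownership of the ArrowSchema struct it receives and invokes its release callback, so a single exported schema cannot be shared across several RecordBatch imports. Below is a minimal Python sketch of that ownership rule, using pyarrow's cffi helper; it is illustrative only, not Kùzu code.

    import pyarrow as pa
    from pyarrow.cffi import ffi  # ships with pyarrow; declares ArrowSchema/ArrowArray

    batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])

    c_schema = ffi.new("struct ArrowSchema*")
    c_array = ffi.new("struct ArrowArray*")
    schema_ptr = int(ffi.cast("uintptr_t", c_schema))
    array_ptr = int(ffi.cast("uintptr_t", c_array))

    # Export one batch and its schema through the C Data Interface.
    batch._export_to_c(array_ptr, schema_ptr)

    # Importing consumes both structs: their release callbacks are invoked.
    imported = pa.RecordBatch._import_from_c(array_ptr, schema_ptr)

    # Re-importing from the now-released structs raises, which is why the fix
    # exports a fresh schema per batch instead of suppressing the release callback.
    try:
        pa.RecordBatch._import_from_c(array_ptr, schema_ptr)
    except pa.ArrowInvalid as err:
        print(err)  # e.g. "Cannot import released ArrowArray"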
2 changes: 1 addition & 1 deletion tools/python_api/requirements_dev.txt
@@ -3,7 +3,7 @@ pytest
pandas
networkx~=3.0.0
numpy
-pyarrow~=10.0.0
+pyarrow~=13.0.0
torch-cluster~=1.6.0
torch-geometric~=2.2.0
torch-scatter~=2.1.0
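(The ~= pin is the PEP 440 compatible-release operator: pyarrow~=13.0.0 accepts any 13.0.x release at or above 13.0.0, i.e. >=13.0.0,<13.1.0.)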
8 changes: 6 additions & 2 deletions tools/python_api/src_cpp/include/py_query_result.h
@@ -1,5 +1,8 @@
#pragma once

+#include <vector>
+#include <memory>
+
#include "arrow_array.h"
#include "common/arrow/arrow.h"
#include "common/types/internal_id_t.h"
@@ -49,8 +52,9 @@ class PyQueryResult {
private:
static py::dict convertNodeIdToPyDict(const kuzu::common::nodeID_t& nodeId);

-bool getNextArrowChunk(const ArrowSchema& schema, py::list& batches, std::int64_t chunk_size);
-py::object getArrowChunks(const ArrowSchema& schema, std::int64_t chunkSize);
+bool getNextArrowChunk(const std::vector<std::unique_ptr<DataTypeInfo>>& typesInfo, py::list& batches, std::int64_t chunk_size);
+py::object getArrowChunks(
+const std::vector<std::unique_ptr<DataTypeInfo>>& typesInfo, std::int64_t chunkSize);

private:
std::unique_ptr<QueryResult> queryResult;
19 changes: 7 additions & 12 deletions tools/python_api/src_cpp/py_query_result.cpp
@@ -191,7 +191,7 @@ py::object PyQueryResult::getAsDF() {
}

bool PyQueryResult::getNextArrowChunk(
-const ArrowSchema& schema, py::list& batches, std::int64_t chunkSize) {
+const std::vector<std::unique_ptr<DataTypeInfo>>& typesInfo, py::list& batches, std::int64_t chunkSize) {
if (!queryResult->hasNext()) {
return false;
}
@@ -200,14 +200,16 @@ bool PyQueryResult::getNextArrowChunk(

auto pyarrowLibModule = py::module::import("pyarrow").attr("lib");
auto batchImportFunc = pyarrowLibModule.attr("RecordBatch").attr("_import_from_c");
-batches.append(batchImportFunc((std::uint64_t)&data, (std::uint64_t)&schema));
+auto schema = ArrowConverter::toArrowSchema(typesInfo);
+batches.append(batchImportFunc((std::uint64_t)&data, (std::uint64_t)schema.get()));
return true;
}

-py::object PyQueryResult::getArrowChunks(const ArrowSchema& schema, std::int64_t chunkSize) {
+py::object PyQueryResult::getArrowChunks(
+const std::vector<std::unique_ptr<DataTypeInfo>>& typesInfo, std::int64_t chunkSize) {
auto pyarrowLibModule = py::module::import("pyarrow").attr("lib");
py::list batches;
-while (getNextArrowChunk(schema, batches, chunkSize)) {}
+while (getNextArrowChunk(typesInfo, batches, chunkSize)) {}
return std::move(batches);
}

@@ -219,16 +221,9 @@ kuzu::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize) {
auto schemaImportFunc = pyarrowLibModule.attr("Schema").attr("_import_from_c");

auto typesInfo = queryResult->getColumnTypesInfo();
+py::list batches = getArrowChunks(typesInfo, chunkSize);
auto schema = ArrowConverter::toArrowSchema(typesInfo);
-// Prevent arrow from releasing the schema until it gets passed to the table
-// It seems like you are expected to pass a new schema for each RecordBatch
-auto release = schema->release;
-schema->release = [](ArrowSchema*) {};
-
-py::list batches = getArrowChunks(*schema, chunkSize);
auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
-
-schema->release = release;
return py::cast<kuzu::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
}

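From the Python side, the affected entry point is QueryResult.get_as_arrow, which assembles a pyarrow Table from one RecordBatch per chunk; after this change every batch is imported against its own freshly exported schema. A short usage sketch follows (the database path, node table, and column names are hypothetical):

    import kuzu

    db = kuzu.Database("./demo_db")  # hypothetical database path
    conn = kuzu.Connection(db)

    result = conn.execute("MATCH (p:Person) RETURN p.name, p.age")
    # chunk_size bounds the number of rows per Arrow RecordBatch in the Table.
    table = result.get_as_arrow(chunk_size=1000)
    print(table.schema)
    print(table.num_rows)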
