From 6cc936403f8e4223e203cddf2bafa6aa744bc043 Mon Sep 17 00:00:00 2001 From: Benjamin Winger Date: Tue, 19 Sep 2023 12:46:54 -0400 Subject: [PATCH] Use a different schema for each RecordBatch in python getAsArrow --- tools/python_api/requirements_dev.txt | 2 +- .../src_cpp/include/py_query_result.h | 8 ++++++-- tools/python_api/src_cpp/py_query_result.cpp | 19 +++++++------------ 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tools/python_api/requirements_dev.txt b/tools/python_api/requirements_dev.txt index b1ee48444f..8a1f28b83d 100644 --- a/tools/python_api/requirements_dev.txt +++ b/tools/python_api/requirements_dev.txt @@ -3,7 +3,7 @@ pytest pandas networkx~=3.0.0 numpy -pyarrow~=10.0.0 +pyarrow~=13.0.0 torch-cluster~=1.6.0 torch-geometric~=2.2.0 torch-scatter~=2.1.0 diff --git a/tools/python_api/src_cpp/include/py_query_result.h b/tools/python_api/src_cpp/include/py_query_result.h index 00ccbcba00..1db3872a21 100644 --- a/tools/python_api/src_cpp/include/py_query_result.h +++ b/tools/python_api/src_cpp/include/py_query_result.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + #include "arrow_array.h" #include "common/arrow/arrow.h" #include "common/types/internal_id_t.h" @@ -49,8 +52,9 @@ class PyQueryResult { private: static py::dict convertNodeIdToPyDict(const kuzu::common::nodeID_t& nodeId); - bool getNextArrowChunk(const ArrowSchema& schema, py::list& batches, std::int64_t chunk_size); - py::object getArrowChunks(const ArrowSchema& schema, std::int64_t chunkSize); + bool getNextArrowChunk(const std::vector>& typesInfo, py::list& batches, std::int64_t chunk_size); + py::object getArrowChunks( + const std::vector>& typesInfo, std::int64_t chunkSize); private: std::unique_ptr queryResult; diff --git a/tools/python_api/src_cpp/py_query_result.cpp b/tools/python_api/src_cpp/py_query_result.cpp index 2dedafd4b3..53e0ff78a7 100644 --- a/tools/python_api/src_cpp/py_query_result.cpp +++ b/tools/python_api/src_cpp/py_query_result.cpp @@ -191,7 +191,7 @@ py::object PyQueryResult::getAsDF() { } bool PyQueryResult::getNextArrowChunk( - const ArrowSchema& schema, py::list& batches, std::int64_t chunkSize) { + const std::vector>& typesInfo, py::list& batches, std::int64_t chunkSize) { if (!queryResult->hasNext()) { return false; } @@ -200,14 +200,16 @@ bool PyQueryResult::getNextArrowChunk( auto pyarrowLibModule = py::module::import("pyarrow").attr("lib"); auto batchImportFunc = pyarrowLibModule.attr("RecordBatch").attr("_import_from_c"); - batches.append(batchImportFunc((std::uint64_t)&data, (std::uint64_t)&schema)); + auto schema = ArrowConverter::toArrowSchema(typesInfo); + batches.append(batchImportFunc((std::uint64_t)&data, (std::uint64_t)schema.get())); return true; } -py::object PyQueryResult::getArrowChunks(const ArrowSchema& schema, std::int64_t chunkSize) { +py::object PyQueryResult::getArrowChunks( + const std::vector>& typesInfo, std::int64_t chunkSize) { auto pyarrowLibModule = py::module::import("pyarrow").attr("lib"); py::list batches; - while (getNextArrowChunk(schema, batches, chunkSize)) {} + while (getNextArrowChunk(typesInfo, batches, chunkSize)) {} return std::move(batches); } @@ -219,16 +221,9 @@ kuzu::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize) { auto schemaImportFunc = pyarrowLibModule.attr("Schema").attr("_import_from_c"); auto typesInfo = queryResult->getColumnTypesInfo(); + py::list batches = getArrowChunks(typesInfo, chunkSize); auto schema = ArrowConverter::toArrowSchema(typesInfo); - // Prevent arrow from releasing the schema until it gets passed to the table - // It seems like you are expected to pass a new schema for each RecordBatch - auto release = schema->release; - schema->release = [](ArrowSchema*) {}; - - py::list batches = getArrowChunks(*schema, chunkSize); auto schemaObj = schemaImportFunc((std::uint64_t)schema.get()); - - schema->release = release; return py::cast(fromBatchesFunc(batches, schemaObj)); }