From 4a5ab8c19353819b0986c711611952229a9a0379 Mon Sep 17 00:00:00 2001 From: mxwli Date: Tue, 9 Apr 2024 17:53:00 -0400 Subject: [PATCH 01/12] coverage --- src/common/arrow/arrow_converter.cpp | 6 +-- src/common/arrow/arrow_row_batch.cpp | 9 ++++ tools/python_api/test/ground_truth.py | 68 ++++++++++++++++++--------- tools/python_api/test/test_arrow.py | 58 +++++++++++++---------- 4 files changed, 92 insertions(+), 49 deletions(-) diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index fcff47e499..b34403701b 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -76,10 +76,10 @@ void ArrowConverter::setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, } child.children = &rootHolder.nestedChildrenPtr.back()[0]; initializeChild(*child.children[0]); - child.children[0]->name = copyName(rootHolder, "table"); + child.children[0]->name = copyName(rootHolder, "offset"); setArrowFormat(rootHolder, *child.children[0], *LogicalType::INT64()); initializeChild(*child.children[1]); - child.children[1]->name = copyName(rootHolder, "offset"); + child.children[1]->name = copyName(rootHolder, "table"); setArrowFormat(rootHolder, *child.children[1], *LogicalType::INT64()); } @@ -142,7 +142,7 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child.format = "tsu:"; } break; case LogicalTypeID::INTERVAL: { - child.format = "tDm"; + child.format = "tDu"; } break; case LogicalTypeID::UUID: case LogicalTypeID::STRING: { diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index 5e8208e171..e156680a1a 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -239,6 +239,15 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalT std::memcpy(vector->data.data() + pos * valSize, &value->val, valSize); } +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const LogicalType& /*type*/, Value* value, std::int64_t pos) { + auto destAddr = (int64_t*)(vector->data.data() + pos * sizeof(std::int64_t)); + auto intervalVal = value->val.intervalVal; + *destAddr = intervalVal.micros + intervalVal.days * Interval::MICROS_PER_DAY + + intervalVal.months * Interval::MICROS_PER_MONTH; +} + template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& /*type*/, Value* value, std::int64_t pos) { diff --git a/tools/python_api/test/ground_truth.py b/tools/python_api/test/ground_truth.py index a587b288c8..ea2270995b 100644 --- a/tools/python_api/test/ground_truth.py +++ b/tools/python_api/test/ground_truth.py @@ -15,8 +15,11 @@ "workedHours": [10, 5], "usedNames": ["Aida"], "courseScoresPerTerm": [[10, 8], [6, 7, 8]], - "_label": "person", - "_id": {"offset": 0, "table": 0}, + "grades": [96, 54, 86, 92], + "height": 1.731, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", + "_LABEL": "person", + "_ID": {"offset": 0, "table": 0}, }, 2: { "ID": 2, @@ -32,8 +35,11 @@ "workedHours": [12, 8], "usedNames": ["Bobby"], "courseScoresPerTerm": [[8, 9], [9, 10]], - "_label": "person", - "_id": {"offset": 1, "table": 0}, + "grades": [98, 42, 93, 88], + "height": 0.99, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a12", + "_LABEL": "person", + "_ID": {"offset": 1, "table": 0}, }, 3: { "ID": 3, @@ -49,8 +55,11 @@ "workedHours": [4, 5], "usedNames": ["Carmen", "Fred"], "courseScoresPerTerm": [[8, 10]], - "_label": "person", - "_id": {"offset": 2, "table": 0}, + "grades": [91, 75, 21, 95], + "height": 1.00, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a13", + "_LABEL": "person", + "_ID": {"offset": 2, "table": 0}, }, 5: { "ID": 5, @@ -66,8 +75,11 @@ "workedHours": [1, 9], "usedNames": ["Wolfeschlegelstein", "Daniel"], "courseScoresPerTerm": [[7, 4], [8, 8], [9]], - "_label": "person", - "_id": {"offset": 3, "table": 0}, + "grades": [76, 88, 99, 89], + "height": 1.30, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a14", + "_LABEL": "person", + "_ID": {"offset": 3, "table": 0}, }, 7: { "ID": 7, @@ -83,8 +95,11 @@ "workedHours": [2], "usedNames": ["Ein"], "courseScoresPerTerm": [[6], [7], [8]], - "_label": "person", - "_id": {"offset": 4, "table": 0}, + "grades": [96, 59, 65, 88], + "height": 1.463, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a15", + "_LABEL": "person", + "_ID": {"offset": 4, "table": 0}, }, 8: { "ID": 8, @@ -100,8 +115,11 @@ "workedHours": [3, 4, 5, 6, 7], "usedNames": ["Fesdwe"], "courseScoresPerTerm": [[8]], - "_label": "person", - "_id": {"offset": 5, "table": 0}, + "grades": [80, 78, 34, 83], + "height": 1.51, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a16", + "_LABEL": "person", + "_ID": {"offset": 5, "table": 0}, }, 9: { "ID": 9, @@ -117,8 +135,11 @@ "workedHours": [1], "usedNames": ["Grad"], "courseScoresPerTerm": [[10]], - "_label": "person", - "_id": {"offset": 6, "table": 0}, + "grades": [43, 83, 67, 43], + "height": 1.6, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a17", + "_LABEL": "person", + "_ID": {"offset": 6, "table": 0}, }, 10: { "ID": 10, @@ -134,8 +155,11 @@ "workedHours": [10, 11, 12, 3, 4, 5, 6, 7], "usedNames": ["Ad", "De", "Hi", "Kye", "Orlan"], "courseScoresPerTerm": [[7], [10], [6, 7]], - "_label": "person", - "_id": {"offset": 7, "table": 0}, + "grades": [77, 64, 100, 54], + "height": 1.323, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18", + "_LABEL": "person", + "_ID": {"offset": 7, "table": 0}, }, } @@ -149,8 +173,8 @@ "history": "10 years 5 months 13 hours 24 us", "licenseValidInterval": timedelta(days=1085), "rating": 1.0, - "_label": "organisation", - "_id": {"offset": 0, "table": 2}, + "_LABEL": "organisation", + "_ID": {"offset": 0, "table": 2}, }, 4: { "ID": 4, @@ -161,8 +185,8 @@ "history": "2 years 4 days 10 hours", "licenseValidInterval": timedelta(days=9414), "rating": 0.78, - "_label": "organisation", - "_id": {"offset": 1, "table": 2}, + "_LABEL": "organisation", + "_ID": {"offset": 1, "table": 2}, }, 6: { "ID": 6, @@ -173,8 +197,8 @@ "history": "2 years 4 hours 22 us 34 minutes", "licenseValidInterval": timedelta(days=3, seconds=36000, microseconds=100000), "rating": 0.52, - "_label": "organisation", - "_id": {"offset": 2, "table": 2}, + "_LABEL": "organisation", + "_ID": {"offset": 2, "table": 2}, }, } diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index e579c034f4..6164be8238 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -3,6 +3,7 @@ from datetime import date, datetime, timedelta from decimal import Decimal from typing import Any +import math import ground_truth import kuzu @@ -29,7 +30,7 @@ "a.eyeSight": {"arrow": pa.float64(), "pl": pl.Float64}, "a.birthdate": {"arrow": pa.date32(), "pl": pl.Date}, "a.registerTime": {"arrow": pa.timestamp("us"), "pl": pl.Datetime("us")}, - "a.lastJobDuration": {"arrow": pa.duration("ms"), "pl": pl.Duration("ms")}, + "a.lastJobDuration": {"arrow": pa.duration("us"), "pl": pl.Duration("us")}, "a.workedHours": {"arrow": pa.list_(pa.int64()), "pl": pl.List(pl.Int64)}, "a.usedNames": {"arrow": pa.list_(pa.string()), "pl": pl.List(pl.String)}, "a.courseScoresPerTerm": {"arrow": pa.list_(pa.list_(pa.int64())), "pl": pl.List(pl.List(pl.Int64))}, @@ -209,14 +210,14 @@ def _test_person_table(_conn: kuzu.Connection, return_type: str, chunk_size: int col_name="a.lastJobDuration", return_type=return_type, expected_values=[ - timedelta(days=99, seconds=36334, microseconds=628000), - timedelta(days=543, seconds=4800), - timedelta(microseconds=125000), - timedelta(days=541, seconds=57600, microseconds=24000), - timedelta(0), - timedelta(days=2016, seconds=68600), - timedelta(microseconds=125000), - timedelta(days=541, seconds=57600, microseconds=24000), + timedelta(days=1082, seconds=46920), + timedelta(days=3750, seconds=46800, microseconds=24), + timedelta(days=2, seconds=1451), + timedelta(days=3750, seconds=46800, microseconds=24), + timedelta(days=2, seconds=1451), + timedelta(seconds=1080, microseconds=24000), + timedelta(days=3750, seconds=46800, microseconds=24), + timedelta(days=1082, seconds=46920), ], ) @@ -470,13 +471,22 @@ def _test_with_nulls(_conn: kuzu.Connection, return_type: str, chunk_size: int | def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly + def _test_node_helper(srcStruct, dstStruct): + assert srcStruct.keys() == dstStruct.keys() + for key in srcStruct.keys(): + if type(srcStruct[key]) is float: + assert math.fabs(srcStruct[key] - dstStruct[key]) < 1e-5 + else: + assert srcStruct[key] == dstStruct[key] + + def _test_node(_conn: kuzu.Connection) -> None: query = "MATCH (p:person) RETURN p ORDER BY p.ID" query_result = _conn.execute(query) arrow_tbl = query_result.get_as_arrow() p_col = arrow_tbl.column(0) - assert p_col.to_pylist() == [ + for a, b in zip(p_col.to_pylist(), [ ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[0], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[2], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], @@ -484,8 +494,8 @@ def _test_node(_conn: kuzu.Connection) -> None: ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[8], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[9], - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[10], - ] + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[10]]): + _test_node_helper(a, b) def _test_node_rel(_conn: kuzu.Connection) -> None: query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;" @@ -498,11 +508,11 @@ def _test_node_rel(_conn: kuzu.Connection) -> None: assert len(a_col) == 3 b_col = arrow_tbl.column(2) assert len(a_col) == 3 - assert a_col.to_pylist() == [ + for a, b in zip(p_col.to_pylist(), [ ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[5], - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7], - ] + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7]]): + _test_node_helper(a, b) assert e_col.to_pylist() == [ { "_src": {"offset": 2, "tableID": 0}, @@ -523,11 +533,11 @@ def _test_node_rel(_conn: kuzu.Connection) -> None: "year": 2015, }, ] - assert b_col.to_pylist() == [ - ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[4], - ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6], - ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6], - ] + for a, b in zip(b_col.to_pylist(), [ + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[4], + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[6], + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[6]]): + _test_node_helper(a, b) def _test_marries_table(_conn: kuzu.Connection) -> None: query = "MATCH (:person)-[e:marries]->(:person) RETURN e.*" @@ -549,12 +559,12 @@ def _test_marries_table(_conn: kuzu.Connection) -> None: assert len(used_addr_col) == 3 assert used_addr_col.to_pylist() == [None, "long long long string", "short str"] - # _test_node(conn) + _test_node(conn) # _test_node_rel(conn) # _test_marries_table(conn) - def test_to_arrow1(conn_db_readonly: ConnDB) -> None: - conn, db = conn_db_readonly + def test_to_arrow1(conn: kuzu.Connection) -> None: query = "MATCH (a:person)-[e:knows]->(:person) RETURN e.summary" + res = conn.execute(query) arrow_tbl = conn.execute(query).get_as_arrow(-1) - assert arrow_tbl == [] + assert arrow_tbl == [] \ No newline at end of file From 02544c88abc8a1002c58ab8020535055bbd97fa4 Mon Sep 17 00:00:00 2001 From: mxwli Date: Wed, 10 Apr 2024 15:00:05 -0400 Subject: [PATCH 02/12] progress --- Makefile | 4 +- src/common/arrow/arrow_converter.cpp | 22 ++++++ src/common/arrow/arrow_row_batch.cpp | 79 +++++++++++++++++++++- src/common/types/value/rel.cpp | 5 ++ src/include/common/arrow/arrow_converter.h | 2 + src/include/common/arrow/arrow_row_batch.h | 1 + src/include/common/types/value/rel.h | 4 ++ tools/python_api/test/test_arrow.py | 2 +- 8 files changed, 113 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 1eeae88d6a..eb975ad0d1 100644 --- a/Makefile +++ b/Makefile @@ -111,7 +111,7 @@ nodejs: $(call run-cmake-release, -DBUILD_NODEJS=TRUE) python: - $(call run-cmake-release, -DBUILD_PYTHON=TRUE) + $(call run-cmake-debug, -DBUILD_PYTHON=TRUE) rust: ifeq ($(OS),Windows_NT) @@ -142,7 +142,7 @@ nodejstest: nodejs cd tools/nodejs_api && npm test pytest: python - cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -v tools/python_api/test + cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -vv tools/python_api/test rusttest: rust cd tools/rust_api && cargo test --locked --all-features -- --test-threads=1 diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index b34403701b..3e55efd6e7 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -62,6 +62,25 @@ void ArrowConverter::setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, Arro } } +void ArrowConverter::setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType) { + std::string formatStr = "+us"; + child.n_children = (std::int64_t)UnionType::getNumFields(&dataType); + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(child.n_children); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().resize(child.n_children); + for (auto i = 0u; i < child.n_children; i++) { + initializeChild(*child.children[i]); + auto unionFieldType = UnionType::getFieldType(&dataType, i); + auto unionFieldName = UnionType::getFieldName(&dataType, i); + child.children[i]->name = copyName(rootHolder, unionFieldName); + setArrowFormat(rootHolder, *child.children[i], *unionFieldType); + formatStr += std::string(child.children[i]->format); + } + child.format = copyName(rootHolder, formatStr); +} + void ArrowConverter::setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& /*dataType*/) { child.format = "+s"; @@ -181,6 +200,9 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& case LogicalTypeID::INTERNAL_ID: setArrowFormatForInternalID(rootHolder, child, dataType); break; + case LogicalTypeID::UNION: + setArrowFormatForUnion(rootHolder, child, dataType); + break; default: KU_UNREACHABLE; } diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index e156680a1a..75d6712524 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -76,6 +76,19 @@ void ArrowRowBatch::templateInitializeVector(ArrowVector* initializeStructVector(vector, type, capacity); } +template<> +void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, + const LogicalType& type, std::int64_t capacity) { + // Interestingly, unions don't have their own validity bitmap + // Initialize type buffer + vector->data.reserve((capacity) * sizeof(std::uint8_t)); + // Initialize children + for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { + auto childVector = createVector(*UnionType::getFieldType(&type, i), capacity); + vector->childData.push_back(std::move(childVector)); + } +} + void ArrowRowBatch::initializeStructVector(ArrowVector* vector, const LogicalType& type, std::int64_t capacity) { initializeNullBits(vector->validity, capacity); @@ -187,6 +200,9 @@ std::unique_ptr ArrowRowBatch::createVector(const LogicalType& type case LogicalTypeID::STRUCT: { templateInitializeVector(result.get(), type, capacity); } break; + case LogicalTypeID::UNION: { + templateInitializeVector(result.get(), type, capacity); + } break; case LogicalTypeID::INTERNAL_ID: { templateInitializeVector(result.get(), type, capacity); } break; @@ -348,6 +364,16 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* } } +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const LogicalType& type, Value* value, std::int64_t pos) { + auto childType = value->children[0]->val.uint8Val; + auto typeBuffer = (std::uint8_t*)vector->data.data(); + typeBuffer[pos] = childType; + appendValue(vector->childData[childType].get(), *UnionType::getFieldType(&type, childType), + value->children[childType + 1].get()); +} + template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& /*type*/, Value* value, std::int64_t /*pos*/) { @@ -382,10 +408,15 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* ve RelVal::getSrcNodeIDVal(value)); appendValue(vector->childData[1].get(), *StructType::getFieldTypes(&type)[1], RelVal::getDstNodeIDVal(value)); - std::int64_t propertyId = 2; - auto numProperties = NodeVal::getNumProperties(value); + appendValue(vector->childData[2].get(), *StructType::getFieldTypes(&type)[2], + RelVal::getLabelVal(value)); + appendValue(vector->childData[3].get(), *StructType::getFieldTypes(&type)[3], + RelVal::getIDVal(value)); + std::int64_t propertyId = 4; + auto numProperties = RelVal::getNumProperties(value); for (auto i = 0u; i < numProperties; i++) { - auto val = NodeVal::getPropertyVal(value, i); + auto name = RelVal::getPropertyName(value, i); + auto val = RelVal::getPropertyVal(value, i); appendValue(vector->childData[propertyId].get(), *StructType::getFieldTypes(&type)[propertyId], val); propertyId++; @@ -467,6 +498,9 @@ void ArrowRowBatch::copyNonNullValue(ArrowVector* vector, const LogicalType& typ case LogicalTypeID::STRUCT: { templateCopyNonNullValue(vector, type, value, pos); } break; + case LogicalTypeID::UNION: { + templateCopyNonNullValue(vector, type, value, pos); + } break; case LogicalTypeID::INTERNAL_ID: { templateCopyNonNullValue(vector, type, value, pos); } break; @@ -521,6 +555,14 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* ve vector->numNulls++; } +void ArrowRowBatch::copyNullValueUnion(ArrowVector* vector, Value* value, + std::int64_t pos) { + auto typeBuffer = (std::uint8_t*)vector->data.data(); + typeBuffer[pos] = 0; + copyNullValue(vector->childData[0].get(), value->children[0].get(), pos); + vector->numNulls++; +} + void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_t pos) { switch (value->dataType->getLogicalTypeID()) { case LogicalTypeID::BOOL: { @@ -596,6 +638,9 @@ void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_ case LogicalTypeID::STRUCT: { templateCopyNullValue(vector, pos); } break; + case LogicalTypeID::UNION: { + copyNullValueUnion(vector, value, pos); + } break; case LogicalTypeID::NODE: { templateCopyNullValue(vector, pos); } break; @@ -716,6 +761,31 @@ ArrowArray* ArrowRowBatch::convertInternalIDVectorToArray(ArrowVector& vector, return vector.array.get(); } +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type) { + //since union is a special case, we make the ArrowArray ourselves instead of using createArrayFromVector + auto nChildren = UnionType::getNumFields(&type); + vector.array = std::make_unique(); + vector.array->private_data = nullptr; + vector.array->release = releaseArrowVector; + vector.array->n_children = nChildren; + vector.array->children = vector.childPointers.data(); + vector.array->offset = 0; + vector.array->dictionary = nullptr; + vector.array->buffers = vector.buffers.data(); + vector.array->null_count = vector.numNulls; + vector.array->length = vector.numValues; + vector.array->n_buffers = 1; + vector.array->buffers[0] = vector.data.data(); + vector.childPointers.resize(nChildren); + for (auto i = 0u; i < nChildren; i++) { + auto childType = UnionType::getFieldType(&type, i); + vector.childPointers[i] = convertVectorToArray(*vector.childData[i], *childType); + } + return vector.array.get(); +} + template<> ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, const LogicalType& type) { @@ -806,6 +876,9 @@ ArrowArray* ArrowRowBatch::convertVectorToArray(ArrowVector& vector, const Logic case LogicalTypeID::STRUCT: { return templateCreateArray(vector, type); } + case LogicalTypeID::UNION: { + return templateCreateArray(vector, type); + } case LogicalTypeID::INTERNAL_ID: { return templateCreateArray(vector, type); } diff --git a/src/common/types/value/rel.cpp b/src/common/types/value/rel.cpp index e7b3bfefaf..89e1d0a8a3 100644 --- a/src/common/types/value/rel.cpp +++ b/src/common/types/value/rel.cpp @@ -48,6 +48,11 @@ Value* RelVal::getPropertyVal(const Value* val, uint64_t index) { return val->children[index + OFFSET].get(); } +Value* RelVal::getIDVal(const Value* val) { + auto fieldIdx = StructType::getFieldIdx(val->dataType.get(), InternalKeyword::ID); + return val->children[fieldIdx].get(); +} + Value* RelVal::getSrcNodeIDVal(const Value* val) { auto fieldIdx = StructType::getFieldIdx(val->dataType.get(), InternalKeyword::SRC); return val->children[fieldIdx].get(); diff --git a/src/include/common/arrow/arrow_converter.h b/src/include/common/arrow/arrow_converter.h index 5e1126663d..fdbeb15b63 100644 --- a/src/include/common/arrow/arrow_converter.h +++ b/src/include/common/arrow/arrow_converter.h @@ -38,6 +38,8 @@ struct ArrowConverter { static void initializeChild(ArrowSchema& child, const std::string& name = ""); static void setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& dataType); + static void setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType); static void setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& dataType); static void setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child, diff --git a/src/include/common/arrow/arrow_row_batch.h b/src/include/common/arrow/arrow_row_batch.h index 7e45278c32..116d1ad987 100644 --- a/src/include/common/arrow/arrow_row_batch.h +++ b/src/include/common/arrow/arrow_row_batch.h @@ -77,6 +77,7 @@ class ArrowRowBatch { std::int64_t pos); template static void templateCopyNullValue(ArrowVector* vector, std::int64_t pos); + static void copyNullValueUnion(ArrowVector* vector, Value* value, std::int64_t pos); template static ArrowArray* templateCreateArray(ArrowVector& vector, const LogicalType& type); diff --git a/src/include/common/types/value/rel.h b/src/include/common/types/value/rel.h index 9eaf0fda34..a13d9af78c 100644 --- a/src/include/common/types/value/rel.h +++ b/src/include/common/types/value/rel.h @@ -46,6 +46,10 @@ class RelVal { * @return the dst nodeID value of the RelVal in Value. */ KUZU_API static Value* getDstNodeIDVal(const Value* val); + /** + * @return the internal ID value of the RelVal in Value. + */ + KUZU_API static Value* getIDVal(const Value* val); /** * @return the label value of the RelVal. */ diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index 6164be8238..a9a4565133 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -560,7 +560,7 @@ def _test_marries_table(_conn: kuzu.Connection) -> None: assert used_addr_col.to_pylist() == [None, "long long long string", "short str"] _test_node(conn) - # _test_node_rel(conn) + _test_node_rel(conn) # _test_marries_table(conn) def test_to_arrow1(conn: kuzu.Connection) -> None: From 52f2c766558432a65f0581ba987a22b4b039cbc2 Mon Sep 17 00:00:00 2001 From: mxwli Date: Thu, 11 Apr 2024 14:53:37 -0400 Subject: [PATCH 03/12] it's done --- src/common/arrow/arrow_converter.cpp | 8 ++- src/common/arrow/arrow_row_batch.cpp | 23 ++++++--- tools/python_api/test/ground_truth.py | 12 +++-- tools/python_api/test/test_arrow.py | 73 +++++++++++++++------------ 4 files changed, 74 insertions(+), 42 deletions(-) diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index 3e55efd6e7..6211772b76 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -64,19 +64,23 @@ void ArrowConverter::setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, Arro void ArrowConverter::setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& dataType) { - std::string formatStr = "+us"; + std::string formatStr = "+ud"; child.n_children = (std::int64_t)UnionType::getNumFields(&dataType); rootHolder.nestedChildren.emplace_back(); rootHolder.nestedChildren.back().resize(child.n_children); rootHolder.nestedChildrenPtr.emplace_back(); rootHolder.nestedChildrenPtr.back().resize(child.n_children); + for (auto i = 0u; i < child.n_children; i++) { + rootHolder.nestedChildrenPtr.back()[i] = &rootHolder.nestedChildren.back()[i]; + } + child.children = &rootHolder.nestedChildrenPtr.back()[0]; for (auto i = 0u; i < child.n_children; i++) { initializeChild(*child.children[i]); auto unionFieldType = UnionType::getFieldType(&dataType, i); auto unionFieldName = UnionType::getFieldName(&dataType, i); child.children[i]->name = copyName(rootHolder, unionFieldName); setArrowFormat(rootHolder, *child.children[i], *unionFieldType); - formatStr += std::string(child.children[i]->format); + formatStr += (i == 0u? ":" : ",") + std::to_string(i); } child.format = copyName(rootHolder, formatStr); } diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index 75d6712524..987e7c57d5 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -82,6 +82,8 @@ void ArrowRowBatch::templateInitializeVector(ArrowVector* // Interestingly, unions don't have their own validity bitmap // Initialize type buffer vector->data.reserve((capacity) * sizeof(std::uint8_t)); + // Initialize offsets buffer + vector->overflow.reserve((capacity) * sizeof(std::int32_t)); // Initialize children for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { auto childVector = createVector(*UnionType::getFieldType(&type, i), capacity); @@ -367,11 +369,17 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& type, Value* value, std::int64_t pos) { - auto childType = value->children[0]->val.uint8Val; auto typeBuffer = (std::uint8_t*)vector->data.data(); - typeBuffer[pos] = childType; - appendValue(vector->childData[childType].get(), *UnionType::getFieldType(&type, childType), - value->children[childType + 1].get()); + auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); + for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { + if (*UnionType::getFieldType(&type, i) == *value->children[0]->dataType) { + typeBuffer[pos] = i; + offsetsBuffer[pos] = vector->childData[i]->numValues; + return appendValue(vector->childData[i].get(), *UnionType::getFieldType(&type, i), + value->children[0].get()); + } + } + KU_UNREACHABLE; // We should always be able to find a matching type } template<> @@ -558,7 +566,9 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* ve void ArrowRowBatch::copyNullValueUnion(ArrowVector* vector, Value* value, std::int64_t pos) { auto typeBuffer = (std::uint8_t*)vector->data.data(); + auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); typeBuffer[pos] = 0; + offsetsBuffer[pos] = vector->childData[0]->numValues; copyNullValue(vector->childData[0].get(), value->children[0].get(), pos); vector->numNulls++; } @@ -770,15 +780,16 @@ ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector vector.array->private_data = nullptr; vector.array->release = releaseArrowVector; vector.array->n_children = nChildren; + vector.childPointers.resize(nChildren); vector.array->children = vector.childPointers.data(); vector.array->offset = 0; vector.array->dictionary = nullptr; vector.array->buffers = vector.buffers.data(); vector.array->null_count = vector.numNulls; vector.array->length = vector.numValues; - vector.array->n_buffers = 1; + vector.array->n_buffers = 2; vector.array->buffers[0] = vector.data.data(); - vector.childPointers.resize(nChildren); + vector.array->buffers[1] = vector.overflow.data(); for (auto i = 0u; i < nChildren; i++) { auto childType = UnionType::getFieldType(&type, i); vector.childPointers[i] = convertVectorToArray(*vector.childData[i], *childType); diff --git a/tools/python_api/test/ground_truth.py b/tools/python_api/test/ground_truth.py index ea2270995b..72d2fda911 100644 --- a/tools/python_api/test/ground_truth.py +++ b/tools/python_api/test/ground_truth.py @@ -173,8 +173,10 @@ "history": "10 years 5 months 13 hours 24 us", "licenseValidInterval": timedelta(days=1085), "rating": 1.0, + "state": {'revenue': 138, 'location': ['toronto', 'montr,eal'], 'stock': {'price': [96, 56], 'volume': 1000}}, + "info": 3.12, "_LABEL": "organisation", - "_ID": {"offset": 0, "table": 2}, + "_ID": {"offset": 0, "table": 1}, }, 4: { "ID": 4, @@ -185,8 +187,10 @@ "history": "2 years 4 days 10 hours", "licenseValidInterval": timedelta(days=9414), "rating": 0.78, + "state": {'revenue': 152, 'location': ['\"vanco,uver north area\"'], 'stock': {'price': [15, 78, 671], 'volume': 432}}, + "info": "abcd", "_LABEL": "organisation", - "_ID": {"offset": 1, "table": 2}, + "_ID": {"offset": 1, "table": 1}, }, 6: { "ID": 6, @@ -197,8 +201,10 @@ "history": "2 years 4 hours 22 us 34 minutes", "licenseValidInterval": timedelta(days=3, seconds=36000, microseconds=100000), "rating": 0.52, + "state": {'revenue': 558, 'location': ['\'very long city name\'', '\'new york\''], 'stock': {'price': [22], 'volume': 99}}, + "info": date(2023, 12, 15), "_LABEL": "organisation", - "_ID": {"offset": 2, "table": 2}, + "_ID": {"offset": 2, "table": 1}, }, } diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index a9a4565133..e9f702e113 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -472,7 +472,7 @@ def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly def _test_node_helper(srcStruct, dstStruct): - assert srcStruct.keys() == dstStruct.keys() + assert set(srcStruct.keys()) == set(dstStruct.keys()) for key in srcStruct.keys(): if type(srcStruct[key]) is float: assert math.fabs(srcStruct[key] - dstStruct[key]) < 1e-5 @@ -500,43 +500,53 @@ def _test_node(_conn: kuzu.Connection) -> None: def _test_node_rel(_conn: kuzu.Connection) -> None: query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;" query_result = _conn.execute(query) - arrow_tbl = query_result.get_as_arrow(0) + arrow_tbl = query_result.get_as_arrow(3) assert arrow_tbl.num_columns == 3 a_col = arrow_tbl.column(0) assert len(a_col) == 3 e_col = arrow_tbl.column(1) - assert len(a_col) == 3 + assert len(e_col) == 3 b_col = arrow_tbl.column(2) - assert len(a_col) == 3 - for a, b in zip(p_col.to_pylist(), [ + assert len(b_col) == 3 + for a, b in zip(a_col.to_pylist(), [ ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[5], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7]]): _test_node_helper(a, b) - assert e_col.to_pylist() == [ + for a, b in zip(e_col.to_pylist(), [ { - "_src": {"offset": 2, "tableID": 0}, - "_dst": {"offset": 1, "tableID": 2}, - "_id": {"offset": 0, "tableID": 4}, + "_SRC": {"offset": 2, "table": 0}, + "_DST": {"offset": 1, "table": 1}, + "_ID": {"offset": 0, "table": 5}, + "_LABEL": "workAt", + "grading": [3.8, 2.5], + "rating": 8.2, "year": 2015, }, { - "_src": {"offset": 3, "tableID": 0}, - "_dst": {"offset": 2, "tableID": 2}, - "_id": {"offset": 1, "tableID": 4}, + "_SRC": {"offset": 3, "table": 0}, + "_DST": {"offset": 2, "table": 1}, + "_ID": {"offset": 1, "table": 5}, + "_LABEL": "workAt", + "grading": [2.1, 4.4], + "rating": 7.6, "year": 2010, }, { - "_src": {"offset": 4, "tableID": 0}, - "_dst": {"offset": 2, "tableID": 2}, - "_id": {"offset": 2, "tableID": 4}, + "_SRC": {"offset": 4, "table": 0}, + "_DST": {"offset": 2, "table": 1}, + "_ID": {"offset": 2, "table": 5}, + "_LABEL": "workAt", + "grading": [9.2, 3.1], + "rating": 9.2, "year": 2015, - }, - ] + }]): + _test_node_helper(a, b) + for a, b in zip(b_col.to_pylist(), [ - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[4], - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[6], - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[6]]): + ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[4], + ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6], + ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6]]): _test_node_helper(a, b) def _test_marries_table(_conn: kuzu.Connection) -> None: @@ -549,22 +559,23 @@ def _test_marries_table(_conn: kuzu.Connection) -> None: assert len(used_addr_col) == 3 assert used_addr_col.to_pylist() == [["toronto"], None, []] - arrow_tbl.column(1) - assert used_addr_col.type == pa.list_(pa.int16(), 2) - assert len(used_addr_col) == 3 - assert used_addr_col.to_pylist() == [[4, 5], [2, 5], [3, 9]] + addr_col = arrow_tbl.column(1) + assert addr_col.type == pa.list_(pa.int16(), 2) + assert len(addr_col) == 3 + assert addr_col.to_pylist() == [[4, 5], [2, 5], [3, 9]] - arrow_tbl.column(2) - assert used_addr_col.type == pa.string() - assert len(used_addr_col) == 3 - assert used_addr_col.to_pylist() == [None, "long long long string", "short str"] + note_col = arrow_tbl.column(2) + assert note_col.type == pa.string() + assert len(note_col) == 3 + assert note_col.to_pylist() == [None, "long long long string", "short str"] _test_node(conn) _test_node_rel(conn) - # _test_marries_table(conn) + _test_marries_table(conn) def test_to_arrow1(conn: kuzu.Connection) -> None: query = "MATCH (a:person)-[e:knows]->(:person) RETURN e.summary" res = conn.execute(query) - arrow_tbl = conn.execute(query).get_as_arrow(-1) - assert arrow_tbl == [] \ No newline at end of file + arrow_tbl = conn.execute(query).get_as_arrow(-1) # what is a chunk size of -1 even supposed to mean? + assert arrow_tbl == [] + \ No newline at end of file From 8062a1dff10748b27fe0fc03ea0e6acf6bb94e43 Mon Sep 17 00:00:00 2001 From: mxwli Date: Thu, 11 Apr 2024 15:18:50 -0400 Subject: [PATCH 04/12] remove whitespace --- tools/python_api/test/test_arrow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index e9f702e113..3fd3e03fea 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -578,4 +578,3 @@ def test_to_arrow1(conn: kuzu.Connection) -> None: res = conn.execute(query) arrow_tbl = conn.execute(query).get_as_arrow(-1) # what is a chunk size of -1 even supposed to mean? assert arrow_tbl == [] - \ No newline at end of file From 5810360a29c8bce0a31b4418ac260694595eb28c Mon Sep 17 00:00:00 2001 From: CI Bot Date: Thu, 11 Apr 2024 19:21:26 +0000 Subject: [PATCH 05/12] Run clang-format --- src/common/arrow/arrow_converter.cpp | 4 ++-- src/common/arrow/arrow_row_batch.cpp | 6 +++--- src/include/common/types/value/rel.h | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index 6211772b76..8c917eed31 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -80,7 +80,7 @@ void ArrowConverter::setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, Arrow auto unionFieldName = UnionType::getFieldName(&dataType, i); child.children[i]->name = copyName(rootHolder, unionFieldName); setArrowFormat(rootHolder, *child.children[i], *unionFieldType); - formatStr += (i == 0u? ":" : ",") + std::to_string(i); + formatStr += (i == 0u ? ":" : ",") + std::to_string(i); } child.format = copyName(rootHolder, formatStr); } @@ -206,7 +206,7 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& break; case LogicalTypeID::UNION: setArrowFormatForUnion(rootHolder, child, dataType); - break; + break; default: KU_UNREACHABLE; } diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index 987e7c57d5..b319a882e9 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -563,8 +563,7 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* ve vector->numNulls++; } -void ArrowRowBatch::copyNullValueUnion(ArrowVector* vector, Value* value, - std::int64_t pos) { +void ArrowRowBatch::copyNullValueUnion(ArrowVector* vector, Value* value, std::int64_t pos) { auto typeBuffer = (std::uint8_t*)vector->data.data(); auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); typeBuffer[pos] = 0; @@ -774,7 +773,8 @@ ArrowArray* ArrowRowBatch::convertInternalIDVectorToArray(ArrowVector& vector, template<> ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, const LogicalType& type) { - //since union is a special case, we make the ArrowArray ourselves instead of using createArrayFromVector + // since union is a special case, we make the ArrowArray ourselves instead of using + // createArrayFromVector auto nChildren = UnionType::getNumFields(&type); vector.array = std::make_unique(); vector.array->private_data = nullptr; diff --git a/src/include/common/types/value/rel.h b/src/include/common/types/value/rel.h index a13d9af78c..73eff15cac 100644 --- a/src/include/common/types/value/rel.h +++ b/src/include/common/types/value/rel.h @@ -48,7 +48,7 @@ class RelVal { KUZU_API static Value* getDstNodeIDVal(const Value* val); /** * @return the internal ID value of the RelVal in Value. - */ + */ KUZU_API static Value* getIDVal(const Value* val); /** * @return the label value of the RelVal. From 2a2d82a59b902a1faf35c3bd80091dca89513f6c Mon Sep 17 00:00:00 2001 From: mxwli Date: Thu, 11 Apr 2024 15:22:19 -0400 Subject: [PATCH 06/12] python lint --- tools/python_api/test/test_arrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index 3fd3e03fea..ce4b4b36c3 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -473,7 +473,7 @@ def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None: def _test_node_helper(srcStruct, dstStruct): assert set(srcStruct.keys()) == set(dstStruct.keys()) - for key in srcStruct.keys(): + for key in srcStruct: if type(srcStruct[key]) is float: assert math.fabs(srcStruct[key] - dstStruct[key]) < 1e-5 else: From ad06b45f56ee314648cb2e887ab5c0fbc68f1a18 Mon Sep 17 00:00:00 2001 From: mxwli Date: Thu, 11 Apr 2024 15:37:01 -0400 Subject: [PATCH 07/12] clang-tidy --- src/common/arrow/arrow_row_batch.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index b319a882e9..d6d8d4d57e 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -423,7 +423,6 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* ve std::int64_t propertyId = 4; auto numProperties = RelVal::getNumProperties(value); for (auto i = 0u; i < numProperties; i++) { - auto name = RelVal::getPropertyName(value, i); auto val = RelVal::getPropertyVal(value, i); appendValue(vector->childData[propertyId].get(), *StructType::getFieldTypes(&type)[propertyId], val); From 9f526c090c791f9740558b43e63eda6b93a401a8 Mon Sep 17 00:00:00 2001 From: mxwli Date: Thu, 11 Apr 2024 16:08:02 -0400 Subject: [PATCH 08/12] makefile fix --- Makefile | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Makefile b/Makefile index eb975ad0d1..f1815d16de 100644 --- a/Makefile +++ b/Makefile @@ -111,6 +111,9 @@ nodejs: $(call run-cmake-release, -DBUILD_NODEJS=TRUE) python: + $(call run-cmake-release, -DBUILD_PYTHON=TRUE) + +python-debug: $(call run-cmake-debug, -DBUILD_PYTHON=TRUE) rust: @@ -144,6 +147,9 @@ nodejstest: nodejs pytest: python cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -vv tools/python_api/test +pytest-debug: python-debug + cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -vv tools/python_api/test + rusttest: rust cd tools/rust_api && cargo test --locked --all-features -- --test-threads=1 From b053263067c450fba0fe182d8767563c2e792f11 Mon Sep 17 00:00:00 2001 From: mxwli Date: Thu, 11 Apr 2024 16:59:57 -0400 Subject: [PATCH 09/12] add map export --- src/common/arrow/arrow_converter.cpp | 12 ++++++++++ src/common/arrow/arrow_row_batch.cpp | 36 ++++++++++++++++++++++++++++ tools/python_api/test/test_arrow.py | 4 ++++ 3 files changed, 52 insertions(+) diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index 8c917eed31..1dc76d4edb 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -196,6 +196,18 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child.children[0]->name = "l"; setArrowFormat(rootHolder, **child.children, *ArrayType::getChildType(&dataType)); } break; + case LogicalTypeID::MAP: { + child.format = "+m"; + child.n_children = 1; + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(1); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().push_back(&rootHolder.nestedChildren.back()[0]); + initializeChild(rootHolder.nestedChildren.back()[0]); + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + child.children[0]->name = "l"; + setArrowFormat(rootHolder, **child.children, *ListType::getChildType(&dataType)); + } break; case LogicalTypeID::STRUCT: case LogicalTypeID::NODE: case LogicalTypeID::REL: diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index d6d8d4d57e..5f24b9aed3 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -70,6 +70,12 @@ void ArrowRowBatch::templateInitializeVector(ArrowVector* vector->childData.push_back(std::move(childVector)); } +template<> +void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, + const LogicalType& type, std::int64_t capacity) { + return templateInitializeVector(vector, type, capacity); +} + template<> void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, const LogicalType& type, std::int64_t capacity) { @@ -199,6 +205,9 @@ std::unique_ptr ArrowRowBatch::createVector(const LogicalType& type case LogicalTypeID::ARRAY: { templateInitializeVector(result.get(), type, capacity); } break; + case LogicalTypeID::MAP: { + templateInitializeVector(result.get(), type, capacity); + } break; case LogicalTypeID::STRUCT: { templateInitializeVector(result.get(), type, capacity); } break; @@ -357,6 +366,12 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* } } +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const LogicalType& type, Value* value, std::int64_t pos) { + return templateCopyNonNullValue(vector, type, value, pos); +} + template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& type, Value* value, std::int64_t /*pos*/) { @@ -502,6 +517,9 @@ void ArrowRowBatch::copyNonNullValue(ArrowVector* vector, const LogicalType& typ case LogicalTypeID::ARRAY: { templateCopyNonNullValue(vector, type, value, pos); } break; + case LogicalTypeID::MAP: { + templateCopyNonNullValue(vector, type, value, pos); + } break; case LogicalTypeID::STRUCT: { templateCopyNonNullValue(vector, type, value, pos); } break; @@ -555,6 +573,12 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* vec vector->numNulls++; } +template<> +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, + std::int64_t pos) { + return templateCopyNullValue(vector, pos); +} + template<> void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) { @@ -640,6 +664,9 @@ void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_ case LogicalTypeID::ARRAY: { templateCopyNullValue(vector, pos); } break; + case LogicalTypeID::MAP: { + templateCopyNullValue(vector, pos); + } break; case LogicalTypeID::INTERNAL_ID: { templateCopyNullValue(vector, pos); } break; @@ -733,6 +760,12 @@ ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector return vector.array.get(); } +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type) { + return templateCreateArray(vector, type); +} + template<> ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, const LogicalType& type) { @@ -883,6 +916,9 @@ ArrowArray* ArrowRowBatch::convertVectorToArray(ArrowVector& vector, const Logic case LogicalTypeID::ARRAY: { return templateCreateArray(vector, type); } + case LogicalTypeID::MAP: { + return templateCreateArray(vector, type); + } case LogicalTypeID::STRUCT: { return templateCreateArray(vector, type); } diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index ce4b4b36c3..9e7829bbc1 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -467,6 +467,10 @@ def _test_with_nulls(_conn: kuzu.Connection, return_type: str, chunk_size: int | _test_with_nulls(conn, "arrow", 12) _test_with_nulls(conn, "pl") +def test_to_arrow_map(conn_db_readonly: ConnDB) -> None: + conn = conn_db_readonly[0] + results = conn.execute("RETURN map([1, 2, 3], [{a: 1, b: 2, c: '3'}, {a: 2, b: 3, c: '4'}, {a: 3, b: 4, c: '5'}])").get_as_arrow(8)[0].to_pylist() + assert results == [[(1, {'a': 1, 'b': 2, 'c': '3'}), (2, {'a': 2, 'b': 3, 'c': '4'}), (3, {'a': 3, 'b': 4, 'c': '5'})]] def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly From c7c8a6b82def1b09ddc3b859406be9cc7d96fc55 Mon Sep 17 00:00:00 2001 From: mxwli Date: Fri, 12 Apr 2024 12:03:46 -0400 Subject: [PATCH 10/12] add doc comment --- src/common/arrow/arrow_row_batch.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index 5f24b9aed3..9db2310117 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -86,6 +86,7 @@ template<> void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, const LogicalType& type, std::int64_t capacity) { // Interestingly, unions don't have their own validity bitmap + // https://arrow.apache.org/docs/format/Columnar.html#union-layout // Initialize type buffer vector->data.reserve((capacity) * sizeof(std::uint8_t)); // Initialize offsets buffer From 1279988ba1c38f9c7fd44a1bc983fe2df7635b5b Mon Sep 17 00:00:00 2001 From: mxwli Date: Fri, 12 Apr 2024 13:48:54 -0400 Subject: [PATCH 11/12] ignore internal IDs in complex test --- tools/python_api/test/test_arrow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index 9e7829bbc1..6061b5bd6d 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -478,6 +478,8 @@ def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None: def _test_node_helper(srcStruct, dstStruct): assert set(srcStruct.keys()) == set(dstStruct.keys()) for key in srcStruct: + if key == "_ID": # there isn't any guarantee on the value of _ID, so ignore it + continue if type(srcStruct[key]) is float: assert math.fabs(srcStruct[key] - dstStruct[key]) < 1e-5 else: From 6ab3458544390c7d2340db4f85abc72041a539a5 Mon Sep 17 00:00:00 2001 From: mxwli Date: Fri, 12 Apr 2024 14:28:35 -0400 Subject: [PATCH 12/12] added ID order --- tools/python_api/test/test_arrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index 6061b5bd6d..b87597129b 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -504,7 +504,7 @@ def _test_node(_conn: kuzu.Connection) -> None: _test_node_helper(a, b) def _test_node_rel(_conn: kuzu.Connection) -> None: - query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;" + query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b ORDER BY a.ID, b.ID" query_result = _conn.execute(query) arrow_tbl = query_result.get_as_arrow(3) assert arrow_tbl.num_columns == 3 @@ -556,7 +556,7 @@ def _test_node_rel(_conn: kuzu.Connection) -> None: _test_node_helper(a, b) def _test_marries_table(_conn: kuzu.Connection) -> None: - query = "MATCH (:person)-[e:marries]->(:person) RETURN e.*" + query = "MATCH (a:person)-[e:marries]->(b:person) RETURN e.* ORDER BY a.ID, b.ID" arrow_tbl = _conn.execute(query).get_as_arrow(0) assert arrow_tbl.num_columns == 3