From 02b0e99d6d0a4130ab031a2cbb24eb7c994fa45c Mon Sep 17 00:00:00 2001 From: Maxwell <49460053+mxwli@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:57:00 -0400 Subject: [PATCH] Fix problematic to_arrow tests (#3257) * coverage * progress * it's done * remove whitespace * Run clang-format * python lint * clang-tidy * makefile fix * add map export * add doc comment * ignore internal IDs in complex test * added ID order --------- Co-authored-by: CI Bot --- Makefile | 8 +- src/common/arrow/arrow_converter.cpp | 44 ++++++- src/common/arrow/arrow_row_batch.cpp | 135 ++++++++++++++++++++- src/common/types/value/rel.cpp | 5 + src/include/common/arrow/arrow_converter.h | 2 + src/include/common/arrow/arrow_row_batch.h | 1 + src/include/common/types/value/rel.h | 4 + tools/python_api/test/ground_truth.py | 74 +++++++---- tools/python_api/test/test_arrow.py | 124 +++++++++++-------- 9 files changed, 319 insertions(+), 78 deletions(-) diff --git a/Makefile b/Makefile index 1eeae88d6a..f1815d16de 100644 --- a/Makefile +++ b/Makefile @@ -113,6 +113,9 @@ nodejs: python: $(call run-cmake-release, -DBUILD_PYTHON=TRUE) +python-debug: + $(call run-cmake-debug, -DBUILD_PYTHON=TRUE) + rust: ifeq ($(OS),Windows_NT) set KUZU_TESTING=1 @@ -142,7 +145,10 @@ nodejstest: nodejs cd tools/nodejs_api && npm test pytest: python - cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -v tools/python_api/test + cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -vv tools/python_api/test + +pytest-debug: python-debug + cmake -E env PYTHONPATH=tools/python_api/build python3 -m pytest -vv tools/python_api/test rusttest: rust cd tools/rust_api && cargo test --locked --all-features -- --test-threads=1 diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index fcff47e499..1dc76d4edb 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -62,6 +62,29 @@ void ArrowConverter::setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, Arro } } +void ArrowConverter::setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType) { + std::string formatStr = "+ud"; + child.n_children = (std::int64_t)UnionType::getNumFields(&dataType); + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(child.n_children); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().resize(child.n_children); + for (auto i = 0u; i < child.n_children; i++) { + rootHolder.nestedChildrenPtr.back()[i] = &rootHolder.nestedChildren.back()[i]; + } + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + for (auto i = 0u; i < child.n_children; i++) { + initializeChild(*child.children[i]); + auto unionFieldType = UnionType::getFieldType(&dataType, i); + auto unionFieldName = UnionType::getFieldName(&dataType, i); + child.children[i]->name = copyName(rootHolder, unionFieldName); + setArrowFormat(rootHolder, *child.children[i], *unionFieldType); + formatStr += (i == 0u ? ":" : ",") + std::to_string(i); + } + child.format = copyName(rootHolder, formatStr); +} + void ArrowConverter::setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& /*dataType*/) { child.format = "+s"; @@ -76,10 +99,10 @@ void ArrowConverter::setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, } child.children = &rootHolder.nestedChildrenPtr.back()[0]; initializeChild(*child.children[0]); - child.children[0]->name = copyName(rootHolder, "table"); + child.children[0]->name = copyName(rootHolder, "offset"); setArrowFormat(rootHolder, *child.children[0], *LogicalType::INT64()); initializeChild(*child.children[1]); - child.children[1]->name = copyName(rootHolder, "offset"); + child.children[1]->name = copyName(rootHolder, "table"); setArrowFormat(rootHolder, *child.children[1], *LogicalType::INT64()); } @@ -142,7 +165,7 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child.format = "tsu:"; } break; case LogicalTypeID::INTERVAL: { - child.format = "tDm"; + child.format = "tDu"; } break; case LogicalTypeID::UUID: case LogicalTypeID::STRING: { @@ -173,6 +196,18 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child.children[0]->name = "l"; setArrowFormat(rootHolder, **child.children, *ArrayType::getChildType(&dataType)); } break; + case LogicalTypeID::MAP: { + child.format = "+m"; + child.n_children = 1; + rootHolder.nestedChildren.emplace_back(); + rootHolder.nestedChildren.back().resize(1); + rootHolder.nestedChildrenPtr.emplace_back(); + rootHolder.nestedChildrenPtr.back().push_back(&rootHolder.nestedChildren.back()[0]); + initializeChild(rootHolder.nestedChildren.back()[0]); + child.children = &rootHolder.nestedChildrenPtr.back()[0]; + child.children[0]->name = "l"; + setArrowFormat(rootHolder, **child.children, *ListType::getChildType(&dataType)); + } break; case LogicalTypeID::STRUCT: case LogicalTypeID::NODE: case LogicalTypeID::REL: @@ -181,6 +216,9 @@ void ArrowConverter::setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& case LogicalTypeID::INTERNAL_ID: setArrowFormatForInternalID(rootHolder, child, dataType); break; + case LogicalTypeID::UNION: + setArrowFormatForUnion(rootHolder, child, dataType); + break; default: KU_UNREACHABLE; } diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index 5e8208e171..9db2310117 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -70,12 +70,34 @@ void ArrowRowBatch::templateInitializeVector(ArrowVector* vector->childData.push_back(std::move(childVector)); } +template<> +void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, + const LogicalType& type, std::int64_t capacity) { + return templateInitializeVector(vector, type, capacity); +} + template<> void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, const LogicalType& type, std::int64_t capacity) { initializeStructVector(vector, type, capacity); } +template<> +void ArrowRowBatch::templateInitializeVector(ArrowVector* vector, + const LogicalType& type, std::int64_t capacity) { + // Interestingly, unions don't have their own validity bitmap + // https://arrow.apache.org/docs/format/Columnar.html#union-layout + // Initialize type buffer + vector->data.reserve((capacity) * sizeof(std::uint8_t)); + // Initialize offsets buffer + vector->overflow.reserve((capacity) * sizeof(std::int32_t)); + // Initialize children + for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { + auto childVector = createVector(*UnionType::getFieldType(&type, i), capacity); + vector->childData.push_back(std::move(childVector)); + } +} + void ArrowRowBatch::initializeStructVector(ArrowVector* vector, const LogicalType& type, std::int64_t capacity) { initializeNullBits(vector->validity, capacity); @@ -184,9 +206,15 @@ std::unique_ptr ArrowRowBatch::createVector(const LogicalType& type case LogicalTypeID::ARRAY: { templateInitializeVector(result.get(), type, capacity); } break; + case LogicalTypeID::MAP: { + templateInitializeVector(result.get(), type, capacity); + } break; case LogicalTypeID::STRUCT: { templateInitializeVector(result.get(), type, capacity); } break; + case LogicalTypeID::UNION: { + templateInitializeVector(result.get(), type, capacity); + } break; case LogicalTypeID::INTERNAL_ID: { templateInitializeVector(result.get(), type, capacity); } break; @@ -239,6 +267,15 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalT std::memcpy(vector->data.data() + pos * valSize, &value->val, valSize); } +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const LogicalType& /*type*/, Value* value, std::int64_t pos) { + auto destAddr = (int64_t*)(vector->data.data() + pos * sizeof(std::int64_t)); + auto intervalVal = value->val.intervalVal; + *destAddr = intervalVal.micros + intervalVal.days * Interval::MICROS_PER_DAY + + intervalVal.months * Interval::MICROS_PER_MONTH; +} + template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& /*type*/, Value* value, std::int64_t pos) { @@ -330,6 +367,12 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* } } +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const LogicalType& type, Value* value, std::int64_t pos) { + return templateCopyNonNullValue(vector, type, value, pos); +} + template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& type, Value* value, std::int64_t /*pos*/) { @@ -339,6 +382,22 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* } } +template<> +void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, + const LogicalType& type, Value* value, std::int64_t pos) { + auto typeBuffer = (std::uint8_t*)vector->data.data(); + auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); + for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { + if (*UnionType::getFieldType(&type, i) == *value->children[0]->dataType) { + typeBuffer[pos] = i; + offsetsBuffer[pos] = vector->childData[i]->numValues; + return appendValue(vector->childData[i].get(), *UnionType::getFieldType(&type, i), + value->children[0].get()); + } + } + KU_UNREACHABLE; // We should always be able to find a matching type +} + template<> void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* vector, const LogicalType& /*type*/, Value* value, std::int64_t /*pos*/) { @@ -373,10 +432,14 @@ void ArrowRowBatch::templateCopyNonNullValue(ArrowVector* ve RelVal::getSrcNodeIDVal(value)); appendValue(vector->childData[1].get(), *StructType::getFieldTypes(&type)[1], RelVal::getDstNodeIDVal(value)); - std::int64_t propertyId = 2; - auto numProperties = NodeVal::getNumProperties(value); + appendValue(vector->childData[2].get(), *StructType::getFieldTypes(&type)[2], + RelVal::getLabelVal(value)); + appendValue(vector->childData[3].get(), *StructType::getFieldTypes(&type)[3], + RelVal::getIDVal(value)); + std::int64_t propertyId = 4; + auto numProperties = RelVal::getNumProperties(value); for (auto i = 0u; i < numProperties; i++) { - auto val = NodeVal::getPropertyVal(value, i); + auto val = RelVal::getPropertyVal(value, i); appendValue(vector->childData[propertyId].get(), *StructType::getFieldTypes(&type)[propertyId], val); propertyId++; @@ -455,9 +518,15 @@ void ArrowRowBatch::copyNonNullValue(ArrowVector* vector, const LogicalType& typ case LogicalTypeID::ARRAY: { templateCopyNonNullValue(vector, type, value, pos); } break; + case LogicalTypeID::MAP: { + templateCopyNonNullValue(vector, type, value, pos); + } break; case LogicalTypeID::STRUCT: { templateCopyNonNullValue(vector, type, value, pos); } break; + case LogicalTypeID::UNION: { + templateCopyNonNullValue(vector, type, value, pos); + } break; case LogicalTypeID::INTERNAL_ID: { templateCopyNonNullValue(vector, type, value, pos); } break; @@ -505,6 +574,12 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* vec vector->numNulls++; } +template<> +void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, + std::int64_t pos) { + return templateCopyNullValue(vector, pos); +} + template<> void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) { @@ -512,6 +587,15 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* ve vector->numNulls++; } +void ArrowRowBatch::copyNullValueUnion(ArrowVector* vector, Value* value, std::int64_t pos) { + auto typeBuffer = (std::uint8_t*)vector->data.data(); + auto offsetsBuffer = (std::int32_t*)vector->overflow.data(); + typeBuffer[pos] = 0; + offsetsBuffer[pos] = vector->childData[0]->numValues; + copyNullValue(vector->childData[0].get(), value->children[0].get(), pos); + vector->numNulls++; +} + void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_t pos) { switch (value->dataType->getLogicalTypeID()) { case LogicalTypeID::BOOL: { @@ -581,12 +665,18 @@ void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_ case LogicalTypeID::ARRAY: { templateCopyNullValue(vector, pos); } break; + case LogicalTypeID::MAP: { + templateCopyNullValue(vector, pos); + } break; case LogicalTypeID::INTERNAL_ID: { templateCopyNullValue(vector, pos); } break; case LogicalTypeID::STRUCT: { templateCopyNullValue(vector, pos); } break; + case LogicalTypeID::UNION: { + copyNullValueUnion(vector, value, pos); + } break; case LogicalTypeID::NODE: { templateCopyNullValue(vector, pos); } break; @@ -671,6 +761,12 @@ ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector return vector.array.get(); } +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type) { + return templateCreateArray(vector, type); +} + template<> ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, const LogicalType& type) { @@ -707,6 +803,33 @@ ArrowArray* ArrowRowBatch::convertInternalIDVectorToArray(ArrowVector& vector, return vector.array.get(); } +template<> +ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, + const LogicalType& type) { + // since union is a special case, we make the ArrowArray ourselves instead of using + // createArrayFromVector + auto nChildren = UnionType::getNumFields(&type); + vector.array = std::make_unique(); + vector.array->private_data = nullptr; + vector.array->release = releaseArrowVector; + vector.array->n_children = nChildren; + vector.childPointers.resize(nChildren); + vector.array->children = vector.childPointers.data(); + vector.array->offset = 0; + vector.array->dictionary = nullptr; + vector.array->buffers = vector.buffers.data(); + vector.array->null_count = vector.numNulls; + vector.array->length = vector.numValues; + vector.array->n_buffers = 2; + vector.array->buffers[0] = vector.data.data(); + vector.array->buffers[1] = vector.overflow.data(); + for (auto i = 0u; i < nChildren; i++) { + auto childType = UnionType::getFieldType(&type, i); + vector.childPointers[i] = convertVectorToArray(*vector.childData[i], *childType); + } + return vector.array.get(); +} + template<> ArrowArray* ArrowRowBatch::templateCreateArray(ArrowVector& vector, const LogicalType& type) { @@ -794,9 +917,15 @@ ArrowArray* ArrowRowBatch::convertVectorToArray(ArrowVector& vector, const Logic case LogicalTypeID::ARRAY: { return templateCreateArray(vector, type); } + case LogicalTypeID::MAP: { + return templateCreateArray(vector, type); + } case LogicalTypeID::STRUCT: { return templateCreateArray(vector, type); } + case LogicalTypeID::UNION: { + return templateCreateArray(vector, type); + } case LogicalTypeID::INTERNAL_ID: { return templateCreateArray(vector, type); } diff --git a/src/common/types/value/rel.cpp b/src/common/types/value/rel.cpp index e7b3bfefaf..89e1d0a8a3 100644 --- a/src/common/types/value/rel.cpp +++ b/src/common/types/value/rel.cpp @@ -48,6 +48,11 @@ Value* RelVal::getPropertyVal(const Value* val, uint64_t index) { return val->children[index + OFFSET].get(); } +Value* RelVal::getIDVal(const Value* val) { + auto fieldIdx = StructType::getFieldIdx(val->dataType.get(), InternalKeyword::ID); + return val->children[fieldIdx].get(); +} + Value* RelVal::getSrcNodeIDVal(const Value* val) { auto fieldIdx = StructType::getFieldIdx(val->dataType.get(), InternalKeyword::SRC); return val->children[fieldIdx].get(); diff --git a/src/include/common/arrow/arrow_converter.h b/src/include/common/arrow/arrow_converter.h index 5e1126663d..fdbeb15b63 100644 --- a/src/include/common/arrow/arrow_converter.h +++ b/src/include/common/arrow/arrow_converter.h @@ -38,6 +38,8 @@ struct ArrowConverter { static void initializeChild(ArrowSchema& child, const std::string& name = ""); static void setArrowFormatForStruct(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& dataType); + static void setArrowFormatForUnion(ArrowSchemaHolder& rootHolder, ArrowSchema& child, + const LogicalType& dataType); static void setArrowFormatForInternalID(ArrowSchemaHolder& rootHolder, ArrowSchema& child, const LogicalType& dataType); static void setArrowFormat(ArrowSchemaHolder& rootHolder, ArrowSchema& child, diff --git a/src/include/common/arrow/arrow_row_batch.h b/src/include/common/arrow/arrow_row_batch.h index 7e45278c32..116d1ad987 100644 --- a/src/include/common/arrow/arrow_row_batch.h +++ b/src/include/common/arrow/arrow_row_batch.h @@ -77,6 +77,7 @@ class ArrowRowBatch { std::int64_t pos); template static void templateCopyNullValue(ArrowVector* vector, std::int64_t pos); + static void copyNullValueUnion(ArrowVector* vector, Value* value, std::int64_t pos); template static ArrowArray* templateCreateArray(ArrowVector& vector, const LogicalType& type); diff --git a/src/include/common/types/value/rel.h b/src/include/common/types/value/rel.h index 9eaf0fda34..73eff15cac 100644 --- a/src/include/common/types/value/rel.h +++ b/src/include/common/types/value/rel.h @@ -46,6 +46,10 @@ class RelVal { * @return the dst nodeID value of the RelVal in Value. */ KUZU_API static Value* getDstNodeIDVal(const Value* val); + /** + * @return the internal ID value of the RelVal in Value. + */ + KUZU_API static Value* getIDVal(const Value* val); /** * @return the label value of the RelVal. */ diff --git a/tools/python_api/test/ground_truth.py b/tools/python_api/test/ground_truth.py index a587b288c8..72d2fda911 100644 --- a/tools/python_api/test/ground_truth.py +++ b/tools/python_api/test/ground_truth.py @@ -15,8 +15,11 @@ "workedHours": [10, 5], "usedNames": ["Aida"], "courseScoresPerTerm": [[10, 8], [6, 7, 8]], - "_label": "person", - "_id": {"offset": 0, "table": 0}, + "grades": [96, 54, 86, 92], + "height": 1.731, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", + "_LABEL": "person", + "_ID": {"offset": 0, "table": 0}, }, 2: { "ID": 2, @@ -32,8 +35,11 @@ "workedHours": [12, 8], "usedNames": ["Bobby"], "courseScoresPerTerm": [[8, 9], [9, 10]], - "_label": "person", - "_id": {"offset": 1, "table": 0}, + "grades": [98, 42, 93, 88], + "height": 0.99, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a12", + "_LABEL": "person", + "_ID": {"offset": 1, "table": 0}, }, 3: { "ID": 3, @@ -49,8 +55,11 @@ "workedHours": [4, 5], "usedNames": ["Carmen", "Fred"], "courseScoresPerTerm": [[8, 10]], - "_label": "person", - "_id": {"offset": 2, "table": 0}, + "grades": [91, 75, 21, 95], + "height": 1.00, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a13", + "_LABEL": "person", + "_ID": {"offset": 2, "table": 0}, }, 5: { "ID": 5, @@ -66,8 +75,11 @@ "workedHours": [1, 9], "usedNames": ["Wolfeschlegelstein", "Daniel"], "courseScoresPerTerm": [[7, 4], [8, 8], [9]], - "_label": "person", - "_id": {"offset": 3, "table": 0}, + "grades": [76, 88, 99, 89], + "height": 1.30, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a14", + "_LABEL": "person", + "_ID": {"offset": 3, "table": 0}, }, 7: { "ID": 7, @@ -83,8 +95,11 @@ "workedHours": [2], "usedNames": ["Ein"], "courseScoresPerTerm": [[6], [7], [8]], - "_label": "person", - "_id": {"offset": 4, "table": 0}, + "grades": [96, 59, 65, 88], + "height": 1.463, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a15", + "_LABEL": "person", + "_ID": {"offset": 4, "table": 0}, }, 8: { "ID": 8, @@ -100,8 +115,11 @@ "workedHours": [3, 4, 5, 6, 7], "usedNames": ["Fesdwe"], "courseScoresPerTerm": [[8]], - "_label": "person", - "_id": {"offset": 5, "table": 0}, + "grades": [80, 78, 34, 83], + "height": 1.51, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a16", + "_LABEL": "person", + "_ID": {"offset": 5, "table": 0}, }, 9: { "ID": 9, @@ -117,8 +135,11 @@ "workedHours": [1], "usedNames": ["Grad"], "courseScoresPerTerm": [[10]], - "_label": "person", - "_id": {"offset": 6, "table": 0}, + "grades": [43, 83, 67, 43], + "height": 1.6, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a17", + "_LABEL": "person", + "_ID": {"offset": 6, "table": 0}, }, 10: { "ID": 10, @@ -134,8 +155,11 @@ "workedHours": [10, 11, 12, 3, 4, 5, 6, 7], "usedNames": ["Ad", "De", "Hi", "Kye", "Orlan"], "courseScoresPerTerm": [[7], [10], [6, 7]], - "_label": "person", - "_id": {"offset": 7, "table": 0}, + "grades": [77, 64, 100, 54], + "height": 1.323, + "u": "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18", + "_LABEL": "person", + "_ID": {"offset": 7, "table": 0}, }, } @@ -149,8 +173,10 @@ "history": "10 years 5 months 13 hours 24 us", "licenseValidInterval": timedelta(days=1085), "rating": 1.0, - "_label": "organisation", - "_id": {"offset": 0, "table": 2}, + "state": {'revenue': 138, 'location': ['toronto', 'montr,eal'], 'stock': {'price': [96, 56], 'volume': 1000}}, + "info": 3.12, + "_LABEL": "organisation", + "_ID": {"offset": 0, "table": 1}, }, 4: { "ID": 4, @@ -161,8 +187,10 @@ "history": "2 years 4 days 10 hours", "licenseValidInterval": timedelta(days=9414), "rating": 0.78, - "_label": "organisation", - "_id": {"offset": 1, "table": 2}, + "state": {'revenue': 152, 'location': ['\"vanco,uver north area\"'], 'stock': {'price': [15, 78, 671], 'volume': 432}}, + "info": "abcd", + "_LABEL": "organisation", + "_ID": {"offset": 1, "table": 1}, }, 6: { "ID": 6, @@ -173,8 +201,10 @@ "history": "2 years 4 hours 22 us 34 minutes", "licenseValidInterval": timedelta(days=3, seconds=36000, microseconds=100000), "rating": 0.52, - "_label": "organisation", - "_id": {"offset": 2, "table": 2}, + "state": {'revenue': 558, 'location': ['\'very long city name\'', '\'new york\''], 'stock': {'price': [22], 'volume': 99}}, + "info": date(2023, 12, 15), + "_LABEL": "organisation", + "_ID": {"offset": 2, "table": 1}, }, } diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index e579c034f4..b87597129b 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -3,6 +3,7 @@ from datetime import date, datetime, timedelta from decimal import Decimal from typing import Any +import math import ground_truth import kuzu @@ -29,7 +30,7 @@ "a.eyeSight": {"arrow": pa.float64(), "pl": pl.Float64}, "a.birthdate": {"arrow": pa.date32(), "pl": pl.Date}, "a.registerTime": {"arrow": pa.timestamp("us"), "pl": pl.Datetime("us")}, - "a.lastJobDuration": {"arrow": pa.duration("ms"), "pl": pl.Duration("ms")}, + "a.lastJobDuration": {"arrow": pa.duration("us"), "pl": pl.Duration("us")}, "a.workedHours": {"arrow": pa.list_(pa.int64()), "pl": pl.List(pl.Int64)}, "a.usedNames": {"arrow": pa.list_(pa.string()), "pl": pl.List(pl.String)}, "a.courseScoresPerTerm": {"arrow": pa.list_(pa.list_(pa.int64())), "pl": pl.List(pl.List(pl.Int64))}, @@ -209,14 +210,14 @@ def _test_person_table(_conn: kuzu.Connection, return_type: str, chunk_size: int col_name="a.lastJobDuration", return_type=return_type, expected_values=[ - timedelta(days=99, seconds=36334, microseconds=628000), - timedelta(days=543, seconds=4800), - timedelta(microseconds=125000), - timedelta(days=541, seconds=57600, microseconds=24000), - timedelta(0), - timedelta(days=2016, seconds=68600), - timedelta(microseconds=125000), - timedelta(days=541, seconds=57600, microseconds=24000), + timedelta(days=1082, seconds=46920), + timedelta(days=3750, seconds=46800, microseconds=24), + timedelta(days=2, seconds=1451), + timedelta(days=3750, seconds=46800, microseconds=24), + timedelta(days=2, seconds=1451), + timedelta(seconds=1080, microseconds=24000), + timedelta(days=3750, seconds=46800, microseconds=24), + timedelta(days=1082, seconds=46920), ], ) @@ -466,17 +467,32 @@ def _test_with_nulls(_conn: kuzu.Connection, return_type: str, chunk_size: int | _test_with_nulls(conn, "arrow", 12) _test_with_nulls(conn, "pl") +def test_to_arrow_map(conn_db_readonly: ConnDB) -> None: + conn = conn_db_readonly[0] + results = conn.execute("RETURN map([1, 2, 3], [{a: 1, b: 2, c: '3'}, {a: 2, b: 3, c: '4'}, {a: 3, b: 4, c: '5'}])").get_as_arrow(8)[0].to_pylist() + assert results == [[(1, {'a': 1, 'b': 2, 'c': '3'}), (2, {'a': 2, 'b': 3, 'c': '4'}), (3, {'a': 3, 'b': 4, 'c': '5'})]] def test_to_arrow_complex(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly + def _test_node_helper(srcStruct, dstStruct): + assert set(srcStruct.keys()) == set(dstStruct.keys()) + for key in srcStruct: + if key == "_ID": # there isn't any guarantee on the value of _ID, so ignore it + continue + if type(srcStruct[key]) is float: + assert math.fabs(srcStruct[key] - dstStruct[key]) < 1e-5 + else: + assert srcStruct[key] == dstStruct[key] + + def _test_node(_conn: kuzu.Connection) -> None: query = "MATCH (p:person) RETURN p ORDER BY p.ID" query_result = _conn.execute(query) arrow_tbl = query_result.get_as_arrow() p_col = arrow_tbl.column(0) - assert p_col.to_pylist() == [ + for a, b in zip(p_col.to_pylist(), [ ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[0], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[2], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], @@ -484,53 +500,63 @@ def _test_node(_conn: kuzu.Connection) -> None: ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[8], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[9], - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[10], - ] + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[10]]): + _test_node_helper(a, b) def _test_node_rel(_conn: kuzu.Connection) -> None: - query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;" + query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b ORDER BY a.ID, b.ID" query_result = _conn.execute(query) - arrow_tbl = query_result.get_as_arrow(0) + arrow_tbl = query_result.get_as_arrow(3) assert arrow_tbl.num_columns == 3 a_col = arrow_tbl.column(0) assert len(a_col) == 3 e_col = arrow_tbl.column(1) - assert len(a_col) == 3 + assert len(e_col) == 3 b_col = arrow_tbl.column(2) - assert len(a_col) == 3 - assert a_col.to_pylist() == [ + assert len(b_col) == 3 + for a, b in zip(a_col.to_pylist(), [ ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[5], - ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7], - ] - assert e_col.to_pylist() == [ + ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7]]): + _test_node_helper(a, b) + for a, b in zip(e_col.to_pylist(), [ { - "_src": {"offset": 2, "tableID": 0}, - "_dst": {"offset": 1, "tableID": 2}, - "_id": {"offset": 0, "tableID": 4}, + "_SRC": {"offset": 2, "table": 0}, + "_DST": {"offset": 1, "table": 1}, + "_ID": {"offset": 0, "table": 5}, + "_LABEL": "workAt", + "grading": [3.8, 2.5], + "rating": 8.2, "year": 2015, }, { - "_src": {"offset": 3, "tableID": 0}, - "_dst": {"offset": 2, "tableID": 2}, - "_id": {"offset": 1, "tableID": 4}, + "_SRC": {"offset": 3, "table": 0}, + "_DST": {"offset": 2, "table": 1}, + "_ID": {"offset": 1, "table": 5}, + "_LABEL": "workAt", + "grading": [2.1, 4.4], + "rating": 7.6, "year": 2010, }, { - "_src": {"offset": 4, "tableID": 0}, - "_dst": {"offset": 2, "tableID": 2}, - "_id": {"offset": 2, "tableID": 4}, + "_SRC": {"offset": 4, "table": 0}, + "_DST": {"offset": 2, "table": 1}, + "_ID": {"offset": 2, "table": 5}, + "_LABEL": "workAt", + "grading": [9.2, 3.1], + "rating": 9.2, "year": 2015, - }, - ] - assert b_col.to_pylist() == [ + }]): + _test_node_helper(a, b) + + for a, b in zip(b_col.to_pylist(), [ ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[4], ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6], - ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6], - ] + ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6]]): + _test_node_helper(a, b) def _test_marries_table(_conn: kuzu.Connection) -> None: - query = "MATCH (:person)-[e:marries]->(:person) RETURN e.*" + query = "MATCH (a:person)-[e:marries]->(b:person) RETURN e.* ORDER BY a.ID, b.ID" arrow_tbl = _conn.execute(query).get_as_arrow(0) assert arrow_tbl.num_columns == 3 @@ -539,22 +565,22 @@ def _test_marries_table(_conn: kuzu.Connection) -> None: assert len(used_addr_col) == 3 assert used_addr_col.to_pylist() == [["toronto"], None, []] - arrow_tbl.column(1) - assert used_addr_col.type == pa.list_(pa.int16(), 2) - assert len(used_addr_col) == 3 - assert used_addr_col.to_pylist() == [[4, 5], [2, 5], [3, 9]] + addr_col = arrow_tbl.column(1) + assert addr_col.type == pa.list_(pa.int16(), 2) + assert len(addr_col) == 3 + assert addr_col.to_pylist() == [[4, 5], [2, 5], [3, 9]] - arrow_tbl.column(2) - assert used_addr_col.type == pa.string() - assert len(used_addr_col) == 3 - assert used_addr_col.to_pylist() == [None, "long long long string", "short str"] + note_col = arrow_tbl.column(2) + assert note_col.type == pa.string() + assert len(note_col) == 3 + assert note_col.to_pylist() == [None, "long long long string", "short str"] - # _test_node(conn) - # _test_node_rel(conn) - # _test_marries_table(conn) + _test_node(conn) + _test_node_rel(conn) + _test_marries_table(conn) - def test_to_arrow1(conn_db_readonly: ConnDB) -> None: - conn, db = conn_db_readonly + def test_to_arrow1(conn: kuzu.Connection) -> None: query = "MATCH (a:person)-[e:knows]->(:person) RETURN e.summary" - arrow_tbl = conn.execute(query).get_as_arrow(-1) + res = conn.execute(query) + arrow_tbl = conn.execute(query).get_as_arrow(-1) # what is a chunk size of -1 even supposed to mean? assert arrow_tbl == []