From fd3760d005702c4c4019b7df600fdf4aa67d6f79 Mon Sep 17 00:00:00 2001 From: xiyang Date: Mon, 22 Apr 2024 23:26:27 -0400 Subject: [PATCH] Resolve default any type --- src/binder/bind/bind_graph_pattern.cpp | 8 +- src/binder/bind/read/bind_unwind.cpp | 15 ++- src/binder/expression/CMakeLists.txt | 3 +- .../expression/parameter_expression.cpp | 2 +- src/binder/expression/variable_expression.cpp | 22 ++++ src/binder/expression_binder.cpp | 42 +++++++- src/common/types/types.cpp | 24 +++++ src/common/types/value/value.cpp | 2 +- src/include/binder/expression/expression.h | 3 +- .../binder/expression/node_rel_expression.h | 49 +++++---- .../binder/expression/variable_expression.h | 8 +- src/include/common/types/types.h | 49 +++++---- test/test_files/function/cast.test | 33 ++++++ tools/python_api/test/test_df_pyarrow.py | 42 ++++---- tools/python_api/test/test_issue.py | 102 ++++++++++++++++++ 15 files changed, 321 insertions(+), 83 deletions(-) create mode 100644 src/binder/expression/variable_expression.cpp create mode 100644 test/test_files/function/cast.test create mode 100644 tools/python_api/test/test_issue.py diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index 4390232599..f796b0f131 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -307,7 +307,7 @@ std::shared_ptr Binder::createNonRecursiveQueryRel(const std::str fields.emplace_back(property->getPropertyName(), property->getDataType().copy()); } auto extraInfo = std::make_unique(std::move(fields)); - RelType::setExtraTypeInfo(queryRel->getDataTypeReference(), std::move(extraInfo)); + queryRel->setExtraTypeInfo(std::move(extraInfo)); return queryRel; } @@ -357,7 +357,7 @@ std::shared_ptr Binder::createRecursiveQueryRel(const parser::Rel } bindRecursiveRelProjectionList(nodeProjectionList, nodeFields); auto nodeExtraInfo = std::make_unique(std::move(nodeFields)); - node->getDataTypeReference().setExtraTypeInfo(std::move(nodeExtraInfo)); + node->setExtraTypeInfo(std::move(nodeExtraInfo)); auto nodeCopy = createQueryNode(recursivePatternInfo->nodeName, std::vector{nodeTableIDs.begin(), nodeTableIDs.end()}); // Bind intermediate rel @@ -384,7 +384,7 @@ std::shared_ptr Binder::createRecursiveQueryRel(const parser::Rel relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID()); bindRecursiveRelProjectionList(relProjectionList, relFields); auto relExtraInfo = std::make_unique(std::move(relFields)); - rel->getDataTypeReference().setExtraTypeInfo(std::move(relExtraInfo)); + rel->setExtraTypeInfo(std::move(relExtraInfo)); // Bind predicates in {}, e.g. [e* {date=1999-01-01}] std::shared_ptr relPredicate; for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) { @@ -569,7 +569,7 @@ std::shared_ptr Binder::createQueryNode(const std::string& parse fieldTypes.emplace_back(property->dataType.copy()); } auto extraInfo = std::make_unique(fieldNames, fieldTypes); - NodeType::setExtraTypeInfo(queryNode->getDataTypeReference(), std::move(extraInfo)); + queryNode->setExtraTypeInfo(std::move(extraInfo)); return queryNode; } diff --git a/src/binder/bind/read/bind_unwind.cpp b/src/binder/bind/read/bind_unwind.cpp index fec228019c..97f2138ab8 100644 --- a/src/binder/bind/read/bind_unwind.cpp +++ b/src/binder/bind/read/bind_unwind.cpp @@ -9,12 +9,23 @@ using namespace kuzu::common; namespace kuzu { namespace binder { +// E.g. UNWIND $1. We cannot validate $1 has data type LIST until we see the actual parameter. +static bool skipDataTypeValidation(const Expression& expr) { + return expr.expressionType == ExpressionType::PARAMETER && + expr.getDataType().getLogicalTypeID() == LogicalTypeID::ANY; +} + std::unique_ptr Binder::bindUnwindClause(const ReadingClause& readingClause) { auto& unwindClause = readingClause.constCast(); auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression()); - ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST); auto aliasName = unwindClause.getAlias(); - auto alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType)); + std::shared_ptr alias; + if (!skipDataTypeValidation(*boundExpression)) { + ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST); + alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType)); + } else { + alias = createVariable(aliasName, *LogicalType::ANY()); + } std::shared_ptr idExpr = nullptr; if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) { auto tableIDs = scope.getMemorizedTableIDs(boundExpression->getAlias()); diff --git a/src/binder/expression/CMakeLists.txt b/src/binder/expression/CMakeLists.txt index ee106d6723..dcb9d3e59b 100644 --- a/src/binder/expression/CMakeLists.txt +++ b/src/binder/expression/CMakeLists.txt @@ -6,7 +6,8 @@ add_library( expression_util.cpp function_expression.cpp literal_expression.cpp - parameter_expression.cpp) + parameter_expression.cpp + variable_expression.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/binder/expression/parameter_expression.cpp b/src/binder/expression/parameter_expression.cpp index c0a93cbf2e..5f7435e8a1 100644 --- a/src/binder/expression/parameter_expression.cpp +++ b/src/binder/expression/parameter_expression.cpp @@ -8,7 +8,7 @@ using namespace common; namespace binder { void ParameterExpression::cast(const LogicalType& type) { - if (dataType.getLogicalTypeID() != LogicalTypeID::ANY) { + if (!dataType.containsAny()) { // LCOV_EXCL_START throw BinderException( stringFormat("Cannot change parameter expression data type from {} to {}.", diff --git a/src/binder/expression/variable_expression.cpp b/src/binder/expression/variable_expression.cpp new file mode 100644 index 0000000000..cda127c8a6 --- /dev/null +++ b/src/binder/expression/variable_expression.cpp @@ -0,0 +1,22 @@ +#include "binder/expression/variable_expression.h" + +#include "common/exception/binder.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace binder { + +void VariableExpression::cast(const LogicalType& type) { + if (!dataType.containsAny()) { + // LCOV_EXCL_START + throw BinderException( + stringFormat("Cannot change variable expression data type from {} to {}.", + dataType.toString(), type.toString())); + // LCOV_EXCL_STOP + } + dataType = type; +} + +} // namespace binder +} // namespace kuzu diff --git a/src/binder/expression_binder.cpp b/src/binder/expression_binder.cpp index e5f6ccac12..fb248d70ff 100644 --- a/src/binder/expression_binder.cpp +++ b/src/binder/expression_binder.cpp @@ -81,12 +81,50 @@ static std::string unsupportedImplicitCastException(const Expression& expression expression.toString(), expression.dataType.toString(), targetTypeStr); } +static bool compatible(const LogicalType& type, const LogicalType& target) { + if (type.getLogicalTypeID() == LogicalTypeID::ANY) { + return true; + } + if (type.getLogicalTypeID() != target.getLogicalTypeID()) { + return false; + } + switch (type.getLogicalTypeID()) { + case LogicalTypeID::LIST: { + return compatible(*ListType::getChildType(&type), *ListType::getChildType(&target)); + } + case LogicalTypeID::ARRAY: { + return compatible(*ArrayType::getChildType(&type), *ArrayType::getChildType(&target)); + } + case LogicalTypeID::STRUCT: { + if (StructType::getNumFields(&type) != StructType::getNumFields(&target)) { + return false; + } + for (auto i = 0u; i < StructType::getNumFields(&type); ++i) { + if (!compatible(*StructType::getField(&type, i)->getType(), + *StructType::getField(&target, i)->getType())) { + return false; + } + } + return true; + } + case LogicalTypeID::RDF_VARIANT: + case LogicalTypeID::UNION: + case LogicalTypeID::MAP: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + return false; + default: + return true; + } +} + std::shared_ptr ExpressionBinder::implicitCastIfNecessary( const std::shared_ptr& expression, const LogicalType& targetType) { - if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) { + if (expression->dataType == targetType || targetType.containsAny()) { // No need to cast. return expression; } - if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { + if (compatible(expression->getDataType(), targetType)) { expression->cast(targetType); return expression; } diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index 42e75c26c3..a902c6a7f7 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -102,6 +102,10 @@ uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) { } } +bool ListTypeInfo::containsAny() const { + return childType->containsAny(); +} + bool ListTypeInfo::operator==(const ExtraTypeInfo& other) const { auto otherListTypeInfo = ku_dynamic_cast(&other); if (otherListTypeInfo) { @@ -147,6 +151,10 @@ void ArrayTypeInfo::serializeInternal(Serializer& serializer) const { serializer.serializeValue(numElements); } +bool StructField::containsAny() const { + return type->containsAny(); +} + bool StructField::operator==(const StructField& other) const { return *type == *other.type; } @@ -238,6 +246,15 @@ std::vector StructTypeInfo::getStructFields() const { return structFields; } +bool StructTypeInfo::containsAny() const { + for (auto& field : fields) { + if (field.containsAny()) { + return true; + } + } + return false; +} + bool StructTypeInfo::operator==(const ExtraTypeInfo& other) const { auto otherStructTypeInfo = ku_dynamic_cast(&other); if (otherStructTypeInfo) { @@ -314,6 +331,13 @@ LogicalType::LogicalType(const LogicalType& other) { } } +bool LogicalType::containsAny() const { + if (extraTypeInfo != nullptr) { + return extraTypeInfo->containsAny(); + } + return typeID == LogicalTypeID::ANY; +} + LogicalType& LogicalType::operator=(const LogicalType& other) { // Reuse the copy constructor and move assignment operator. LogicalType copy(other); diff --git a/src/common/types/value/value.cpp b/src/common/types/value/value.cpp index c9eb92836e..f6ca3b3687 100644 --- a/src/common/types/value/value.cpp +++ b/src/common/types/value/value.cpp @@ -72,7 +72,7 @@ bool Value::operator==(const Value& rhs) const { } void Value::setDataType(const LogicalType& dataType_) { - KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::ANY); + KU_ASSERT(dataType->containsAny()); dataType = dataType_.copy(); } diff --git a/src/include/binder/expression/expression.h b/src/include/binder/expression/expression.h index 5caee6c11a..c46f4e1a5e 100644 --- a/src/include/binder/expression/expression.h +++ b/src/include/binder/expression/expression.h @@ -63,8 +63,9 @@ class Expression : public std::enable_shared_from_this { } virtual void cast(const common::LogicalType& type); + // NOTE: Avoid using the following unsafe getter. It is meant for resolving ANY data type only. + common::LogicalType& getDataTypeUnsafe() { return dataType; } common::LogicalType getDataType() const { return dataType; } - common::LogicalType& getDataTypeReference() { return dataType; } bool hasAlias() const { return !alias.empty(); } std::string getAlias() const { return alias; } diff --git a/src/include/binder/expression/node_rel_expression.h b/src/include/binder/expression/node_rel_expression.h index 3ef320aa93..6db09b49cc 100644 --- a/src/include/binder/expression/node_rel_expression.h +++ b/src/include/binder/expression/node_rel_expression.h @@ -15,12 +15,16 @@ class NodeOrRelExpression : public Expression { variableName(std::move(variableName)), tableIDs{std::move(tableIDs)} {} ~NodeOrRelExpression() override = default; - inline std::string getVariableName() const { return variableName; } - - inline void setTableIDs(common::table_id_vector_t tableIDs) { - this->tableIDs = std::move(tableIDs); + // Note: ideally I would try to remove this function. But for now, we have to create type + // after expression. + void setExtraTypeInfo(std::unique_ptr info) { + dataType.setExtraTypeInfo(std::move(info)); } - inline void addTableIDs(const common::table_id_vector_t& tableIDsToAdd) { + + std::string getVariableName() const { return variableName; } + + void setTableIDs(common::table_id_vector_t tableIDs_) { tableIDs = std::move(tableIDs_); } + void addTableIDs(const common::table_id_vector_t& tableIDsToAdd) { auto tableIDsSet = getTableIDsSet(); for (auto tableID : tableIDsToAdd) { if (!tableIDsSet.contains(tableID)) { @@ -29,35 +33,34 @@ class NodeOrRelExpression : public Expression { } } - inline bool isMultiLabeled() const { return tableIDs.size() > 1; } - inline uint32_t getNumTableIDs() const { return tableIDs.size(); } - inline std::vector getTableIDs() const { return tableIDs; } - inline std::unordered_set getTableIDsSet() const { + bool isMultiLabeled() const { return tableIDs.size() > 1; } + uint32_t getNumTableIDs() const { return tableIDs.size(); } + std::vector getTableIDs() const { return tableIDs; } + std::unordered_set getTableIDsSet() const { return {tableIDs.begin(), tableIDs.end()}; } - inline common::table_id_t getSingleTableID() const { + common::table_id_t getSingleTableID() const { KU_ASSERT(tableIDs.size() == 1); return tableIDs[0]; } - inline void addPropertyExpression(const std::string& propertyName, + void addPropertyExpression(const std::string& propertyName, std::unique_ptr property) { KU_ASSERT(!propertyNameToIdx.contains(propertyName)); propertyNameToIdx.insert({propertyName, propertyExprs.size()}); propertyExprs.push_back(std::move(property)); } - inline bool hasPropertyExpression(const std::string& propertyName) const { + bool hasPropertyExpression(const std::string& propertyName) const { return propertyNameToIdx.contains(propertyName); } - inline std::shared_ptr getPropertyExpression( - const std::string& propertyName) const { + std::shared_ptr getPropertyExpression(const std::string& propertyName) const { KU_ASSERT(propertyNameToIdx.contains(propertyName)); return propertyExprs[propertyNameToIdx.at(propertyName)]->copy(); } - inline const std::vector>& getPropertyExprsRef() const { + const std::vector>& getPropertyExprsRef() const { return propertyExprs; } - inline expression_vector getPropertyExprs() const { + expression_vector getPropertyExprs() const { expression_vector result; for (auto& expr : propertyExprs) { result.push_back(expr->copy()); @@ -65,27 +68,27 @@ class NodeOrRelExpression : public Expression { return result; } - inline void setLabelExpression(std::shared_ptr expression) { + void setLabelExpression(std::shared_ptr expression) { labelExpression = std::move(expression); } - inline std::shared_ptr getLabelExpression() const { return labelExpression; } + std::shared_ptr getLabelExpression() const { return labelExpression; } - inline void addPropertyDataExpr(std::string propertyName, std::shared_ptr expr) { + void addPropertyDataExpr(std::string propertyName, std::shared_ptr expr) { propertyDataExprs.insert({propertyName, expr}); } - inline const std::unordered_map>& + const std::unordered_map>& getPropertyDataExprRef() const { return propertyDataExprs; } - inline bool hasPropertyDataExpr(const std::string& propertyName) const { + bool hasPropertyDataExpr(const std::string& propertyName) const { return propertyDataExprs.contains(propertyName); } - inline std::shared_ptr getPropertyDataExpr(const std::string& propertyName) const { + std::shared_ptr getPropertyDataExpr(const std::string& propertyName) const { KU_ASSERT(propertyDataExprs.contains(propertyName)); return propertyDataExprs.at(propertyName); } - inline std::string toStringInternal() const final { return variableName; } + std::string toStringInternal() const final { return variableName; } protected: std::string variableName; diff --git a/src/include/binder/expression/variable_expression.h b/src/include/binder/expression/variable_expression.h index 45b7c5d89b..96cdedcfe4 100644 --- a/src/include/binder/expression/variable_expression.h +++ b/src/include/binder/expression/variable_expression.h @@ -12,11 +12,13 @@ class VariableExpression : public Expression { : Expression{common::ExpressionType::VARIABLE, std::move(dataType), std::move(uniqueName)}, variableName{std::move(variableName)} {} - inline std::string getVariableName() const { return variableName; } + std::string getVariableName() const { return variableName; } - inline std::string toStringInternal() const final { return variableName; } + void cast(const common::LogicalType& type) override; - inline std::unique_ptr copy() const final { + std::string toStringInternal() const final { return variableName; } + + std::unique_ptr copy() const final { return std::make_unique(*dataType.copy(), uniqueName, variableName); } diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index bed4497585..8f9eb5fa4f 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -173,7 +173,9 @@ class ExtraTypeInfo { public: virtual ~ExtraTypeInfo() = default; - inline void serialize(Serializer& serializer) const { serializeInternal(serializer); } + void serialize(Serializer& serializer) const { serializeInternal(serializer); } + + virtual bool containsAny() const = 0; virtual bool operator==(const ExtraTypeInfo& other) const = 0; @@ -188,8 +190,13 @@ class ListTypeInfo : public ExtraTypeInfo { ListTypeInfo() = default; explicit ListTypeInfo(std::unique_ptr childType) : childType{std::move(childType)} {} - inline LogicalType* getChildType() const { return childType.get(); } + + LogicalType* getChildType() const { return childType.get(); } + + bool containsAny() const override; + bool operator==(const ExtraTypeInfo& other) const override; + std::unique_ptr copy() const override; static std::unique_ptr deserialize(Deserializer& deserializer); @@ -206,9 +213,13 @@ class ArrayTypeInfo : public ListTypeInfo { ArrayTypeInfo() = default; explicit ArrayTypeInfo(std::unique_ptr childType, uint64_t numElements) : ListTypeInfo{std::move(childType)}, numElements{numElements} {} - inline uint64_t getNumElements() const { return numElements; } + + uint64_t getNumElements() const { return numElements; } + bool operator==(const ExtraTypeInfo& other) const override; + static std::unique_ptr deserialize(Deserializer& deserializer); + std::unique_ptr copy() const override; private: @@ -224,11 +235,14 @@ class StructField { StructField(std::string name, std::unique_ptr type) : name{std::move(name)}, type{std::move(type)} {}; - inline bool operator!=(const StructField& other) const { return !(*this == other); } - inline std::string getName() const { return name; } - inline LogicalType* getType() const { return type.get(); } + std::string getName() const { return name; } + + LogicalType* getType() const { return type.get(); } + + bool containsAny() const; bool operator==(const StructField& other) const; + bool operator!=(const StructField& other) const { return !(*this == other); } void serialize(Serializer& serializer) const; @@ -252,10 +266,14 @@ class StructTypeInfo : public ExtraTypeInfo { struct_field_idx_t getStructFieldIdx(std::string fieldName) const; const StructField* getStructField(struct_field_idx_t idx) const; const StructField* getStructField(const std::string& fieldName) const; + std::vector getStructFields() const; + LogicalType* getChildType(struct_field_idx_t idx) const; std::vector getChildrenTypes() const; std::vector getChildrenNames() const; - std::vector getStructFields() const; + + bool containsAny() const override; + bool operator==(const ExtraTypeInfo& other) const override; static std::unique_ptr deserialize(Deserializer& deserializer); @@ -292,6 +310,7 @@ class LogicalType { static LogicalType fromString(const std::string& str); KUZU_API LogicalTypeID getLogicalTypeID() const { return typeID; } + bool containsAny() const; PhysicalTypeID getPhysicalType() const { return physicalType; } static PhysicalTypeID getPhysicalType(LogicalTypeID logicalType); @@ -464,22 +483,6 @@ struct ArrayType { } }; -struct NodeType { - static inline void setExtraTypeInfo(LogicalType& type, - std::unique_ptr extraTypeInfo) { - KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::NODE); - type.setExtraTypeInfo(std::move(extraTypeInfo)); - } -}; - -struct RelType { - static inline void setExtraTypeInfo(LogicalType& type, - std::unique_ptr extraTypeInfo) { - KU_ASSERT(type.getLogicalTypeID() == LogicalTypeID::REL); - type.setExtraTypeInfo(std::move(extraTypeInfo)); - } -}; - struct StructType { static inline std::vector getFieldTypes(const LogicalType* type) { KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT); diff --git a/test/test_files/function/cast.test b/test/test_files/function/cast.test new file mode 100644 index 0000000000..7b633e28fc --- /dev/null +++ b/test/test_files/function/cast.test @@ -0,0 +1,33 @@ +-GROUP CastFunction +-DATASET CSV tinysnb + +-- + +-CASE NestTypeImplicitCast + +-LOG ListAny +-STATEMENT RETURN [], [[]], [{'b': NULL}]; +---- 1 +[]|[[]]|[{b: }] + +-LOG StructAny +-STATEMENT RETURN {'a': NULL}, {'c': []}, {'d': {'x': NULL}}; +---- 1 +{a: }|{c: []}|{d: {x: }} + +-LOG MapAny +-STATEMENT RETURN map([1,2], [NULL, NULL]), map([1,2], ['a', NULL]); +---- 1 +{1=, 2=}|{1=a, 2=} +-STATEMENT RETURN map([NULL, NULL], [1, 2]); +---- error +Runtime exception: Null value key is not allowed in map. + +-LOG ListImplicitCast +-STATEMENT CREATE NODE TABLE T0(a SERIAL, b INT64[], c DOUBLE[], d STRING[], PRIMARY KEY(a)); +---- ok +-STATEMENT CREATE (:T0 {b: [1.0, 2.0], c: [1, 2], d: [3, 5]}); +---- ok +-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d; +---- 1 +0|[1,2]|[1.000000,2.000000]|[3,5] diff --git a/tools/python_api/test/test_df_pyarrow.py b/tools/python_api/test/test_df_pyarrow.py index 51cb05c58f..aef8348b1f 100644 --- a/tools/python_api/test/test_df_pyarrow.py +++ b/tools/python_api/test/test_df_pyarrow.py @@ -514,22 +514,20 @@ def test_pyarrow_struct_offset(conn_db_readonly: ConnDB) -> None: assert idx == len(index) -def test_pyarrow_union_sparse(conn_db_readonly : ConnDB) -> None: +def test_pyarrow_union_sparse(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly random.seed(100) datalength = 4096 index = pa.array(range(datalength)) type_codes = pa.array([random.randint(0, 2) for i in range(datalength)], type=pa.int8()) - arr1 = pa.array([generate_primitive('int32[pyarrow]') for i in range(datalength + 1)], type=pa.int32()) + arr1 = pa.array([generate_primitive("int32[pyarrow]") for i in range(datalength + 1)], type=pa.int32()) arr2 = pa.array([generate_string(random.randint(1, 10)) for i in range(datalength + 2)]) - arr3 = pa.array([generate_primitive('float32[pyarrow]') for j in range(datalength + 3)]) - col1 = pa.UnionArray.from_sparse(type_codes, [ - arr1.slice(1, datalength), arr2.slice(2, datalength), arr3.slice(3, datalength)]) - df = pd.DataFrame({ - 'index': arrowtopd(index), - 'col1': arrowtopd(col1) - }) - result = conn.execute('LOAD FROM df RETURN * ORDER BY index') + arr3 = pa.array([generate_primitive("float32[pyarrow]") for j in range(datalength + 3)]) + col1 = pa.UnionArray.from_sparse( + type_codes, [arr1.slice(1, datalength), arr2.slice(2, datalength), arr3.slice(3, datalength)] + ) + df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)}) + result = conn.execute("LOAD FROM df RETURN * ORDER BY index") idx = 0 while result.has_next(): assert idx < len(index) @@ -540,7 +538,8 @@ def test_pyarrow_union_sparse(conn_db_readonly : ConnDB) -> None: assert idx == len(index) -def test_pyarrow_union_dense(conn_db_readonly : ConnDB) -> None: + +def test_pyarrow_union_dense(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly random.seed(100) datalength = 4096 @@ -548,21 +547,19 @@ def test_pyarrow_union_dense(conn_db_readonly : ConnDB) -> None: _type_codes = [random.randint(0, 2) for i in range(datalength)] type_codes = pa.array(_type_codes, type=pa.int8()) _offsets = [0 for _ in range(datalength)] - _cnt = [0,0,0] + _cnt = [0, 0, 0] for i in range(len(_type_codes)): _offsets[i] = _cnt[_type_codes[i]] _cnt[_type_codes[i]] += 1 offsets = pa.array(_offsets, type=pa.int32()) - arr1 = pa.array([generate_primitive('int32[pyarrow]') for i in range(datalength + 1)], type=pa.int32()) + arr1 = pa.array([generate_primitive("int32[pyarrow]") for i in range(datalength + 1)], type=pa.int32()) arr2 = pa.array([generate_string(random.randint(1, 10)) for i in range(datalength + 2)]) - arr3 = pa.array([generate_primitive('float32[pyarrow]') for j in range(datalength + 3)]) - col1 = pa.UnionArray.from_dense(type_codes, offsets, [ - arr1.slice(1, datalength), arr2.slice(2, datalength), arr3.slice(3, datalength)]) - df = pd.DataFrame({ - 'index': arrowtopd(index), - 'col1': arrowtopd(col1) - }) - result = conn.execute('LOAD FROM df RETURN * ORDER BY index') + arr3 = pa.array([generate_primitive("float32[pyarrow]") for j in range(datalength + 3)]) + col1 = pa.UnionArray.from_dense( + type_codes, offsets, [arr1.slice(1, datalength), arr2.slice(2, datalength), arr3.slice(3, datalength)] + ) + df = pd.DataFrame({"index": arrowtopd(index), "col1": arrowtopd(col1)}) + result = conn.execute("LOAD FROM df RETURN * ORDER BY index") idx = 0 while result.has_next(): assert idx < len(index) @@ -573,7 +570,8 @@ def test_pyarrow_union_dense(conn_db_readonly : ConnDB) -> None: assert idx == len(index) -def test_pyarrow_map(conn_db_readonly : ConnDB) -> None: + +def test_pyarrow_map(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly random.seed(100) datalength = 4096 diff --git a/tools/python_api/test/test_issue.py b/tools/python_api/test/test_issue.py new file mode 100644 index 0000000000..60aa1283d6 --- /dev/null +++ b/tools/python_api/test/test_issue.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from tools.python_api.test.type_aliases import ConnDB + +# required by python-lint + + +def test_param_empty(conn_db_readwrite: ConnDB) -> None: + conn, db = conn_db_readwrite + lst = [[]] + conn.execute("CREATE NODE TABLE tab(id SERIAL, lst INT64[][], PRIMARY KEY(id))") + result = conn.execute("CREATE (t:tab {lst: $1}) RETURN t.*", {"1": lst}) + assert result.has_next() + assert result.get_next() == [0, lst] + assert not result.has_next() + result.close() + + +def test_issue_2874(conn_db_readwrite: ConnDB) -> None: + conn, db = conn_db_readwrite + result = conn.execute("UNWIND $idList as tid MATCH (t:person {ID: tid}) RETURN t.fName;", {"idList": [1, 2, 3]}) + assert result.has_next() + assert result.get_next() == ["Bob"] + assert result.has_next() + assert result.get_next() == ["Carol"] + assert not result.has_next() + result.close() + + +def test_issue_2906(conn_db_readwrite: ConnDB) -> None: + conn, db = conn_db_readwrite + result = conn.execute("MATCH (a:person) WHERE $1 > a.ID AND $1 < a.age / 5 RETURN a.fName;", {"1": 6}) + assert result.has_next() + assert result.get_next() == ["Alice"] + assert result.has_next() + assert result.get_next() == ["Carol"] + assert not result.has_next() + result.close() + + +def test_issue_3135(conn_db_readwrite: ConnDB) -> None: + conn, db = conn_db_readwrite + conn.execute("CREATE NODE TABLE t1(id SERIAL, number INT32, PRIMARY KEY(id));") + conn.execute("CREATE (:t1 {number: $1})", {"1": 2}) + result = conn.execute("MATCH (n:t1) RETURN n.number;") + assert result.has_next() + assert result.get_next() == [2] + assert not result.has_next() + result.close() + + +def test_empty_list2(conn_db_readwrite: ConnDB) -> None: + conn, db = conn_db_readwrite + conn.execute( + """ + CREATE NODE TABLE SnapArtifactScan ( + artifact_name STRING, + scan_columns STRING[], + scan_filter STRING, + scan_limit INT64, + scan_id STRING, + PRIMARY KEY(scan_id) + ) + """ + ) + result = conn.execute( + """ + MERGE (n:SnapArtifactScan { scan_id: $scan_id }) + SET n.artifact_name = $artifact_name, n.scan_columns = $scan_columns + RETURN n.scan_id + """, + { + "artifact_name": "taxi_zones", + "scan_columns": [], + "scan_id": "896de6b9c7b69fa2598def49e8c61de07949be374d229a82899c9c75994fad20", + }, + ) + assert result.has_next() + assert result.get_next() == ["896de6b9c7b69fa2598def49e8c61de07949be374d229a82899c9c75994fad20"] + assert not result.has_next() + result.close() + + +# TODO(Maxwell): check if we should change getCastCost() for the following test +# def test_issue_3248(conn_db_readwrite: ConnDB) -> None: +# conn, db = conn_db_readwrite +# # Define schema +# conn.execute("CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id))") +# +# # Add data +# conn.execute("MERGE (a:Item {id: 1, item: 'apple', price: 2.0, vector: cast([3.1, 4.1], 'DOUBLE[2]')})") +# conn.execute("MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: cast([5.9, 26.5], 'DOUBLE[2]')})") +# +# # Run similarity search +# result = conn.execute("MATCH (a:Item) RETURN a.item, a.price, +# array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC") +# assert result.has_next() +# assert result.get_next() == [2] +# assert result.has_next() +# assert result.get_next() == [2] +# assert not result.has_next() +# result.close()