diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index fa5d964f4b..7ef73cb2a2 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -98,8 +98,13 @@ uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) { } } -bool ListTypeInfo::operator==(const ListTypeInfo& other) const { - return *childType == *other.childType; +bool ListTypeInfo::operator==(const ExtraTypeInfo& other) const { + const ListTypeInfo* otherListTypeInfo = + ku_dynamic_cast(&other); + if (otherListTypeInfo) { + return *childType == *otherListTypeInfo->childType; + } + return false; } std::unique_ptr ListTypeInfo::copy() const { @@ -114,8 +119,14 @@ void ListTypeInfo::serializeInternal(Serializer& serializer) const { childType->serialize(serializer); } -bool ArrayTypeInfo::operator==(const ArrayTypeInfo& other) const { - return *childType == *other.childType && numElements == other.numElements; +bool ArrayTypeInfo::operator==(const ExtraTypeInfo& other) const { + const ArrayTypeInfo* otherArrayTypeInfo = + ku_dynamic_cast(&other); + if (otherArrayTypeInfo) { + return *childType == *otherArrayTypeInfo->childType && + numElements == otherArrayTypeInfo->numElements; + } + return false; } std::unique_ptr ArrayTypeInfo::deserialize(Deserializer& deserializer) { @@ -225,16 +236,21 @@ std::vector StructTypeInfo::getStructFields() const { return structFields; } -bool StructTypeInfo::operator==(const StructTypeInfo& other) const { - if (fields.size() != other.fields.size()) { - return false; - } - for (auto i = 0u; i < fields.size(); ++i) { - if (fields[i] != other.fields[i]) { +bool StructTypeInfo::operator==(const ExtraTypeInfo& other) const { + const StructTypeInfo* otherStructTypeInfo = + ku_dynamic_cast(&other); + if (otherStructTypeInfo) { + if (fields.size() != otherStructTypeInfo->fields.size()) { return false; } + for (auto i = 0u; i < fields.size(); ++i) { + if (fields[i] != otherStructTypeInfo->fields[i]) { + return false; + } + } + return true; } - return true; + return false; } std::unique_ptr StructTypeInfo::deserialize(Deserializer& deserializer) { @@ -288,19 +304,10 @@ bool LogicalType::operator==(const LogicalType& other) const { if (typeID != other.typeID) { return false; } - switch (other.getPhysicalType()) { - case PhysicalTypeID::LIST: - return *ku_dynamic_cast(extraTypeInfo.get()) == - *ku_dynamic_cast(other.extraTypeInfo.get()); - case PhysicalTypeID::ARRAY: - return *ku_dynamic_cast(extraTypeInfo.get()) == - *ku_dynamic_cast(other.extraTypeInfo.get()); - case PhysicalTypeID::STRUCT: - return *ku_dynamic_cast(extraTypeInfo.get()) == - *ku_dynamic_cast(other.extraTypeInfo.get()); - default: - return true; + if (extraTypeInfo) { + return *extraTypeInfo == *other.extraTypeInfo; } + return true; } bool LogicalType::operator!=(const LogicalType& other) const { diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index 84b1ca0ec3..017f35f576 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -175,6 +175,8 @@ class ExtraTypeInfo { inline void serialize(Serializer& serializer) const { serializeInternal(serializer); } + virtual bool operator==(const ExtraTypeInfo& other) const = 0; + virtual std::unique_ptr copy() const = 0; protected: @@ -187,7 +189,7 @@ class ListTypeInfo : public ExtraTypeInfo { explicit ListTypeInfo(std::unique_ptr childType) : childType{std::move(childType)} {} inline LogicalType* getChildType() const { return childType.get(); } - bool operator==(const ListTypeInfo& other) const; + bool operator==(const ExtraTypeInfo& other) const override; std::unique_ptr copy() const override; static std::unique_ptr deserialize(Deserializer& deserializer); @@ -205,7 +207,7 @@ class ArrayTypeInfo : public ListTypeInfo { explicit ArrayTypeInfo(std::unique_ptr childType, uint64_t numElements) : ListTypeInfo{std::move(childType)}, numElements{numElements} {} inline uint64_t getNumElements() const { return numElements; } - bool operator==(const ArrayTypeInfo& other) const; + bool operator==(const ExtraTypeInfo& other) const override; static std::unique_ptr deserialize(Deserializer& deserializer); std::unique_ptr copy() const override; @@ -254,7 +256,7 @@ class StructTypeInfo : public ExtraTypeInfo { std::vector getChildrenTypes() const; std::vector getChildrenNames() const; std::vector getStructFields() const; - bool operator==(const kuzu::common::StructTypeInfo& other) const; + bool operator==(const ExtraTypeInfo& other) const override; static std::unique_ptr deserialize(Deserializer& deserializer); std::unique_ptr copy() const override; diff --git a/src/include/common/vector/value_vector.h b/src/include/common/vector/value_vector.h index a3c5c71ae6..07a451ce9e 100644 --- a/src/include/common/vector/value_vector.h +++ b/src/include/common/vector/value_vector.h @@ -146,6 +146,7 @@ struct KUZU_API BlobVector { } }; +// Currently, ListVector is used for both VAR_LIST and ARRAY physical type class KUZU_API ListVector { public: static void setDataVector(const ValueVector* vector, std::shared_ptr dataVector) {