diff --git a/src/binder/bind/read/bind_in_query_call.cpp b/src/binder/bind/read/bind_in_query_call.cpp index fd6a1572517..9fc987a951a 100644 --- a/src/binder/bind/read/bind_in_query_call.cpp +++ b/src/binder/bind/read/bind_in_query_call.cpp @@ -3,6 +3,7 @@ #include "binder/expression/literal_expression.h" #include "binder/query/reading_clause/bound_in_query_call.h" #include "catalog/catalog.h" +#include "common/exception/binder.h" #include "function/built_in_function_utils.h" #include "parser/expression/parsed_function_expression.h" #include "parser/query/reading_clause/in_query_call_clause.h" @@ -10,6 +11,7 @@ using namespace kuzu::common; using namespace kuzu::parser; using namespace kuzu::function; +using namespace kuzu::catalog; namespace kuzu { namespace binder { @@ -18,6 +20,7 @@ std::unique_ptr Binder::bindInQueryCall(const ReadingClause& auto& call = readingClause.constCast(); auto expr = call.getFunctionExpression(); auto functionExpr = expr->constPtrCast(); + auto functionName = functionExpr->getFunctionName(); expression_vector params; for (auto i = 0u; i < functionExpr->getNumChildren(); i++) { auto child = functionExpr->getChild(i); @@ -32,10 +35,14 @@ std::unique_ptr Binder::bindInQueryCall(const ReadingClause& inputTypes.push_back(literalExpr->getDataType()); inputValues.push_back(*literalExpr->getValue()); } - auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTx()); + auto catalogSet = clientContext->getCatalog()->getFunctions(clientContext->getTx()); + auto functionEntry = BuiltInFunctionsUtils::getFunctionCatalogEntry(functionName, catalogSet); + if (functionEntry->getType() != CatalogEntryType::TABLE_FUNCTION_ENTRY) { + throw BinderException(stringFormat("{} is not a table function.", functionName)); + } auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(), inputTypes, - functions); - tableFunction = *ku_dynamic_cast(func); + functionEntry); + tableFunction = *func->constPtrCast(); auto 
bindInput = function::TableFuncBindInput(); bindInput.inputs = std::move(inputValues); auto bindData = tableFunction.bindFunc(clientContext, &bindInput); diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 26c124377e3..972ff0eda4c 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -27,9 +27,8 @@ namespace kuzu { namespace binder { std::shared_ptr ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) { - auto& funcExpr = - ku_dynamic_cast(expr); - auto functionName = funcExpr.getNormalizedFunctionName(); + auto funcExpr = expr.constPtrCast(); + auto functionName = funcExpr->getNormalizedFunctionName(); auto result = rewriteFunctionExpression(expr, functionName); if (result != nullptr) { return result; @@ -41,11 +40,13 @@ std::shared_ptr ExpressionBinder::bindFunctionExpression(const Parse case CatalogEntryType::REWRITE_FUNCTION_ENTRY: return bindRewriteFunctionExpression(expr); case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: - return bindAggregateFunctionExpression(expr, functionName, funcExpr.getIsDistinct()); + return bindAggregateFunctionExpression(expr, functionName, funcExpr->getIsDistinct()); case CatalogEntryType::SCALAR_MACRO_ENTRY: return bindMacroExpression(expr, functionName); default: - KU_UNREACHABLE; + throw BinderException( + stringFormat("{} is a {}. Scalar function, aggregate function or macro was expected. 
", + functionName, CatalogEntryTypeUtils::toString(entry->getType()))); } } diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp index d416829746e..cadf737cc25 100644 --- a/src/binder/binder.cpp +++ b/src/binder/binder.cpp @@ -118,7 +118,7 @@ std::shared_ptr Binder::createVariable(const std::string& name, } std::unique_ptr Binder::bindDataType(const std::string& dataType) { - auto boundType = LogicalTypeUtils::dataTypeFromString(dataType); + auto boundType = LogicalType::fromString(dataType); if (boundType.getLogicalTypeID() == LogicalTypeID::ARRAY) { auto numElementsInArray = ArrayType::getNumElements(&boundType); if (numElementsInArray == 0) { diff --git a/src/catalog/catalog_entry/CMakeLists.txt b/src/catalog/catalog_entry/CMakeLists.txt index 648fc3041b6..0b2f9dbde46 100644 --- a/src/catalog/catalog_entry/CMakeLists.txt +++ b/src/catalog/catalog_entry/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(kuzu_catalog_entry OBJECT catalog_entry.cpp + catalog_entry_type.cpp function_catalog_entry.cpp table_catalog_entry.cpp node_table_catalog_entry.cpp diff --git a/src/catalog/catalog_entry/catalog_entry_type.cpp b/src/catalog/catalog_entry/catalog_entry_type.cpp new file mode 100644 index 00000000000..6420ad96585 --- /dev/null +++ b/src/catalog/catalog_entry/catalog_entry_type.cpp @@ -0,0 +1,36 @@ +#include "catalog/catalog_entry/catalog_entry_type.h" + +#include "common/assert.h" + +namespace kuzu { +namespace catalog { + +std::string CatalogEntryTypeUtils::toString(CatalogEntryType type) { + switch (type) { + case CatalogEntryType::NODE_TABLE_ENTRY: + return "NODE_TABLE_ENTRY"; + case CatalogEntryType::REL_TABLE_ENTRY: + return "REL_TABLE_ENTRY"; + case CatalogEntryType::REL_GROUP_ENTRY: + return "REL_GROUP_ENTRY"; + case CatalogEntryType::RDF_GRAPH_ENTRY: + return "RDF_GRAPH_ENTRY"; + case CatalogEntryType::SCALAR_MACRO_ENTRY: + return "SCALAR_MACRO_ENTRY"; + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: + return "AGGREGATE_FUNCTION_ENTRY"; + case 
CatalogEntryType::SCALAR_FUNCTION_ENTRY: + return "SCALAR_FUNCTION_ENTRY"; + case CatalogEntryType::REWRITE_FUNCTION_ENTRY: + return "REWRITE_FUNCTION_ENTRY"; + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + return "TABLE_FUNCTION_ENTRY"; + case CatalogEntryType::FOREIGN_TABLE_ENTRY: + return "FOREIGN_TABLE_ENTRY"; + default: + KU_UNREACHABLE; + } +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index 69244ede627..26e923142a2 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -389,6 +389,36 @@ std::string LogicalType::toString() const { } } +static LogicalTypeID strToLogicalTypeID(const std::string& trimmedStr); +static std::vector parseStructFields(const std::string& structTypeStr); +static std::unique_ptr parseListType(const std::string& trimmedStr); +static std::unique_ptr parseArrayType(const std::string& trimmedStr); +static std::vector parseStructTypeInfo(const std::string& structTypeStr); +static std::unique_ptr parseStructType(const std::string& trimmedStr); +static std::unique_ptr parseMapType(const std::string& trimmedStr); +static std::unique_ptr parseUnionType(const std::string& trimmedStr); + +LogicalType LogicalType::fromString(const std::string& str) { + LogicalType dataType; + auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(str)); + auto upperDataTypeString = StringUtils::getUpper(trimmedStr); + if (upperDataTypeString.ends_with("[]")) { + dataType = *parseListType(trimmedStr); + } else if (upperDataTypeString.ends_with("]")) { + dataType = *parseArrayType(trimmedStr); + } else if (upperDataTypeString.starts_with("STRUCT")) { + dataType = *parseStructType(trimmedStr); + } else if (upperDataTypeString.starts_with("MAP")) { + dataType = *parseMapType(trimmedStr); + } else if (upperDataTypeString.starts_with("UNION")) { + dataType = *parseUnionType(trimmedStr); + } else { + dataType.typeID = strToLogicalTypeID(upperDataTypeString); + } + 
dataType.physicalType = LogicalType::getPhysicalType(dataType.typeID); + return dataType; +} + void LogicalType::serialize(Serializer& serializer) const { serializer.serializeValue(typeID); serializer.serializeValue(physicalType); @@ -544,79 +574,58 @@ PhysicalTypeID LogicalType::getPhysicalType(LogicalTypeID typeID) { } } -LogicalType LogicalTypeUtils::dataTypeFromString(const std::string& dataTypeString) { - LogicalType dataType; - auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(dataTypeString)); - auto upperDataTypeString = StringUtils::getUpper(trimmedStr); - if (upperDataTypeString.ends_with("[]")) { - dataType = *parseListType(trimmedStr); - } else if (upperDataTypeString.ends_with("]")) { - dataType = *parseArrayType(trimmedStr); - } else if (upperDataTypeString.starts_with("STRUCT")) { - dataType = *parseStructType(trimmedStr); - } else if (upperDataTypeString.starts_with("MAP")) { - dataType = *parseMapType(trimmedStr); - } else if (upperDataTypeString.starts_with("UNION")) { - dataType = *parseUnionType(trimmedStr); - } else { - dataType.typeID = dataTypeIDFromString(upperDataTypeString); - } - dataType.physicalType = LogicalType::getPhysicalType(dataType.typeID); - return dataType; -} - -LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataTypeIDString) { - auto upperDataTypeIDString = StringUtils::getUpper(dataTypeIDString); - if ("INTERNAL_ID" == upperDataTypeIDString) { +LogicalTypeID strToLogicalTypeID(const std::string& str) { + auto upperStr = StringUtils::getUpper(str); + if ("INTERNAL_ID" == upperStr) { return LogicalTypeID::INTERNAL_ID; - } else if ("INT64" == upperDataTypeIDString) { + } else if ("INT64" == upperStr) { return LogicalTypeID::INT64; - } else if ("INT32" == upperDataTypeIDString || "INT" == upperDataTypeIDString) { + } else if ("INT32" == upperStr || "INT" == upperStr) { return LogicalTypeID::INT32; - } else if ("INT16" == upperDataTypeIDString) { + } else if ("INT16" == upperStr) { return 
LogicalTypeID::INT16; - } else if ("INT8" == upperDataTypeIDString) { + } else if ("INT8" == upperStr) { return LogicalTypeID::INT8; - } else if ("UINT64" == upperDataTypeIDString) { + } else if ("UINT64" == upperStr) { return LogicalTypeID::UINT64; - } else if ("UINT32" == upperDataTypeIDString) { + } else if ("UINT32" == upperStr) { return LogicalTypeID::UINT32; - } else if ("UINT16" == upperDataTypeIDString) { + } else if ("UINT16" == upperStr) { return LogicalTypeID::UINT16; - } else if ("UINT8" == upperDataTypeIDString) { + } else if ("UINT8" == upperStr) { return LogicalTypeID::UINT8; - } else if ("INT128" == upperDataTypeIDString) { + } else if ("INT128" == upperStr) { return LogicalTypeID::INT128; - } else if ("DOUBLE" == upperDataTypeIDString) { + } else if ("DOUBLE" == upperStr || "FLOAT8" == upperStr) { return LogicalTypeID::DOUBLE; - } else if ("FLOAT" == upperDataTypeIDString) { + } else if ("FLOAT" == upperStr || "FLOAT4" == upperStr || "REAL" == upperStr) { return LogicalTypeID::FLOAT; - } else if ("BOOLEAN" == upperDataTypeIDString || "BOOL" == upperDataTypeIDString) { + } else if ("BOOLEAN" == upperStr || "BOOL" == upperStr) { return LogicalTypeID::BOOL; - } else if ("BYTEA" == upperDataTypeIDString || "BLOB" == upperDataTypeIDString) { + } else if ("BYTEA" == upperStr || "BLOB" == upperStr) { return LogicalTypeID::BLOB; - } else if ("UUID" == upperDataTypeIDString) { + } else if ("UUID" == upperStr) { return LogicalTypeID::UUID; - } else if ("STRING" == upperDataTypeIDString) { + } else if ("STRING" == upperStr) { return LogicalTypeID::STRING; - } else if ("DATE" == upperDataTypeIDString) { + } else if ("DATE" == upperStr) { return LogicalTypeID::DATE; - } else if ("TIMESTAMP" == upperDataTypeIDString) { + } else if ("TIMESTAMP" == upperStr) { return LogicalTypeID::TIMESTAMP; - } else if ("TIMESTAMP_NS" == upperDataTypeIDString) { + } else if ("TIMESTAMP_NS" == upperStr) { return LogicalTypeID::TIMESTAMP_NS; - } else if ("TIMESTAMP_MS" == 
upperDataTypeIDString) { + } else if ("TIMESTAMP_MS" == upperStr) { return LogicalTypeID::TIMESTAMP_MS; - } else if ("TIMESTAMP_SEC" == upperDataTypeIDString || "TIMESTAMP_S" == upperDataTypeIDString) { + } else if ("TIMESTAMP_SEC" == upperStr || "TIMESTAMP_S" == upperStr) { return LogicalTypeID::TIMESTAMP_SEC; - } else if ("TIMESTAMP_TZ" == upperDataTypeIDString) { + } else if ("TIMESTAMP_TZ" == upperStr) { return LogicalTypeID::TIMESTAMP_TZ; - } else if ("INTERVAL" == upperDataTypeIDString) { + } else if ("INTERVAL" == upperStr || "DURATION" == upperStr) { return LogicalTypeID::INTERVAL; - } else if ("SERIAL" == upperDataTypeIDString) { + } else if ("SERIAL" == upperStr) { return LogicalTypeID::SERIAL; } else { - throw NotImplementedException("Cannot parse dataTypeID: " + dataTypeIDString); + throw NotImplementedException("Cannot parse dataTypeID: " + str); } } @@ -822,7 +831,7 @@ std::vector LogicalTypeUtils::getAllValidLogicTypes() { LogicalTypeID::RDF_VARIANT}; } -std::vector LogicalTypeUtils::parseStructFields(const std::string& structTypeStr) { +std::vector parseStructFields(const std::string& structTypeStr) { std::vector structFieldsStr; auto startPos = 0u; auto curPos = 0u; @@ -853,22 +862,22 @@ std::vector LogicalTypeUtils::parseStructFields(const std::string& return structFieldsStr; } -std::unique_ptr LogicalTypeUtils::parseListType(const std::string& trimmedStr) { - return LogicalType::LIST(dataTypeFromString(trimmedStr.substr(0, trimmedStr.size() - 2))); +std::unique_ptr parseListType(const std::string& trimmedStr) { + return LogicalType::LIST(LogicalType::fromString(trimmedStr.substr(0, trimmedStr.size() - 2))); } -std::unique_ptr LogicalTypeUtils::parseArrayType(const std::string& trimmedStr) { +std::unique_ptr parseArrayType(const std::string& trimmedStr) { auto leftBracketPos = trimmedStr.find_last_of('['); auto rightBracketPos = trimmedStr.find_last_of(']'); - auto childType = - std::make_unique(dataTypeFromString(trimmedStr.substr(0, 
leftBracketPos))); + auto childType = std::make_unique( + LogicalType::fromString(trimmedStr.substr(0, leftBracketPos))); auto numElements = std::strtoll( trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1).c_str(), nullptr, 0 /* base */); return LogicalType::ARRAY(std::move(childType), numElements); } -std::vector LogicalTypeUtils::parseStructTypeInfo(const std::string& structTypeStr) { +std::vector parseStructTypeInfo(const std::string& structTypeStr) { auto leftBracketPos = structTypeStr.find('('); auto rightBracketPos = structTypeStr.find_last_of(')'); if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) { @@ -884,16 +893,16 @@ std::vector LogicalTypeUtils::parseStructTypeInfo(const std::string auto fieldName = structFieldStr.substr(0, pos); auto fieldTypeString = structFieldStr.substr(pos + 1); structFields.emplace_back(fieldName, - std::make_unique(dataTypeFromString(fieldTypeString))); + std::make_unique(LogicalType::fromString(fieldTypeString))); } return structFields; } -std::unique_ptr LogicalTypeUtils::parseStructType(const std::string& trimmedStr) { +std::unique_ptr parseStructType(const std::string& trimmedStr) { return LogicalType::STRUCT(parseStructTypeInfo(trimmedStr)); } -std::unique_ptr LogicalTypeUtils::parseMapType(const std::string& trimmedStr) { +std::unique_ptr parseMapType(const std::string& trimmedStr) { auto leftBracketPos = trimmedStr.find('('); auto rightBracketPos = trimmedStr.find_last_of(')'); if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) { @@ -901,11 +910,11 @@ std::unique_ptr LogicalTypeUtils::parseMapType(const std::string& t } auto mapTypeStr = trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1); auto keyValueTypes = StringUtils::splitComma(mapTypeStr); - return LogicalType::MAP(dataTypeFromString(keyValueTypes[0]), - dataTypeFromString(keyValueTypes[1])); + return 
LogicalType::MAP(LogicalType::fromString(keyValueTypes[0]), + LogicalType::fromString(keyValueTypes[1])); } -std::unique_ptr LogicalTypeUtils::parseUnionType(const std::string& trimmedStr) { +std::unique_ptr parseUnionType(const std::string& trimmedStr) { auto unionFields = parseStructTypeInfo(trimmedStr); auto unionTagField = StructField(UnionType::TAG_FIELD_NAME, std::make_unique(UnionType::TAG_FIELD_TYPE)); diff --git a/src/function/built_in_function_utils.cpp b/src/function/built_in_function_utils.cpp index a43ef52d365..3c860b3d2de 100644 --- a/src/function/built_in_function_utils.cpp +++ b/src/function/built_in_function_utils.cpp @@ -17,11 +17,11 @@ namespace kuzu { namespace function { static void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, - const std::string& name, const std::vector& inputTypes, bool isDistinct, - function::function_set& set); + const std::string& name, const std::vector& inputTypes, bool isDistinct, + const function::function_set& set); static void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, - const std::string& name, const std::vector& inputTypes, - function::function_set& set); + const std::string& name, const std::vector& inputTypes, + const function::function_set& set); void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) { auto functions = FunctionCollection::getFunctions(); @@ -32,17 +32,28 @@ void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) { } } +catalog::CatalogEntry* BuiltInFunctionsUtils::getFunctionCatalogEntry(const std::string& name, + CatalogSet* catalogSet) { + if (!catalogSet->containsEntry(name)) { + throw CatalogException(stringFormat("{} function does not exist.", name)); + } + return catalogSet->getEntry(name); +} + Function* BuiltInFunctionsUtils::matchFunction(const std::string& name, CatalogSet* catalogSet) { - return matchFunction(name, std::vector{}, catalogSet); + return matchFunction(name, std::vector{}, catalogSet); } Function* 
BuiltInFunctionsUtils::matchFunction(const std::string& name, - const std::vector& inputTypes, CatalogSet* catalogSet) { - if (!catalogSet->containsEntry(name)) { - throw CatalogException(stringFormat("{} function does not exist.", name)); - } - auto& functionSet = - reinterpret_cast(catalogSet->getEntry(name))->getFunctionSet(); + const std::vector& inputTypes, CatalogSet* catalogSet) { + auto entry = getFunctionCatalogEntry(name, catalogSet); + return matchFunction(name, inputTypes, entry); +} + +Function* BuiltInFunctionsUtils::matchFunction(const std::string& name, + const std::vector& inputTypes, const catalog::CatalogEntry* catalogEntry) { + auto functionEntry = catalogEntry->constPtrCast(); + auto& functionSet = functionEntry->getFunctionSet(); bool isOverload = functionSet.size() > 1; std::vector candidateFunctions; uint32_t minCost = UINT32_MAX; @@ -69,7 +80,7 @@ Function* BuiltInFunctionsUtils::matchFunction(const std::string& name, } AggregateFunction* BuiltInFunctionsUtils::matchAggregateFunction(const std::string& name, - const std::vector& inputTypes, bool isDistinct, CatalogSet* catalogSet) { + const std::vector& inputTypes, bool isDistinct, CatalogSet* catalogSet) { auto& functionSet = reinterpret_cast(catalogSet->getEntry(name))->getFunctionSet(); std::vector candidateFunctions; @@ -461,7 +472,7 @@ uint32_t BuiltInFunctionsUtils::matchVarLengthParameters(const std::vector& candidateFunctions, const std::string& name, const std::vector& inputTypes, - function::function_set& set) { + const function::function_set& set) { // special case for add func if (name == AddFunction::name) { auto targetType0 = candidateFunctions[0]->parameterTypeIDs[0]; @@ -500,7 +511,7 @@ static std::string getFunctionMatchFailureMsg(const std::string name, void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, const std::string& name, const std::vector& inputTypes, bool isDistinct, - function::function_set& set) { + const function::function_set& set) { if 
(candidateFunctions.empty()) { std::string supportedInputsString; for (auto& function : set) { @@ -517,7 +528,7 @@ void validateNonEmptyCandidateFunctions(std::vector& candida void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, const std::string& name, const std::vector& inputTypes, - function::function_set& set) { + const function::function_set& set) { if (candidateFunctions.empty()) { std::string supportedInputsString; for (auto& function : set) { diff --git a/src/include/catalog/catalog_entry/catalog_entry.h b/src/include/catalog/catalog_entry/catalog_entry.h index 65e4fedade2..bec303e19d5 100644 --- a/src/include/catalog/catalog_entry/catalog_entry.h +++ b/src/include/catalog/catalog_entry/catalog_entry.h @@ -33,6 +33,11 @@ class KUZU_API CatalogEntry { virtual std::unique_ptr copy() const = 0; virtual std::string toCypher(main::ClientContext* /*clientContext*/) const { KU_UNREACHABLE; } + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + private: CatalogEntryType type; std::string name; diff --git a/src/include/catalog/catalog_entry/catalog_entry_type.h b/src/include/catalog/catalog_entry/catalog_entry_type.h index d44db816aca..2d8dad4311f 100644 --- a/src/include/catalog/catalog_entry/catalog_entry_type.h +++ b/src/include/catalog/catalog_entry/catalog_entry_type.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace kuzu { namespace catalog { @@ -18,5 +19,9 @@ enum class CatalogEntryType : uint8_t { FOREIGN_TABLE_ENTRY = 9, }; +struct CatalogEntryTypeUtils { + static std::string toString(CatalogEntryType type); +}; + } // namespace catalog } // namespace kuzu diff --git a/src/include/catalog/catalog_entry/function_catalog_entry.h b/src/include/catalog/catalog_entry/function_catalog_entry.h index 969f0b077cb..2d4db8492ec 100644 --- a/src/include/catalog/catalog_entry/function_catalog_entry.h +++ b/src/include/catalog/catalog_entry/function_catalog_entry.h @@ -18,7 +18,7 @@ class 
FunctionCatalogEntry : public CatalogEntry { //===--------------------------------------------------------------------===// // getters & setters //===--------------------------------------------------------------------===// - function::function_set& getFunctionSet() { return functionSet; } + const function::function_set& getFunctionSet() const { return functionSet; } //===--------------------------------------------------------------------===// // serialization & deserialization diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index d44a926338d..baa8e82ba34 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -270,7 +270,7 @@ class StructTypeInfo : public ExtraTypeInfo { }; class LogicalType { - friend class LogicalTypeUtils; + friend struct LogicalTypeUtils; friend struct StructType; friend struct ListType; friend struct ArrayType; @@ -284,22 +284,19 @@ class LogicalType { KUZU_API LogicalType(LogicalType&& other) = default; KUZU_API LogicalType& operator=(const LogicalType& other); - KUZU_API bool operator==(const LogicalType& other) const; - KUZU_API bool operator!=(const LogicalType& other) const; - KUZU_API LogicalType& operator=(LogicalType&& other) = default; KUZU_API std::string toString() const; + static LogicalType fromString(const std::string& str); - KUZU_API inline LogicalTypeID getLogicalTypeID() const { return typeID; } + KUZU_API LogicalTypeID getLogicalTypeID() const { return typeID; } - inline PhysicalTypeID getPhysicalType() const { return physicalType; } + PhysicalTypeID getPhysicalType() const { return physicalType; } static PhysicalTypeID getPhysicalType(LogicalTypeID logicalType); - inline bool hasExtraTypeInfo() const { return extraTypeInfo != nullptr; } - inline void setExtraTypeInfo(std::unique_ptr typeInfo) { + void setExtraTypeInfo(std::unique_ptr typeInfo) { extraTypeInfo = std::move(typeInfo); } @@ -442,8 +439,6 @@ class LogicalType { std::unique_ptr extraTypeInfo; }; 
-// TODO: Should remove `logical_types_t`. -using logical_types_t = std::vector>; using logical_type_vec_t = std::vector; struct ListType { @@ -576,12 +571,10 @@ struct PhysicalTypeUtils { static uint32_t getFixedTypeSize(PhysicalTypeID physicalType); }; -class LogicalTypeUtils { -public: +struct LogicalTypeUtils { KUZU_API static std::string toString(LogicalTypeID dataTypeID); KUZU_API static std::string toString(const std::vector& dataTypes); KUZU_API static std::string toString(const std::vector& dataTypeIDs); - KUZU_API static LogicalType dataTypeFromString(const std::string& dataTypeString); static uint32_t getRowLayoutSize(const LogicalType& logicalType); static bool isNumerical(const LogicalType& dataType); static bool isNested(const LogicalType& dataType); @@ -590,16 +583,6 @@ class LogicalTypeUtils { static std::vector getNumericalLogicalTypeIDs(); static std::vector getIntegerLogicalTypeIDs(); static std::vector getAllValidLogicTypes(); - -private: - static LogicalTypeID dataTypeIDFromString(const std::string& trimmedStr); - static std::vector parseStructFields(const std::string& structTypeStr); - static std::unique_ptr parseListType(const std::string& trimmedStr); - static std::unique_ptr parseArrayType(const std::string& trimmedStr); - static std::vector parseStructTypeInfo(const std::string& structTypeStr); - static std::unique_ptr parseStructType(const std::string& trimmedStr); - static std::unique_ptr parseMapType(const std::string& trimmedStr); - static std::unique_ptr parseUnionType(const std::string& trimmedStr); }; enum class FileVersionType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 }; diff --git a/src/include/function/built_in_function_utils.h b/src/include/function/built_in_function_utils.h index ae8d19d0a78..0f748d3fbee 100644 --- a/src/include/function/built_in_function_utils.h +++ b/src/include/function/built_in_function_utils.h @@ -6,6 +6,7 @@ namespace kuzu { namespace catalog { class CatalogSet; +class CatalogEntry; } // namespace catalog 
namespace function { @@ -14,11 +15,16 @@ class BuiltInFunctionsUtils { public: static void createFunctions(catalog::CatalogSet* catalogSet); + static catalog::CatalogEntry* getFunctionCatalogEntry(const std::string& name, + catalog::CatalogSet* catalogSet); static Function* matchFunction(const std::string& name, catalog::CatalogSet* catalogSet); // TODO(Ziyi): We should have a unified interface for matching table, aggregate and scalar // functions. static Function* matchFunction(const std::string& name, const std::vector& inputTypes, catalog::CatalogSet* catalogSet); + static Function* matchFunction(const std::string& name, + const std::vector& inputTypes, + const catalog::CatalogEntry* catalogEntry); static AggregateFunction* matchAggregateFunction(const std::string& name, const std::vector& inputTypes, bool isDistinct, @@ -78,7 +84,7 @@ class BuiltInFunctionsUtils { static void validateSpecialCases(std::vector& candidateFunctions, const std::string& name, const std::vector& inputTypes, - function::function_set& set); + const function::function_set& set); }; } // namespace function diff --git a/src/include/function/function.h b/src/include/function/function.h index ea3f028d0e3..23906c920c1 100644 --- a/src/include/function/function.h +++ b/src/include/function/function.h @@ -45,6 +45,11 @@ struct Function { FunctionType type; std::string name; std::vector parameterTypeIDs; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } }; struct BaseScalarFunction : public Function { diff --git a/src/processor/operator/persistent/reader/csv/driver.cpp b/src/processor/operator/persistent/reader/csv/driver.cpp index fbf03c29861..7d9d762a08d 100644 --- a/src/processor/operator/persistent/reader/csv/driver.cpp +++ b/src/processor/operator/persistent/reader/csv/driver.cpp @@ -94,7 +94,7 @@ void SniffCSVNameAndTypeDriver::addValue(uint64_t, common::column_id_t, std::str auto it = value.rfind(':'); if (it != std::string_view::npos) { try { 
- columnType = LogicalTypeUtils::dataTypeFromString(std::string(value.substr(it + 1))); + columnType = LogicalType::fromString(std::string(value.substr(it + 1))); columnName = std::string(value.substr(0, it)); } catch (const Exception&) { // NOLINT(bugprone-empty-catch): // This is how we check for a suitable diff --git a/test/test_files/common/types/type_alias.test b/test/test_files/common/types/type_alias.test new file mode 100644 index 00000000000..1780a4a6933 --- /dev/null +++ b/test/test_files/common/types/type_alias.test @@ -0,0 +1,23 @@ +-GROUP Types +-DATASET CSV EMPTY + +-- + +-CASE TypeAlias + +-STATEMENT CREATE NODE TABLE durationTest (ID INT, a FLOAT8, b FLOAT4, c REAL, d BOOLEAN, e BYTEA, f TIMESTAMP_S, g DURATION, PRIMARY KEY (ID)); +---- ok +-STATEMENT CREATE (a:durationTest { + ID: 1, + a: cast(2.1, "INT"), + b: cast(1.1, "FLOAT8"), + c: cast(1.1, "FLOAT4"), + d: cast(true, "BOOLEAN"), + e: encode('ü'), + f: cast("2020-01-01T00:00:00Z", "TIMESTAMP_S"), + g: cast("1s", "DURATION") + }); +---- ok +-STATEMENT MATCH (a:durationTest) RETURN a.*; +---- 1 +1|2.000000|1.100000|1.100000|True|\xC3\xBC|2020-01-01 00:00:00|00:00:01 diff --git a/test/test_files/exceptions/catalog/catalog.test b/test/test_files/exceptions/catalog/catalog.test index 3a1db3e840d..8ed27ca8c74 100644 --- a/test/test_files/exceptions/catalog/catalog.test +++ b/test/test_files/exceptions/catalog/catalog.test @@ -15,3 +15,12 @@ Catalog exception: function DUMMY does not exist. -STATEMENT CREATE REL TABLE knows_post ( FROM person TO person, MANY_LOT) ---- error Binder exception: Cannot bind MANY_LOT as relationship multiplicity. + +-LOG InvalidFunctionCall +-STATEMENT RETURN db_version(); +---- error +Binder exception: DB_VERSION is a TABLE_FUNCTION_ENTRY. Scalar function, aggregate function or macro was expected. +-STATEMENT CALL lower('AAA') RETURN *; +---- error +Binder exception: lower is not a table function. +