diff --git a/src/binder/bind/bind_reading_clause.cpp b/src/binder/bind/bind_reading_clause.cpp index 8475501a3e..db2f77bceb 100644 --- a/src/binder/bind/bind_reading_clause.cpp +++ b/src/binder/bind/bind_reading_clause.cpp @@ -126,8 +126,8 @@ std::unique_ptr Binder::bindInQueryCall(const ReadingClause& for (auto& val : inputValues) { inputTypes.push_back(*val.getDataType()); } - auto func = catalog.getBuiltInFunctions(clientContext->getTx()) - ->matchFunction(functionExpr->getFunctionName(), inputTypes); + auto func = BuiltInFunctionsUtils::matchFunction( + functionExpr->getFunctionName(), inputTypes, catalog.getFunctions(clientContext->getTx())); auto tableFunc = ku_dynamic_cast(func); auto bindInput = std::make_unique(); bindInput->inputs = std::move(inputValues); @@ -157,9 +157,9 @@ std::unique_ptr Binder::bindLoadFrom(const ReadingClause& re auto objectExpr = expressionBinder.bindVariableExpression(objectName); auto literalExpr = ku_dynamic_cast(objectExpr.get()); - auto func = catalog.getBuiltInFunctions(clientContext->getTx()) - ->matchFunction(READ_PANDAS_FUNC_NAME, - std::vector{objectExpr->getDataType()}); + auto func = BuiltInFunctionsUtils::matchFunction(READ_PANDAS_FUNC_NAME, + std::vector{objectExpr->getDataType()}, + catalog.getFunctions(clientContext->getTx())); scanFunction = ku_dynamic_cast(func); bindInput = std::make_unique(); bindInput->inputs.push_back(*literalExpr->getValue()); diff --git a/src/binder/bind/copy/bind_copy_rdf_graph.cpp b/src/binder/bind/copy/bind_copy_rdf_graph.cpp index 74d2886a74..ee67c4d544 100644 --- a/src/binder/bind/copy/bind_copy_rdf_graph.cpp +++ b/src/binder/bind/copy/bind_copy_rdf_graph.cpp @@ -18,7 +18,7 @@ namespace binder { std::unique_ptr Binder::bindCopyRdfFrom(const parser::Statement&, std::unique_ptr config, RDFGraphCatalogEntry* rdfGraphEntry) { - auto functions = catalog.getBuiltInFunctions(clientContext->getTx()); + auto functions = catalog.getFunctions(clientContext->getTx()); auto offset = expressionBinder.createVariableExpression( *LogicalType::INT64(), InternalKeyword::ROW_OFFSET); auto r = expressionBinder.createVariableExpression(*LogicalType::STRING(), rdf::IRI); @@ -36,15 +36,16 @@ std::unique_ptr Binder::bindCopyRdfFrom(const parser::Statement& Function* func; // Bind file scan; auto inMemory = RdfReaderConfig::construct(config->options).inMemory; - func = functions->matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME); + func = BuiltInFunctionsUtils::matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME, functions); auto scanFunc = ku_dynamic_cast(func); auto bindData = scanFunc->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog, storageManager); auto scanInfo = std::make_unique( scanFunc, bindData->copy(), expression_vector{}, offset); // Bind copy resource. - func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME) : - functions->matchFunction(READ_RDF_RESOURCE_FUNC_NAME); + func = inMemory ? + BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, functions) : + BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_FUNC_NAME, functions); auto rScanFunc = ku_dynamic_cast(func); auto rColumns = expression_vector{r}; auto rScanInfo = std::make_unique( @@ -53,8 +54,9 @@ std::unique_ptr Binder::bindCopyRdfFrom(const parser::Statement& auto rSchema = catalog.getTableCatalogEntry(clientContext->getTx(), rTableID); auto rCopyInfo = BoundCopyFromInfo(rSchema, std::move(rScanInfo), false, nullptr); // Bind copy literal. - func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME) : - functions->matchFunction(READ_RDF_LITERAL_FUNC_NAME); + func = inMemory ? + BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME, functions) : + BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_FUNC_NAME, functions); auto lScanFunc = ku_dynamic_cast(func); auto lColumns = expression_vector{l, lang}; auto lScanInfo = std::make_unique( @@ -63,8 +65,10 @@ std::unique_ptr Binder::bindCopyRdfFrom(const parser::Statement& auto lSchema = catalog.getTableCatalogEntry(clientContext->getTx(), lTableID); auto lCopyInfo = BoundCopyFromInfo(lSchema, std::move(lScanInfo), true, nullptr); // Bind copy resource triples - func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME) : - functions->matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME); + func = inMemory ? + BuiltInFunctionsUtils::matchFunction( + IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions) : + BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions); auto rrrScanFunc = ku_dynamic_cast(func); auto rrrColumns = expression_vector{s, p, o}; auto rrrScanInfo = @@ -83,8 +87,10 @@ std::unique_ptr Binder::bindCopyRdfFrom(const parser::Statement& auto rrrCopyInfo = BoundCopyFromInfo(rrrSchema, std::move(rrrScanInfo), false, std::move(rrrExtraInfo)); // Bind copy literal triples - func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME) : - functions->matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME); + func = inMemory ? + BuiltInFunctionsUtils::matchFunction( + IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions) : + BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions); auto rrlScanFunc = ku_dynamic_cast(func); auto rrlColumns = expression_vector{s, p, oOffset}; auto rrlScanInfo = diff --git a/src/binder/bind_expression/bind_comparison_expression.cpp b/src/binder/bind_expression/bind_comparison_expression.cpp index 03ce542079..b5f4366795 100644 --- a/src/binder/bind_expression/bind_comparison_expression.cpp +++ b/src/binder/bind_expression/bind_comparison_expression.cpp @@ -21,14 +21,15 @@ std::shared_ptr ExpressionBinder::bindComparisonExpression( std::shared_ptr ExpressionBinder::bindComparisonExpression( ExpressionType expressionType, const expression_vector& children) { - auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx()); + auto builtInFunctions = binder->catalog.getFunctions(binder->clientContext->getTx()); auto functionName = expressionTypeToString(expressionType); std::vector childrenTypes; for (auto& child : children) { childrenTypes.push_back(child->dataType); } auto function = ku_dynamic_cast( - builtInFunctions->matchFunction(functionName, childrenTypes)); + function::BuiltInFunctionsUtils::matchFunction( + functionName, childrenTypes, builtInFunctions)); expression_vector childrenAfterCast; for (auto i = 0u; i < children.size(); ++i) { childrenAfterCast.push_back( diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 2aa5f0a8f5..850aa10dc2 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -53,13 +53,13 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( const expression_vector& children, const std::string& functionName) { - auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx()); std::vector childrenTypes; for (auto& child : children) { childrenTypes.push_back(child->dataType); } auto function = ku_dynamic_cast( - builtInFunctions->matchFunction(functionName, childrenTypes)); + function::BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, + binder->catalog.getFunctions(binder->clientContext->getTx()))); expression_vector childrenAfterCast; std::unique_ptr bindData; if (functionName == CAST_FUNC_NAME) { @@ -98,7 +98,6 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) { - auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx()); std::vector childrenTypes; expression_vector children; for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { @@ -111,8 +110,9 @@ std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( childrenTypes.push_back(child->dataType); children.push_back(std::move(child)); } - auto function = - builtInFunctions->matchAggregateFunction(functionName, childrenTypes, isDistinct)->clone(); + auto function = function::BuiltInFunctionsUtils::matchAggregateFunction(functionName, + childrenTypes, isDistinct, binder->catalog.getFunctions(binder->clientContext->getTx())) + ->clone(); if (function->paramRewriteFunc) { function->paramRewriteFunc(children); } diff --git a/src/binder/bind_expression/bind_subquery_expression.cpp b/src/binder/bind_expression/bind_subquery_expression.cpp index 8e0b4e5b46..ec8fd9eaee 100644 --- a/src/binder/bind_expression/bind_subquery_expression.cpp +++ b/src/binder/bind_expression/bind_subquery_expression.cpp @@ -3,7 +3,6 @@ #include "binder/expression/subquery_expression.h" #include "binder/expression_binder.h" #include "common/types/value/value.h" -#include "main/client_context.h" #include "parser/expression/parsed_subquery_expression.h" using namespace kuzu::parser; @@ -32,9 +31,9 @@ std::shared_ptr ExpressionBinder::bindSubqueryExpression( std::move(boundGraphPattern.queryGraphCollection), uniqueName, std::move(rawName)); boundSubqueryExpr->setWhereExpression(boundGraphPattern.where); // Bind projection - auto function = - binder->catalog.getBuiltInFunctions(binder->clientContext->getTx()) - ->matchAggregateFunction(COUNT_STAR_FUNC_NAME, std::vector{}, false); + auto function = BuiltInFunctionsUtils::matchAggregateFunction(COUNT_STAR_FUNC_NAME, + std::vector{}, false, + binder->catalog.getFunctions(binder->clientContext->getTx())); auto bindData = std::make_unique(std::make_unique(function->returnTypeID)); auto countStarExpr = std::make_shared(COUNT_STAR_FUNC_NAME, diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp index 36de035d21..eef7536f9e 100644 --- a/src/binder/binder.cpp +++ b/src/binder/binder.cpp @@ -209,19 +209,21 @@ function::TableFunction* Binder::getScanFunction(FileType fileType, const Reader auto stringType = LogicalType(LogicalTypeID::STRING); std::vector inputTypes; inputTypes.push_back(stringType); - auto functions = catalog.getBuiltInFunctions(clientContext->getTx()); + auto functions = catalog.getFunctions(clientContext->getTx()); switch (fileType) { case FileType::PARQUET: { - func = functions->matchFunction(READ_PARQUET_FUNC_NAME, inputTypes); + func = function::BuiltInFunctionsUtils::matchFunction( + READ_PARQUET_FUNC_NAME, inputTypes, functions); } break; case FileType::NPY: { - func = functions->matchFunction(READ_NPY_FUNC_NAME, inputTypes); + func = function::BuiltInFunctionsUtils::matchFunction( + READ_NPY_FUNC_NAME, inputTypes, functions); } break; case FileType::CSV: { auto csvConfig = CSVReaderConfig::construct(config.options); - func = functions->matchFunction( + func = function::BuiltInFunctionsUtils::matchFunction( csvConfig.parallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME, - inputTypes); + inputTypes, functions); } break; default: KU_UNREACHABLE; diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index f96f304001..720adba49a 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -100,6 +100,10 @@ std::vector Catalog::getTableSchemas( return result; } +CatalogSet* Catalog::getFunctions(Transaction* tx) const { + return getVersion(tx)->functions.get(); +} + void Catalog::prepareCommitOrRollback(TransactionAction action) { if (hasUpdates()) { wal->logCatalogRecord(); @@ -265,8 +269,10 @@ void Catalog::addScalarMacroFunction( std::vector Catalog::getMacroNames(transaction::Transaction* tx) const { std::vector macroNames; - for (auto& macro : getVersion(tx)->macros) { - macroNames.push_back(macro.first); + for (auto& [_, function] : getVersion(tx)->functions->getEntries()) { + if (function->getType() == CatalogEntryType::SCALAR_MACRO_ENTRY) { + macroNames.push_back(function->getName()); + } } return macroNames; } diff --git a/src/catalog/catalog_content.cpp b/src/catalog/catalog_content.cpp index 1b32dc6d33..00f3bd8ec6 100644 --- a/src/catalog/catalog_content.cpp +++ b/src/catalog/catalog_content.cpp @@ -7,6 +7,8 @@ #include "catalog/catalog_entry/rdf_graph_catalog_entry.h" #include "catalog/catalog_entry/rel_group_catalog_entry.h" #include "catalog/catalog_entry/rel_table_catalog_entry.h" +#include "catalog/catalog_entry/scalar_function_catalog_entry.h" +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" #include "common/cast.h" #include "common/exception/catalog.h" #include "common/exception/runtime.h" @@ -15,7 +17,6 @@ #include "common/serializer/deserializer.h" #include "common/serializer/serializer.h" #include "common/string_format.h" -#include "common/string_utils.h" #include "storage/storage_info.h" #include "storage/storage_utils.h" @@ -27,8 +28,9 @@ namespace kuzu { namespace catalog { CatalogContent::CatalogContent(common::VirtualFileSystem* vfs) : nextTableID{0}, vfs{vfs} { - registerBuiltInFunctions(); tables = std::make_unique(); + functions = std::make_unique(); + registerBuiltInFunctions(); } CatalogContent::CatalogContent(const std::string& directory, VirtualFileSystem* vfs) : vfs{vfs} { @@ -180,7 +182,7 @@ void CatalogContent::saveToFile(const std::string& directory, FileVersionType db serializer.serializeValue(StorageVersionInfo::getStorageVersion()); tables->serialize(serializer); serializer.serializeValue(nextTableID); - serializer.serializeUnorderedMap(macros); + functions->serialize(serializer); } void CatalogContent::readFromFile(const std::string& directory, FileVersionType dbFileType) { @@ -197,19 +199,20 @@ void CatalogContent::readFromFile(const std::string& directory, FileVersionType ku_dynamic_cast(entry.get())->getTableID(); } deserializer.deserializeValue(nextTableID); - deserializer.deserializeUnorderedMap(macros); + functions = CatalogSet::deserialize(deserializer); } ExpressionType CatalogContent::getFunctionType(const std::string& name) const { - auto normalizedName = StringUtils::getUpper(name); - if (macros.contains(normalizedName)) { - return ExpressionType::MACRO; + if (!functions->containsEntry(name)) { + throw CatalogException{common::stringFormat("function {} does not exist.", name)}; } - auto functionType = builtInFunctions->getFunctionType(name); - switch (functionType) { - case function::FunctionType::SCALAR: + auto functionEntry = functions->getEntry(name); + switch (functionEntry->getType()) { + case CatalogEntryType::SCALAR_MACRO_ENTRY: + return ExpressionType::MACRO; + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: return ExpressionType::FUNCTION; - case function::FunctionType::AGGREGATE: + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: return ExpressionType::AGGREGATE_FUNCTION; default: KU_UNREACHABLE; @@ -217,27 +220,33 @@ ExpressionType CatalogContent::getFunctionType(const std::string& name) const { } void CatalogContent::addFunction(std::string name, function::function_set definitions) { - StringUtils::toUpper(name); - builtInFunctions->addFunction(std::move(name), std::move(definitions)); + if (functions->containsEntry(name)) { + throw CatalogException{common::stringFormat("function {} already exists.", name)}; + } + functions->createEntry( + std::make_unique(std::move(name), std::move(definitions))); } void CatalogContent::addScalarMacroFunction( std::string name, std::unique_ptr macro) { - StringUtils::toUpper(name); - macros.emplace(std::move(name), std::move(macro)); + functions->createEntry( + std::make_unique(std::move(name), std::move(macro))); +} + +function::ScalarMacroFunction* CatalogContent::getScalarMacroFunction( + const std::string& name) const { + return ku_dynamic_cast(functions->getEntry(name)) + ->getMacroFunction(); } std::unique_ptr CatalogContent::copy() const { std::unordered_map> macrosToCopy; - for (auto& macro : macros) { - macrosToCopy.emplace(macro.first, macro.second->copy()); - } - return std::make_unique(tables->copy(), tableNameToIDMap, nextTableID, - builtInFunctions->copy(), std::move(macrosToCopy), vfs); + return std::make_unique( + tables->copy(), tableNameToIDMap, nextTableID, functions->copy(), vfs); } void CatalogContent::registerBuiltInFunctions() { - builtInFunctions = std::make_unique(); + function::BuiltInFunctionsUtils::createFunctions(functions.get()); } bool CatalogContent::containsTable(const std::string& tableName) const { @@ -264,7 +273,7 @@ CatalogEntry* CatalogContent::getTableCatalogEntry(table_id_t tableID) const { return table.get(); } } - KU_ASSERT(false); + KU_UNREACHABLE; } common::table_id_t CatalogContent::getTableID(const std::string& tableName) const { diff --git a/src/catalog/catalog_entry/CMakeLists.txt b/src/catalog/catalog_entry/CMakeLists.txt index 60564d1e39..52560c84ec 100644 --- a/src/catalog/catalog_entry/CMakeLists.txt +++ b/src/catalog/catalog_entry/CMakeLists.txt @@ -1,11 +1,16 @@ add_library(kuzu_catalog_entry OBJECT + aggregate_function_catalog_entry.cpp catalog_entry.cpp + function_catalog_entry.cpp table_catalog_entry.cpp node_table_catalog_entry.cpp rel_table_catalog_entry.cpp rel_group_catalog_entry.cpp - rdf_graph_catalog_entry.cpp) + rdf_graph_catalog_entry.cpp + scalar_macro_catalog_entry.cpp + scalar_function_catalog_entry.cpp + table_function_catalog_entry.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/catalog/catalog_entry/aggregate_function_catalog_entry.cpp b/src/catalog/catalog_entry/aggregate_function_catalog_entry.cpp new file mode 100644 index 0000000000..157e9d2e03 --- /dev/null +++ b/src/catalog/catalog_entry/aggregate_function_catalog_entry.cpp @@ -0,0 +1,19 @@ +#include "catalog/catalog_entry/aggregate_function_catalog_entry.h" + +#include "common/utils.h" + +namespace kuzu { +namespace catalog { + +AggregateFunctionCatalogEntry::AggregateFunctionCatalogEntry( + std::string name, function::function_set functionSet) + : FunctionCatalogEntry{ + CatalogEntryType::AGGREGATE_FUNCTION_ENTRY, std::move(name), std::move(functionSet)} {} + +std::unique_ptr AggregateFunctionCatalogEntry::copy() const { + return std::make_unique( + getName(), common::copyVector(functionSet)); +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/catalog/catalog_entry/catalog_entry.cpp b/src/catalog/catalog_entry/catalog_entry.cpp index 37e62f148c..5b50a906fa 100644 --- a/src/catalog/catalog_entry/catalog_entry.cpp +++ b/src/catalog/catalog_entry/catalog_entry.cpp @@ -1,5 +1,6 @@ #include "catalog/catalog_entry/catalog_entry.h" +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" #include "catalog/catalog_entry/table_catalog_entry.h" namespace kuzu { @@ -20,9 +21,12 @@ std::unique_ptr CatalogEntry::deserialize(common::Deserializer& de case CatalogEntryType::NODE_TABLE_ENTRY: case CatalogEntryType::REL_TABLE_ENTRY: case CatalogEntryType::REL_GROUP_ENTRY: - case CatalogEntryType::RDF_GRAPH_ENTRY: + case CatalogEntryType::RDF_GRAPH_ENTRY: { entry = TableCatalogEntry::deserialize(deserializer, type); - break; + } break; + case CatalogEntryType::SCALAR_MACRO_ENTRY: { + entry = ScalarMacroCatalogEntry::deserialize(deserializer); + } break; default: KU_UNREACHABLE; } diff --git a/src/catalog/catalog_entry/function_catalog_entry.cpp b/src/catalog/catalog_entry/function_catalog_entry.cpp new file mode 100644 index 0000000000..edc8c74164 --- /dev/null +++ b/src/catalog/catalog_entry/function_catalog_entry.cpp @@ -0,0 +1,18 @@ +#include "catalog/catalog_entry/function_catalog_entry.h" + +#include "common/utils.h" + +namespace kuzu { +namespace catalog { + +FunctionCatalogEntry::FunctionCatalogEntry( + CatalogEntryType entryType, std::string name, function::function_set functionSet) + : CatalogEntry{entryType, std::move(name)}, functionSet{std::move(functionSet)} {} + +std::unique_ptr FunctionCatalogEntry::copy() const { + return std::make_unique( + getType(), getName(), common::copyVector>(functionSet)); +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp b/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp new file mode 100644 index 0000000000..de921bda66 --- /dev/null +++ b/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp @@ -0,0 +1,18 @@ +#include "catalog/catalog_entry/scalar_function_catalog_entry.h" + +#include "common/utils.h" + +namespace kuzu { +namespace catalog { + +ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry( + std::string name, function::function_set functionSet) + : FunctionCatalogEntry{ + CatalogEntryType::SCALAR_FUNCTION_ENTRY, std::move(name), std::move(functionSet)} {} + +std::unique_ptr ScalarFunctionCatalogEntry::copy() const { + return std::make_unique(getName(), common::copyVector(functionSet)); +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/catalog/catalog_entry/scalar_macro_catalog_entry.cpp b/src/catalog/catalog_entry/scalar_macro_catalog_entry.cpp new file mode 100644 index 0000000000..aea87f907f --- /dev/null +++ b/src/catalog/catalog_entry/scalar_macro_catalog_entry.cpp @@ -0,0 +1,33 @@ +#include "catalog/catalog_entry/scalar_macro_catalog_entry.h" + +namespace kuzu { +namespace catalog { + +ScalarMacroCatalogEntry::ScalarMacroCatalogEntry( + std::string name, std::unique_ptr macroFunction) + : CatalogEntry{CatalogEntryType::SCALAR_MACRO_ENTRY, std::move(name)}, macroFunction{std::move( + macroFunction)} {} + +void ScalarMacroCatalogEntry::serialize(common::Serializer& serializer) const { + CatalogEntry::serialize(serializer); + macroFunction->serialize(serializer); +} + +std::unique_ptr ScalarMacroCatalogEntry::deserialize( + common::Deserializer& deserializer) { + auto scalarMacroCatalogEntry = std::make_unique(); + scalarMacroCatalogEntry->macroFunction = + function::ScalarMacroFunction::deserialize(deserializer); + return scalarMacroCatalogEntry; +} + +std::unique_ptr ScalarMacroCatalogEntry::copy() const { + return std::make_unique(getName(), macroFunction->copy()); +} + +std::string ScalarMacroCatalogEntry::toCypher(main::ClientContext* /*clientContext*/) const { + return macroFunction->toCypher(getName()); +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/catalog/catalog_entry/table_function_catalog_entry.cpp b/src/catalog/catalog_entry/table_function_catalog_entry.cpp new file mode 100644 index 0000000000..26b4f69dd5 --- /dev/null +++ b/src/catalog/catalog_entry/table_function_catalog_entry.cpp @@ -0,0 +1,18 @@ +#include "catalog/catalog_entry/table_function_catalog_entry.h" + +#include "common/utils.h" + +namespace kuzu { +namespace catalog { + +TableFunctionCatalogEntry::TableFunctionCatalogEntry( + std::string name, function::function_set functionSet) + : FunctionCatalogEntry{ + CatalogEntryType::TABLE_FUNCTION_ENTRY, std::move(name), std::move(functionSet)} {} + +std::unique_ptr TableFunctionCatalogEntry::copy() const { + return std::make_unique(getName(), common::copyVector(functionSet)); +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/catalog/catalog_set.cpp b/src/catalog/catalog_set.cpp index cf1fe29630..7f0a512206 100644 --- a/src/catalog/catalog_set.cpp +++ b/src/catalog/catalog_set.cpp @@ -10,7 +10,6 @@ bool CatalogSet::containsEntry(const std::string& name) const { } CatalogEntry* CatalogSet::getEntry(const std::string& name) { - KU_ASSERT(containsEntry(name)); return entries.at(name).get(); } @@ -33,7 +32,18 @@ void CatalogSet::renameEntry(const std::string& oldName, const std::string& newN } void CatalogSet::serialize(common::Serializer serializer) const { - serializer.serializeValue(entries.size()); + uint64_t numEntries = 0; + for (auto& [name, entry] : entries) { + switch (entry->getType()) { + case CatalogEntryType::SCALAR_FUNCTION_ENTRY: + case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: + case CatalogEntryType::TABLE_FUNCTION_ENTRY: + continue; + default: + numEntries++; + } + } + serializer.serializeValue(numEntries); for (auto& [name, entry] : entries) { entry->serialize(serializer); } @@ -45,7 +55,9 @@ std::unique_ptr CatalogSet::deserialize(common::Deserializer& deseri deserializer.deserializeValue(numEntries); for (uint64_t i = 0; i < numEntries; i++) { auto entry = CatalogEntry::deserialize(deserializer); - catalogSet->createEntry(std::move(entry)); + if (entry != nullptr) { + catalogSet->createEntry(std::move(entry)); + } } return catalogSet; } diff --git a/src/common/utils.cpp b/src/common/utils.cpp index e5ef0ac9c9..febf1e6114 100644 --- a/src/common/utils.cpp +++ b/src/common/utils.cpp @@ -59,5 +59,24 @@ std::string LoggerUtils::getLoggerName(LoggerConstants::LoggerEnum loggerEnum) { } } } + +uint64_t nextPowerOfTwo(uint64_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} + +bool isLittleEndian() { + // Little endian arch stores the least significant value in the lower bytes. + int testNumber = 1; + return *(uint8_t*)&testNumber == 1; +} + } // namespace common } // namespace kuzu diff --git a/src/function/CMakeLists.txt b/src/function/CMakeLists.txt index 6ee3ea9457..9cc03f47c9 100644 --- a/src/function/CMakeLists.txt +++ b/src/function/CMakeLists.txt @@ -7,7 +7,7 @@ add_library(kuzu_function OBJECT aggregate_function.cpp base_lower_upper_operation.cpp - built_in_functions.cpp + built_in_function_utils.cpp cast_string_non_nested_functions.cpp cast_from_string_functions.cpp comparison_functions.cpp diff --git a/src/function/built_in_function_utils.cpp b/src/function/built_in_function_utils.cpp new file mode 100644 index 0000000000..af226da0c8 --- /dev/null +++ b/src/function/built_in_function_utils.cpp @@ -0,0 +1,1103 @@ +#include "function/built_in_function_utils.h" + +#include "catalog/catalog_entry/aggregate_function_catalog_entry.h" +#include "catalog/catalog_entry/scalar_function_catalog_entry.h" +#include "catalog/catalog_entry/table_function_catalog_entry.h" +#include "catalog/catalog_set.h" +#include "common/exception/binder.h" +#include "common/exception/catalog.h" +#include "function/aggregate/collect.h" +#include "function/aggregate/count.h" +#include "function/aggregate/count_star.h" +#include "function/aggregate_function.h" +#include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/blob/vector_blob_functions.h" +#include "function/cast/vector_cast_functions.h" +#include "function/comparison/vector_comparison_functions.h" +#include "function/date/vector_date_functions.h" +#include "function/interval/vector_interval_functions.h" +#include "function/list/vector_list_functions.h" +#include "function/map/vector_map_functions.h" +#include "function/path/vector_path_functions.h" +#include "function/rdf/vector_rdf_functions.h" +#include "function/schema/vector_node_rel_functions.h" +#include "function/string/vector_string_functions.h" +#include "function/struct/vector_struct_functions.h" +#include "function/table_functions/call_functions.h" +#include "function/timestamp/vector_timestamp_functions.h" +#include "function/union/vector_union_functions.h" +#include "function/uuid/vector_uuid_functions.h" +#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" +#include "processor/operator/persistent/reader/npy/npy_reader.h" +#include "processor/operator/persistent/reader/parquet/parquet_reader.h" +#include "processor/operator/persistent/reader/rdf/rdf_scan.h" +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +using namespace kuzu::catalog; + +void BuiltInFunctionsUtils::createFunctions(catalog::CatalogSet* catalogSet) { + registerScalarFunctions(catalogSet); + registerAggregateFunctions(catalogSet); + registerTableFunctions(catalogSet); +} + +void BuiltInFunctionsUtils::registerScalarFunctions(catalog::CatalogSet* catalogSet) { + registerComparisonFunctions(catalogSet); + registerArithmeticFunctions(catalogSet); + registerDateFunctions(catalogSet); + registerTimestampFunctions(catalogSet); + registerIntervalFunctions(catalogSet); + registerStringFunctions(catalogSet); + registerCastFunctions(catalogSet); + registerListFunctions(catalogSet); + registerStructFunctions(catalogSet); + registerMapFunctions(catalogSet); + registerUnionFunctions(catalogSet); + registerNodeRelFunctions(catalogSet); + registerPathFunctions(catalogSet); + registerBlobFunctions(catalogSet); + registerUUIDFunctions(catalogSet); + registerRdfFunctions(catalogSet); +} + +void BuiltInFunctionsUtils::registerAggregateFunctions(catalog::CatalogSet* catalogSet) { + registerCountStar(catalogSet); + registerCount(catalogSet); + registerSum(catalogSet); + registerAvg(catalogSet); + registerMin(catalogSet); + registerMax(catalogSet); + registerCollect(catalogSet); +} + +Function* BuiltInFunctionsUtils::matchFunction( + const std::string& name, catalog::CatalogSet* catalogSet) { + return matchFunction(name, std::vector{}, catalogSet); +} + +Function* BuiltInFunctionsUtils::matchFunction(const std::string& name, + const std::vector& inputTypes, catalog::CatalogSet* catalogSet) { + if (!catalogSet->containsEntry(name)) { + throw CatalogException(stringFormat("{} function does not exist.", name)); + } + auto& functionSet = + reinterpret_cast(catalogSet->getEntry(name))->getFunctionSet(); + bool isOverload = functionSet.size() > 1; + std::vector candidateFunctions; + uint32_t minCost = UINT32_MAX; + for (auto& function : functionSet) { + auto func = reinterpret_cast(function.get()); + if (name == CAST_FUNC_NAME) { + return func; + } + auto cost = getFunctionCost(inputTypes, func, isOverload); + if (cost == UINT32_MAX) { + continue; + } + if (cost < minCost) { + candidateFunctions.clear(); + candidateFunctions.push_back(func); + minCost = cost; + } else if (cost == minCost) { + candidateFunctions.push_back(func); + } + } + validateNonEmptyCandidateFunctions(candidateFunctions, name, inputTypes, functionSet); + if (candidateFunctions.size() > 1) { + return getBestMatch(candidateFunctions); + } + validateSpecialCases(candidateFunctions, name, inputTypes, functionSet); + return candidateFunctions[0]; +} + +AggregateFunction* BuiltInFunctionsUtils::matchAggregateFunction(const std::string& name, + const std::vector& inputTypes, bool isDistinct, + catalog::CatalogSet* catalogSet) { + auto& functionSet = + reinterpret_cast(catalogSet->getEntry(name))->getFunctionSet(); + std::vector candidateFunctions; + for (auto& function : functionSet) { + auto aggregateFunction = ku_dynamic_cast(function.get()); + auto cost = getAggregateFunctionCost(inputTypes, isDistinct, aggregateFunction); + if (cost == UINT32_MAX) { + continue; + } + candidateFunctions.push_back(aggregateFunction); + } + validateNonEmptyCandidateFunctions( + candidateFunctions, name, inputTypes, isDistinct, functionSet); + KU_ASSERT(candidateFunctions.size() == 1); + return candidateFunctions[0]; +} + +uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTypeID targetTypeID) { + if (inputTypeID == targetTypeID) { + return 0; + } + // TODO(Jiamin): should check any type + if (inputTypeID == LogicalTypeID::ANY || targetTypeID == LogicalTypeID::ANY || + inputTypeID == LogicalTypeID::RDF_VARIANT) { + // anything can be cast to ANY type for (almost no) cost + return 1; + } + if (targetTypeID == LogicalTypeID::RDF_VARIANT) { + return castFromRDFVariant(inputTypeID); + } + if (targetTypeID == LogicalTypeID::STRING) { + return castFromString(inputTypeID); + } + switch (inputTypeID) { + case LogicalTypeID::INT64: + return castInt64(targetTypeID); + case LogicalTypeID::INT32: + return castInt32(targetTypeID); + case LogicalTypeID::INT16: + return castInt16(targetTypeID); + case LogicalTypeID::INT8: + return castInt8(targetTypeID); + case LogicalTypeID::UINT64: + return castUInt64(targetTypeID); + case LogicalTypeID::UINT32: + return castUInt32(targetTypeID); + case LogicalTypeID::UINT16: + return castUInt16(targetTypeID); + case LogicalTypeID::UINT8: + return castUInt8(targetTypeID); + case LogicalTypeID::INT128: + return castInt128(targetTypeID); + case LogicalTypeID::DOUBLE: + return castDouble(targetTypeID); + case LogicalTypeID::FLOAT: + return castFloat(targetTypeID); + case LogicalTypeID::DATE: + return castDate(targetTypeID); + case LogicalTypeID::UUID: + return castUUID(targetTypeID); + case LogicalTypeID::SERIAL: + return castSerial(targetTypeID); + case LogicalTypeID::TIMESTAMP_SEC: + case LogicalTypeID::TIMESTAMP_MS: + case LogicalTypeID::TIMESTAMP_NS: + case LogicalTypeID::TIMESTAMP_TZ: + // currently don't allow timestamp to other timestamp types + return castTimestamp(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::getTargetTypeCost(LogicalTypeID typeID) { + switch (typeID) { + case LogicalTypeID::INT64: + return 101; + case LogicalTypeID::INT32: + return 102; + case LogicalTypeID::INT128: + return 103; + case LogicalTypeID::DOUBLE: + return 104; + case LogicalTypeID::TIMESTAMP: + return 120; + case LogicalTypeID::STRING: + return 149; + case LogicalTypeID::STRUCT: + case LogicalTypeID::MAP: + case LogicalTypeID::FIXED_LIST: + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::UNION: + return 160; + case LogicalTypeID::RDF_VARIANT: + return 170; + default: + return 110; + } +} + +uint32_t BuiltInFunctionsUtils::castInt64(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt32(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt16(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT32: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt8(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT16: + case LogicalTypeID::INT32: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt64(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT128: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt32(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT64: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt16(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT32: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT64: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUInt8(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT16: + case LogicalTypeID::INT32: + case LogicalTypeID::INT64: + case LogicalTypeID::INT128: + case LogicalTypeID::UINT16: + case LogicalTypeID::UINT32: + case LogicalTypeID::UINT64: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castInt128(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::FLOAT: + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castUUID(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::STRING: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castDouble(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castFloat(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::DOUBLE: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castDate(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::TIMESTAMP: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castSerial(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::INT64: + return 0; + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castTimestamp(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::TIMESTAMP: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castFromString(LogicalTypeID inputTypeID) { + switch (inputTypeID) { + case LogicalTypeID::BLOB: + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + return UNDEFINED_CAST_COST; + default: // Any other inputTypeID can be cast to String, but this cast has a high cost + return getTargetTypeCost(LogicalTypeID::STRING); + } +} + +uint32_t BuiltInFunctionsUtils::castFromRDFVariant(LogicalTypeID inputTypeID) { + switch (inputTypeID) { + case LogicalTypeID::STRUCT: + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::FIXED_LIST: + case LogicalTypeID::UNION: + case LogicalTypeID::MAP: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::RECURSIVE_REL: + case LogicalTypeID::RDF_VARIANT: + return UNDEFINED_CAST_COST; + default: // Any other inputTypeID can be cast to RDF_VARIANT, but this cast has a high cost + return getTargetTypeCost(LogicalTypeID::RDF_VARIANT); + } +} + +// When there is multiple candidates functions, e.g. double + int and double + double for input +// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double. +// Additionally, we prefer function with string parameter because string is most permissive and can +// be cast to any type. +Function* BuiltInFunctionsUtils::getBestMatch(std::vector& functionsToMatch) { + KU_ASSERT(functionsToMatch.size() > 1); + Function* result = nullptr; + auto cost = UNDEFINED_CAST_COST; + for (auto& function : functionsToMatch) { + auto currentCost = 0u; + std::unordered_set distinctParameterTypes; + for (auto& parameterTypeID : function->parameterTypeIDs) { + if (parameterTypeID != LogicalTypeID::STRING) { + currentCost++; + } + if (!distinctParameterTypes.contains(parameterTypeID)) { + currentCost++; + distinctParameterTypes.insert(parameterTypeID); + } + } + if (currentCost < cost) { + cost = currentCost; + result = function; + } + } + KU_ASSERT(result != nullptr); + return result; +} + +uint32_t BuiltInFunctionsUtils::getFunctionCost( + const std::vector& inputTypes, Function* function, bool isOverload) { + switch (function->type) { + case FunctionType::SCALAR: { + auto scalarFunction = ku_dynamic_cast(function); + if (scalarFunction->isVarLength) { + KU_ASSERT(function->parameterTypeIDs.size() == 1); + return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload); + } else { + return matchParameters(inputTypes, function->parameterTypeIDs, isOverload); + } + } + case FunctionType::TABLE: + return matchParameters(inputTypes, function->parameterTypeIDs, isOverload); + default: + KU_UNREACHABLE; + } +} + +uint32_t BuiltInFunctionsUtils::getAggregateFunctionCost( + const std::vector& inputTypes, bool isDistinct, AggregateFunction* function) { + if (inputTypes.size() != function->parameterTypeIDs.size() || + isDistinct != function->isDistinct) { + return UINT32_MAX; + } + for (auto i = 0u; i < inputTypes.size(); ++i) { + if (function->parameterTypeIDs[i] == LogicalTypeID::ANY) { + continue; + } else if (inputTypes[i].getLogicalTypeID() != function->parameterTypeIDs[i]) { + return UINT32_MAX; + } + } + return 0; +} + +uint32_t BuiltInFunctionsUtils::matchParameters(const std::vector& inputTypes, + const std::vector& targetTypeIDs, bool /*isOverload*/) { + if (inputTypes.size() != targetTypeIDs.size()) { + return UINT32_MAX; + } + auto cost = 0u; + for (auto i = 0u; i < inputTypes.size(); ++i) { + auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]); + if (castCost == UNDEFINED_CAST_COST) { + return UINT32_MAX; + } + cost += castCost; + } + return cost; +} + +uint32_t BuiltInFunctionsUtils::matchVarLengthParameters( + const std::vector& inputTypes, LogicalTypeID targetTypeID, bool /*isOverload*/) { + auto cost = 0u; + for (auto inputType : inputTypes) { + auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID); + if (castCost == UNDEFINED_CAST_COST) { + return UINT32_MAX; + } + cost += castCost; + } + return cost; +} + +void BuiltInFunctionsUtils::validateSpecialCases(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + function::function_set& set) { + // special case for add func + if (name == ADD_FUNC_NAME) { + auto targetType0 = candidateFunctions[0]->parameterTypeIDs[0]; + auto targetType1 = candidateFunctions[0]->parameterTypeIDs[1]; + auto inputType0 = inputTypes[0].getLogicalTypeID(); + auto inputType1 = inputTypes[1].getLogicalTypeID(); + if ((inputType0 != LogicalTypeID::STRING || inputType1 != LogicalTypeID::STRING) && + targetType0 == LogicalTypeID::STRING && targetType1 == LogicalTypeID::STRING) { + if (inputType0 != inputType1 && (inputType0 == LogicalTypeID::RDF_VARIANT || + inputType1 == LogicalTypeID::RDF_VARIANT)) { + return; + } + std::string supportedInputsString; + for (auto& function : set) { + supportedInputsString += function->signatureToString() + "\n"; + } + throw BinderException("Cannot match a built-in function for given function " + name + + LogicalTypeUtils::toString(inputTypes) + + ". Supported inputs are\n" + supportedInputsString); + } + } +} + +void BuiltInFunctionsUtils::registerComparisonFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + EQUALS_FUNC_NAME, EqualsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + NOT_EQUALS_FUNC_NAME, NotEqualsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + GREATER_THAN_FUNC_NAME, GreaterThanFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + GREATER_THAN_EQUALS_FUNC_NAME, GreaterThanEqualsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LESS_THAN_FUNC_NAME, LessThanFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LESS_THAN_EQUALS_FUNC_NAME, LessThanEqualsFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerArithmeticFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + ADD_FUNC_NAME, AddFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SUBTRACT_FUNC_NAME, SubtractFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MULTIPLY_FUNC_NAME, MultiplyFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DIVIDE_FUNC_NAME, DivideFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MODULO_FUNC_NAME, ModuloFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + POWER_FUNC_NAME, PowerFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ABS_FUNC_NAME, AbsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ACOS_FUNC_NAME, AcosFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ASIN_FUNC_NAME, AsinFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ATAN_FUNC_NAME, AtanFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ATAN2_FUNC_NAME, Atan2Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + BITWISE_XOR_FUNC_NAME, BitwiseXorFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + BITWISE_AND_FUNC_NAME, BitwiseAndFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + BITWISE_OR_FUNC_NAME, BitwiseOrFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + BITSHIFT_LEFT_FUNC_NAME, BitShiftLeftFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + BITSHIFT_RIGHT_FUNC_NAME, BitShiftRightFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CBRT_FUNC_NAME, CbrtFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CEIL_FUNC_NAME, CeilFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CEILING_FUNC_NAME, CeilFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + COS_FUNC_NAME, CosFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + COT_FUNC_NAME, CotFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DEGREES_FUNC_NAME, DegreesFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + EVEN_FUNC_NAME, EvenFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + FACTORIAL_FUNC_NAME, FactorialFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + FLOOR_FUNC_NAME, FloorFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + GAMMA_FUNC_NAME, GammaFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LGAMMA_FUNC_NAME, LgammaFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LN_FUNC_NAME, LnFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LOG_FUNC_NAME, LogFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LOG2_FUNC_NAME, Log2Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LOG10_FUNC_NAME, LogFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + NEGATE_FUNC_NAME, NegateFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + PI_FUNC_NAME, PiFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + POW_FUNC_NAME, PowerFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + RADIANS_FUNC_NAME, RadiansFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ROUND_FUNC_NAME, RoundFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SIN_FUNC_NAME, SinFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SIGN_FUNC_NAME, SignFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SQRT_FUNC_NAME, SqrtFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TAN_FUNC_NAME, TanFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerDateFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + DATE_PART_FUNC_NAME, DatePartFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DATEPART_FUNC_NAME, DatePartFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DATE_TRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DATETRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DAYNAME_FUNC_NAME, DayNameFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + GREATEST_FUNC_NAME, GreatestFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LAST_DAY_FUNC_NAME, LastDayFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LEAST_FUNC_NAME, LeastFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MAKE_DATE_FUNC_NAME, MakeDateFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MONTHNAME_FUNC_NAME, MonthNameFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerTimestampFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + CENTURY_FUNC_NAME, CenturyFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + EPOCH_MS_FUNC_NAME, EpochMsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_TIMESTAMP_FUNC_NAME, ToTimestampFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerIntervalFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + TO_YEARS_FUNC_NAME, ToYearsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_MONTHS_FUNC_NAME, ToMonthsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_DAYS_FUNC_NAME, ToDaysFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_HOURS_FUNC_NAME, ToHoursFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_MINUTES_FUNC_NAME, ToMinutesFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_SECONDS_FUNC_NAME, ToSecondsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_MILLISECONDS_FUNC_NAME, ToMillisecondsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TO_MICROSECONDS_FUNC_NAME, ToMicrosecondsFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerBlobFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + OCTET_LENGTH_FUNC_NAME, OctetLengthFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ENCODE_FUNC_NAME, EncodeFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DECODE_FUNC_NAME, DecodeFunctions::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerUUIDFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + GEN_RANDOM_UUID_FUNC_NAME, GenRandomUUIDFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerStringFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + ARRAY_EXTRACT_FUNC_NAME, ArrayExtractFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CONCAT_FUNC_NAME, ConcatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CONTAINS_FUNC_NAME, ContainsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ENDS_WITH_FUNC_NAME, EndsWithFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LCASE_FUNC_NAME, LowerFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LEFT_FUNC_NAME, LeftFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LOWER_FUNC_NAME, LowerFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LPAD_FUNC_NAME, LpadFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LTRIM_FUNC_NAME, LtrimFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + PREFIX_FUNC_NAME, StartsWithFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REPEAT_FUNC_NAME, RepeatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REVERSE_FUNC_NAME, ReverseFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + RIGHT_FUNC_NAME, RightFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + RPAD_FUNC_NAME, RpadFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + RTRIM_FUNC_NAME, RtrimFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + STARTS_WITH_FUNC_NAME, StartsWithFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SUBSTR_FUNC_NAME, SubStrFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SUBSTRING_FUNC_NAME, SubStrFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SUFFIX_FUNC_NAME, EndsWithFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TRIM_FUNC_NAME, TrimFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + UCASE_FUNC_NAME, UpperFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + UPPER_FUNC_NAME, UpperFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REGEXP_FULL_MATCH_FUNC_NAME, RegexpFullMatchFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REGEXP_MATCHES_FUNC_NAME, RegexpMatchesFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REGEXP_REPLACE_FUNC_NAME, RegexpReplaceFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REGEXP_EXTRACT_FUNC_NAME, RegexpExtractFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + REGEXP_EXTRACT_ALL_FUNC_NAME, RegexpExtractAllFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerCastFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + CAST_DATE_FUNC_NAME, CastToDateFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_DATE_FUNC_NAME, CastToDateFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_TIMESTAMP_FUNC_NAME, CastToTimestampFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_INTERVAL_FUNC_NAME, CastToIntervalFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_INTERVAL_FUNC_NAME, CastToIntervalFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_STRING_FUNC_NAME, CastToStringFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_STRING_FUNC_NAME, CastToStringFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_BLOB_FUNC_NAME, CastToBlobFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_BLOB_FUNC_NAME, CastToBlobFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_UUID_FUNC_NAME, CastToUUIDFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_UUID_FUNC_NAME, CastToUUIDFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_DOUBLE_FUNC_NAME, CastToDoubleFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_FLOAT_FUNC_NAME, CastToFloatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_SERIAL_FUNC_NAME, CastToSerialFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_INT64_FUNC_NAME, CastToInt64Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_INT32_FUNC_NAME, CastToInt32Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_INT16_FUNC_NAME, CastToInt16Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_INT8_FUNC_NAME, CastToInt8Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_UINT64_FUNC_NAME, CastToUInt64Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_UINT32_FUNC_NAME, CastToUInt32Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_UINT16_FUNC_NAME, CastToUInt16Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_UINT8_FUNC_NAME, CastToUInt8Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_INT128_FUNC_NAME, CastToInt128Function::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_TO_BOOL_FUNC_NAME, CastToBoolFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CAST_FUNC_NAME, CastAnyFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerListFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + LIST_CREATION_FUNC_NAME, ListCreationFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_RANGE_FUNC_NAME, ListRangeFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SIZE_FUNC_NAME, SizeFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_EXTRACT_FUNC_NAME, ListExtractFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_ELEMENT_FUNC_NAME, ListExtractFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_CONCAT_FUNC_NAME, ListConcatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_CAT_FUNC_NAME, ListConcatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_CONCAT_FUNC_NAME, ListConcatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_CAT_FUNC_NAME, ListConcatFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_APPEND_FUNC_NAME, ListAppendFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_APPEND_FUNC_NAME, ListAppendFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_PUSH_BACK_FUNC_NAME, ListAppendFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_PREPEND_FUNC_NAME, ListPrependFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_PREPEND_FUNC_NAME, ListPrependFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_PUSH_FRONT_FUNC_NAME, ListPrependFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_POSITION_FUNC_NAME, ListPositionFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_POSITION_FUNC_NAME, ListPositionFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_INDEXOF_FUNC_NAME, ListPositionFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_INDEXOF_FUNC_NAME, ListPositionFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_CONTAINS_FUNC_NAME, ListContainsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_HAS_FUNC_NAME, ListContainsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_CONTAINS_FUNC_NAME, ListContainsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_HAS_FUNC_NAME, ListContainsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_SLICE_FUNC_NAME, ListSliceFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ARRAY_SLICE_FUNC_NAME, ListSliceFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_SORT_FUNC_NAME, ListSortFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_SUM_FUNC_NAME, ListSumFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_PRODUCT_FUNC_NAME, ListProductFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_DISTINCT_FUNC_NAME, ListDistinctFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_UNIQUE_FUNC_NAME, ListUniqueFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + LIST_ANY_VALUE_FUNC_NAME, ListAnyValueFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerStructFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + STRUCT_PACK_FUNC_NAME, StructPackFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + STRUCT_EXTRACT_FUNC_NAME, StructExtractFunctions::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerMapFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + MAP_CREATION_FUNC_NAME, MapCreationFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MAP_EXTRACT_FUNC_NAME, MapExtractFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + ELEMENT_AT_FUNC_NAME, MapExtractFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + CARDINALITY_FUNC_NAME, SizeFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MAP_KEYS_FUNC_NAME, MapKeysFunctions::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + MAP_VALUES_FUNC_NAME, MapValuesFunctions::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerUnionFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + UNION_VALUE_FUNC_NAME, UnionValueFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + UNION_TAG_FUNC_NAME, UnionTagFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + UNION_EXTRACT_FUNC_NAME, UnionExtractFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerNodeRelFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + OFFSET_FUNC_NAME, OffsetFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerPathFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + NODES_FUNC_NAME, NodesFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + RELS_FUNC_NAME, RelsFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + PROPERTIES_FUNC_NAME, PropertiesFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + IS_TRAIL_FUNC_NAME, IsTrailFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + IS_ACYCLIC_FUNC_NAME, IsACyclicFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerRdfFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + TYPE_FUNC_NAME, RDFTypeFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + VALIDATE_PREDICATE_FUNC_NAME, ValidatePredicateFunction::getFunctionSet())); +} + +void BuiltInFunctionsUtils::registerCountStar(catalog::CatalogSet* catalogSet) { + function_set functionSet; + functionSet.push_back(std::make_unique(COUNT_STAR_FUNC_NAME, + std::vector{}, LogicalTypeID::INT64, CountStarFunction::initialize, + CountStarFunction::updateAll, CountStarFunction::updatePos, CountStarFunction::combine, + CountStarFunction::finalize, false)); + catalogSet->createEntry(std::make_unique( + COUNT_STAR_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerCount(catalog::CatalogSet* catalogSet) { + function_set functionSet; + for (auto& type : LogicalTypeUtils::getAllValidLogicTypes()) { + for (auto isDistinct : std::vector{true, false}) { + functionSet.push_back(AggregateFunctionUtil::getAggFunc(COUNT_FUNC_NAME, + type, LogicalTypeID::INT64, isDistinct, CountFunction::paramRewriteFunc)); + } + } + catalogSet->createEntry(std::make_unique( + COUNT_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerSum(catalog::CatalogSet* catalogSet) { + function_set functionSet; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + for (auto isDistinct : std::vector{true, false}) { + functionSet.push_back( + AggregateFunctionUtil::getSumFunc(SUM_FUNC_NAME, typeID, typeID, isDistinct)); + } + } + catalogSet->createEntry(std::make_unique( + SUM_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerAvg(catalog::CatalogSet* catalogSet) { + function_set functionSet; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + for (auto isDistinct : std::vector{true, false}) { + functionSet.push_back(AggregateFunctionUtil::getAvgFunc( + AVG_FUNC_NAME, typeID, LogicalTypeID::DOUBLE, isDistinct)); + } + } + catalogSet->createEntry(std::make_unique( + AVG_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerMin(catalog::CatalogSet* catalogSet) { + function_set functionSet; + for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) { + for (auto isDistinct : std::vector{true, false}) { + functionSet.push_back(AggregateFunctionUtil::getMinFunc(type, isDistinct)); + } + } + catalogSet->createEntry(std::make_unique( + MIN_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerMax(catalog::CatalogSet* catalogSet) { + function_set functionSet; + for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) { + for (auto isDistinct : std::vector{true, false}) { + functionSet.push_back(AggregateFunctionUtil::getMaxFunc(type, isDistinct)); + } + } + catalogSet->createEntry(std::make_unique( + MAX_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerCollect(catalog::CatalogSet* catalogSet) { + function_set functionSet; + for (auto isDistinct : std::vector{true, false}) { + functionSet.push_back(std::make_unique(COLLECT_FUNC_NAME, + std::vector{common::LogicalTypeID::ANY}, LogicalTypeID::VAR_LIST, + CollectFunction::initialize, CollectFunction::updateAll, CollectFunction::updatePos, + CollectFunction::combine, CollectFunction::finalize, isDistinct, + CollectFunction::bindFunc)); + } + catalogSet->createEntry(std::make_unique( + COLLECT_FUNC_NAME, std::move(functionSet))); +} + +void BuiltInFunctionsUtils::registerTableFunctions(catalog::CatalogSet* catalogSet) { + catalogSet->createEntry(std::make_unique( + CURRENT_SETTING_FUNC_NAME, CurrentSettingFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + DB_VERSION_FUNC_NAME, DBVersionFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SHOW_TABLES_FUNC_NAME, ShowTablesFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + TABLE_INFO_FUNC_NAME, TableInfoFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + SHOW_CONNECTION_FUNC_NAME, ShowConnectionFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_PARQUET_FUNC_NAME, processor::ParquetScanFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_NPY_FUNC_NAME, processor::NpyScanFunction::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_CSV_SERIAL_FUNC_NAME, processor::SerialCSVScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_CSV_PARALLEL_FUNC_NAME, processor::ParallelCSVScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_RDF_RESOURCE_FUNC_NAME, processor::RdfResourceScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_RDF_LITERAL_FUNC_NAME, processor::RdfLiteralScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, processor::RdfResourceTripleScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_RDF_LITERAL_TRIPLE_FUNC_NAME, processor::RdfLiteralTripleScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + READ_RDF_ALL_TRIPLE_FUNC_NAME, processor::RdfAllTripleScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, processor::RdfResourceInMemScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + IN_MEM_READ_RDF_LITERAL_FUNC_NAME, processor::RdfLiteralInMemScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, + processor::RdfResourceTripleInMemScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, + processor::RdfLiteralTripleInMemScan::getFunctionSet())); + catalogSet->createEntry(std::make_unique( + STORAGE_INFO_FUNC_NAME, StorageInfoFunction::getFunctionSet())); +} + +static std::string getFunctionMatchFailureMsg(const std::string name, + const std::vector& inputTypes, const std::string& supportedInputs, + bool isDistinct = false) { + auto result = stringFormat("Cannot match a built-in function for given function {}{}{}.", name, + isDistinct ? "DISTINCT " : "", LogicalTypeUtils::toString(inputTypes)); + if (supportedInputs.empty()) { + result += " Expect empty inputs."; + } else { + result += " Supported inputs are\n" + supportedInputs; + } + return result; +} + +void BuiltInFunctionsUtils::validateNonEmptyCandidateFunctions( + std::vector& candidateFunctions, const std::string& name, + const std::vector& inputTypes, bool isDistinct, function::function_set& set) { + if (candidateFunctions.empty()) { + std::string supportedInputsString; + for (auto& function : set) { + auto aggregateFunction = ku_dynamic_cast(function.get()); + if (aggregateFunction->isDistinct) { + supportedInputsString += "DISTINCT "; + } + supportedInputsString += aggregateFunction->signatureToString() + "\n"; + } + throw BinderException( + getFunctionMatchFailureMsg(name, inputTypes, supportedInputsString, isDistinct)); + } +} + +void BuiltInFunctionsUtils::validateNonEmptyCandidateFunctions( + std::vector& candidateFunctions, const std::string& name, + const std::vector& inputTypes, function::function_set& set) { + if (candidateFunctions.empty()) { + std::string supportedInputsString; + for (auto& function : set) { + if (function->parameterTypeIDs.empty()) { + continue; + } + supportedInputsString += function->signatureToString() + "\n"; + } + throw BinderException(getFunctionMatchFailureMsg(name, inputTypes, supportedInputsString)); + } +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/built_in_functions.cpp b/src/function/built_in_functions.cpp deleted file mode 100644 index 2dad84e12f..0000000000 --- a/src/function/built_in_functions.cpp +++ /dev/null @@ -1,931 +0,0 @@ -#include "common/exception/binder.h" -#include "common/exception/catalog.h" -#include "common/string_format.h" -#include "common/string_utils.h" -#include "function/aggregate/collect.h" -#include "function/aggregate/count.h" -#include "function/aggregate/count_star.h" -#include "function/aggregate_function.h" -#include "function/arithmetic/vector_arithmetic_functions.h" -#include "function/blob/vector_blob_functions.h" -#include "function/built_in_function.h" -#include "function/cast/vector_cast_functions.h" -#include "function/comparison/vector_comparison_functions.h" -#include "function/date/vector_date_functions.h" -#include "function/interval/vector_interval_functions.h" -#include "function/list/vector_list_functions.h" -#include "function/map/vector_map_functions.h" -#include "function/path/vector_path_functions.h" -#include "function/rdf/vector_rdf_functions.h" -#include "function/schema/vector_node_rel_functions.h" -#include "function/string/vector_string_functions.h" -#include "function/struct/vector_struct_functions.h" -#include "function/table_functions/call_functions.h" -#include "function/timestamp/vector_timestamp_functions.h" -#include "function/union/vector_union_functions.h" -#include "function/uuid/vector_uuid_functions.h" -#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" -#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" -#include "processor/operator/persistent/reader/npy/npy_reader.h" -#include "processor/operator/persistent/reader/parquet/parquet_reader.h" -#include "processor/operator/persistent/reader/rdf/rdf_scan.h" - -using namespace kuzu::common; - -namespace kuzu { -namespace function { - -BuiltInFunctions::BuiltInFunctions() { - registerScalarFunctions(); - registerAggregateFunctions(); - registerTableFunctions(); -} - -void BuiltInFunctions::registerScalarFunctions() { - registerComparisonFunctions(); - registerArithmeticFunctions(); - registerDateFunctions(); - registerTimestampFunctions(); - registerIntervalFunctions(); - registerStringFunctions(); - registerCastFunctions(); - registerListFunctions(); - registerStructFunctions(); - registerMapFunctions(); - registerUnionFunctions(); - registerNodeRelFunctions(); - registerPathFunctions(); - registerBlobFunctions(); - registerUUIDFunctions(); - registerRdfFunctions(); -} - -void BuiltInFunctions::registerAggregateFunctions() { - registerCountStar(); - registerCount(); - registerSum(); - registerAvg(); - registerMin(); - registerMax(); - registerCollect(); -} - -FunctionType BuiltInFunctions::getFunctionType(const std::string& name) { - auto normalizedName = StringUtils::getUpper(name); - validateFunctionExists(normalizedName); - auto& functionSet = functions.at(normalizedName); - KU_ASSERT(!functionSet.empty()); - return functionSet[0]->type; -} - -Function* BuiltInFunctions::matchFunction(const std::string& name) { - return matchFunction(name, std::vector{}); -} - -Function* BuiltInFunctions::matchFunction( - const std::string& name, const std::vector& inputTypes) { - auto normalizedName = StringUtils::getUpper(name); - validateFunctionExists(normalizedName); - auto& functionSet = functions.at(normalizedName); - bool isOverload = functionSet.size() > 1; - std::vector candidateFunctions; - uint32_t minCost = UINT32_MAX; - for (auto& function : functionSet) { - auto func = reinterpret_cast(function.get()); - if (normalizedName == CAST_FUNC_NAME) { - return func; - } - auto cost = getFunctionCost(inputTypes, func, isOverload); - if (cost == UINT32_MAX) { - continue; - } - if (cost < minCost) { - candidateFunctions.clear(); - candidateFunctions.push_back(func); - minCost = cost; - } else if (cost == minCost) { - candidateFunctions.push_back(func); - } - } - validateNonEmptyCandidateFunctions(candidateFunctions, normalizedName, inputTypes); - if (candidateFunctions.size() > 1) { - return getBestMatch(candidateFunctions); - } - validateSpecialCases(candidateFunctions, normalizedName, inputTypes); - return candidateFunctions[0]; -} - -AggregateFunction* BuiltInFunctions::matchAggregateFunction( - const std::string& name, const std::vector& inputTypes, bool isDistinct) { - auto& functionSet = functions.at(name); - std::vector candidateFunctions; - for (auto& function : functionSet) { - auto aggregateFunction = ku_dynamic_cast(function.get()); - auto cost = getAggregateFunctionCost(inputTypes, isDistinct, aggregateFunction); - if (cost == UINT32_MAX) { - continue; - } - candidateFunctions.push_back(aggregateFunction); - } - validateNonEmptyCandidateFunctions(candidateFunctions, name, inputTypes, isDistinct); - KU_ASSERT(candidateFunctions.size() == 1); - return candidateFunctions[0]; -} - -uint32_t BuiltInFunctions::getCastCost(LogicalTypeID inputTypeID, LogicalTypeID targetTypeID) { - if (inputTypeID == targetTypeID) { - return 0; - } - // TODO(Jiamin): should check any type - if (inputTypeID == LogicalTypeID::ANY || targetTypeID == LogicalTypeID::ANY || - inputTypeID == LogicalTypeID::RDF_VARIANT) { - // anything can be cast to ANY type for (almost no) cost - return 1; - } - if (targetTypeID == LogicalTypeID::RDF_VARIANT) { - return castFromRDFVariant(inputTypeID); - } - if (targetTypeID == LogicalTypeID::STRING) { - return castFromString(inputTypeID); - } - switch (inputTypeID) { - case LogicalTypeID::INT64: - return castInt64(targetTypeID); - case LogicalTypeID::INT32: - return castInt32(targetTypeID); - case LogicalTypeID::INT16: - return castInt16(targetTypeID); - case LogicalTypeID::INT8: - return castInt8(targetTypeID); - case LogicalTypeID::UINT64: - return castUInt64(targetTypeID); - case LogicalTypeID::UINT32: - return castUInt32(targetTypeID); - case LogicalTypeID::UINT16: - return castUInt16(targetTypeID); - case LogicalTypeID::UINT8: - return castUInt8(targetTypeID); - case LogicalTypeID::INT128: - return castInt128(targetTypeID); - case LogicalTypeID::DOUBLE: - return castDouble(targetTypeID); - case LogicalTypeID::FLOAT: - return castFloat(targetTypeID); - case LogicalTypeID::DATE: - return castDate(targetTypeID); - case LogicalTypeID::UUID: - return castUUID(targetTypeID); - case LogicalTypeID::SERIAL: - return castSerial(targetTypeID); - case LogicalTypeID::TIMESTAMP_SEC: - case LogicalTypeID::TIMESTAMP_MS: - case LogicalTypeID::TIMESTAMP_NS: - case LogicalTypeID::TIMESTAMP_TZ: - // currently don't allow timestamp to other timestamp types - return castTimestamp(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -std::unique_ptr BuiltInFunctions::copy() { - auto result = std::make_unique(); - for (auto& [name, functionSet] : functions) { - std::vector> functionSetToCopy; - for (auto& function : functionSet) { - functionSetToCopy.push_back(function->copy()); - } - result->functions.emplace(name, std::move(functionSetToCopy)); - } - return result; -} - -uint32_t BuiltInFunctions::getTargetTypeCost(LogicalTypeID typeID) { - switch (typeID) { - case LogicalTypeID::INT64: - return 101; - case LogicalTypeID::INT32: - return 102; - case LogicalTypeID::INT128: - return 103; - case LogicalTypeID::DOUBLE: - return 104; - case LogicalTypeID::TIMESTAMP: - return 120; - case LogicalTypeID::STRING: - return 149; - case LogicalTypeID::STRUCT: - case LogicalTypeID::MAP: - case LogicalTypeID::FIXED_LIST: - case LogicalTypeID::VAR_LIST: - case LogicalTypeID::UNION: - return 160; - case LogicalTypeID::RDF_VARIANT: - return 170; - default: - return 110; - } -} - -uint32_t BuiltInFunctions::castInt64(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT128: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castInt32(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT64: - case LogicalTypeID::INT128: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castInt16(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT32: - case LogicalTypeID::INT64: - case LogicalTypeID::INT128: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castInt8(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT16: - case LogicalTypeID::INT32: - case LogicalTypeID::INT64: - case LogicalTypeID::INT128: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castUInt64(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT128: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castUInt32(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT64: - case LogicalTypeID::INT128: - case LogicalTypeID::UINT64: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castUInt16(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT32: - case LogicalTypeID::INT64: - case LogicalTypeID::INT128: - case LogicalTypeID::UINT32: - case LogicalTypeID::UINT64: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castUInt8(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT16: - case LogicalTypeID::INT32: - case LogicalTypeID::INT64: - case LogicalTypeID::INT128: - case LogicalTypeID::UINT16: - case LogicalTypeID::UINT32: - case LogicalTypeID::UINT64: - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castInt128(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::FLOAT: - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castUUID(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::STRING: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castDouble(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castFloat(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::DOUBLE: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castDate(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::TIMESTAMP: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castSerial(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::INT64: - return 0; - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castTimestamp(LogicalTypeID targetTypeID) { - switch (targetTypeID) { - case LogicalTypeID::TIMESTAMP: - return getTargetTypeCost(targetTypeID); - default: - return UNDEFINED_CAST_COST; - } -} - -uint32_t BuiltInFunctions::castFromString(LogicalTypeID inputTypeID) { - switch (inputTypeID) { - case LogicalTypeID::BLOB: - case LogicalTypeID::INTERNAL_ID: - case LogicalTypeID::NODE: - case LogicalTypeID::REL: - case LogicalTypeID::RECURSIVE_REL: - return UNDEFINED_CAST_COST; - default: // Any other inputTypeID can be cast to String, but this cast has a high cost - return getTargetTypeCost(LogicalTypeID::STRING); - } -} - -uint32_t BuiltInFunctions::castFromRDFVariant(LogicalTypeID inputTypeID) { - switch (inputTypeID) { - case LogicalTypeID::STRUCT: - case LogicalTypeID::VAR_LIST: - case LogicalTypeID::FIXED_LIST: - case LogicalTypeID::UNION: - case LogicalTypeID::MAP: - case LogicalTypeID::NODE: - case LogicalTypeID::REL: - case LogicalTypeID::RECURSIVE_REL: - case LogicalTypeID::RDF_VARIANT: - return UNDEFINED_CAST_COST; - default: // Any other inputTypeID can be cast to RDF_VARIANT, but this cast has a high cost - return getTargetTypeCost(LogicalTypeID::RDF_VARIANT); - } -} - -// When there is multiple candidates functions, e.g. double + int and double + double for input -// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double. -// Additionally, we prefer function with string parameter because string is most permissive and can -// be cast to any type. -Function* BuiltInFunctions::getBestMatch(std::vector& functionsToMatch) { - KU_ASSERT(functionsToMatch.size() > 1); - Function* result = nullptr; - auto cost = UNDEFINED_CAST_COST; - for (auto& function : functionsToMatch) { - auto currentCost = 0u; - std::unordered_set distinctParameterTypes; - for (auto& parameterTypeID : function->parameterTypeIDs) { - if (parameterTypeID != LogicalTypeID::STRING) { - currentCost++; - } - if (!distinctParameterTypes.contains(parameterTypeID)) { - currentCost++; - distinctParameterTypes.insert(parameterTypeID); - } - } - if (currentCost < cost) { - cost = currentCost; - result = function; - } - } - KU_ASSERT(result != nullptr); - return result; -} - -uint32_t BuiltInFunctions::getFunctionCost( - const std::vector& inputTypes, Function* function, bool isOverload) { - switch (function->type) { - case FunctionType::SCALAR: { - auto scalarFunction = ku_dynamic_cast(function); - if (scalarFunction->isVarLength) { - KU_ASSERT(function->parameterTypeIDs.size() == 1); - return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload); - } else { - return matchParameters(inputTypes, function->parameterTypeIDs, isOverload); - } - } - case FunctionType::TABLE: - return matchParameters(inputTypes, function->parameterTypeIDs, isOverload); - default: - KU_UNREACHABLE; - } -} - -uint32_t BuiltInFunctions::getAggregateFunctionCost( - const std::vector& inputTypes, bool isDistinct, AggregateFunction* function) { - if (inputTypes.size() != function->parameterTypeIDs.size() || - isDistinct != function->isDistinct) { - return UINT32_MAX; - } - for (auto i = 0u; i < inputTypes.size(); ++i) { - if (function->parameterTypeIDs[i] == LogicalTypeID::ANY) { - continue; - } else if (inputTypes[i].getLogicalTypeID() != function->parameterTypeIDs[i]) { - return UINT32_MAX; - } - } - return 0; -} - -uint32_t BuiltInFunctions::matchParameters(const std::vector& inputTypes, - const std::vector& targetTypeIDs, bool /*isOverload*/) { - if (inputTypes.size() != targetTypeIDs.size()) { - return UINT32_MAX; - } - auto cost = 0u; - for (auto i = 0u; i < inputTypes.size(); ++i) { - auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]); - if (castCost == UNDEFINED_CAST_COST) { - return UINT32_MAX; - } - cost += castCost; - } - return cost; -} - -uint32_t BuiltInFunctions::matchVarLengthParameters( - const std::vector& inputTypes, LogicalTypeID targetTypeID, bool /*isOverload*/) { - auto cost = 0u; - for (auto inputType : inputTypes) { - auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID); - if (castCost == UNDEFINED_CAST_COST) { - return UINT32_MAX; - } - cost += castCost; - } - return cost; -} - -void BuiltInFunctions::validateSpecialCases(std::vector& candidateFunctions, - const std::string& name, const std::vector& inputTypes) { - // special case for add func - if (name == ADD_FUNC_NAME) { - auto targetType0 = candidateFunctions[0]->parameterTypeIDs[0]; - auto targetType1 = candidateFunctions[0]->parameterTypeIDs[1]; - auto inputType0 = inputTypes[0].getLogicalTypeID(); - auto inputType1 = inputTypes[1].getLogicalTypeID(); - if ((inputType0 != LogicalTypeID::STRING || inputType1 != LogicalTypeID::STRING) && - targetType0 == LogicalTypeID::STRING && targetType1 == LogicalTypeID::STRING) { - if (inputType0 != inputType1 && (inputType0 == LogicalTypeID::RDF_VARIANT || - inputType1 == LogicalTypeID::RDF_VARIANT)) { - return; - } - std::string supportedInputsString; - for (auto& function : functions.at(name)) { - supportedInputsString += function->signatureToString() + "\n"; - } - throw BinderException("Cannot match a built-in function for given function " + name + - LogicalTypeUtils::toString(inputTypes) + - ". Supported inputs are\n" + supportedInputsString); - } - } -} - -void BuiltInFunctions::registerComparisonFunctions() { - functions.insert({EQUALS_FUNC_NAME, EqualsFunction::getFunctionSet()}); - functions.insert({NOT_EQUALS_FUNC_NAME, NotEqualsFunction::getFunctionSet()}); - functions.insert({GREATER_THAN_FUNC_NAME, GreaterThanFunction::getFunctionSet()}); - functions.insert({GREATER_THAN_EQUALS_FUNC_NAME, GreaterThanEqualsFunction::getFunctionSet()}); - functions.insert({LESS_THAN_FUNC_NAME, LessThanFunction::getFunctionSet()}); - functions.insert({LESS_THAN_EQUALS_FUNC_NAME, LessThanEqualsFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerArithmeticFunctions() { - functions.insert({ADD_FUNC_NAME, AddFunction::getFunctionSet()}); - functions.insert({SUBTRACT_FUNC_NAME, SubtractFunction::getFunctionSet()}); - functions.insert({MULTIPLY_FUNC_NAME, MultiplyFunction::getFunctionSet()}); - functions.insert({DIVIDE_FUNC_NAME, DivideFunction::getFunctionSet()}); - functions.insert({MODULO_FUNC_NAME, ModuloFunction::getFunctionSet()}); - functions.insert({POWER_FUNC_NAME, PowerFunction::getFunctionSet()}); - - functions.insert({ABS_FUNC_NAME, AbsFunction::getFunctionSet()}); - functions.insert({ACOS_FUNC_NAME, AcosFunction::getFunctionSet()}); - functions.insert({ASIN_FUNC_NAME, AsinFunction::getFunctionSet()}); - functions.insert({ATAN_FUNC_NAME, AtanFunction::getFunctionSet()}); - functions.insert({ATAN2_FUNC_NAME, Atan2Function::getFunctionSet()}); - functions.insert({BITWISE_XOR_FUNC_NAME, BitwiseXorFunction::getFunctionSet()}); - functions.insert({BITWISE_AND_FUNC_NAME, BitwiseAndFunction::getFunctionSet()}); - functions.insert({BITWISE_OR_FUNC_NAME, BitwiseOrFunction::getFunctionSet()}); - functions.insert({BITSHIFT_LEFT_FUNC_NAME, BitShiftLeftFunction::getFunctionSet()}); - functions.insert({BITSHIFT_RIGHT_FUNC_NAME, BitShiftRightFunction::getFunctionSet()}); - functions.insert({CBRT_FUNC_NAME, CbrtFunction::getFunctionSet()}); - functions.insert({CEIL_FUNC_NAME, CeilFunction::getFunctionSet()}); - functions.insert({CEILING_FUNC_NAME, CeilFunction::getFunctionSet()}); - functions.insert({COS_FUNC_NAME, CosFunction::getFunctionSet()}); - functions.insert({COT_FUNC_NAME, CotFunction::getFunctionSet()}); - functions.insert({DEGREES_FUNC_NAME, DegreesFunction::getFunctionSet()}); - functions.insert({EVEN_FUNC_NAME, EvenFunction::getFunctionSet()}); - functions.insert({FACTORIAL_FUNC_NAME, FactorialFunction::getFunctionSet()}); - functions.insert({FLOOR_FUNC_NAME, FloorFunction::getFunctionSet()}); - functions.insert({GAMMA_FUNC_NAME, GammaFunction::getFunctionSet()}); - functions.insert({LGAMMA_FUNC_NAME, LgammaFunction::getFunctionSet()}); - functions.insert({LN_FUNC_NAME, LnFunction::getFunctionSet()}); - functions.insert({LOG_FUNC_NAME, LogFunction::getFunctionSet()}); - functions.insert({LOG2_FUNC_NAME, Log2Function::getFunctionSet()}); - functions.insert({LOG10_FUNC_NAME, LogFunction::getFunctionSet()}); - functions.insert({NEGATE_FUNC_NAME, NegateFunction::getFunctionSet()}); - functions.insert({PI_FUNC_NAME, PiFunction::getFunctionSet()}); - functions.insert({POW_FUNC_NAME, PowerFunction::getFunctionSet()}); - functions.insert({RADIANS_FUNC_NAME, RadiansFunction::getFunctionSet()}); - functions.insert({ROUND_FUNC_NAME, RoundFunction::getFunctionSet()}); - functions.insert({SIN_FUNC_NAME, SinFunction::getFunctionSet()}); - functions.insert({SIGN_FUNC_NAME, SignFunction::getFunctionSet()}); - functions.insert({SQRT_FUNC_NAME, SqrtFunction::getFunctionSet()}); - functions.insert({TAN_FUNC_NAME, TanFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerDateFunctions() { - functions.insert({DATE_PART_FUNC_NAME, DatePartFunction::getFunctionSet()}); - functions.insert({DATEPART_FUNC_NAME, DatePartFunction::getFunctionSet()}); - functions.insert({DATE_TRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet()}); - functions.insert({DATETRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet()}); - functions.insert({DAYNAME_FUNC_NAME, DayNameFunction::getFunctionSet()}); - functions.insert({GREATEST_FUNC_NAME, GreatestFunction::getFunctionSet()}); - functions.insert({LAST_DAY_FUNC_NAME, LastDayFunction::getFunctionSet()}); - functions.insert({LEAST_FUNC_NAME, LeastFunction::getFunctionSet()}); - functions.insert({MAKE_DATE_FUNC_NAME, MakeDateFunction::getFunctionSet()}); - functions.insert({MONTHNAME_FUNC_NAME, MonthNameFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerTimestampFunctions() { - functions.insert({CENTURY_FUNC_NAME, CenturyFunction::getFunctionSet()}); - functions.insert({EPOCH_MS_FUNC_NAME, EpochMsFunction::getFunctionSet()}); - functions.insert({TO_TIMESTAMP_FUNC_NAME, ToTimestampFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerIntervalFunctions() { - functions.insert({TO_YEARS_FUNC_NAME, ToYearsFunction::getFunctionSet()}); - functions.insert({TO_MONTHS_FUNC_NAME, ToMonthsFunction::getFunctionSet()}); - functions.insert({TO_DAYS_FUNC_NAME, ToDaysFunction::getFunctionSet()}); - functions.insert({TO_HOURS_FUNC_NAME, ToHoursFunction::getFunctionSet()}); - functions.insert({TO_MINUTES_FUNC_NAME, ToMinutesFunction::getFunctionSet()}); - functions.insert({TO_SECONDS_FUNC_NAME, ToSecondsFunction::getFunctionSet()}); - functions.insert({TO_MILLISECONDS_FUNC_NAME, ToMillisecondsFunction::getFunctionSet()}); - functions.insert({TO_MICROSECONDS_FUNC_NAME, ToMicrosecondsFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerBlobFunctions() { - functions.insert({OCTET_LENGTH_FUNC_NAME, OctetLengthFunctions::getFunctionSet()}); - functions.insert({ENCODE_FUNC_NAME, EncodeFunctions::getFunctionSet()}); - functions.insert({DECODE_FUNC_NAME, DecodeFunctions::getFunctionSet()}); -} - -void BuiltInFunctions::registerUUIDFunctions() { - functions.insert({GEN_RANDOM_UUID_FUNC_NAME, GenRandomUUIDFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerStringFunctions() { - functions.insert({ARRAY_EXTRACT_FUNC_NAME, ArrayExtractFunction::getFunctionSet()}); - functions.insert({CONCAT_FUNC_NAME, ConcatFunction::getFunctionSet()}); - functions.insert({CONTAINS_FUNC_NAME, ContainsFunction::getFunctionSet()}); - functions.insert({ENDS_WITH_FUNC_NAME, EndsWithFunction::getFunctionSet()}); - functions.insert({LCASE_FUNC_NAME, LowerFunction::getFunctionSet()}); - functions.insert({LEFT_FUNC_NAME, LeftFunction::getFunctionSet()}); - functions.insert({LOWER_FUNC_NAME, LowerFunction::getFunctionSet()}); - functions.insert({LPAD_FUNC_NAME, LpadFunction::getFunctionSet()}); - functions.insert({LTRIM_FUNC_NAME, LtrimFunction::getFunctionSet()}); - functions.insert({PREFIX_FUNC_NAME, StartsWithFunction::getFunctionSet()}); - functions.insert({REPEAT_FUNC_NAME, RepeatFunction::getFunctionSet()}); - functions.insert({REVERSE_FUNC_NAME, ReverseFunction::getFunctionSet()}); - functions.insert({RIGHT_FUNC_NAME, RightFunction::getFunctionSet()}); - functions.insert({RPAD_FUNC_NAME, RpadFunction::getFunctionSet()}); - functions.insert({RTRIM_FUNC_NAME, RtrimFunction::getFunctionSet()}); - functions.insert({STARTS_WITH_FUNC_NAME, StartsWithFunction::getFunctionSet()}); - functions.insert({SUBSTR_FUNC_NAME, SubStrFunction::getFunctionSet()}); - functions.insert({SUBSTRING_FUNC_NAME, SubStrFunction::getFunctionSet()}); - functions.insert({SUFFIX_FUNC_NAME, EndsWithFunction::getFunctionSet()}); - functions.insert({TRIM_FUNC_NAME, TrimFunction::getFunctionSet()}); - functions.insert({UCASE_FUNC_NAME, UpperFunction::getFunctionSet()}); - functions.insert({UPPER_FUNC_NAME, UpperFunction::getFunctionSet()}); - functions.insert({REGEXP_FULL_MATCH_FUNC_NAME, RegexpFullMatchFunction::getFunctionSet()}); - functions.insert({REGEXP_MATCHES_FUNC_NAME, RegexpMatchesFunction::getFunctionSet()}); - functions.insert({REGEXP_REPLACE_FUNC_NAME, RegexpReplaceFunction::getFunctionSet()}); - functions.insert({REGEXP_EXTRACT_FUNC_NAME, RegexpExtractFunction::getFunctionSet()}); - functions.insert({REGEXP_EXTRACT_ALL_FUNC_NAME, RegexpExtractAllFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerCastFunctions() { - functions.insert({CAST_DATE_FUNC_NAME, CastToDateFunction::getFunctionSet()}); - functions.insert({CAST_TO_DATE_FUNC_NAME, CastToDateFunction::getFunctionSet()}); - functions.insert({CAST_TO_TIMESTAMP_FUNC_NAME, CastToTimestampFunction::getFunctionSet()}); - functions.insert({CAST_INTERVAL_FUNC_NAME, CastToIntervalFunction::getFunctionSet()}); - functions.insert({CAST_TO_INTERVAL_FUNC_NAME, CastToIntervalFunction::getFunctionSet()}); - functions.insert({CAST_STRING_FUNC_NAME, CastToStringFunction::getFunctionSet()}); - functions.insert({CAST_TO_STRING_FUNC_NAME, CastToStringFunction::getFunctionSet()}); - functions.insert({CAST_BLOB_FUNC_NAME, CastToBlobFunction::getFunctionSet()}); - functions.insert({CAST_TO_BLOB_FUNC_NAME, CastToBlobFunction::getFunctionSet()}); - functions.insert({CAST_UUID_FUNC_NAME, CastToUUIDFunction::getFunctionSet()}); - functions.insert({CAST_TO_UUID_FUNC_NAME, CastToUUIDFunction::getFunctionSet()}); - functions.insert({CAST_TO_DOUBLE_FUNC_NAME, CastToDoubleFunction::getFunctionSet()}); - functions.insert({CAST_TO_FLOAT_FUNC_NAME, CastToFloatFunction::getFunctionSet()}); - functions.insert({CAST_TO_SERIAL_FUNC_NAME, CastToSerialFunction::getFunctionSet()}); - functions.insert({CAST_TO_INT64_FUNC_NAME, CastToInt64Function::getFunctionSet()}); - functions.insert({CAST_TO_INT32_FUNC_NAME, CastToInt32Function::getFunctionSet()}); - functions.insert({CAST_TO_INT16_FUNC_NAME, CastToInt16Function::getFunctionSet()}); - functions.insert({CAST_TO_INT8_FUNC_NAME, CastToInt8Function::getFunctionSet()}); - functions.insert({CAST_TO_UINT64_FUNC_NAME, CastToUInt64Function::getFunctionSet()}); - functions.insert({CAST_TO_UINT32_FUNC_NAME, CastToUInt32Function::getFunctionSet()}); - functions.insert({CAST_TO_UINT16_FUNC_NAME, CastToUInt16Function::getFunctionSet()}); - functions.insert({CAST_TO_UINT8_FUNC_NAME, CastToUInt8Function::getFunctionSet()}); - functions.insert({CAST_TO_INT128_FUNC_NAME, CastToInt128Function::getFunctionSet()}); - functions.insert({CAST_TO_BOOL_FUNC_NAME, CastToBoolFunction::getFunctionSet()}); - functions.insert({CAST_FUNC_NAME, CastAnyFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerListFunctions() { - functions.insert({LIST_CREATION_FUNC_NAME, ListCreationFunction::getFunctionSet()}); - functions.insert({LIST_RANGE_FUNC_NAME, ListRangeFunction::getFunctionSet()}); - functions.insert({SIZE_FUNC_NAME, SizeFunction::getFunctionSet()}); - functions.insert({LIST_EXTRACT_FUNC_NAME, ListExtractFunction::getFunctionSet()}); - functions.insert({LIST_ELEMENT_FUNC_NAME, ListExtractFunction::getFunctionSet()}); - functions.insert({LIST_CONCAT_FUNC_NAME, ListConcatFunction::getFunctionSet()}); - functions.insert({LIST_CAT_FUNC_NAME, ListConcatFunction::getFunctionSet()}); - functions.insert({ARRAY_CONCAT_FUNC_NAME, ListConcatFunction::getFunctionSet()}); - functions.insert({ARRAY_CAT_FUNC_NAME, ListConcatFunction::getFunctionSet()}); - functions.insert({LIST_APPEND_FUNC_NAME, ListAppendFunction::getFunctionSet()}); - functions.insert({ARRAY_APPEND_FUNC_NAME, ListAppendFunction::getFunctionSet()}); - functions.insert({ARRAY_PUSH_BACK_FUNC_NAME, ListAppendFunction::getFunctionSet()}); - functions.insert({LIST_PREPEND_FUNC_NAME, ListPrependFunction::getFunctionSet()}); - functions.insert({ARRAY_PREPEND_FUNC_NAME, ListPrependFunction::getFunctionSet()}); - functions.insert({ARRAY_PUSH_FRONT_FUNC_NAME, ListPrependFunction::getFunctionSet()}); - functions.insert({LIST_POSITION_FUNC_NAME, ListPositionFunction::getFunctionSet()}); - functions.insert({ARRAY_POSITION_FUNC_NAME, ListPositionFunction::getFunctionSet()}); - functions.insert({LIST_INDEXOF_FUNC_NAME, ListPositionFunction::getFunctionSet()}); - functions.insert({ARRAY_INDEXOF_FUNC_NAME, ListPositionFunction::getFunctionSet()}); - functions.insert({LIST_CONTAINS_FUNC_NAME, ListContainsFunction::getFunctionSet()}); - functions.insert({LIST_HAS_FUNC_NAME, ListContainsFunction::getFunctionSet()}); - functions.insert({ARRAY_CONTAINS_FUNC_NAME, ListContainsFunction::getFunctionSet()}); - functions.insert({ARRAY_HAS_FUNC_NAME, ListContainsFunction::getFunctionSet()}); - functions.insert({LIST_SLICE_FUNC_NAME, ListSliceFunction::getFunctionSet()}); - functions.insert({ARRAY_SLICE_FUNC_NAME, ListSliceFunction::getFunctionSet()}); - functions.insert({LIST_SORT_FUNC_NAME, ListSortFunction::getFunctionSet()}); - functions.insert({LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortFunction::getFunctionSet()}); - functions.insert({LIST_SUM_FUNC_NAME, ListSumFunction::getFunctionSet()}); - functions.insert({LIST_PRODUCT_FUNC_NAME, ListProductFunction::getFunctionSet()}); - functions.insert({LIST_DISTINCT_FUNC_NAME, ListDistinctFunction::getFunctionSet()}); - functions.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueFunction::getFunctionSet()}); - functions.insert({LIST_ANY_VALUE_FUNC_NAME, ListAnyValueFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerStructFunctions() { - functions.insert({STRUCT_PACK_FUNC_NAME, StructPackFunctions::getFunctionSet()}); - functions.insert({STRUCT_EXTRACT_FUNC_NAME, StructExtractFunctions::getFunctionSet()}); -} - -void BuiltInFunctions::registerMapFunctions() { - functions.insert({MAP_CREATION_FUNC_NAME, MapCreationFunctions::getFunctionSet()}); - functions.insert({MAP_EXTRACT_FUNC_NAME, MapExtractFunctions::getFunctionSet()}); - functions.insert({ELEMENT_AT_FUNC_NAME, MapExtractFunctions::getFunctionSet()}); - functions.insert({CARDINALITY_FUNC_NAME, SizeFunction::getFunctionSet()}); - functions.insert({MAP_KEYS_FUNC_NAME, MapKeysFunctions::getFunctionSet()}); - functions.insert({MAP_VALUES_FUNC_NAME, MapValuesFunctions::getFunctionSet()}); -} - -void BuiltInFunctions::registerUnionFunctions() { - functions.insert({UNION_VALUE_FUNC_NAME, UnionValueFunction::getFunctionSet()}); - functions.insert({UNION_TAG_FUNC_NAME, UnionTagFunction::getFunctionSet()}); - functions.insert({UNION_EXTRACT_FUNC_NAME, UnionExtractFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerNodeRelFunctions() { - functions.insert({OFFSET_FUNC_NAME, OffsetFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerPathFunctions() { - functions.insert({NODES_FUNC_NAME, NodesFunction::getFunctionSet()}); - functions.insert({RELS_FUNC_NAME, RelsFunction::getFunctionSet()}); - functions.insert({PROPERTIES_FUNC_NAME, PropertiesFunction::getFunctionSet()}); - functions.insert({IS_TRAIL_FUNC_NAME, IsTrailFunction::getFunctionSet()}); - functions.insert({IS_ACYCLIC_FUNC_NAME, IsACyclicFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerRdfFunctions() { - functions.insert({TYPE_FUNC_NAME, RDFTypeFunction::getFunctionSet()}); - functions.insert({VALIDATE_PREDICATE_FUNC_NAME, ValidatePredicateFunction::getFunctionSet()}); -} - -void BuiltInFunctions::registerCountStar() { - function_set functionSet; - functionSet.push_back(std::make_unique(COUNT_STAR_FUNC_NAME, - std::vector{}, LogicalTypeID::INT64, CountStarFunction::initialize, - CountStarFunction::updateAll, CountStarFunction::updatePos, CountStarFunction::combine, - CountStarFunction::finalize, false)); - functions.insert({COUNT_STAR_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerCount() { - function_set functionSet; - for (auto& type : LogicalTypeUtils::getAllValidLogicTypes()) { - for (auto isDistinct : std::vector{true, false}) { - functionSet.push_back(AggregateFunctionUtil::getAggFunc(COUNT_FUNC_NAME, - type, LogicalTypeID::INT64, isDistinct, CountFunction::paramRewriteFunc)); - } - } - functions.insert({COUNT_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerSum() { - function_set functionSet; - for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - for (auto isDistinct : std::vector{true, false}) { - functionSet.push_back( - AggregateFunctionUtil::getSumFunc(SUM_FUNC_NAME, typeID, typeID, isDistinct)); - } - } - functions.insert({SUM_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerAvg() { - function_set functionSet; - for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - for (auto isDistinct : std::vector{true, false}) { - functionSet.push_back(AggregateFunctionUtil::getAvgFunc( - AVG_FUNC_NAME, typeID, LogicalTypeID::DOUBLE, isDistinct)); - } - } - functions.insert({AVG_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerMin() { - function_set functionSet; - for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) { - for (auto isDistinct : std::vector{true, false}) { - functionSet.push_back(AggregateFunctionUtil::getMinFunc(type, isDistinct)); - } - } - functions.insert({MIN_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerMax() { - function_set functionSet; - for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) { - for (auto isDistinct : std::vector{true, false}) { - functionSet.push_back(AggregateFunctionUtil::getMaxFunc(type, isDistinct)); - } - } - functions.insert({MAX_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerCollect() { - function_set functionSet; - for (auto isDistinct : std::vector{true, false}) { - functionSet.push_back(std::make_unique(COLLECT_FUNC_NAME, - std::vector{common::LogicalTypeID::ANY}, LogicalTypeID::VAR_LIST, - CollectFunction::initialize, CollectFunction::updateAll, CollectFunction::updatePos, - CollectFunction::combine, CollectFunction::finalize, isDistinct, - CollectFunction::bindFunc)); - } - functions.insert({COLLECT_FUNC_NAME, std::move(functionSet)}); -} - -void BuiltInFunctions::registerTableFunctions() { - functions.insert({CURRENT_SETTING_FUNC_NAME, CurrentSettingFunction::getFunctionSet()}); - functions.insert({DB_VERSION_FUNC_NAME, DBVersionFunction::getFunctionSet()}); - functions.insert({SHOW_TABLES_FUNC_NAME, ShowTablesFunction::getFunctionSet()}); - functions.insert({TABLE_INFO_FUNC_NAME, TableInfoFunction::getFunctionSet()}); - functions.insert({SHOW_CONNECTION_FUNC_NAME, ShowConnectionFunction::getFunctionSet()}); - functions.insert({READ_PARQUET_FUNC_NAME, processor::ParquetScanFunction::getFunctionSet()}); - functions.insert({READ_NPY_FUNC_NAME, processor::NpyScanFunction::getFunctionSet()}); - functions.insert({READ_CSV_SERIAL_FUNC_NAME, processor::SerialCSVScan::getFunctionSet()}); - functions.insert({READ_CSV_PARALLEL_FUNC_NAME, processor::ParallelCSVScan::getFunctionSet()}); - functions.insert({READ_RDF_RESOURCE_FUNC_NAME, processor::RdfResourceScan::getFunctionSet()}); - functions.insert({READ_RDF_LITERAL_FUNC_NAME, processor::RdfLiteralScan::getFunctionSet()}); - functions.insert( - {READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, processor::RdfResourceTripleScan::getFunctionSet()}); - functions.insert( - {READ_RDF_LITERAL_TRIPLE_FUNC_NAME, processor::RdfLiteralTripleScan::getFunctionSet()}); - functions.insert( - {READ_RDF_ALL_TRIPLE_FUNC_NAME, processor::RdfAllTripleScan::getFunctionSet()}); - functions.insert( - {IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, processor::RdfResourceInMemScan::getFunctionSet()}); - functions.insert( - {IN_MEM_READ_RDF_LITERAL_FUNC_NAME, processor::RdfLiteralInMemScan::getFunctionSet()}); - functions.insert({IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, - processor::RdfResourceTripleInMemScan::getFunctionSet()}); - functions.insert({IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, - processor::RdfLiteralTripleInMemScan::getFunctionSet()}); - functions.insert({STORAGE_INFO_FUNC_NAME, StorageInfoFunction::getFunctionSet()}); -} - -void BuiltInFunctions::validateFunctionExists(const std::string& name) { - if (!functions.contains(name)) { - throw CatalogException(stringFormat("{} function does not exist.", name)); - } -} - -static std::string getFunctionMatchFailureMsg(const std::string name, - const std::vector& inputTypes, const std::string& supportedInputs, - bool isDistinct = false) { - auto result = stringFormat("Cannot match a built-in function for given function {}{}{}.", name, - isDistinct ? "DISTINCT " : "", LogicalTypeUtils::toString(inputTypes)); - if (supportedInputs.empty()) { - result += " Expect empty inputs."; - } else { - result += " Supported inputs are\n" + supportedInputs; - } - return result; -} - -void BuiltInFunctions::validateNonEmptyCandidateFunctions( - std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes, bool isDistinct) { - if (candidateFunctions.empty()) { - std::string supportedInputsString; - for (auto& function : functions.at(name)) { - auto aggregateFunction = ku_dynamic_cast(function.get()); - if (aggregateFunction->isDistinct) { - supportedInputsString += "DISTINCT "; - } - supportedInputsString += aggregateFunction->signatureToString() + "\n"; - } - throw BinderException( - getFunctionMatchFailureMsg(name, inputTypes, supportedInputsString, isDistinct)); - } -} - -void BuiltInFunctions::validateNonEmptyCandidateFunctions( - std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes) { - if (candidateFunctions.empty()) { - std::string supportedInputsString; - for (auto& function : functions.at(name)) { - if (function->parameterTypeIDs.empty()) { - continue; - } - supportedInputsString += function->signatureToString() + "\n"; - } - throw BinderException(getFunctionMatchFailureMsg(name, inputTypes, supportedInputsString)); - } -} - -void BuiltInFunctions::addFunction(std::string name, function::function_set definitions) { - if (functions.contains(name)) { - throw CatalogException{stringFormat("function {} already exists.", name)}; - } - functions.emplace(std::move(name), std::move(definitions)); -} - -} // namespace function -} // namespace kuzu diff --git a/src/function/vector_cast_functions.cpp b/src/function/vector_cast_functions.cpp index 72a8247ca3..ebfcbba068 100644 --- a/src/function/vector_cast_functions.cpp +++ b/src/function/vector_cast_functions.cpp @@ -158,8 +158,8 @@ bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType dstType.getLogicalTypeID() == LogicalTypeID::VAR_LIST) { return false; } - if (BuiltInFunctions::getCastCost(srcType.getLogicalTypeID(), dstType.getLogicalTypeID()) != - UNDEFINED_CAST_COST) { + if (BuiltInFunctionsUtils::getCastCost( + srcType.getLogicalTypeID(), dstType.getLogicalTypeID()) != UNDEFINED_CAST_COST) { return true; } // TODO(Jiamin): there are still other special cases diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index 3263a8c5e3..cc980e33e9 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -50,6 +50,7 @@ class Catalog { std::vector getTableEntries(transaction::Transaction* tx) const; std::vector getTableSchemas( transaction::Transaction* tx, const common::table_id_vector_t& tableIDs) const; + CatalogSet* getFunctions(transaction::Transaction* tx) const; common::table_id_t addNodeTableSchema(const binder::BoundCreateTableInfo& info); common::table_id_t addRelTableSchema(const binder::BoundCreateTableInfo& info); @@ -71,9 +72,6 @@ class Catalog { void setTableComment(common::table_id_t tableID, const std::string& comment); // ----------------------------- Functions ---------------------------- - inline function::BuiltInFunctions* getBuiltInFunctions(transaction::Transaction* tx) const { - return getVersion(tx)->builtInFunctions.get(); - } common::ExpressionType getFunctionType( transaction::Transaction* tx, const std::string& name) const; void addFunction(std::string name, function::function_set functionSet); @@ -82,8 +80,8 @@ class Catalog { void addScalarMacroFunction( std::string name, std::unique_ptr macro); // TODO(Ziyi): pass transaction pointer here. - inline function::ScalarMacroFunction* getScalarMacroFunction(const std::string& name) const { - return readOnlyVersion->macros.at(name).get(); + function::ScalarMacroFunction* getScalarMacroFunction(const std::string& name) const { + return readOnlyVersion->getScalarMacroFunction(name); } std::vector getMacroNames(transaction::Transaction* tx) const; diff --git a/src/include/catalog/catalog_content.h b/src/include/catalog/catalog_content.h index 7df89c6876..2281911185 100644 --- a/src/include/catalog/catalog_content.h +++ b/src/include/catalog/catalog_content.h @@ -3,7 +3,7 @@ #include "binder/ddl/bound_create_table_info.h" #include "catalog_set.h" #include "common/cast.h" -#include "function/built_in_function.h" +#include "function/built_in_function_utils.h" #include "function/scalar_macro_function.h" namespace kuzu { @@ -24,13 +24,10 @@ class CatalogContent { CatalogContent(std::unique_ptr tables, std::unordered_map tableNameToIDMap, - common::table_id_t nextTableID, - std::unique_ptr builtInFunctions, - std::unordered_map> macros, + common::table_id_t nextTableID, std::unique_ptr functions, common::VirtualFileSystem* vfs) - : tableNameToIDMap{std::move(tableNameToIDMap)}, nextTableID{nextTableID}, - builtInFunctions{std::move(builtInFunctions)}, macros{std::move(macros)}, vfs{vfs}, - tables{std::move(tables)} {} + : tableNameToIDMap{std::move(tableNameToIDMap)}, nextTableID{nextTableID}, vfs{vfs}, + tables{std::move(tables)}, functions{std::move(functions)} {} void saveToFile(const std::string& directory, common::FileVersionType dbFileType); void readFromFile(const std::string& directory, common::FileVersionType dbFileType); @@ -43,11 +40,15 @@ class CatalogContent { void registerBuiltInFunctions(); - bool containMacro(const std::string& macroName) const { return macros.contains(macroName); } + bool containMacro(const std::string& macroName) const { + return functions->containsEntry(macroName); + } void addFunction(std::string name, function::function_set definitions); void addScalarMacroFunction( std::string name, std::unique_ptr macro); + function::ScalarMacroFunction* getScalarMacroFunction(const std::string& name) const; + // ----------------------------- Table entries ---------------------------- common::table_id_t assignNextTableID() { return nextTableID++; } uint64_t getNumTables() const { return tables->getEntries().size(); } @@ -85,10 +86,9 @@ class CatalogContent { // is re-constructed when reading from the catalog file. std::unordered_map tableNameToIDMap; common::table_id_t nextTableID; - std::unique_ptr builtInFunctions; - std::unordered_map> macros; common::VirtualFileSystem* vfs; std::unique_ptr tables; + std::unique_ptr functions; }; } // namespace catalog diff --git a/src/include/catalog/catalog_entry/aggregate_function_catalog_entry.h b/src/include/catalog/catalog_entry/aggregate_function_catalog_entry.h new file mode 100644 index 0000000000..16a29abb67 --- /dev/null +++ b/src/include/catalog/catalog_entry/aggregate_function_catalog_entry.h @@ -0,0 +1,24 @@ +#pragma once + +#include "catalog_entry.h" +#include "function_catalog_entry.h" + +namespace kuzu { +namespace catalog { + +class AggregateFunctionCatalogEntry : public FunctionCatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + AggregateFunctionCatalogEntry() = default; + AggregateFunctionCatalogEntry(std::string name, function::function_set functionSet); + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + std::unique_ptr copy() const override; +}; + +} // namespace catalog +} // namespace kuzu diff --git a/src/include/catalog/catalog_entry/catalog_entry.h b/src/include/catalog/catalog_entry/catalog_entry.h index 2a4b590562..e2b0ab3e5b 100644 --- a/src/include/catalog/catalog_entry/catalog_entry.h +++ b/src/include/catalog/catalog_entry/catalog_entry.h @@ -4,6 +4,7 @@ #include "catalog_entry_type.h" #include "common/serializer/deserializer.h" #include "common/serializer/serializer.h" +#include "main/client_context.h" namespace kuzu { namespace catalog { @@ -30,6 +31,7 @@ class CatalogEntry { virtual void serialize(common::Serializer& serializer) const; static std::unique_ptr deserialize(common::Deserializer& deserializer); virtual std::unique_ptr copy() const = 0; + virtual std::string toCypher(main::ClientContext* /*clientContext*/) const { KU_UNREACHABLE; } private: CatalogEntryType type; diff --git a/src/include/catalog/catalog_entry/catalog_entry_type.h b/src/include/catalog/catalog_entry/catalog_entry_type.h index 127e24df7e..f0dcf30f8f 100644 --- a/src/include/catalog/catalog_entry/catalog_entry_type.h +++ b/src/include/catalog/catalog_entry/catalog_entry_type.h @@ -10,6 +10,10 @@ enum class CatalogEntryType : uint8_t { REL_TABLE_ENTRY = 1, REL_GROUP_ENTRY = 2, RDF_GRAPH_ENTRY = 3, + SCALAR_MACRO_ENTRY = 4, + AGGREGATE_FUNCTION_ENTRY = 5, + SCALAR_FUNCTION_ENTRY = 6, + TABLE_FUNCTION_ENTRY = 7, }; } // namespace catalog diff --git a/src/include/catalog/catalog_entry/function_catalog_entry.h b/src/include/catalog/catalog_entry/function_catalog_entry.h new file mode 100644 index 0000000000..8c05b914ba --- /dev/null +++ b/src/include/catalog/catalog_entry/function_catalog_entry.h @@ -0,0 +1,36 @@ +#pragma once + +#include "catalog_entry.h" +#include "function/scalar_function.h" + +namespace kuzu { +namespace catalog { + +class FunctionCatalogEntry : public CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + FunctionCatalogEntry() = default; + FunctionCatalogEntry( + CatalogEntryType entryType, std::string name, function::function_set functionSet); + + //===--------------------------------------------------------------------===// + // getters & setters + //===--------------------------------------------------------------------===// + function::function_set& getFunctionSet() { return functionSet; } + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + // We always register functions while initializing the catalog, so we don't have to + // serialize functions. + void serialize(common::Serializer& /*serializer*/) const override { return; } + std::unique_ptr copy() const override; + +protected: + function::function_set functionSet; +}; + +} // namespace catalog +} // namespace kuzu diff --git a/src/include/catalog/catalog_entry/scalar_function_catalog_entry.h b/src/include/catalog/catalog_entry/scalar_function_catalog_entry.h new file mode 100644 index 0000000000..90470208ff --- /dev/null +++ b/src/include/catalog/catalog_entry/scalar_function_catalog_entry.h @@ -0,0 +1,24 @@ +#pragma once + +#include "catalog_entry.h" +#include "function_catalog_entry.h" + +namespace kuzu { +namespace catalog { + +class ScalarFunctionCatalogEntry final : public FunctionCatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + ScalarFunctionCatalogEntry() = default; + ScalarFunctionCatalogEntry(std::string name, function::function_set functionSet); + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + std::unique_ptr copy() const override; +}; + +} // namespace catalog +} // namespace kuzu diff --git a/src/include/catalog/catalog_entry/scalar_macro_catalog_entry.h b/src/include/catalog/catalog_entry/scalar_macro_catalog_entry.h new file mode 100644 index 0000000000..c211164d95 --- /dev/null +++ b/src/include/catalog/catalog_entry/scalar_macro_catalog_entry.h @@ -0,0 +1,36 @@ +#pragma once + +#include "catalog_entry.h" +#include "function/scalar_macro_function.h" + +namespace kuzu { +namespace catalog { + +class ScalarMacroCatalogEntry final : public CatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + ScalarMacroCatalogEntry() = default; + ScalarMacroCatalogEntry( + std::string name, std::unique_ptr macroFunction); + + //===--------------------------------------------------------------------===// + // getter & setter + //===--------------------------------------------------------------------===// + function::ScalarMacroFunction* getMacroFunction() const { return macroFunction.get(); } + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + void serialize(common::Serializer& serializer) const override; + static std::unique_ptr deserialize(common::Deserializer& deserializer); + std::unique_ptr copy() const override; + std::string toCypher(main::ClientContext* clientContext) const override; + +private: + std::unique_ptr macroFunction; +}; + +} // namespace catalog +} // namespace kuzu diff --git a/src/include/catalog/catalog_entry/table_catalog_entry.h b/src/include/catalog/catalog_entry/table_catalog_entry.h index 96a1ac47b2..eaf505bd36 100644 --- a/src/include/catalog/catalog_entry/table_catalog_entry.h +++ b/src/include/catalog/catalog_entry/table_catalog_entry.h @@ -5,7 +5,6 @@ #include "catalog/property.h" #include "catalog_entry.h" #include "common/enums/table_type.h" -#include "main/client_context.h" namespace kuzu { namespace catalog { @@ -51,7 +50,6 @@ class TableCatalogEntry : public CatalogEntry { void serialize(common::Serializer& serializer) const override; static std::unique_ptr deserialize( common::Deserializer& deserializer, CatalogEntryType type); - virtual std::string toCypher(main::ClientContext* /*clientContext*/) const { KU_UNREACHABLE; } private: common::table_id_t tableID; diff --git a/src/include/catalog/catalog_entry/table_function_catalog_entry.h b/src/include/catalog/catalog_entry/table_function_catalog_entry.h new file mode 100644 index 0000000000..ba5b8e9d5d --- /dev/null +++ b/src/include/catalog/catalog_entry/table_function_catalog_entry.h @@ -0,0 +1,23 @@ +#pragma once + +#include "function_catalog_entry.h" + +namespace kuzu { +namespace catalog { + +class TableFunctionCatalogEntry : public FunctionCatalogEntry { +public: + //===--------------------------------------------------------------------===// + // constructors + //===--------------------------------------------------------------------===// + TableFunctionCatalogEntry() = default; + TableFunctionCatalogEntry(std::string name, function::function_set functionSet); + + //===--------------------------------------------------------------------===// + // serialization & deserialization + //===--------------------------------------------------------------------===// + std::unique_ptr copy() const override; +}; + +} // namespace catalog +} // namespace kuzu diff --git a/src/include/common/utils.h b/src/include/common/utils.h index 222908bb0b..9b287542b3 100644 --- a/src/include/common/utils.h +++ b/src/include/common/utils.h @@ -2,6 +2,7 @@ #include #include +#include #include "common/assert.h" #include "common/constants.h" @@ -27,28 +28,24 @@ class LoggerUtils { class BitmaskUtils { public: - static inline uint64_t all1sMaskForLeastSignificantBits(uint64_t numBits) { + static uint64_t all1sMaskForLeastSignificantBits(uint64_t numBits) { KU_ASSERT(numBits <= 64); return numBits == 64 ? UINT64_MAX : ((uint64_t)1 << numBits) - 1; } }; -inline uint64_t nextPowerOfTwo(uint64_t v) { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v |= v >> 32; - v++; - return v; -} +uint64_t nextPowerOfTwo(uint64_t v); + +bool isLittleEndian(); -inline bool isLittleEndian() { - // Little endian arch stores the least significant value in the lower bytes. - int testNumber = 1; - return *(uint8_t*)&testNumber == 1; +template +std::vector copyVector(const std::vector& objects) { + std::vector result; + result.reserve(objects.size()); + for (auto& object : objects) { + result.push_back(object->copy()); + } + return result; } } // namespace common diff --git a/src/include/function/built_in_function.h b/src/include/function/built_in_function.h deleted file mode 100644 index 2b5bb53526..0000000000 --- a/src/include/function/built_in_function.h +++ /dev/null @@ -1,127 +0,0 @@ -#pragma once - -#include "aggregate_function.h" -#include "scalar_function.h" - -namespace kuzu { -namespace function { - -class BuiltInFunctions { -public: - BuiltInFunctions(); - - FunctionType getFunctionType(const std::string& name); - - Function* matchFunction(const std::string& name); - // TODO(Ziyi): We should have a unified interface for matching table, aggregate and scalar - // functions. - Function* matchFunction( - const std::string& name, const std::vector& inputTypes); - - AggregateFunction* matchAggregateFunction(const std::string& name, - const std::vector& inputTypes, bool isDistinct); - - static uint32_t getCastCost( - common::LogicalTypeID inputTypeID, common::LogicalTypeID targetTypeID); - - void addFunction(std::string name, function::function_set definitions); - - std::unique_ptr copy(); - -private: - static uint32_t getTargetTypeCost(common::LogicalTypeID typeID); - - static uint32_t castInt64(common::LogicalTypeID targetTypeID); - - static uint32_t castInt32(common::LogicalTypeID targetTypeID); - - static uint32_t castInt16(common::LogicalTypeID targetTypeID); - - static uint32_t castInt8(common::LogicalTypeID targetTypeID); - - static uint32_t castUInt64(common::LogicalTypeID targetTypeID); - - static uint32_t castUInt32(common::LogicalTypeID targetTypeID); - - static uint32_t castUInt16(common::LogicalTypeID targetTypeID); - - static uint32_t castUInt8(common::LogicalTypeID targetTypeID); - - static uint32_t castInt128(common::LogicalTypeID targetTypeID); - - static uint32_t castDouble(common::LogicalTypeID targetTypeID); - - static uint32_t castFloat(common::LogicalTypeID targetTypeID); - - static uint32_t castDate(common::LogicalTypeID targetTypeID); - - static uint32_t castSerial(common::LogicalTypeID targetTypeID); - - static uint32_t castTimestamp(common::LogicalTypeID targetTypeID); - - static uint32_t castFromString(common::LogicalTypeID inputTypeID); - - static uint32_t castFromRDFVariant(common::LogicalTypeID inputTypeID); - - static uint32_t castUUID(common::LogicalTypeID targetTypeID); - - Function* getBestMatch(std::vector& functions); - - uint32_t getFunctionCost( - const std::vector& inputTypes, Function* function, bool isOverload); - uint32_t matchParameters(const std::vector& inputTypes, - const std::vector& targetTypeIDs, bool isOverload); - uint32_t matchVarLengthParameters(const std::vector& inputTypes, - common::LogicalTypeID targetTypeID, bool isOverload); - uint32_t getAggregateFunctionCost(const std::vector& inputTypes, - bool isDistinct, AggregateFunction* function); - - void validateSpecialCases(std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes); - - // Scalar functions. - void registerScalarFunctions(); - void registerComparisonFunctions(); - void registerArithmeticFunctions(); - void registerDateFunctions(); - void registerTimestampFunctions(); - void registerIntervalFunctions(); - void registerBlobFunctions(); - void registerUUIDFunctions(); - void registerStringFunctions(); - void registerCastFunctions(); - void registerListFunctions(); - void registerStructFunctions(); - void registerMapFunctions(); - void registerUnionFunctions(); - void registerNodeRelFunctions(); - void registerPathFunctions(); - void registerRdfFunctions(); - - // Aggregate functions. - void registerAggregateFunctions(); - void registerCountStar(); - void registerCount(); - void registerSum(); - void registerAvg(); - void registerMin(); - void registerMax(); - void registerCollect(); - - // Table functions. - void registerTableFunctions(); - - // Validations - void validateFunctionExists(const std::string& name); - void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, - const std::string& name, const std::vector& inputTypes, - bool isDistinct); - void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, - const std::string& name, const std::vector& inputTypes); - -private: - std::unordered_map functions; -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/built_in_function_utils.h b/src/include/function/built_in_function_utils.h new file mode 100644 index 0000000000..e93577f975 --- /dev/null +++ b/src/include/function/built_in_function_utils.h @@ -0,0 +1,126 @@ +#pragma once + +#include "aggregate_function.h" +#include "scalar_function.h" + +namespace kuzu { +namespace catalog { +class CatalogSet; +} // namespace catalog + +namespace function { + +class BuiltInFunctionsUtils { +public: + static void createFunctions(catalog::CatalogSet* catalogSet); + + static Function* matchFunction(const std::string& name, catalog::CatalogSet* catalogSet); + // TODO(Ziyi): We should have a unified interface for matching table, aggregate and scalar + // functions. + static Function* matchFunction(const std::string& name, + const std::vector& inputTypes, catalog::CatalogSet* catalogSet); + + static AggregateFunction* matchAggregateFunction(const std::string& name, + const std::vector& inputTypes, bool isDistinct, + catalog::CatalogSet* catalogSet); + + static uint32_t getCastCost( + common::LogicalTypeID inputTypeID, common::LogicalTypeID targetTypeID); + +private: + // TODO(Xiyang): move casting cost related functions to binder. + static uint32_t getTargetTypeCost(common::LogicalTypeID typeID); + + static uint32_t castInt64(common::LogicalTypeID targetTypeID); + + static uint32_t castInt32(common::LogicalTypeID targetTypeID); + + static uint32_t castInt16(common::LogicalTypeID targetTypeID); + + static uint32_t castInt8(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt64(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt32(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt16(common::LogicalTypeID targetTypeID); + + static uint32_t castUInt8(common::LogicalTypeID targetTypeID); + + static uint32_t castInt128(common::LogicalTypeID targetTypeID); + + static uint32_t castDouble(common::LogicalTypeID targetTypeID); + + static uint32_t castFloat(common::LogicalTypeID targetTypeID); + + static uint32_t castDate(common::LogicalTypeID targetTypeID); + + static uint32_t castSerial(common::LogicalTypeID targetTypeID); + + static uint32_t castTimestamp(common::LogicalTypeID targetTypeID); + + static uint32_t castFromString(common::LogicalTypeID inputTypeID); + + static uint32_t castFromRDFVariant(common::LogicalTypeID inputTypeID); + + static uint32_t castUUID(common::LogicalTypeID targetTypeID); + + static Function* getBestMatch(std::vector& functions); + + static uint32_t getFunctionCost( + const std::vector& inputTypes, Function* function, bool isOverload); + static uint32_t matchParameters(const std::vector& inputTypes, + const std::vector& targetTypeIDs, bool isOverload); + static uint32_t matchVarLengthParameters(const std::vector& inputTypes, + common::LogicalTypeID targetTypeID, bool isOverload); + static uint32_t getAggregateFunctionCost(const std::vector& inputTypes, + bool isDistinct, AggregateFunction* function); + + static void validateSpecialCases(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + function::function_set& set); + + // Scalar functions. + static void registerScalarFunctions(catalog::CatalogSet* catalogSet); + static void registerComparisonFunctions(catalog::CatalogSet* catalogSet); + static void registerArithmeticFunctions(catalog::CatalogSet* catalogSet); + static void registerDateFunctions(catalog::CatalogSet* catalogSet); + static void registerTimestampFunctions(catalog::CatalogSet* catalogSet); + static void registerIntervalFunctions(catalog::CatalogSet* catalogSet); + static void registerBlobFunctions(catalog::CatalogSet* catalogSet); + static void registerUUIDFunctions(catalog::CatalogSet* catalogSet); + static void registerStringFunctions(catalog::CatalogSet* catalogSet); + static void registerCastFunctions(catalog::CatalogSet* catalogSet); + static void registerListFunctions(catalog::CatalogSet* catalogSet); + static void registerStructFunctions(catalog::CatalogSet* catalogSet); + static void registerMapFunctions(catalog::CatalogSet* catalogSet); + static void registerUnionFunctions(catalog::CatalogSet* catalogSet); + static void registerNodeRelFunctions(catalog::CatalogSet* catalogSet); + static void registerPathFunctions(catalog::CatalogSet* catalogSet); + static void registerRdfFunctions(catalog::CatalogSet* catalogSet); + + // Aggregate functions. + static void registerAggregateFunctions(catalog::CatalogSet* catalogSet); + static void registerCountStar(catalog::CatalogSet* catalogSet); + static void registerCount(catalog::CatalogSet* catalogSet); + static void registerSum(catalog::CatalogSet* catalogSet); + static void registerAvg(catalog::CatalogSet* catalogSet); + static void registerMin(catalog::CatalogSet* catalogSet); + static void registerMax(catalog::CatalogSet* catalogSet); + static void registerCollect(catalog::CatalogSet* catalogSet); + + // Table functions. + static void registerTableFunctions(catalog::CatalogSet* catalogSet); + + // Validations + static void validateNonEmptyCandidateFunctions( + std::vector& candidateFunctions, const std::string& name, + const std::vector& inputTypes, bool isDistinct, + function::function_set& set); + static void validateNonEmptyCandidateFunctions(std::vector& candidateFunctions, + const std::string& name, const std::vector& inputTypes, + function::function_set& set); +}; + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/scalar_macro_function.h b/src/include/function/scalar_macro_function.h index 6e60ac8014..62ea51081e 100644 --- a/src/include/function/scalar_macro_function.h +++ b/src/include/function/scalar_macro_function.h @@ -14,6 +14,8 @@ struct ScalarMacroFunction { std::vector positionalArgs; parser::default_macro_args defaultArgs; + ScalarMacroFunction() = default; + ScalarMacroFunction(std::unique_ptr expression, std::vector positionalArgs, parser::default_macro_args defaultArgs) : expression{std::move(expression)}, positionalArgs{std::move(positionalArgs)}, diff --git a/src/include/function/table_functions/call_functions.h b/src/include/function/table_functions/call_functions.h index b23e464475..d393f4a1ec 100644 --- a/src/include/function/table_functions/call_functions.h +++ b/src/include/function/table_functions/call_functions.h @@ -1,9 +1,9 @@ #pragma once -#include "catalog/catalog_content.h" #include "catalog/catalog_entry/table_catalog_entry.h" #include "common/data_chunk/data_chunk_collection.h" #include "common/vector/value_vector.h" +#include "function/scalar_function.h" #include "function/table_functions.h" #include "function/table_functions/bind_data.h" #include "function/table_functions/bind_input.h" diff --git a/test/main/udf_test.cpp b/test/main/udf_test.cpp index 3e4b55515f..b29d15391f 100644 --- a/test/main/udf_test.cpp +++ b/test/main/udf_test.cpp @@ -202,7 +202,7 @@ static void validateUDFError(std::function createFunc, std::string errMs TEST_F(ApiTest, UDFError) { conn->createScalarFunction("add5", &add5); validateUDFError([&]() { conn->createScalarFunction("add5", &add5); }, - "Catalog exception: function ADD5 already exists."); + "Catalog exception: function add5 already exists."); } TEST_F(ApiTest, UDFTypeError) { @@ -336,7 +336,7 @@ TEST_F(ApiTest, UDFTrxTest) { conn->createScalarFunction("times2", ×2); ASSERT_TRUE(conn->query("ROLLBACK;")->isSuccess()); ASSERT_EQ(conn->query("return times2(5)")->getErrorMessage(), - "Catalog exception: TIMES2 function does not exist."); + "Catalog exception: function TIMES2 does not exist."); } } // namespace testing diff --git a/test/test_files/exceptions/catalog/catalog.test b/test/test_files/exceptions/catalog/catalog.test index af1454b417..3a1db3e840 100644 --- a/test/test_files/exceptions/catalog/catalog.test +++ b/test/test_files/exceptions/catalog/catalog.test @@ -6,13 +6,12 @@ -CASE CatalogExeception -STATEMENT MATCH (a:person) RETURN dummy(n) ---- error -Catalog exception: DUMMY function does not exist. +Catalog exception: function DUMMY does not exist. -STATEMENT MATCH (a:person) WHERE dummy() < 2 RETURN COUNT(*) ---- error -Catalog exception: DUMMY function does not exist. +Catalog exception: function DUMMY does not exist. -STATEMENT CREATE REL TABLE knows_post ( FROM person TO person, MANY_LOT) ---- error Binder exception: Cannot bind MANY_LOT as relationship multiplicity. - diff --git a/test/test_files/rdf/patch.test b/test/test_files/rdf/patch.test index ab9f68856f..cd9782296f 100644 --- a/test/test_files/rdf/patch.test +++ b/test/test_files/rdf/patch.test @@ -34,10 +34,10 @@ X|RDFGraph| -STATEMENT RETURN dummy(); ---- error -Catalog exception: DUMMY function does not exist. +Catalog exception: function DUMMY does not exist. -STATEMENT CALL show_tables('x') RETURN *; ---- error -Binder exception: Cannot match a built-in function for given function SHOW_TABLES(STRING). Expect empty inputs. +Binder exception: Cannot match a built-in function for given function show_tables(STRING). Expect empty inputs. -STATEMENT MATCH (a:X_lt) RETURN *; ---- error Binder exception: Cannot bind X_lt as a node pattern label. diff --git a/test/test_files/tck/return/return2.test b/test/test_files/tck/return/return2.test index b873e16724..7b2a082575 100644 --- a/test/test_files/tck/return/return2.test +++ b/test/test_files/tck/return/return2.test @@ -200,4 +200,4 @@ EntityNotFound: DeletedEntityAccess ---- ok -STATEMENT MATCH (a) RETURN foo(a); ---- error -Catalog exception: FOO function does not exist. +Catalog exception: function FOO does not exist. diff --git a/test/test_files/tinysnb/call/call.test b/test/test_files/tinysnb/call/call.test index 075001f59e..34516febc0 100644 --- a/test/test_files/tinysnb/call/call.test +++ b/test/test_files/tinysnb/call/call.test @@ -172,7 +172,7 @@ Binder exception: Show connection can only be called on a rel table! -LOG WrongNumParameter -STATEMENT CALL table_info('person', 'knows') RETURN * ---- error -Binder exception: Cannot match a built-in function for given function TABLE_INFO(STRING,STRING). Supported inputs are +Binder exception: Cannot match a built-in function for given function table_info(STRING,STRING). Supported inputs are (STRING) -LOG WrongParameterType diff --git a/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp b/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp index 856e30ad45..5fd2af3d43 100644 --- a/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp +++ b/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp @@ -1,6 +1,6 @@ #include "pandas/pandas_analyzer.h" -#include "function/built_in_function.h" +#include "function/built_in_function_utils.h" #include "py_conversion.h" namespace kuzu { @@ -22,7 +22,7 @@ static bool upgradeType(common::LogicalType& left, const common::LogicalType& ri return true; } auto leftToRightCost = - function::BuiltInFunctions::getCastCost(left.getLogicalTypeID(), right.getLogicalTypeID()); + function::BuiltInFunctionsUtils::getCastCost(left.getLogicalTypeID(), right.getLogicalTypeID()); if (leftToRightCost != common::UNDEFINED_CAST_COST) { left = right; } else {