Skip to content

Commit

Permalink
Merge pull request #3069 from kuzudb/replace-func
Browse files Browse the repository at this point in the history
Add scalar_func_rewrite_t
  • Loading branch information
andyfengHKU committed Mar 18, 2024
2 parents a612c0f + 1d7b9f3 commit a0ee10e
Show file tree
Hide file tree
Showing 38 changed files with 311 additions and 201 deletions.
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
childrenAfterCast.push_back(implicitCastIfNecessary(child, LogicalTypeID::BOOL));
}
auto functionName = expressionTypeToString(expressionType);
function::scalar_exec_func execFunc;
function::scalar_func_exec_t execFunc;
function::VectorBooleanFunction::bindExecFunction(expressionType, childrenAfterCast, execFunc);
function::scalar_select_func selectFunc;
function::scalar_func_select_t selectFunc;
function::VectorBooleanFunction::bindSelectFunction(
expressionType, childrenAfterCast, selectFunc);
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::BOOL());
Expand Down
88 changes: 44 additions & 44 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "binder/expression/property_expression.h"
#include "binder/expression_binder.h"
#include "common/exception/binder.h"
#include "common/string_utils.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_label_functions.h"
#include "main/client_context.h"
#include "parser/expression/parsed_function_expression.h"
Expand All @@ -13,29 +13,29 @@
using namespace kuzu::common;
using namespace kuzu::parser;
using namespace kuzu::function;
using namespace kuzu::catalog;

namespace kuzu {
namespace binder {

std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
const ParsedExpression& parsedExpression) {
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
auto result = rewriteFunctionExpression(parsedExpression, functionName);
std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) {
auto& funcExpr =
ku_dynamic_cast<const ParsedExpression&, const ParsedFunctionExpression&>(expr);
auto functionName = funcExpr.getNormalizedFunctionName();
auto result = rewriteFunctionExpression(expr, functionName);
if (result != nullptr) {
return result;
}
auto functionType =
context->getCatalog()->getFunctionType(binder->clientContext->getTx(), functionName);
switch (functionType) {
case ExpressionType::FUNCTION:
return bindScalarFunctionExpression(parsedExpression, functionName);
case ExpressionType::AGGREGATE_FUNCTION:
return bindAggregateFunctionExpression(
parsedExpression, functionName, parsedFunctionExpression.getIsDistinct());
case ExpressionType::MACRO:
return bindMacroExpression(parsedExpression, functionName);
auto entry = context->getCatalog()->getFunctionEntry(context->getTx(), functionName);
switch (entry->getType()) {
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
return bindScalarFunctionExpression(expr, functionName);
case CatalogEntryType::REWRITE_FUNCTION_ENTRY:
return bindRewriteFunctionExpression(expr);
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return bindAggregateFunctionExpression(expr, functionName, funcExpr.getIsDistinct());
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return bindMacroExpression(expr, functionName);
default:
KU_UNREACHABLE;
}
Expand All @@ -51,12 +51,17 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
return bindScalarFunctionExpression(children, functionName);
}

static std::vector<LogicalType> getTypes(const expression_vector& exprs) {
std::vector<LogicalType> result;
for (auto& expr : exprs) {
result.push_back(expr->getDataType());
}
return result;
}

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto childrenTypes = getTypes(children);
auto functions = context->getCatalog()->getFunctions(context->getTx());
auto function = ku_dynamic_cast<Function*, function::ScalarFunction*>(
function::BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, functions));
Expand Down Expand Up @@ -93,6 +98,23 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
function->compileFunc, uniqueExpressionName);
}

// Binds a call to a function whose catalog entry is a rewrite function.
// Every child of the parsed expression is bound, the overload is matched against the catalog's
// function set using the bound children's data types, and the matched function's rewriteFunc is
// invoked with the bound children and this binder to produce the resulting expression.
std::shared_ptr<Expression> ExpressionBinder::bindRewriteFunctionExpression(
    const parser::ParsedExpression& expr) {
    auto& parsedFunc =
        ku_dynamic_cast<const ParsedExpression&, const ParsedFunctionExpression&>(expr);
    // Bind the arguments in their original order.
    expression_vector boundArgs;
    for (auto idx = 0u; idx < expr.getNumChildren(); ++idx) {
        auto boundArg = bindExpression(*expr.getChild(idx));
        boundArgs.push_back(boundArg);
    }
    // Resolve the overload from the argument types.
    auto argTypes = getTypes(boundArgs);
    auto catalogFunctions = context->getCatalog()->getFunctions(context->getTx());
    auto matched = BuiltInFunctionsUtils::matchFunction(
        parsedFunc.getNormalizedFunctionName(), argTypes, catalogFunctions);
    auto rewriteFunction = ku_dynamic_cast<Function*, RewriteFunction*>(matched);
    KU_ASSERT(rewriteFunction->rewriteFunc != nullptr);
    return rewriteFunction->rewriteFunc(boundArgs, this);
}

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) {
std::vector<LogicalType> childrenTypes;
Expand Down Expand Up @@ -159,19 +181,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
// Function rewriting happens when we need to expose internal property access through a function
// so that it becomes read-only, or when the function involves catalog information. Currently we rewrite
// Before | After
// ID(a) | a._id
// LABEL(a) | LIST_EXTRACT(offset(a), [table names from catalog])
// LENGTH(e) | e._length
// STARTNODE(a) | a._src
// ENDNODE(a) | a._dst
std::shared_ptr<Expression> ExpressionBinder::rewriteFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName) {
if (functionName == ID_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(*child, std::vector<LogicalTypeID>{LogicalTypeID::NODE,
LogicalTypeID::REL, LogicalTypeID::STRUCT});
return bindInternalIDExpression(child);
} else if (functionName == LABEL_FUNC_NAME) {
if (functionName == LABEL_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
*child, std::vector<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
Expand Down Expand Up @@ -203,22 +219,6 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const std::shared_ptr<Expression>& expression) {
if (ExpressionUtil::isNodePattern(*expression)) {
auto& node = (NodeExpression&)*expression;
return node.getInternalID();
}
if (ExpressionUtil::isRelPattern(*expression)) {
return bindNodeOrRelPropertyExpression(*expression, InternalKeyword::ID);
}
KU_ASSERT(expression->dataType.getPhysicalType() == PhysicalTypeID::STRUCT);
auto stringValue = std::make_unique<Value>(LogicalType::STRING(), InternalKeyword::ID);
return bindScalarFunctionExpression(
expression_vector{expression, createLiteralExpression(std::move(stringValue))},
STRUCT_EXTRACT_FUNC_NAME);
}

std::shared_ptr<Expression> ExpressionBinder::bindStartNodeExpression(
const Expression& expression) {
auto& rel = (RelExpression&)expression;
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_null_operator_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
}
auto expressionType = parsedExpression.getExpressionType();
auto functionName = expressionTypeToString(expressionType);
function::scalar_exec_func execFunc;
function::scalar_func_exec_t execFunc;
function::VectorNullFunction::bindExecFunction(expressionType, children, execFunc);
function::scalar_select_func selectFunc;
function::scalar_func_select_t selectFunc;
function::VectorNullFunction::bindSelectFunction(expressionType, children, selectFunc);
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::BOOL());
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children);
Expand Down
21 changes: 13 additions & 8 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "catalog/catalog_entry/scalar_macro_catalog_entry.h"
#include "common/exception/catalog.h"
#include "storage/wal/wal.h"
#include "transaction/transaction.h"
#include "transaction/transaction_action.h"
Expand Down Expand Up @@ -101,10 +102,6 @@ std::vector<TableCatalogEntry*> Catalog::getTableSchemas(
return result;
}

CatalogSet* Catalog::getFunctions(Transaction* tx) const {
return getVersion(tx)->functions.get();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
if (hasUpdates()) {
wal->logCatalogRecord();
Expand All @@ -121,10 +118,6 @@ void Catalog::checkpointInMemory() {
}
}

ExpressionType Catalog::getFunctionType(Transaction* tx, const std::string& name) const {
return getVersion(tx)->getFunctionType(name);
}

table_id_t Catalog::addNodeTableSchema(const binder::BoundCreateTableInfo& info) {
KU_ASSERT(readWriteVersion != nullptr);
setToUpdated();
Expand Down Expand Up @@ -257,6 +250,18 @@ void Catalog::addBuiltInFunction(std::string name, function::function_set functi
readOnlyVersion->addFunction(std::move(name), std::move(functionSet));
}

// Returns the raw pointer to the function set held by the catalog version that getVersion
// selects for `tx`.
CatalogSet* Catalog::getFunctions(Transaction* tx) const {
    auto& content = *getVersion(tx);
    return content.functions.get();
}

// Looks up the function catalog entry called `name` in the function set of the catalog version
// that getVersion selects for `tx`.
// Throws CatalogException when the function set has no entry with that name.
CatalogEntry* Catalog::getFunctionEntry(transaction::Transaction* tx, const std::string& name) {
    auto functionSet = getVersion(tx)->functions.get();
    if (functionSet->containsEntry(name)) {
        return functionSet->getEntry(name);
    }
    throw CatalogException(stringFormat("function {} does not exist.", name));
}

// Reports whether the catalog version that getVersion selects for `tx` contains a macro named
// `macroName`; the check itself is delegated to containMacro.
bool Catalog::containsMacro(Transaction* tx, const std::string& macroName) const {
    auto& content = *getVersion(tx);
    return content.containMacro(macroName);
}
Expand Down
17 changes: 0 additions & 17 deletions src/catalog/catalog_content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,23 +202,6 @@ void CatalogContent::readFromFile(const std::string& directory, FileVersionType
functions = CatalogSet::deserialize(deserializer);
}

ExpressionType CatalogContent::getFunctionType(const std::string& name) const {
if (!functions->containsEntry(name)) {
throw CatalogException{common::stringFormat("function {} does not exist.", name)};
}
auto functionEntry = functions->getEntry(name);
switch (functionEntry->getType()) {
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return ExpressionType::MACRO;
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
return ExpressionType::FUNCTION;
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return ExpressionType::AGGREGATE_FUNCTION;
default:
KU_UNREACHABLE;
}
}

void CatalogContent::addFunction(std::string name, function::function_set definitions) {
if (functions->containsEntry(name)) {
throw CatalogException{common::stringFormat("function {} already exists.", name)};
Expand Down
12 changes: 12 additions & 0 deletions src/catalog/catalog_set.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include "catalog/catalog_set.h"

#include "common/assert.h"
#include "common/exception/catalog.h"
#include "common/string_format.h"

using namespace kuzu::common;

namespace kuzu {
namespace catalog {
Expand All @@ -10,6 +14,13 @@ bool CatalogSet::containsEntry(const std::string& name) const {
}

// Returns the raw pointer held by the entry stored under `name`.
// Throws CatalogException when containsEntry reports that no such entry exists.
CatalogEntry* CatalogSet::getEntry(const std::string& name) {
    // LCOV_EXCL_START
    // Defensive guard: this is not expected to fire. If it does, the catalog level should be
    // raising a more informative error before reaching this lookup.
    if (!containsEntry(name)) {
        throw CatalogException(stringFormat("Cannot find catalog entry with name {}.", name));
    }
    // LCOV_EXCL_STOP
    auto& stored = entries.at(name);
    return stored.get();
}

Expand All @@ -36,6 +47,7 @@ void CatalogSet::serialize(common::Serializer serializer) const {
for (auto& [name, entry] : entries) {
switch (entry->getType()) {
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
case CatalogEntryType::REWRITE_FUNCTION_ENTRY:
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
case CatalogEntryType::TABLE_FUNCTION_ENTRY:
continue;
Expand Down
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(aggregate)
add_subdirectory(arithmetic)
add_subdirectory(cast)
add_subdirectory(pattern)
add_subdirectory(table)

add_library(kuzu_function
Expand Down
26 changes: 17 additions & 9 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ using namespace kuzu::processor;
namespace kuzu {
namespace function {

static void validateNonEmptyCandidateFunctions(std::vector<AggregateFunction*>& candidateFunctions,
const std::string& name, const std::vector<common::LogicalType>& inputTypes, bool isDistinct,
function::function_set& set);
static void validateNonEmptyCandidateFunctions(std::vector<Function*>& candidateFunctions,
const std::string& name, const std::vector<common::LogicalType>& inputTypes,
function::function_set& set);

void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) {
registerScalarFunctions(catalogSet);
registerAggregateFunctions(catalogSet);
Expand Down Expand Up @@ -453,10 +460,8 @@ uint32_t BuiltInFunctionsUtils::getFunctionCost(
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
}
}
case FunctionType::TABLE:
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
default:
KU_UNREACHABLE;
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
}
}

Expand Down Expand Up @@ -909,6 +914,9 @@ void BuiltInFunctionsUtils::registerUnionFunctions(CatalogSet* catalogSet) {
// Registers the node/rel built-in functions in `catalogSet`:
//   - OFFSET as a scalar function entry, and
//   - ID as a function entry tagged REWRITE_FUNCTION_ENTRY.
void BuiltInFunctionsUtils::registerNodeRelFunctions(CatalogSet* catalogSet) {
    auto offsetEntry = std::make_unique<ScalarFunctionCatalogEntry>(
        OFFSET_FUNC_NAME, OffsetFunction::getFunctionSet());
    catalogSet->createEntry(std::move(offsetEntry));

    auto idEntry = std::make_unique<FunctionCatalogEntry>(
        catalog::CatalogEntryType::REWRITE_FUNCTION_ENTRY, ID_FUNC_NAME,
        IDFunction::getFunctionSet());
    catalogSet->createEntry(std::move(idEntry));
}

void BuiltInFunctionsUtils::registerPathFunctions(CatalogSet* catalogSet) {
Expand Down Expand Up @@ -1069,9 +1077,9 @@ static std::string getFunctionMatchFailureMsg(const std::string name,
return result;
}

void BuiltInFunctionsUtils::validateNonEmptyCandidateFunctions(
std::vector<AggregateFunction*>& candidateFunctions, const std::string& name,
const std::vector<LogicalType>& inputTypes, bool isDistinct, function::function_set& set) {
void validateNonEmptyCandidateFunctions(std::vector<AggregateFunction*>& candidateFunctions,
const std::string& name, const std::vector<LogicalType>& inputTypes, bool isDistinct,
function::function_set& set) {
if (candidateFunctions.empty()) {
std::string supportedInputsString;
for (auto& function : set) {
Expand All @@ -1086,9 +1094,9 @@ void BuiltInFunctionsUtils::validateNonEmptyCandidateFunctions(
}
}

void BuiltInFunctionsUtils::validateNonEmptyCandidateFunctions(
std::vector<Function*>& candidateFunctions, const std::string& name,
const std::vector<LogicalType>& inputTypes, function::function_set& set) {
void validateNonEmptyCandidateFunctions(std::vector<Function*>& candidateFunctions,
const std::string& name, const std::vector<LogicalType>& inputTypes,
function::function_set& set) {
if (candidateFunctions.empty()) {
std::string supportedInputsString;
for (auto& function : set) {
Expand Down
7 changes: 7 additions & 0 deletions src/function/pattern/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Object library holding the pattern function sources (currently only id_function.cpp).
add_library(kuzu_function_pattern
OBJECT
id_function.cpp)

# Append this library's object files to ALL_OBJECT_FILES and publish the updated list to the
# parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function_pattern>
PARENT_SCOPE)
46 changes: 46 additions & 0 deletions src/function/pattern/id_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "binder/expression/expression_util.h"
#include "binder/expression/node_expression.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_node_rel_functions.h"

using namespace kuzu::common;
using namespace kuzu::binder;

namespace kuzu {
namespace function {

static std::shared_ptr<binder::Expression> rewriteFunc(
const expression_vector& params, ExpressionBinder* binder) {
KU_ASSERT(params.size() == 1);
auto param = params[0].get();
if (ExpressionUtil::isNodePattern(*param)) {
auto node = ku_dynamic_cast<Expression*, NodeExpression*>(param);
return node->getInternalID();
}
if (ExpressionUtil::isRelPattern(*param)) {
auto rel = ku_dynamic_cast<Expression*, RelExpression*>(param);
return rel->getPropertyExpression(InternalKeyword::ID);
}
// Bind as struct_extract(param, "_id")
auto key = Value(LogicalType::STRING(), InternalKeyword::ID);
auto keyExpr = binder->createLiteralExpression(key.copy());
auto newParams = expression_vector{params[0], keyExpr};
return binder->bindScalarFunctionExpression(newParams, STRUCT_EXTRACT_FUNC_NAME);
}

// Builds the overload set for the ID function: one RewriteFunction per accepted input type
// (NODE, REL, STRUCT, in that order), each taking a single parameter and sharing rewriteFunc.
function_set IDFunction::getFunctionSet() {
    function_set overloads;
    for (auto typeID : {LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT}) {
        overloads.push_back(std::make_unique<RewriteFunction>(
            InternalKeyword::ID, std::vector<LogicalTypeID>{typeID}, rewriteFunc));
    }
    return overloads;
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit a0ee10e

Please sign in to comment.