Skip to content

Commit

Permalink
Merge pull request #2309 from kuzudb/table-function
Browse files Browse the repository at this point in the history
Fix table-function parameter match
  • Loading branch information
acquamarin committed Oct 31, 2023
2 parents 43128fc + 6e33ca6 commit e7b0c2f
Show file tree
Hide file tree
Showing 21 changed files with 1,473 additions and 1,518 deletions.
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ LOAD : ( 'L' | 'l' ) ( 'O' | 'o' ) ( 'A' | 'a' ) ( 'D' | 'd' ) ;
HEADERS : ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'A' | 'a' ) ( 'D' | 'd' ) ( 'E' | 'e' ) ( 'R' | 'r' ) ( 'S' | 's' ) ;

kU_InQueryCall
: CALL SP oC_FunctionName SP? '(' oC_Literal* ')' ;
: CALL SP oC_FunctionInvocation ;

oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern (SP? oC_Where)? ;
Expand Down
23 changes: 17 additions & 6 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "function/table_functions/bind_input.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"
#include "parser/query/reading_clause/in_query_call_clause.h"
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/unwind_clause.h"
Expand Down Expand Up @@ -98,12 +99,22 @@ std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause

std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = reinterpret_cast<const InQueryCallClause&>(readingClause);
auto tableFunction = catalog.getBuiltInFunctions()->mathTableFunction(call.getFuncName());
auto inputValues = std::vector<Value>{};
for (auto& parameter : call.getParameters()) {
auto boundExpr = expressionBinder.bindLiteralExpression(*parameter);
inputValues.push_back(*reinterpret_cast<LiteralExpression*>(boundExpr.get())->getValue());
auto funcExpr = reinterpret_cast<ParsedFunctionExpression*>(call.getFunctionExpression());
std::vector<std::unique_ptr<Value>> inputValues;
std::vector<LogicalType*> inputTypes;
for (auto i = 0u; i < funcExpr->getNumChildren(); i++) {
auto parameter = funcExpr->getChild(i);
if (parameter->getExpressionType() != ExpressionType::LITERAL) {
throw BinderException{"Parameters in table function must be a literal expression."};
}
auto expressionValue = reinterpret_cast<ParsedLiteralExpression*>(parameter)->getValue();
inputTypes.push_back(expressionValue->getDataType());
inputValues.push_back(expressionValue->copy());
}
auto funcNameToMatch = funcExpr->getFunctionName();
StringUtils::toUpper(funcNameToMatch);
auto tableFunction = reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(std::move(funcNameToMatch), inputTypes));
auto bindData = tableFunction->bindFunc(clientContext,
function::TableFuncBindInput{std::move(inputValues)}, catalog.getReadOnlyVersion());
expression_vector outputExpressions;
Expand Down
7 changes: 4 additions & 3 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
auto functionName = expressionTypeToString(expressionType);
std::vector<LogicalType> childrenTypes;
std::vector<LogicalType*> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
childrenTypes.push_back(&child->dataType);
}
auto function = builtInFunctions->matchScalarFunction(functionName, childrenTypes);
auto function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchScalarFunction(functionName, childrenTypes));
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
Expand Down
11 changes: 6 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
std::vector<LogicalType> childrenTypes;
std::vector<LogicalType*> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
childrenTypes.push_back(&child->dataType);
}
auto function = builtInFunctions->matchScalarFunction(functionName, childrenTypes);
auto function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchScalarFunction(functionName, childrenTypes));
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
Expand All @@ -80,7 +81,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
std::vector<LogicalType> childrenTypes;
std::vector<LogicalType*> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
Expand All @@ -89,7 +90,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
(childTypeID == LogicalTypeID::NODE || childTypeID == LogicalTypeID::REL)) {
throw BinderException{"DISTINCT is not supported for NODE or REL type."};
}
childrenTypes.push_back(child->dataType);
childrenTypes.push_back(&child->dataType);
children.push_back(std::move(child));
}
auto function =
Expand Down
4 changes: 2 additions & 2 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,10 +640,10 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
}
}

std::string LogicalTypeUtils::dataTypesToString(const std::vector<LogicalType>& dataTypes) {
std::string LogicalTypeUtils::dataTypesToString(const std::vector<LogicalType*>& dataTypes) {
std::vector<LogicalTypeID> dataTypeIDs;
for (auto& dataType : dataTypes) {
dataTypeIDs.push_back(dataType.typeID);
dataTypeIDs.push_back(dataType->typeID);
}
return dataTypesToString(dataTypeIDs);
}
Expand Down
68 changes: 36 additions & 32 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,24 @@ void BuiltInFunctions::registerAggregateFunctions() {
registerCollect();
}

ScalarFunction* BuiltInFunctions::matchScalarFunction(
const std::string& name, const std::vector<LogicalType>& inputTypes) {
Function* BuiltInFunctions::matchScalarFunction(
const std::string& name, const std::vector<LogicalType*>& inputTypes) {
auto& functionSet = functions.at(name);
bool isOverload = functionSet.size() > 1;
std::vector<ScalarFunction*> candidateFunctions;
std::vector<Function*> candidateFunctions;
uint32_t minCost = UINT32_MAX;
for (auto& function : functionSet) {
auto scalarFunc = reinterpret_cast<ScalarFunction*>(function.get());
auto cost = getFunctionCost(inputTypes, scalarFunc, isOverload);
auto func = reinterpret_cast<Function*>(function.get());
auto cost = getFunctionCost(inputTypes, func, isOverload);
if (cost == UINT32_MAX) {
continue;
}
if (cost < minCost) {
candidateFunctions.clear();
candidateFunctions.push_back(scalarFunc);
candidateFunctions.push_back(func);
minCost = cost;
} else if (cost == minCost) {
candidateFunctions.push_back(scalarFunc);
candidateFunctions.push_back(func);
}
}
validateNonEmptyCandidateFunctions(candidateFunctions, name, inputTypes);
Expand All @@ -88,7 +88,7 @@ ScalarFunction* BuiltInFunctions::matchScalarFunction(
}

AggregateFunction* BuiltInFunctions::matchAggregateFunction(
const std::string& name, const std::vector<common::LogicalType>& inputTypes, bool isDistinct) {
const std::string& name, const std::vector<common::LogicalType*>& inputTypes, bool isDistinct) {
auto& functionSet = functions.at(name);
std::vector<AggregateFunction*> candidateFunctions;
for (auto& function : functionSet) {
Expand All @@ -104,13 +104,6 @@ AggregateFunction* BuiltInFunctions::matchAggregateFunction(
return candidateFunctions[0];
}

TableFunction* BuiltInFunctions::mathTableFunction(const std::string& name) {
auto upperName = name;
StringUtils::toUpper(upperName);
containsFunction(upperName);
return reinterpret_cast<TableFunction*>(functions.at(upperName)[0].get());
}

uint32_t BuiltInFunctions::getCastCost(LogicalTypeID inputTypeID, LogicalTypeID targetTypeID) {
if (inputTypeID == targetTypeID) {
return 0;
Expand Down Expand Up @@ -156,15 +149,15 @@ uint32_t BuiltInFunctions::getCastCost(LogicalTypeID inputTypeID, LogicalTypeID
}

uint32_t BuiltInFunctions::getAggregateFunctionCost(
const std::vector<LogicalType>& inputTypes, bool isDistinct, AggregateFunction* function) {
const std::vector<LogicalType*>& inputTypes, bool isDistinct, AggregateFunction* function) {
if (inputTypes.size() != function->parameterTypeIDs.size() ||
isDistinct != function->isDistinct) {
return UINT32_MAX;
}
for (auto i = 0u; i < inputTypes.size(); ++i) {
if (function->parameterTypeIDs[i] == LogicalTypeID::ANY) {
continue;
} else if (inputTypes[i].getLogicalTypeID() != function->parameterTypeIDs[i]) {
} else if (inputTypes[i]->getLogicalTypeID() != function->parameterTypeIDs[i]) {
return UINT32_MAX;
}
}
Expand All @@ -173,7 +166,7 @@ uint32_t BuiltInFunctions::getAggregateFunctionCost(

void BuiltInFunctions::validateNonEmptyCandidateFunctions(
std::vector<AggregateFunction*>& candidateFunctions, const std::string& name,
const std::vector<LogicalType>& inputTypes, bool isDistinct) {
const std::vector<LogicalType*>& inputTypes, bool isDistinct) {
if (candidateFunctions.empty()) {
std::string supportedInputsString;
for (auto& function : functions.at(name)) {
Expand Down Expand Up @@ -372,9 +365,9 @@ uint32_t BuiltInFunctions::castSerial(LogicalTypeID targetTypeID) {
// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double.
// Additionally, we prefer function with string parameter because string is most permissive and can
// be cast to any type.
ScalarFunction* BuiltInFunctions::getBestMatch(std::vector<ScalarFunction*>& functionsToMatch) {
Function* BuiltInFunctions::getBestMatch(std::vector<Function*>& functionsToMatch) {
assert(functionsToMatch.size() > 1);
ScalarFunction* result = nullptr;
Function* result = nullptr;
auto cost = UNDEFINED_CAST_COST;
for (auto& function : functionsToMatch) {
auto currentCost = 0;
Expand All @@ -398,23 +391,34 @@ ScalarFunction* BuiltInFunctions::getBestMatch(std::vector<ScalarFunction*>& fun
}

uint32_t BuiltInFunctions::getFunctionCost(
const std::vector<LogicalType>& inputTypes, ScalarFunction* function, bool isOverload) {
if (function->isVarLength) {
assert(function->parameterTypeIDs.size() == 1);
return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload);
} else {
const std::vector<LogicalType*>& inputTypes, Function* function, bool isOverload) {
switch (function->type) {
case FunctionType::SCALAR: {
auto scalarFunction = reinterpret_cast<ScalarFunction*>(function);
if (scalarFunction->isVarLength) {
assert(function->parameterTypeIDs.size() == 1);
return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload);
} else {
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
}
}
case FunctionType::TABLE:
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
// LCOV_EXCL_START
default:
throw NotImplementedException{"BuiltInFunctions::getFunctionCost"};
// LCOV_EXCL_STOP
}
}

uint32_t BuiltInFunctions::matchParameters(const std::vector<LogicalType>& inputTypes,
uint32_t BuiltInFunctions::matchParameters(const std::vector<LogicalType*>& inputTypes,
const std::vector<LogicalTypeID>& targetTypeIDs, bool /*isOverload*/) {
if (inputTypes.size() != targetTypeIDs.size()) {
return UINT32_MAX;
}
auto cost = 0u;
for (auto i = 0u; i < inputTypes.size(); ++i) {
auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]);
auto castCost = getCastCost(inputTypes[i]->getLogicalTypeID(), targetTypeIDs[i]);
if (castCost == UNDEFINED_CAST_COST) {
return UINT32_MAX;
}
Expand All @@ -424,10 +428,10 @@ uint32_t BuiltInFunctions::matchParameters(const std::vector<LogicalType>& input
}

uint32_t BuiltInFunctions::matchVarLengthParameters(
const std::vector<LogicalType>& inputTypes, LogicalTypeID targetTypeID, bool /*isOverload*/) {
const std::vector<LogicalType*>& inputTypes, LogicalTypeID targetTypeID, bool /*isOverload*/) {
auto cost = 0u;
for (auto& inputType : inputTypes) {
auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID);
for (auto inputType : inputTypes) {
auto castCost = getCastCost(inputType->getLogicalTypeID(), targetTypeID);
if (castCost == UNDEFINED_CAST_COST) {
return UINT32_MAX;
}
Expand All @@ -437,8 +441,8 @@ uint32_t BuiltInFunctions::matchVarLengthParameters(
}

void BuiltInFunctions::validateNonEmptyCandidateFunctions(
std::vector<ScalarFunction*>& candidateFunctions, const std::string& name,
const std::vector<LogicalType>& inputTypes) {
std::vector<Function*>& candidateFunctions, const std::string& name,
const std::vector<LogicalType*>& inputTypes) {
if (candidateFunctions.empty()) {
std::string supportedInputsString;
for (auto& function : functions.at(name)) {
Expand Down
26 changes: 13 additions & 13 deletions src/function/table_functions/call_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ using namespace kuzu::main;

function_set CurrentSettingFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>("current_setting", tableFunc, bindFunc, initSharedState));
functionSet.push_back(std::make_unique<TableFunction>("current_setting", tableFunc, bindFunc,
initSharedState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down Expand Up @@ -52,7 +52,7 @@ void CurrentSettingFunction::tableFunc(TableFunctionInput& data, std::vector<Val

std::unique_ptr<TableFuncBindData> CurrentSettingFunction::bindFunc(
ClientContext* context, TableFuncBindInput input, CatalogContent* /*catalog*/) {
auto optionName = input.inputs[0].getValue<std::string>();
auto optionName = input.inputs[0]->getValue<std::string>();
std::vector<std::string> returnColumnNames;
std::vector<LogicalType> returnTypes;
returnColumnNames.emplace_back(optionName);
Expand All @@ -63,8 +63,8 @@ std::unique_ptr<TableFuncBindData> CurrentSettingFunction::bindFunc(

function_set DBVersionFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>("db_version", tableFunc, bindFunc, initSharedState));
functionSet.push_back(std::make_unique<TableFunction>(
"db_version", tableFunc, bindFunc, initSharedState, std::vector<LogicalTypeID>{}));
return functionSet;
}

Expand Down Expand Up @@ -94,8 +94,8 @@ std::unique_ptr<TableFuncBindData> DBVersionFunction::bindFunc(

function_set ShowTablesFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>("show_tables", tableFunc, bindFunc, initSharedState));
functionSet.push_back(std::make_unique<TableFunction>(
"show_tables", tableFunc, bindFunc, initSharedState, std::vector<LogicalTypeID>{}));
return functionSet;
}

Expand Down Expand Up @@ -136,8 +136,8 @@ std::unique_ptr<TableFuncBindData> ShowTablesFunction::bindFunc(

function_set TableInfoFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>("table_info", tableFunc, bindFunc, initSharedState));
functionSet.push_back(std::make_unique<TableFunction>("table_info", tableFunc, bindFunc,
initSharedState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down Expand Up @@ -177,7 +177,7 @@ std::unique_ptr<TableFuncBindData> TableInfoFunction::bindFunc(
ClientContext* /*context*/, TableFuncBindInput input, CatalogContent* catalog) {
std::vector<std::string> returnColumnNames;
std::vector<LogicalType> returnTypes;
auto tableName = input.inputs[0].getValue<std::string>();
auto tableName = input.inputs[0]->getValue<std::string>();
auto tableID = catalog->getTableID(tableName);
auto schema = catalog->getTableSchema(tableID);
returnColumnNames.emplace_back("property id");
Expand All @@ -196,8 +196,8 @@ std::unique_ptr<TableFuncBindData> TableInfoFunction::bindFunc(

function_set ShowConnectionFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>("db_version", tableFunc, bindFunc, initSharedState));
functionSet.push_back(std::make_unique<TableFunction>("db_version", tableFunc, bindFunc,
initSharedState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down Expand Up @@ -253,7 +253,7 @@ std::unique_ptr<TableFuncBindData> ShowConnectionFunction::bindFunc(
ClientContext* /*context*/, TableFuncBindInput input, CatalogContent* catalog) {
std::vector<std::string> returnColumnNames;
std::vector<LogicalType> returnTypes;
auto tableName = input.inputs[0].getValue<std::string>();
auto tableName = input.inputs[0]->getValue<std::string>();
auto tableID = catalog->getTableID(tableName);
auto schema = catalog->getTableSchema(tableID);
auto tableType = schema->getTableType();
Expand Down
4 changes: 2 additions & 2 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ class LogicalTypeUtils {
public:
KUZU_API static std::string dataTypeToString(const LogicalType& dataType);
KUZU_API static std::string dataTypeToString(LogicalTypeID dataTypeID);
static std::string dataTypesToString(const std::vector<LogicalType>& dataTypes);
static std::string dataTypesToString(const std::vector<LogicalTypeID>& dataTypeIDs);
static std::string dataTypesToString(const std::vector<LogicalType*>& dataTypes);
KUZU_API static std::string dataTypesToString(const std::vector<LogicalTypeID>& dataTypeIDs);
KUZU_API static LogicalType dataTypeFromString(const std::string& dataTypeString);
static uint32_t getRowLayoutSize(const LogicalType& logicalType);
static bool isNumerical(const LogicalType& dataType);
Expand Down
Loading

0 comments on commit e7b0c2f

Please sign in to comment.