Skip to content

Commit

Permalink
Move var length field to function (#3328)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 19, 2024
1 parent b6921a2 commit 48188ce
Show file tree
Hide file tree
Showing 15 changed files with 117 additions and 148 deletions.
17 changes: 4 additions & 13 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "function/aggregate_function.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/function_collection.h"
#include "function/scalar_function.h"

using namespace kuzu::common;
using namespace kuzu::catalog;
Expand Down Expand Up @@ -420,19 +419,11 @@ Function* BuiltInFunctionsUtils::getBestMatch(std::vector<Function*>& functionsT

uint32_t BuiltInFunctionsUtils::getFunctionCost(const std::vector<LogicalType>& inputTypes,
Function* function, bool isOverload) {
switch (function->type) {
case FunctionType::SCALAR: {
auto scalarFunction = ku_dynamic_cast<Function*, ScalarFunction*>(function);
if (scalarFunction->isVarLength) {
KU_ASSERT(function->parameterTypeIDs.size() == 1);
return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload);
} else {
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
}
}
default:
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
if (function->isVarLength) {
KU_ASSERT(function->parameterTypeIDs.size() == 1);
return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload);
}
return matchParameters(inputTypes, function->parameterTypeIDs, isOverload);
}

uint32_t BuiltInFunctionsUtils::getAggregateFunctionCost(const std::vector<LogicalType>& inputTypes,
Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_arithmetic_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ function_set AddFunction::getFunctionSet() {
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST,
ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, list_entry_t,
ListConcat>,
nullptr, ListConcatFunction::bindFunc, false /* isVarlength*/));
nullptr, ListConcatFunction::bindFunc));
// string + string -> string
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING},
Expand Down
14 changes: 7 additions & 7 deletions src/function/vector_array_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ std::unique_ptr<FunctionBindData> ArrayValueBindFunc(const binder::expression_ve

function_set ArrayValueFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::ARRAY,
ListCreationFunction::execFunc, nullptr, ArrayValueBindFunc, true /* isVarLength */));
auto function =
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::ANY},
LogicalTypeID::ARRAY, ListCreationFunction::execFunc, nullptr, ArrayValueBindFunc);
function->isVarLength = true;
result.push_back(std::move(function));
return result;
}

Expand Down Expand Up @@ -85,8 +87,7 @@ function_set ArrayCrossProductFunction::getFunctionSet() {
LogicalTypeID::ARRAY,
LogicalTypeID::ARRAY,
},
LogicalTypeID::ARRAY, nullptr, nullptr, ArrayCrossProductBindFunc,
false /* isVarLength */));
LogicalTypeID::ARRAY, nullptr, nullptr, ArrayCrossProductBindFunc));
return result;
}

Expand Down Expand Up @@ -147,8 +148,7 @@ function_set templateGetFunctionSet(const std::string& functionName) {
},
LogicalTypeID::ANY, nullptr, nullptr,
std::bind(arrayTemplateBindFunc<OPERATION>, functionName, std::placeholders::_1,
std::placeholders::_2),
false /* isVarLength */));
std::placeholders::_2)));
return result;
}

Expand Down
8 changes: 3 additions & 5 deletions src/function/vector_blob_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,23 @@ function_set OctetLengthFunctions::getFunctionSet() {
definitions.push_back(
make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::BLOB},
LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction<blob_t, int64_t, OctetLength>,
nullptr, nullptr, nullptr, false /* isVarLength */));
nullptr, nullptr, nullptr));
return definitions;
}

function_set EncodeFunctions::getFunctionSet() {
function_set definitions;
definitions.push_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::BLOB,
ScalarFunction::UnaryStringExecFunction<ku_string_t, blob_t, Encode>, nullptr,
false /* isVarLength */));
ScalarFunction::UnaryStringExecFunction<ku_string_t, blob_t, Encode>, nullptr));
return definitions;
}

function_set DecodeFunctions::getFunctionSet() {
function_set definitions;
definitions.push_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::BLOB}, LogicalTypeID::STRING,
ScalarFunction::UnaryStringExecFunction<blob_t, ku_string_t, Decode>, nullptr,
false /* isVarLength */));
ScalarFunction::UnaryStringExecFunction<blob_t, ku_string_t, Decode>, nullptr));
return definitions;
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ function_set CastAnyFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::ANY, LogicalTypeID::STRING}, LogicalTypeID::ANY,
nullptr, nullptr, castBindFunc, false));
nullptr, nullptr, castBindFunc));
return result;
}

Expand Down
11 changes: 4 additions & 7 deletions src/function/vector_hash_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,15 @@ function_set MD5Function::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::STRING,
ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, MD5Operator>,
false /* isVarLength */));
ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, MD5Operator>));
return functionSet;
}

function_set SHA256Function::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::STRING,
ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, SHA256Operator>,
false /* isVarLength */));
ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, SHA256Operator>));
return functionSet;
}

Expand All @@ -226,9 +224,8 @@ static void HashExecFunc(const std::vector<std::shared_ptr<common::ValueVector>>

function_set HashFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::ANY},
LogicalTypeID::INT64, HashExecFunc, false /* isVarLength */));
functionSet.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::INT64, HashExecFunc));
return functionSet;
}

Expand Down
75 changes: 37 additions & 38 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,11 @@ LogicalType ListCreationFunction::getChildType(const binder::expression_vector&

function_set ListCreationFunction::getFunctionSet() {
function_set result;
result.push_back(
auto function =
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::ANY},
LogicalTypeID::LIST, execFunc, nullptr, ListCreationBindFunc, true /* isVarLength */));
LogicalTypeID::LIST, execFunc, nullptr, ListCreationBindFunc);
function->isVarLength = true;
result.push_back(std::move(function));
return result;
}

Expand All @@ -309,12 +311,12 @@ function_set ListRangeFunction::getFunctionSet() {
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{typeID, typeID}, LogicalTypeID::LIST,
getBinaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{typeID}), nullptr,
ListRangeBindFunc, false));
ListRangeBindFunc));
// start, end, step
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{typeID, typeID, typeID}, LogicalTypeID::LIST,
getTernaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{typeID}), nullptr,
ListRangeBindFunc, false));
ListRangeBindFunc));
}
return result;
}
Expand All @@ -323,13 +325,13 @@ function_set SizeFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::INT64,
ScalarFunction::UnaryExecFunction<list_entry_t, int64_t, ListLen>, true /* isVarlength*/));
ScalarFunction::UnaryExecFunction<list_entry_t, int64_t, ListLen>));
result.push_back(std::make_unique<ScalarFunction>(alias,
std::vector<LogicalTypeID>{LogicalTypeID::MAP}, LogicalTypeID::INT64,
ScalarFunction::UnaryExecFunction<list_entry_t, int64_t, ListLen>, true /* isVarlength*/));
ScalarFunction::UnaryExecFunction<list_entry_t, int64_t, ListLen>));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::INT64,
ScalarFunction::UnaryExecFunction<ku_string_t, int64_t, ListLen>, true /* isVarlength*/));
ScalarFunction::UnaryExecFunction<ku_string_t, int64_t, ListLen>));
return result;
}

Expand Down Expand Up @@ -426,15 +428,14 @@ function_set ListExtractFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::INT64}, LogicalTypeID::ANY,
nullptr, nullptr, ListExtractBindFunc, false /* isVarlength*/));
nullptr, nullptr, ListExtractBindFunc));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::INT64},
LogicalTypeID::STRING,
ScalarFunction::BinaryExecFunction<ku_string_t, int64_t, ku_string_t, ListExtract>,
false /* isVarlength */));
ScalarFunction::BinaryExecFunction<ku_string_t, int64_t, ku_string_t, ListExtract>));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::ARRAY, LogicalTypeID::INT64}, LogicalTypeID::ANY,
nullptr, nullptr, ListExtractBindFunc, false /* isVarlength*/));
nullptr, nullptr, ListExtractBindFunc));
return result;
}

Expand All @@ -453,7 +454,7 @@ function_set ListConcatFunction::getFunctionSet() {
list_entry_t, ListConcat>;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST,
execFunc, nullptr, bindFunc, false /* isVarlength*/));
execFunc, nullptr, bindFunc));
return result;
}

Expand All @@ -474,7 +475,7 @@ function_set ListAppendFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST,
nullptr, nullptr, ListAppendBindFunc, false /* isVarlength*/));
nullptr, nullptr, ListAppendBindFunc));
return result;
}

Expand All @@ -496,7 +497,7 @@ function_set ListPrependFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST,
nullptr, nullptr, ListPrependBindFunc, false /* isVarlength */));
nullptr, nullptr, ListPrependBindFunc));
return result;
}

Expand All @@ -512,7 +513,7 @@ function_set ListPositionFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::INT64,
nullptr, nullptr, ListPositionBindFunc, false /* isVarlength */));
nullptr, nullptr, ListPositionBindFunc));
return result;
}

Expand All @@ -528,12 +529,12 @@ function_set ListContainsFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL,
nullptr, nullptr, ListContainsBindFunc, false /* isVarlength */));
nullptr, nullptr, ListContainsBindFunc));
return result;
}

static std::unique_ptr<FunctionBindData> ListSliceBindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
const binder::expression_vector& arguments, Function*) {
return std::make_unique<FunctionBindData>(arguments[0]->getDataType().copy());
}

Expand All @@ -544,14 +545,13 @@ function_set ListSliceFunction::getFunctionSet() {
LogicalTypeID::LIST,
ScalarFunction::TernaryExecListStructFunction<list_entry_t, int64_t, int64_t, list_entry_t,
ListSlice>,
nullptr, ListSliceBindFunc, false /* isVarlength*/));
nullptr, ListSliceBindFunc));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::INT64,
LogicalTypeID::INT64},
LogicalTypeID::STRING,
ScalarFunction::TernaryExecListStructFunction<ku_string_t, int64_t, int64_t, ku_string_t,
ListSlice>,
false /* isVarlength */));
ListSlice>));
return result;
}

Expand Down Expand Up @@ -650,14 +650,14 @@ function_set ListSortFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::LIST, nullptr, nullptr, ListSortBindFunc, false /* isVarlength*/));
LogicalTypeID::LIST, nullptr, nullptr, ListSortBindFunc));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::STRING}, LogicalTypeID::LIST,
nullptr, nullptr, ListSortBindFunc, false /* isVarlength*/));
nullptr, nullptr, ListSortBindFunc));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::STRING,
LogicalTypeID::STRING},
LogicalTypeID::LIST, nullptr, nullptr, ListSortBindFunc, false /* isVarlength*/));
LogicalTypeID::LIST, nullptr, nullptr, ListSortBindFunc));
return result;
}

Expand Down Expand Up @@ -751,28 +751,28 @@ static std::unique_ptr<FunctionBindData> ListReverseSortBindFunc(

function_set ListReverseSortFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::LIST, nullptr, nullptr,
ListReverseSortBindFunc, false /* isVarlength*/));
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::LIST, nullptr, nullptr, ListReverseSortBindFunc));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::STRING}, LogicalTypeID::LIST,
nullptr, nullptr, ListReverseSortBindFunc, false /* isVarlength*/));
nullptr, nullptr, ListReverseSortBindFunc));
return result;
}

function_set ListSumFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::INT64, nullptr, nullptr,
bindFuncListAggr<ListSum>, false /* isVarlength*/));
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::INT64, nullptr, nullptr, bindFuncListAggr<ListSum>));
return result;
}

function_set ListProductFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::INT64, nullptr, nullptr,
bindFuncListAggr<ListProduct>, false /* isVarlength*/));
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::INT64, nullptr, nullptr, bindFuncListAggr<ListProduct>));
return result;
}

Expand All @@ -786,16 +786,15 @@ function_set ListDistinctFunction::getFunctionSet() {
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::LIST,
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, list_entry_t, ListDistinct>,
nullptr, ListDistinctBindFunc, false /* isVarlength*/));
nullptr, ListDistinctBindFunc));
return result;
}

function_set ListUniqueFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::INT64,
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique>,
false /* isVarlength*/));
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique>));
return result;
}

Expand Down Expand Up @@ -904,7 +903,7 @@ function_set ListAnyValueFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::ANY, nullptr, nullptr, ListAnyValueBindFunc, false /* isVarlength*/));
LogicalTypeID::ANY, nullptr, nullptr, ListAnyValueBindFunc));
return result;
}

Expand All @@ -921,7 +920,7 @@ function_set ListReverseFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::ANY, nullptr, nullptr, ListReverseBindFunc, false /* isVarlength*/));
LogicalTypeID::ANY, nullptr, nullptr, ListReverseBindFunc));
return result;
}

Expand Down
Loading

0 comments on commit 48188ce

Please sign in to comment.