Skip to content

Commit

Permalink
Merge pull request #1817 from kuzudb/udf-with-param
Browse files Browse the repository at this point in the history
Add vectorized UDF support
  • Loading branch information
acquamarin committed Jul 14, 2023
2 parents 3901e66 + d57ee9c commit 1f0e674
Show file tree
Hide file tree
Showing 3 changed files with 425 additions and 136 deletions.
332 changes: 212 additions & 120 deletions src/include/function/udf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,126 +34,218 @@ struct TernaryUDFExecutor {
}
};

// Catch-all overload: chosen when the UDF signature does not have exactly one parameter.
// It never builds an executor; calling it always throws NotImplementedException.
template<typename RESULT_TYPE, typename... Args>
static inline function::scalar_exec_func createUnaryExecFunc(RESULT_TYPE (*udfFunc)(Args...)) {
    throw common::NotImplementedException{"function::createUnaryExecFunc()"};
}

// Wraps a one-parameter UDF in a scalar_exec_func.
// The returned callable expects exactly one input vector and hands it, the result vector and
// the type-erased UDF pointer to UnaryFunctionExecutor::executeUDF.
template<typename RESULT_TYPE, typename OPERAND_TYPE>
static inline function::scalar_exec_func createUnaryExecFunc(RESULT_TYPE (*udfFunc)(OPERAND_TYPE)) {
    return [udfFunc](const std::vector<std::shared_ptr<common::ValueVector>>& inputs,
               common::ValueVector& output) -> void {
        assert(inputs.size() == 1);
        auto& operand = *inputs[0];
        UnaryFunctionExecutor::executeUDF<OPERAND_TYPE, RESULT_TYPE, UnaryUDFExecutor>(
            operand, output, reinterpret_cast<void*>(udfFunc));
    };
}

// Catch-all overload: chosen when the UDF signature does not have exactly two parameters.
// It never builds an executor; calling it always throws NotImplementedException.
template<typename RESULT_TYPE, typename... Args>
static inline function::scalar_exec_func createBinaryExecFunc(RESULT_TYPE (*udfFunc)(Args...)) {
    throw common::NotImplementedException{"function::createBinaryExecFunc()"};
}

// Wraps a two-parameter UDF in a scalar_exec_func.
// The returned callable expects exactly two input vectors (left, right) and forwards them,
// the result vector and the type-erased UDF pointer to BinaryFunctionExecutor::executeUDF.
template<typename RESULT_TYPE, typename LEFT_TYPE, typename RIGHT_TYPE>
static inline function::scalar_exec_func createBinaryExecFunc(
    RESULT_TYPE (*udfFunc)(LEFT_TYPE, RIGHT_TYPE)) {
    return [udfFunc](const std::vector<std::shared_ptr<common::ValueVector>>& inputs,
               common::ValueVector& output) -> void {
        assert(inputs.size() == 2);
        auto& lhs = *inputs[0];
        auto& rhs = *inputs[1];
        BinaryFunctionExecutor::executeUDF<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, BinaryUDFExecutor>(
            lhs, rhs, output, reinterpret_cast<void*>(udfFunc));
    };
}

// Catch-all overload: chosen when the UDF signature does not have exactly three parameters.
// It never builds an executor; calling it always throws NotImplementedException.
template<typename RESULT_TYPE, typename... Args>
static inline function::scalar_exec_func createTernaryExecFunc(RESULT_TYPE (*udfFunc)(Args...)) {
    throw common::NotImplementedException{"function::createTernaryExecFunc()"};
}

// Wraps a three-parameter UDF in a scalar_exec_func.
// The returned callable expects exactly three input vectors (a, b, c) and forwards them,
// the result vector and the type-erased UDF pointer to TernaryFunctionExecutor::executeUDF.
template<typename RESULT_TYPE, typename A_TYPE, typename B_TYPE, typename C_TYPE>
static inline function::scalar_exec_func createTernaryExecFunc(
    RESULT_TYPE (*udfFunc)(A_TYPE, B_TYPE, C_TYPE)) {
    return [udfFunc](const std::vector<std::shared_ptr<common::ValueVector>>& inputs,
               common::ValueVector& output) -> void {
        assert(inputs.size() == 3);
        auto& first = *inputs[0];
        auto& second = *inputs[1];
        auto& third = *inputs[2];
        TernaryFunctionExecutor::executeUDF<A_TYPE, B_TYPE, C_TYPE, RESULT_TYPE,
            TernaryUDFExecutor>(first, second, third, output, reinterpret_cast<void*>(udfFunc));
    };
}

// Picks the executor factory that matches the number of parameters in the UDF signature.
// One, two and three parameters are supported; any other count throws BinderException.
// The checks are ordinary run-time branches, so every create*ExecFunc call below is
// instantiated for every signature regardless of its arity.
template<typename TR, typename... Args>
inline static scalar_exec_func getScalarExecFunc(TR (*udfFunc)(Args...)) {
    constexpr auto arity = sizeof...(Args);
    if (arity == 1) {
        return createUnaryExecFunc<TR, Args...>(udfFunc);
    }
    if (arity == 2) {
        return createBinaryExecFunc<TR, Args...>(udfFunc);
    }
    if (arity == 3) {
        return createTernaryExecFunc<TR, Args...>(udfFunc);
    }
    throw common::BinderException("UDF function only supported until ternary!");
}

// Maps a C++ type used in a UDF signature to the logical type id it represents.
//
// Supported mappings: bool -> BOOL, int16_t -> INT16, int32_t -> INT32, int64_t -> INT64,
// float -> FLOAT, double -> DOUBLE, common::ku_string_t -> STRING.
// Throws NotImplementedException for any other T.
//
// The floating-point cases compare against `float` / `double` directly. `float_t` and
// `double_t` are only aliases for them when FLT_EVAL_METHOD is 0; on other targets they may name
// wider types, which would reject real float/double UDFs and accept long double by mistake.
template<typename T>
inline static common::LogicalTypeID getParameterType() {
    if (std::is_same<T, bool>::value) {
        return common::LogicalTypeID::BOOL;
    } else if (std::is_same<T, int16_t>::value) {
        return common::LogicalTypeID::INT16;
    } else if (std::is_same<T, int32_t>::value) {
        return common::LogicalTypeID::INT32;
    } else if (std::is_same<T, int64_t>::value) {
        return common::LogicalTypeID::INT64;
    } else if (std::is_same<T, float>::value) {
        return common::LogicalTypeID::FLOAT;
    } else if (std::is_same<T, double>::value) {
        return common::LogicalTypeID::DOUBLE;
    } else if (std::is_same<T, common::ku_string_t>::value) {
        return common::LogicalTypeID::STRING;
    } else {
        throw common::NotImplementedException{"function::getParameterType"};
    }
}

// Base case: appends the logical type id of the last remaining signature type.
template<typename TA>
inline static void getParameterTypesRecursive(std::vector<common::LogicalTypeID>& arguments) {
    const auto lastTypeID = getParameterType<TA>();
    arguments.push_back(lastTypeID);
}

// Recursive case: appends the logical type id of the leading type TA, then processes the
// rest of the pack (TB, Args...) so ids end up in declaration order.
template<typename TA, typename TB, typename... Args>
inline static void getParameterTypesRecursive(std::vector<common::LogicalTypeID>& arguments) {
    const auto headTypeID = getParameterType<TA>();
    arguments.push_back(headTypeID);
    getParameterTypesRecursive<TB, Args...>(arguments);
}

template<typename TR, typename... Args>
inline static std::unique_ptr<VectorFunctionDefinition> getFunctionDefinition(
const std::string& name, TR (*udfFunc)(Args...),
std::vector<common::LogicalTypeID> parameterTypes, common::LogicalTypeID returnType) {
function::scalar_exec_func scalarExecFunc = function::getScalarExecFunc<TR, Args...>(udfFunc);
return std::make_unique<function::VectorFunctionDefinition>(
name, std::move(parameterTypes), returnType, std::move(scalarExecFunc));
}

template<typename TR, typename... Args>
inline static std::unique_ptr<VectorFunctionDefinition> getFunctionDefinition(
const std::string& name, TR (*udfFunc)(Args...)) {
std::vector<common::LogicalTypeID> parameterTypes;
getParameterTypesRecursive<Args...>(parameterTypes);
common::LogicalTypeID returnType = getParameterType<TR>();
if (returnType == common::LogicalTypeID::STRING) {
throw common::NotImplementedException{"function::getFunctionDefinition"};
}
return getFunctionDefinition<TR, Args...>(name, udfFunc, std::move(parameterTypes), returnType);
}
struct UDF {
template<typename T>
static bool templateValidateType(const common::LogicalTypeID& type) {
switch (type) {
case common::LogicalTypeID::BOOL:
return std::is_same<T, bool>();
case common::LogicalTypeID::INT16:
return std::is_same<T, int16_t>();
case common::LogicalTypeID::INT32:
return std::is_same<T, int32_t>();
case common::LogicalTypeID::INT64:
return std::is_same<T, int64_t>();
case common::LogicalTypeID::FLOAT:
return std::is_same<T, float>();
case common::LogicalTypeID::DOUBLE:
return std::is_same<T, double>();
case common::LogicalTypeID::DATE:
return std::is_same<T, int32_t>();
case common::LogicalTypeID::TIMESTAMP:
return std::is_same<T, int64_t>();
case common::LogicalTypeID::STRING:
return std::is_same<T, common::ku_string_t>();
case common::LogicalTypeID::BLOB:
return std::is_same<T, common::blob_t>();
default:
throw common::NotImplementedException{"function::validateType"};
}
}

template<typename T>
static void validateType(const common::LogicalTypeID& type) {
if (!templateValidateType<T>(type)) {
throw common::CatalogException{
"Incompatible udf parameter/return type and templated type."};
}
}

template<typename RESULT_TYPE, typename... Args>
static function::scalar_exec_func createUnaryExecFunc(
RESULT_TYPE (*udfFunc)(Args...), std::vector<common::LogicalTypeID> parameterTypes) {
throw common::NotImplementedException{"function::createUnaryExecFunc()"};
}

template<typename RESULT_TYPE, typename OPERAND_TYPE>
static function::scalar_exec_func createUnaryExecFunc(
RESULT_TYPE (*udfFunc)(OPERAND_TYPE), std::vector<common::LogicalTypeID> parameterTypes) {
if (parameterTypes.size() != 1) {
throw common::CatalogException{
"Expected exactly one parameter type for unary udf. Got: " +
std::to_string(parameterTypes.size()) + "."};
}
validateType<OPERAND_TYPE>(parameterTypes[0]);
function::scalar_exec_func execFunc =
[=](const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result) -> void {
assert(params.size() == 1);
UnaryFunctionExecutor::executeUDF<OPERAND_TYPE, RESULT_TYPE, UnaryUDFExecutor>(
*params[0], result, (void*)udfFunc);
};
return execFunc;
}

template<typename RESULT_TYPE, typename... Args>
static function::scalar_exec_func createBinaryExecFunc(
RESULT_TYPE (*udfFunc)(Args...), std::vector<common::LogicalTypeID> parameterTypes) {
throw common::NotImplementedException{"function::createBinaryExecFunc()"};
}

template<typename RESULT_TYPE, typename LEFT_TYPE, typename RIGHT_TYPE>
static function::scalar_exec_func createBinaryExecFunc(
RESULT_TYPE (*udfFunc)(LEFT_TYPE, RIGHT_TYPE),
std::vector<common::LogicalTypeID> parameterTypes) {
if (parameterTypes.size() != 2) {
throw common::CatalogException{
"Expected exactly two parameter types for binary udf. Got: " +
std::to_string(parameterTypes.size()) + "."};
}
validateType<LEFT_TYPE>(parameterTypes[0]);
validateType<RIGHT_TYPE>(parameterTypes[1]);
function::scalar_exec_func execFunc =
[=](const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result) -> void {
assert(params.size() == 2);
BinaryFunctionExecutor::executeUDF<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE,
BinaryUDFExecutor>(*params[0], *params[1], result, (void*)udfFunc);
};
return execFunc;
}

template<typename RESULT_TYPE, typename... Args>
static function::scalar_exec_func createTernaryExecFunc(
RESULT_TYPE (*udfFunc)(Args...), std::vector<common::LogicalTypeID> parameterTypes) {
throw common::NotImplementedException{"function::createTernaryExecFunc()"};
}

template<typename RESULT_TYPE, typename A_TYPE, typename B_TYPE, typename C_TYPE>
static function::scalar_exec_func createTernaryExecFunc(
RESULT_TYPE (*udfFunc)(A_TYPE, B_TYPE, C_TYPE),
std::vector<common::LogicalTypeID> parameterTypes) {
if (parameterTypes.size() != 3) {
throw common::CatalogException{
"Expected exactly three parameter types for ternary udf. Got: " +
std::to_string(parameterTypes.size()) + "."};
}
validateType<A_TYPE>(parameterTypes[0]);
validateType<B_TYPE>(parameterTypes[1]);
validateType<C_TYPE>(parameterTypes[2]);
function::scalar_exec_func execFunc =
[=](const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result) -> void {
assert(params.size() == 3);
TernaryFunctionExecutor::executeUDF<A_TYPE, B_TYPE, C_TYPE, RESULT_TYPE,
TernaryUDFExecutor>(*params[0], *params[1], *params[2], result, (void*)udfFunc);
};
return execFunc;
}

template<typename TR, typename... Args>
static scalar_exec_func getScalarExecFunc(
TR (*udfFunc)(Args...), std::vector<common::LogicalTypeID> parameterTypes) {
constexpr auto numArgs = sizeof...(Args);
switch (numArgs) {
case 1:
return createUnaryExecFunc<TR, Args...>(udfFunc, std::move(parameterTypes));
case 2:
return createBinaryExecFunc<TR, Args...>(udfFunc, std::move(parameterTypes));
case 3:
return createTernaryExecFunc<TR, Args...>(udfFunc, std::move(parameterTypes));
default:
throw common::BinderException("UDF function only supported until ternary!");
}
}

template<typename T>
static common::LogicalTypeID getParameterType() {
if (std::is_same<T, bool>()) {
return common::LogicalTypeID::BOOL;
} else if (std::is_same<T, int16_t>()) {
return common::LogicalTypeID::INT16;
} else if (std::is_same<T, int32_t>()) {
return common::LogicalTypeID::INT32;
} else if (std::is_same<T, int64_t>()) {
return common::LogicalTypeID::INT64;
} else if (std::is_same<T, float_t>()) {
return common::LogicalTypeID::FLOAT;
} else if (std::is_same<T, double_t>()) {
return common::LogicalTypeID::DOUBLE;
} else if (std::is_same<T, common::ku_string_t>()) {
return common::LogicalTypeID::STRING;
} else {
throw common::NotImplementedException{"function::getParameterType"};
}
}

template<typename TA>
static void getParameterTypesRecursive(std::vector<common::LogicalTypeID>& arguments) {
arguments.push_back(getParameterType<TA>());
}

template<typename TA, typename TB, typename... Args>
static void getParameterTypesRecursive(std::vector<common::LogicalTypeID>& arguments) {
arguments.push_back(getParameterType<TA>());
getParameterTypesRecursive<TB, Args...>(arguments);
}

template<typename... Args>
static std::vector<common::LogicalTypeID> getParameterTypes() {
std::vector<common::LogicalTypeID> parameterTypes;
getParameterTypesRecursive<Args...>(parameterTypes);
return parameterTypes;
}

template<typename TR, typename... Args>
static vector_function_definitions getFunctionDefinition(const std::string& name,
TR (*udfFunc)(Args...), std::vector<common::LogicalTypeID> parameterTypes,
common::LogicalTypeID returnType) {
vector_function_definitions definitions;
if (returnType == common::LogicalTypeID::STRING) {
throw common::NotImplementedException{"function::getFunctionDefinition"};
}
validateType<TR>(returnType);
scalar_exec_func scalarExecFunc = getScalarExecFunc<TR, Args...>(udfFunc, parameterTypes);
definitions.push_back(std::make_unique<function::VectorFunctionDefinition>(
name, std::move(parameterTypes), returnType, std::move(scalarExecFunc)));
return definitions;
}

template<typename TR, typename... Args>
static vector_function_definitions getFunctionDefinition(
const std::string& name, TR (*udfFunc)(Args...)) {
return getFunctionDefinition<TR, Args...>(
name, udfFunc, getParameterTypes<Args...>(), getParameterType<TR>());
}

template<typename TR, typename... Args>
static vector_function_definitions getVectorizedFunctionDefinition(
const std::string& name, scalar_exec_func execFunc) {
vector_function_definitions definitions;
definitions.push_back(std::make_unique<function::VectorFunctionDefinition>(
name, getParameterTypes<Args...>(), getParameterType<TR>(), std::move(execFunc)));
return definitions;
}

static vector_function_definitions getVectorizedFunctionDefinition(const std::string& name,
scalar_exec_func execFunc, std::vector<common::LogicalTypeID> parameterTypes,
common::LogicalTypeID returnType) {
vector_function_definitions definitions;
definitions.push_back(std::make_unique<function::VectorFunctionDefinition>(
name, std::move(parameterTypes), returnType, std::move(execFunc)));
return definitions;
}
};

} // namespace function
} // namespace kuzu
28 changes: 25 additions & 3 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,31 @@ class Connection {

// Registers a scalar UDF under `name`, with logical parameter/return types inferred from the
// C++ signature of `udfFunc`.
// The definition list is built once through function::UDF::getFunctionDefinition and handed to
// addScalarFunction; `definitions` is declared a single time.
template<typename TR, typename... Args>
void createScalarFunction(const std::string& name, TR (*udfFunc)(Args...)) {
    auto definitions = function::UDF::getFunctionDefinition<TR, Args...>(name, udfFunc);
    addScalarFunction(name, std::move(definitions));
}

template<typename TR, typename... Args>
void createScalarFunction(const std::string& name,
std::vector<common::LogicalTypeID> parameterTypes, common::LogicalTypeID returnType,
TR (*udfFunc)(Args...)) {
auto definitions = function::UDF::getFunctionDefinition<TR, Args...>(
name, udfFunc, std::move(parameterTypes), returnType);
addScalarFunction(name, std::move(definitions));
}

template<typename TR, typename... Args>
void createVectorizedFunction(const std::string& name, function::scalar_exec_func scalarFunc) {
auto definitions = function::UDF::getVectorizedFunctionDefinition<TR, Args...>(
name, std::move(scalarFunc));
addScalarFunction(name, std::move(definitions));
}

// Registers a caller-provided vectorized executor under `name` with explicitly declared
// logical parameter and return types.
void createVectorizedFunction(const std::string& name,
    std::vector<common::LogicalTypeID> parameterTypes, common::LogicalTypeID returnType,
    function::scalar_exec_func scalarFunc) {
    addScalarFunction(name, function::UDF::getVectorizedFunctionDefinition(name,
                                std::move(scalarFunc), std::move(parameterTypes), returnType));
}

Expand Down
Loading

0 comments on commit 1f0e674

Please sign in to comment.