Skip to content

Commit

Permalink
Add UDF support to c++ API
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jul 12, 2023
1 parent 7ec4007 commit e34f0cd
Show file tree
Hide file tree
Showing 12 changed files with 633 additions and 282 deletions.
6 changes: 6 additions & 0 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,5 +463,11 @@ std::unordered_set<RelTableSchema*> Catalog::getAllRelTableSchemasContainBoundTa
return relTableSchemas;
}

void Catalog::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
common::StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
}

} // namespace catalog
} // namespace kuzu
314 changes: 157 additions & 157 deletions src/function/built_in_vector_functions.cpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ class Catalog {
std::unordered_set<RelTableSchema*> getAllRelTableSchemasContainBoundTable(
common::table_id_t boundTableID) const;

void addVectorFunction(std::string name, function::vector_function_definitions definitions);

private:
inline bool hasUpdates() { return catalogContentForWriteTrx != nullptr; }

Expand Down
92 changes: 55 additions & 37 deletions src/include/function/binary_function_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct BinaryFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result);
}
};
Expand All @@ -27,7 +27,7 @@ struct BinaryListStructFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector);
}
};
Expand All @@ -36,7 +36,7 @@ struct BinaryStringFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, *resultValueVector);
}
};
Expand All @@ -45,53 +45,64 @@ struct BinaryComparisonFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, leftValueVector, rightValueVector);
}
};

struct BinaryUDFFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, dataPtr);
}
};

struct BinaryFunctionExecutor {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static inline void executeOnValue(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos) {
common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos,
void* dataPtr) {
OP_WRAPPER::template operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC>(
((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos],
((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector);
((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector,
dataPtr);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeBothFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeBothFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
auto lPos = left.state->selVector->selectedPositions[0];
auto rPos = right.state->selVector->selectedPositions[0];
auto resPos = result.state->selVector->selectedPositions[0];
result.setNull(resPos, left.isNull(lPos) || right.isNull(rPos));
if (!result.isNull(resPos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, resPos);
left, right, result, lPos, rPos, resPos, dataPtr);
}
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeFlatUnFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeFlatUnFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
auto lPos = left.state->selVector->selectedPositions[0];
if (left.isNull(lPos)) {
result.setAllNull();
} else if (right.hasNoNullsGuarantee()) {
if (right.state->selVector->isUnfiltered()) {
for (auto i = 0u; i < right.state->selVector->selectedSize; ++i) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, i, i);
left, right, result, lPos, i, i, dataPtr);
}
} else {
for (auto i = 0u; i < right.state->selVector->selectedSize; ++i) {
auto rPos = right.state->selVector->selectedPositions[i];
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, rPos);
left, right, result, lPos, rPos, rPos, dataPtr);
}
}
} else {
Expand All @@ -100,7 +111,7 @@ struct BinaryFunctionExecutor {
result.setNull(i, right.isNull(i)); // left is always not null
if (!result.isNull(i)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, i, i);
left, right, result, lPos, i, i, dataPtr);
}
}
} else {
Expand All @@ -109,7 +120,7 @@ struct BinaryFunctionExecutor {
result.setNull(rPos, right.isNull(rPos)); // left is always not null
if (!result.isNull(rPos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, rPos);
left, right, result, lPos, rPos, rPos, dataPtr);
}
}
}
Expand All @@ -118,22 +129,22 @@ struct BinaryFunctionExecutor {

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeUnFlatFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeUnFlatFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
auto rPos = right.state->selVector->selectedPositions[0];
if (right.isNull(rPos)) {
result.setAllNull();
} else if (left.hasNoNullsGuarantee()) {
if (left.state->selVector->isUnfiltered()) {
for (auto i = 0u; i < left.state->selVector->selectedSize; ++i) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, rPos, i);
left, right, result, i, rPos, i, dataPtr);
}
} else {
for (auto i = 0u; i < left.state->selVector->selectedSize; ++i) {
auto lPos = left.state->selVector->selectedPositions[i];
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, lPos);
left, right, result, lPos, rPos, lPos, dataPtr);
}
}
} else {
Expand All @@ -142,7 +153,7 @@ struct BinaryFunctionExecutor {
result.setNull(i, left.isNull(i)); // right is always not null
if (!result.isNull(i)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, rPos, i);
left, right, result, i, rPos, i, dataPtr);
}
}
} else {
Expand All @@ -151,7 +162,7 @@ struct BinaryFunctionExecutor {
result.setNull(lPos, left.isNull(lPos)); // right is always not null
if (!result.isNull(lPos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, lPos);
left, right, result, lPos, rPos, lPos, dataPtr);
}
}
}
Expand All @@ -160,20 +171,20 @@ struct BinaryFunctionExecutor {

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeBothUnFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeBothUnFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
assert(left.state == right.state);
if (left.hasNoNullsGuarantee() && right.hasNoNullsGuarantee()) {
if (result.state->selVector->isUnfiltered()) {
for (uint64_t i = 0; i < result.state->selVector->selectedSize; i++) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, i, i);
left, right, result, i, i, i, dataPtr);
}
} else {
for (uint64_t i = 0; i < result.state->selVector->selectedSize; i++) {
auto pos = result.state->selVector->selectedPositions[i];
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, pos, pos, pos);
left, right, result, pos, pos, pos, dataPtr);
}
}
} else {
Expand All @@ -182,7 +193,7 @@ struct BinaryFunctionExecutor {
result.setNull(i, left.isNull(i) || right.isNull(i));
if (!result.isNull(i)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, i, i);
left, right, result, i, i, i, dataPtr);
}
}
} else {
Expand All @@ -191,7 +202,7 @@ struct BinaryFunctionExecutor {
result.setNull(pos, left.isNull(pos) || right.isNull(pos));
if (!result.isNull(pos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, pos, pos, pos);
left, right, result, pos, pos, pos, dataPtr);
}
}
}
Expand All @@ -200,21 +211,21 @@ struct BinaryFunctionExecutor {

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeSwitch(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeSwitch(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
result.resetAuxiliaryBuffer();
if (left.state->isFlat() && right.state->isFlat()) {
executeBothFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else if (left.state->isFlat() && !right.state->isFlat()) {
executeFlatUnFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else if (!left.state->isFlat() && right.state->isFlat()) {
executeUnFlatFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else if (!left.state->isFlat() && !right.state->isFlat()) {
executeBothUnFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else {
assert(false);
}
Expand All @@ -224,28 +235,35 @@ struct BinaryFunctionExecutor {
static void execute(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeString(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryStringFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeListStruct(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryListStructFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeComparison(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryComparisonFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeUDF(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryUDFFunctionWrapper>(
left, right, result, dataPtr);
}

struct BinarySelectWrapper {
Expand Down
12 changes: 10 additions & 2 deletions src/include/function/built_in_vector_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BuiltInVectorFunctions {
BuiltInVectorFunctions() { registerVectorFunctions(); }

inline bool containsFunction(const std::string& functionName) {
return VectorFunctions.contains(functionName);
return vectorFunctions.contains(functionName);
}

/**
Expand All @@ -27,6 +27,14 @@ class BuiltInVectorFunctions {
static uint32_t getCastCost(
common::LogicalTypeID inputTypeID, common::LogicalTypeID targetTypeID);

inline void addFunction(std::string name, function::vector_function_definitions definitions) {
if (vectorFunctions.contains(name)) {
throw common::CatalogException{
common::StringUtils::string_format("function {} already exists.", name)};
}
vectorFunctions.emplace(std::move(name), std::move(definitions));
}

private:
static uint32_t getTargetTypeCost(common::LogicalTypeID typeID);

Expand Down Expand Up @@ -75,7 +83,7 @@ class BuiltInVectorFunctions {

private:
// TODO(Ziyi): Refactor VectorFunction/tableOperation to inherit from the same base class.
std::unordered_map<std::string, vector_function_definitions> VectorFunctions;
std::unordered_map<std::string, vector_function_definitions> vectorFunctions;
};

} // namespace function
Expand Down
Loading

0 comments on commit e34f0cd

Please sign in to comment.