Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UDF support to c++ API #1808

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,5 +463,11 @@ std::unordered_set<RelTableSchema*> Catalog::getAllRelTableSchemasContainBoundTa
return relTableSchemas;
}

// Adds a set of vector function definitions to the catalog's function registry.
// Both arguments are taken by value and moved into
// builtInVectorFunctions->addFunction, so the caller's copies are not referenced
// after the call.
void Catalog::addVectorFunction(
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
std::string name, function::vector_function_definitions definitions) {
// The return value of toUpper is not used, so it presumably upper-cases `name`
// in place before it is handed over as the lookup key -- TODO confirm.
common::StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
}

} // namespace catalog
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/catalog/catalog_structs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ std::string getRelMultiplicityAsString(RelMultiplicity relMultiplicity) {
}
}

// Reports whether `propertyName` collides with the reserved internal ID keyword.
// The name is passed through StringUtils::getUpper before being compared with
// InternalKeyword::ID (presumably making the check case-insensitive).
bool TableSchema::isReservedPropertyName(const std::string& propertyName) {
    const auto upperCasedName = common::StringUtils::getUpper(propertyName);
    return upperCasedName == common::InternalKeyword::ID;
}

std::string TableSchema::getPropertyName(property_id_t propertyID) const {
for (auto& property : properties) {
if (property.propertyID == propertyID) {
Expand Down
1 change: 1 addition & 0 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common/types/value.h"

#include "common/null_buffer.h"
#include "common/string_utils.h"
#include "storage/storage_utils.h"

namespace kuzu {
Expand Down
1 change: 1 addition & 0 deletions src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "common/vector/auxiliary_buffer.h"

#include "arrow/array.h"
#include "common/vector/value_vector.h"

namespace kuzu {
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/path_evaluator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "expression_evaluator/path_evaluator.h"

#include "binder/expression/path_expression.h"
#include "common/string_utils.h"

using namespace kuzu::common;
using namespace kuzu::binder;
Expand Down
324 changes: 167 additions & 157 deletions src/function/built_in_vector_functions.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/function/vector_path_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/path/vector_path_functions.h"

#include "binder/expression/literal_expression.h"
#include "common/string_utils.h"
#include "function/struct/vector_struct_functions.h"

namespace kuzu {
Expand Down
2 changes: 2 additions & 0 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ class Catalog {
std::unordered_set<RelTableSchema*> getAllRelTableSchemasContainBoundTable(
common::table_id_t boundTableID) const;

// Passes `name` through StringUtils::toUpper and forwards it, together with
// `definitions`, to the built-in vector function registry. Both arguments are
// taken by value and moved from.
void addVectorFunction(std::string name, function::vector_function_definitions definitions);

private:
inline bool hasUpdates() { return catalogContentForWriteTrx != nullptr; }

Expand Down
5 changes: 1 addition & 4 deletions src/include/catalog/catalog_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "common/constants.h"
#include "common/exception.h"
#include "common/rel_direction.h"
#include "common/string_utils.h"
#include "common/types/types_include.h"

namespace kuzu {
Expand Down Expand Up @@ -48,9 +47,7 @@ struct TableSchema {

virtual ~TableSchema() = default;

static inline bool isReservedPropertyName(const std::string& propertyName) {
return common::StringUtils::getUpper(propertyName) == common::InternalKeyword::ID;
}
static bool isReservedPropertyName(const std::string& propertyName);

inline uint32_t getNumProperties() const { return properties.size(); }

Expand Down
5 changes: 4 additions & 1 deletion src/include/common/vector/auxiliary_buffer.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#pragma once

#include "arrow/array.h"
#include "common/in_mem_overflow_buffer.h"

namespace arrow {
class Array;
}

namespace kuzu {
namespace common {

Expand Down
92 changes: 55 additions & 37 deletions src/include/function/binary_function_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct BinaryFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result);
}
};
Expand All @@ -27,7 +27,7 @@ struct BinaryListStructFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector);
}
};
Expand All @@ -36,7 +36,7 @@ struct BinaryStringFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, *resultValueVector);
}
};
Expand All @@ -45,53 +45,64 @@ struct BinaryComparisonFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector) {
common::ValueVector* resultValueVector, void* dataPtr) {
OP::operation(left, right, result, leftValueVector, rightValueVector);
}
};

/**
 * Operation wrapper used for binary user-defined functions.
 *
 * It has the same call signature as the other Binary*FunctionWrapper structs so
 * that BinaryFunctionExecutor can invoke it uniformly, but it passes only the
 * two operands, the result slot and the opaque `dataPtr` on to OP::operation.
 * The three ValueVector pointers are accepted and ignored.
 */
struct BinaryUDFFunctionWrapper {
    template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
    static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
        common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/,
        common::ValueVector* /*resultValueVector*/, void* dataPtr) {
        // dataPtr is handed through untouched; its meaning is defined by OP.
        OP::operation(left, right, result, dataPtr);
    }
};

struct BinaryFunctionExecutor {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static inline void executeOnValue(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos) {
common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos,
void* dataPtr) {
OP_WRAPPER::template operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC>(
((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos],
((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector);
((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector,
dataPtr);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeBothFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeBothFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
auto lPos = left.state->selVector->selectedPositions[0];
auto rPos = right.state->selVector->selectedPositions[0];
auto resPos = result.state->selVector->selectedPositions[0];
result.setNull(resPos, left.isNull(lPos) || right.isNull(rPos));
if (!result.isNull(resPos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, resPos);
left, right, result, lPos, rPos, resPos, dataPtr);
}
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeFlatUnFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeFlatUnFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
auto lPos = left.state->selVector->selectedPositions[0];
if (left.isNull(lPos)) {
result.setAllNull();
} else if (right.hasNoNullsGuarantee()) {
if (right.state->selVector->isUnfiltered()) {
for (auto i = 0u; i < right.state->selVector->selectedSize; ++i) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, i, i);
left, right, result, lPos, i, i, dataPtr);
}
} else {
for (auto i = 0u; i < right.state->selVector->selectedSize; ++i) {
auto rPos = right.state->selVector->selectedPositions[i];
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, rPos);
left, right, result, lPos, rPos, rPos, dataPtr);
}
}
} else {
Expand All @@ -100,7 +111,7 @@ struct BinaryFunctionExecutor {
result.setNull(i, right.isNull(i)); // left is always not null
if (!result.isNull(i)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, i, i);
left, right, result, lPos, i, i, dataPtr);
}
}
} else {
Expand All @@ -109,7 +120,7 @@ struct BinaryFunctionExecutor {
result.setNull(rPos, right.isNull(rPos)); // left is always not null
if (!result.isNull(rPos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, rPos);
left, right, result, lPos, rPos, rPos, dataPtr);
}
}
}
Expand All @@ -118,22 +129,22 @@ struct BinaryFunctionExecutor {

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeUnFlatFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeUnFlatFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
auto rPos = right.state->selVector->selectedPositions[0];
if (right.isNull(rPos)) {
result.setAllNull();
} else if (left.hasNoNullsGuarantee()) {
if (left.state->selVector->isUnfiltered()) {
for (auto i = 0u; i < left.state->selVector->selectedSize; ++i) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, rPos, i);
left, right, result, i, rPos, i, dataPtr);
}
} else {
for (auto i = 0u; i < left.state->selVector->selectedSize; ++i) {
auto lPos = left.state->selVector->selectedPositions[i];
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, lPos);
left, right, result, lPos, rPos, lPos, dataPtr);
}
}
} else {
Expand All @@ -142,7 +153,7 @@ struct BinaryFunctionExecutor {
result.setNull(i, left.isNull(i)); // right is always not null
if (!result.isNull(i)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, rPos, i);
left, right, result, i, rPos, i, dataPtr);
}
}
} else {
Expand All @@ -151,7 +162,7 @@ struct BinaryFunctionExecutor {
result.setNull(lPos, left.isNull(lPos)); // right is always not null
if (!result.isNull(lPos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, lPos, rPos, lPos);
left, right, result, lPos, rPos, lPos, dataPtr);
}
}
}
Expand All @@ -160,20 +171,20 @@ struct BinaryFunctionExecutor {

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeBothUnFlat(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeBothUnFlat(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
assert(left.state == right.state);
if (left.hasNoNullsGuarantee() && right.hasNoNullsGuarantee()) {
if (result.state->selVector->isUnfiltered()) {
for (uint64_t i = 0; i < result.state->selVector->selectedSize; i++) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, i, i);
left, right, result, i, i, i, dataPtr);
}
} else {
for (uint64_t i = 0; i < result.state->selVector->selectedSize; i++) {
auto pos = result.state->selVector->selectedPositions[i];
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, pos, pos, pos);
left, right, result, pos, pos, pos, dataPtr);
}
}
} else {
Expand All @@ -182,7 +193,7 @@ struct BinaryFunctionExecutor {
result.setNull(i, left.isNull(i) || right.isNull(i));
if (!result.isNull(i)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, i, i, i);
left, right, result, i, i, i, dataPtr);
}
}
} else {
Expand All @@ -191,7 +202,7 @@ struct BinaryFunctionExecutor {
result.setNull(pos, left.isNull(pos) || right.isNull(pos));
if (!result.isNull(pos)) {
executeOnValue<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result, pos, pos, pos);
left, right, result, pos, pos, pos, dataPtr);
}
}
}
Expand All @@ -200,21 +211,21 @@ struct BinaryFunctionExecutor {

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
typename OP_WRAPPER>
static void executeSwitch(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
static void executeSwitch(common::ValueVector& left, common::ValueVector& right,
common::ValueVector& result, void* dataPtr) {
result.resetAuxiliaryBuffer();
if (left.state->isFlat() && right.state->isFlat()) {
executeBothFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else if (left.state->isFlat() && !right.state->isFlat()) {
executeFlatUnFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else if (!left.state->isFlat() && right.state->isFlat()) {
executeUnFlatFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else if (!left.state->isFlat() && !right.state->isFlat()) {
executeBothUnFlat<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, OP_WRAPPER>(
left, right, result);
left, right, result, dataPtr);
} else {
assert(false);
}
Expand All @@ -224,28 +235,35 @@ struct BinaryFunctionExecutor {
static void execute(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeString(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryStringFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeListStruct(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryListStructFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeComparison(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, BinaryComparisonFunctionWrapper>(
left, right, result);
left, right, result, nullptr /* dataPtr */);
}

/**
 * Entry point for executing a binary user-defined function.
 *
 * Delegates to executeSwitch with BinaryUDFFunctionWrapper as the operation
 * wrapper, so `dataPtr` reaches FUNC::operation as its last argument.
 *
 * @param left    left operand vector
 * @param right   right operand vector
 * @param result  vector receiving the results
 * @param dataPtr opaque pointer forwarded unchanged to the wrapped operation
 */
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeUDF(common::ValueVector& left, common::ValueVector& right,
    common::ValueVector& result, void* dataPtr) {
    using UDFWrapper = BinaryUDFFunctionWrapper;
    executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, UDFWrapper>(
        left, right, result, dataPtr);
}

struct BinarySelectWrapper {
Expand Down
6 changes: 4 additions & 2 deletions src/include/function/built_in_vector_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BuiltInVectorFunctions {
BuiltInVectorFunctions() { registerVectorFunctions(); }

inline bool containsFunction(const std::string& functionName) {
return VectorFunctions.contains(functionName);
return vectorFunctions.contains(functionName);
}

/**
Expand All @@ -27,6 +27,8 @@ class BuiltInVectorFunctions {
static uint32_t getCastCost(
common::LogicalTypeID inputTypeID, common::LogicalTypeID targetTypeID);

void addFunction(std::string name, function::vector_function_definitions definitions);

private:
static uint32_t getTargetTypeCost(common::LogicalTypeID typeID);

Expand Down Expand Up @@ -75,7 +77,7 @@ class BuiltInVectorFunctions {

private:
// TODO(Ziyi): Refactor VectorFunction/tableOperation to inherit from the same base class.
std::unordered_map<std::string, vector_function_definitions> VectorFunctions;
std::unordered_map<std::string, vector_function_definitions> vectorFunctions;
};

} // namespace function
Expand Down
Loading
Loading