Skip to content

Commit

Permalink
Add scalar_compile_func and nodes/rels function
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jun 14, 2023
1 parent abbb617 commit 896ff12
Show file tree
Hide file tree
Showing 20 changed files with 266 additions and 150 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, FUNCTION, std::move(bindData),
std::move(childrenAfterCast), function->execFunc, function->selectFunc,
uniqueExpressionName);
function->compileFunc, uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
Expand Down
39 changes: 6 additions & 33 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "expression_evaluator/function_evaluator.h"

#include "binder/expression/function_expression.h"
#include "function/struct/vector_struct_operations.h"

using namespace kuzu::common;
using namespace kuzu::processor;
Expand All @@ -16,17 +15,13 @@ void FunctionExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::BOOL) {
selectFunc = ((binder::ScalarFunctionExpression&)*expression).selectFunc;
}
for (auto& child : children) {
parameters.push_back(child->resultVector);
}
}

void FunctionExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
if (execFunc != nullptr) {
// Some functions are evaluated at compile time (e.g. struct_extract).
execFunc(parameters, *resultVector);
}
}
Expand Down Expand Up @@ -63,41 +58,19 @@ std::unique_ptr<BaseExpressionEvaluator> FunctionExpressionEvaluator::clone() {

void FunctionExpressionEvaluator::resolveResultVector(
const ResultSet& resultSet, MemoryManager* memoryManager) {
auto& functionExpression = (binder::ScalarFunctionExpression&)*expression;
auto functionName = functionExpression.getFunctionName();
if (functionName == STRUCT_EXTRACT_FUNC_NAME || functionName == UNION_EXTRACT_FUNC_NAME) {
auto& bindData = (function::StructExtractBindData&)*functionExpression.getBindData();
resultVector =
StructVector::getFieldVector(children[0]->resultVector.get(), bindData.childIdx);
} else {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
for (auto& child : children) {
parameters.push_back(child->resultVector);
}
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
std::vector<BaseExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
inputEvaluators.push_back(child.get());
}
resolveResultStateFromChildren(inputEvaluators);
// TODO(Ziyi): We should move result valueVector state resolution to each function.
if (functionExpression.getFunctionName() == STRUCT_PACK_FUNC_NAME) {
// Our goal is to make the state of the resultVector consistent with its children vectors.
// If the resultVector and inputVector are in different dataChunks, we should create a new
// child valueVector, which shares the state with the resultVector, instead of reusing the
// inputVector.
for (auto i = 0u; i < inputEvaluators.size(); i++) {
auto inputEvaluator = inputEvaluators[i];
if (inputEvaluator->resultVector->state == resultVector->state) {
common::StructVector::referenceVector(
resultVector.get(), i, inputEvaluator->resultVector);
}
}
} else if (functionExpression.getFunctionName() == UNION_VALUE_FUNC_NAME) {
assert(inputEvaluators.size() == 1);
resultVector->setState(inputEvaluators[0]->resultVector->state);
common::UnionVector::getTagVector(resultVector.get())
->setState(inputEvaluators[0]->resultVector->state);
common::UnionVector::referenceVector(
resultVector.get(), 0 /* fieldIdx */, inputEvaluators[0]->resultVector);
auto& functionExpression = (binder::ScalarFunctionExpression&)*expression;
if (functionExpression.compileFunc != nullptr) {
functionExpression.compileFunc(functionExpression.getBindData(), parameters, resultVector);
}
}

Expand Down
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_library(kuzu_function
vector_hash_operations.cpp
vector_list_operation.cpp
vector_null_operations.cpp
vector_node_rel_operations.cpp
vector_string_operations.cpp
vector_timestamp_operations.cpp
vector_struct_operations.cpp
Expand Down
11 changes: 8 additions & 3 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "function/interval/vector_interval_operations.h"
#include "function/list/vector_list_operations.h"
#include "function/map/vector_map_operations.h"
#include "function/schema/vector_offset_operations.h"
#include "function/schema/vector_node_rel_operations.h"
#include "function/string/vector_string_operations.h"
#include "function/struct/vector_struct_operations.h"
#include "function/timestamp/vector_timestamp_operations.h"
Expand All @@ -30,8 +30,7 @@ void BuiltInVectorOperations::registerVectorOperations() {
registerStructOperations();
registerMapOperations();
registerUnionOperations();
// register internal offset operation
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
registerNodeRelOperations();
}

bool BuiltInVectorOperations::canApplyStaticEvaluation(
Expand Down Expand Up @@ -505,5 +504,11 @@ void BuiltInVectorOperations::registerUnionOperations() {
{UNION_EXTRACT_FUNC_NAME, UnionExtractVectorOperations::getDefinitions()});
}

void BuiltInVectorOperations::registerNodeRelOperations() {
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
vectorOperations.insert({NODES_FUNC_NAME, NodesVectorOperation::getDefinitions()});
vectorOperations.insert({RELS_FUNC_NAME, RelsVectorOperation::getDefinitions()});
}

} // namespace function
} // namespace kuzu
117 changes: 58 additions & 59 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,7 @@ template<typename T>
void ListSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments, scalar_exec_func& func) {
if (arguments.size() == 1) {
func = UnaryExecListStructFunctionWithVectors<list_entry_t, list_entry_t,
operation::ListSort<T>>;
func = UnaryExecListStructFunction<list_entry_t, list_entry_t, operation::ListSort<T>>;
return;
} else if (arguments.size() == 2) {
func = BinaryExecListStructFunction<list_entry_t, ku_string_t, list_entry_t,
Expand Down Expand Up @@ -472,8 +471,8 @@ template<typename T>
void ListReverseSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments, scalar_exec_func& func) {
if (arguments.size() == 1) {
func = UnaryExecListStructFunctionWithVectors<list_entry_t, list_entry_t,
operation::ListReverseSort<T>>;
func =
UnaryExecListStructFunction<list_entry_t, list_entry_t, operation::ListReverseSort<T>>;
return;
} else if (arguments.size() == 2) {
func = BinaryExecListStructFunction<list_entry_t, ku_string_t, list_entry_t,
Expand All @@ -500,23 +499,23 @@ std::unique_ptr<FunctionBindData> ListSumVectorOperation::bindFunc(
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, int64_t, operation::ListSum>;
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListSum>;
} break;
case LogicalTypeID::INT32: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, int32_t, operation::ListSum>;
UnaryExecListStructFunction<list_entry_t, int32_t, operation::ListSum>;

Check warning on line 506 in src/function/vector_list_operation.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_list_operation.cpp#L506

Added line #L506 was not covered by tests
} break;
case LogicalTypeID::INT16: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, int16_t, operation::ListSum>;
UnaryExecListStructFunction<list_entry_t, int16_t, operation::ListSum>;

Check warning on line 510 in src/function/vector_list_operation.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_list_operation.cpp#L510

Added line #L510 was not covered by tests
} break;
case LogicalTypeID::DOUBLE: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, double_t, operation::ListSum>;
UnaryExecListStructFunction<list_entry_t, double_t, operation::ListSum>;
} break;
case LogicalTypeID::FLOAT: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, float_t, operation::ListSum>;
UnaryExecListStructFunction<list_entry_t, float_t, operation::ListSum>;

Check warning on line 518 in src/function/vector_list_operation.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_list_operation.cpp#L518

Added line #L518 was not covered by tests
} break;
default: {
throw common::NotImplementedException("ListSumVectorOperation::bindFunc");
Expand All @@ -539,47 +538,47 @@ std::unique_ptr<FunctionBindData> ListDistinctVectorOperation::bindFunc(
switch (VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<int64_t>>;
} break;
case LogicalTypeID::INT32: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<int32_t>>;
} break;
case LogicalTypeID::INT16: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<int16_t>>;
} break;
case LogicalTypeID::DOUBLE: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<double_t>>;
} break;
case LogicalTypeID::FLOAT: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<float_t>>;
} break;
case LogicalTypeID::BOOL: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<uint8_t>>;
} break;
case LogicalTypeID::STRING: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<ku_string_t>>;
} break;
case LogicalTypeID::DATE: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<date_t>>;
} break;
case LogicalTypeID::TIMESTAMP: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<timestamp_t>>;
} break;
case LogicalTypeID::INTERVAL: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<interval_t>>;
} break;
case LogicalTypeID::INTERNAL_ID: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t,
list_entry_t, operation::ListDistinct<internalID_t>>;
} break;
default: {
Expand All @@ -603,48 +602,48 @@ std::unique_ptr<FunctionBindData> ListUniqueVectorOperation::bindFunc(
switch (common::VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<int64_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<int64_t>>;
} break;
case LogicalTypeID::INT32: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<int32_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<int32_t>>;
} break;
case LogicalTypeID::INT16: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<int16_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<int16_t>>;
} break;
case LogicalTypeID::DOUBLE: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<double_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<double_t>>;
} break;
case LogicalTypeID::FLOAT: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<float_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<float_t>>;
} break;
case LogicalTypeID::BOOL: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<uint8_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<uint8_t>>;
} break;
case LogicalTypeID::STRING: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<ku_string_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<ku_string_t>>;
} break;
case LogicalTypeID::DATE: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<date_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<date_t>>;
} break;
case LogicalTypeID::TIMESTAMP: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<timestamp_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<timestamp_t>>;
} break;
case LogicalTypeID::INTERVAL: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<interval_t>>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListUnique<interval_t>>;
} break;
case LogicalTypeID::INTERNAL_ID: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
int64_t, operation::ListUnique<common::internalID_t>>;
vectorOperationDefinition->execFunc = UnaryExecListStructFunction<list_entry_t, int64_t,
operation::ListUnique<common::internalID_t>>;
} break;
default: {
throw common::NotImplementedException("ListUniqueVectorOperation::bindFunc");
Expand All @@ -669,51 +668,51 @@ std::unique_ptr<FunctionBindData> ListAnyValueVectorOperation::bindFunc(
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, int64_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, int64_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::INT32: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, int32_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, int32_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::INT16: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, int16_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, int16_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::DOUBLE: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, double_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, double_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::FLOAT: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, float_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, float_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::BOOL: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, uint8_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, uint8_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::STRING: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
ku_string_t, operation::ListAnyValue>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, ku_string_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::DATE: {
vectorOperationDefinition->execFunc =
UnaryExecListStructFunctionWithVectors<list_entry_t, date_t, operation::ListAnyValue>;
UnaryExecListStructFunction<list_entry_t, date_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::TIMESTAMP: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
timestamp_t, operation::ListAnyValue>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, timestamp_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::INTERVAL: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
interval_t, operation::ListAnyValue>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, interval_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::VAR_LIST: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
list_entry_t, operation::ListAnyValue>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, list_entry_t, operation::ListAnyValue>;
} break;
case LogicalTypeID::INTERNAL_ID: {
vectorOperationDefinition->execFunc = UnaryExecListStructFunctionWithVectors<list_entry_t,
internalID_t, operation::ListAnyValue>;
vectorOperationDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, internalID_t, operation::ListAnyValue>;
} break;
default: {
throw common::NotImplementedException("ListAnyValueVectorOperation::bindFunc");
Expand Down
Loading

0 comments on commit 896ff12

Please sign in to comment.