Skip to content

Commit

Permalink
Merge pull request #1479 from kuzudb/struct-planning
Browse files: browse the repository at this point in the history
Struct planning
  • Loading branch information
andyfengHKU committed Apr 22, 2023
2 parents 3e04497 + 4ce7aa9 commit f0d768e
Show file tree
Hide file tree
Showing 23 changed files with 274 additions and 177 deletions.
3 changes: 2 additions & 1 deletion src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
function::VectorBooleanOperations::bindExecFunction(expressionType, childrenAfterCast);
auto selectFunc =
function::VectorBooleanOperations::bindSelectFunction(expressionType, childrenAfterCast);
auto bindData = std::make_unique<function::FunctionBindData>(DataType(BOOL));
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, expressionType, DataType(BOOL),
return make_shared<ScalarFunctionExpression>(functionName, expressionType, std::move(bindData),
std::move(childrenAfterCast), std::move(execFunc), std::move(selectFunc),
uniqueExpressionName);
}
Expand Down
8 changes: 5 additions & 3 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
childrenAfterCast.push_back(
implicitCastIfNecessary(children[i], function->parameterTypeIDs[i]));
}
auto bindData =
std::make_unique<function::FunctionBindData>(common::DataType(function->returnTypeID));
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, expressionType,
common::DataType(function->returnTypeID), std::move(childrenAfterCast), function->execFunc,
function->selectFunc, uniqueExpressionName);
return make_shared<ScalarFunctionExpression>(functionName, expressionType, std::move(bindData),
std::move(childrenAfterCast), function->execFunc, function->selectFunc,
uniqueExpressionName);
}

} // namespace binder
Expand Down
43 changes: 26 additions & 17 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,40 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
std::vector<DataType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
return bindScalarFunctionExpression(children, functionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
std::vector<DataType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, parsedExpression, children);
return staticEvaluate(functionName, children);
}
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
function->isVarLength ? function->parameterTypeIDs[0] : function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
DataType returnType;
std::unique_ptr<function::FunctionBindData> bindData;
if (function->bindFunc) {
function->bindFunc(childrenAfterCast, function, returnType);
bindData = function->bindFunc(childrenAfterCast, function);
} else {
returnType = DataType(function->returnTypeID);
bindData = std::make_unique<function::FunctionBindData>(DataType(function->returnTypeID));
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, FUNCTION, returnType,
return make_shared<ScalarFunctionExpression>(functionName, FUNCTION, std::move(bindData),
std::move(childrenAfterCast), function->execFunc, function->selectFunc,
uniqueExpressionName);
}
Expand All @@ -87,18 +94,18 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
if (children.empty()) {
uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName);
}
DataType returnType;
std::unique_ptr<function::FunctionBindData> bindData;
if (function->bindFunc) {
function->bindFunc(children, function, returnType);
bindData = function->bindFunc(children, function);
} else {
returnType = DataType(function->returnTypeID);
bindData = std::make_unique<function::FunctionBindData>(DataType(function->returnTypeID));
}
return make_shared<AggregateFunctionExpression>(functionName, returnType, std::move(children),
function->aggregateFunction->clone(), uniqueExpressionName);
return make_shared<AggregateFunctionExpression>(functionName, std::move(bindData),
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(const std::string& functionName,
const ParsedExpression& parsedExpression, const expression_vector& children) {
std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
const std::string& functionName, const expression_vector& children) {
assert(children[0]->expressionType == common::LITERAL);
auto strVal = ((LiteralExpression*)children[0].get())->getValue()->getValue<std::string>();
std::unique_ptr<Value> value;
Expand Down Expand Up @@ -187,8 +194,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expres
populateLabelValues(nodeTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
auto execFunc = function::LabelVectorOperation::execFunction;
auto bindData = std::make_unique<function::FunctionBindData>(DataType(STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, DataType(STRING),
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, std::move(bindData),
std::move(children), execFunc, nullptr, uniqueExpressionName);
}

Expand All @@ -207,8 +215,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindRelLabelFunction(const Express
populateLabelValues(relTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
auto execFunc = function::LabelVectorOperation::execFunction;
auto bindData = std::make_unique<function::FunctionBindData>(DataType(STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, DataType(STRING),
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, std::move(bindData),
std::move(children), execFunc, nullptr, uniqueExpressionName);
}

Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind_expression/bind_null_operator_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ std::shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
auto functionName = expressionTypeToString(expressionType);
auto execFunc = function::VectorNullOperations::bindExecFunction(expressionType, children);
auto selectFunc = function::VectorNullOperations::bindSelectFunction(expressionType, children);
auto bindData = std::make_unique<function::FunctionBindData>(common::DataType(common::BOOL));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children);
return make_shared<ScalarFunctionExpression>(functionName, expressionType,
common::DataType(common::BOOL), std::move(children), std::move(execFunc),
std::move(selectFunc), uniqueExpressionName);
return make_shared<ScalarFunctionExpression>(functionName, expressionType, std::move(bindData),
std::move(children), std::move(execFunc), std::move(selectFunc), uniqueExpressionName);
}

} // namespace binder
Expand Down
12 changes: 9 additions & 3 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression/literal_expression.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_property_expression.h"
Expand All @@ -20,12 +21,17 @@ std::shared_ptr<Expression> ExpressionBinder::bindPropertyExpression(
propertyName + " is reserved for system usage. External access is not allowed.");
}
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(*child, std::unordered_set<DataTypeID>{NODE, REL});
validateExpectedDataType(*child, std::unordered_set<DataTypeID>{NODE, REL, STRUCT});
if (NODE == child->dataType.typeID) {
return bindNodePropertyExpression(*child, propertyName);
} else {
assert(REL == child->dataType.typeID);
} else if (common::REL == child->dataType.typeID) {
return bindRelPropertyExpression(*child, propertyName);
} else {
assert(common::STRUCT == child->dataType.typeID);
auto stringValue = std::make_unique<Value>(propertyName);
return bindScalarFunctionExpression(
expression_vector{child, createLiteralExpression(std::move(stringValue))},
STRUCT_EXTRACT_FUNC_NAME);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCast(
if (VectorCastOperations::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = VectorCastOperations::bindImplicitCastFuncName(targetType);
auto children = expression_vector{expression};
auto bindData = std::make_unique<FunctionBindData>(targetType);
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
return std::make_shared<ScalarFunctionExpression>(functionName, FUNCTION,
DataType{targetType.typeID}, std::move(children),
std::move(bindData), std::move(children),
VectorCastOperations::bindImplicitCastFunc(
expression->dataType.typeID, targetType.typeID),
nullptr /* selectFunc */, std::move(uniqueName));
Expand Down
14 changes: 13 additions & 1 deletion src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "expression_evaluator/function_evaluator.h"

#include "binder/expression/function_expression.h"
#include "function/struct/vector_struct_operations.h"

using namespace kuzu::common;
using namespace kuzu::processor;
Expand Down Expand Up @@ -59,7 +60,18 @@ std::unique_ptr<BaseExpressionEvaluator> FunctionExpressionEvaluator::clone() {

void FunctionExpressionEvaluator::resolveResultVector(
const ResultSet& resultSet, MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
auto& functionExpression = (binder::ScalarFunctionExpression&)*expression;
if (functionExpression.getFunctionName() == STRUCT_PACK_FUNC_NAME) {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
for (auto& child : children) {
resultVector->addChildVector(child->resultVector);
}
} else if (functionExpression.getFunctionName() == STRUCT_EXTRACT_FUNC_NAME) {
auto& bindData = (function::StructExtractBindData&)*functionExpression.getBindData();
resultVector = children[0]->resultVector->getChildVector(bindData.childIdx);
} else {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
}
std::vector<BaseExpressionEvaluator*> inputEvaluators;
for (auto& child : children) {
inputEvaluators.push_back(child.get());
Expand Down
8 changes: 3 additions & 5 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,9 @@ void BuiltInVectorOperations::registerInternalIDOperation() {
}

// Registers the struct operations under STRUCT_PACK_FUNC_NAME and STRUCT_EXTRACT_FUNC_NAME.
// Each operation's definitions come from its own operations class via getDefinitions(), so
// every name is inserted exactly once.
void BuiltInVectorOperations::registerStructOperation() {
    vectorOperations.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorOperations::getDefinitions()});
    vectorOperations.insert(
        {STRUCT_EXTRACT_FUNC_NAME, StructExtractVectorOperations::getDefinitions()});
}

} // namespace function
Expand Down
Loading

0 comments on commit f0d768e

Please sign in to comment.