Skip to content

Commit

Permalink
Merge pull request #1210 from kuzudb/expression-binder-split
Browse files Browse the repository at this point in the history
Split expression binding to multiple cpp files
  • Loading branch information
andyfengHKU committed Jan 29, 2023
2 parents f2b3dba + b7bea76 commit 0a8d48f
Show file tree
Hide file tree
Showing 13 changed files with 519 additions and 417 deletions.
1 change: 1 addition & 0 deletions src/binder/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(bind)
add_subdirectory(bind_expression)
add_subdirectory(expression)
add_subdirectory(query)

Expand Down
17 changes: 17 additions & 0 deletions src/binder/bind_expression/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_library(
kuzu_binder_bind_expression
OBJECT
bind_boolean_expression.cpp
bind_case_expression.cpp
bind_comparison_expression.cpp
bind_function_expression.cpp
bind_literal_expression.cpp
bind_null_operator_expression.cpp
bind_parameter_expression.cpp
bind_property_expression.cpp
bind_subquery_expression.cpp
bind_variable_expression.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_bind_expression>
PARENT_SCOPE)
35 changes: 35 additions & 0 deletions src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "binder/expression/function_expression.h"
#include "binder/expression_binder.h"
#include "function/boolean/vector_boolean_operations.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
children.push_back(bindExpression(*parsedExpression.getChild(i)));
}
return bindBooleanExpression(parsedExpression.getExpressionType(), children);
}

shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
ExpressionType expressionType, const expression_vector& children) {
expression_vector childrenAfterCast;
for (auto& child : children) {
childrenAfterCast.push_back(implicitCastIfNecessary(child, BOOL));
}
auto functionName = expressionTypeToString(expressionType);
auto execFunc = VectorBooleanOperations::bindExecFunction(expressionType, childrenAfterCast);
auto selectFunc =
VectorBooleanOperations::bindSelectFunction(expressionType, childrenAfterCast);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(expressionType, DataType(BOOL),
std::move(childrenAfterCast), std::move(execFunc), std::move(selectFunc),
uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
53 changes: 53 additions & 0 deletions src/binder/bind_expression/bind_case_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "binder/binder.h"
#include "binder/expression/case_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_case_expression.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindCaseExpression(
const ParsedExpression& parsedExpression) {
auto& parsedCaseExpression = (ParsedCaseExpression&)parsedExpression;
auto anchorCaseAlternative = parsedCaseExpression.getCaseAlternative(0);
auto outDataType = bindExpression(*anchorCaseAlternative->thenExpression)->dataType;
auto name = binder->getUniqueExpressionName(parsedExpression.getRawName());
// bind ELSE ...
shared_ptr<Expression> elseExpression;
if (parsedCaseExpression.hasElseExpression()) {
elseExpression = bindExpression(*parsedCaseExpression.getElseExpression());
} else {
elseExpression = bindNullLiteralExpression();
}
elseExpression = implicitCastIfNecessary(elseExpression, outDataType);
auto boundCaseExpression =
make_shared<CaseExpression>(outDataType, std::move(elseExpression), name);
// bind WHEN ... THEN ...
if (parsedCaseExpression.hasCaseExpression()) {
auto boundCase = bindExpression(*parsedCaseExpression.getCaseExpression());
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, boundCase->dataType);
// rewrite "CASE a.age WHEN 1" as "CASE WHEN a.age = 1"
boundWhen = bindComparisonExpression(
EQUALS, vector<shared_ptr<Expression>>{boundCase, boundWhen});
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, outDataType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
} else {
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, BOOL);
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, outDataType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
}
return boundCaseExpression;
}

} // namespace binder
} // namespace kuzu
40 changes: 40 additions & 0 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression_binder.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
children.push_back(std::move(child));
}
return bindComparisonExpression(parsedExpression.getExpressionType(), std::move(children));
}

shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
auto functionName = expressionTypeToString(expressionType);
vector<DataType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
implicitCastIfNecessary(children[i], function->parameterTypeIDs[i]));
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(expressionType, DataType(function->returnTypeID),
std::move(childrenAfterCast), function->execFunc, function->selectFunc,
uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
175 changes: 175 additions & 0 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "function/node/vector_node_operations.h"
#include "parser/expression/parsed_function_expression.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
const ParsedExpression& parsedExpression) {
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
// check for special function binding
if (functionName == ID_FUNC_NAME) {
return bindInternalIDExpression(parsedExpression);
} else if (functionName == LABEL_FUNC_NAME) {
return bindLabelFunction(parsedExpression);
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
return bindScalarFunctionExpression(parsedExpression, functionName);
} else {
assert(functionType == AGGREGATE_FUNCTION);
return bindAggregateFunctionExpression(
parsedExpression, functionName, parsedFunctionExpression.getIsDistinct());
}
}

shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const ParsedExpression& parsedExpression, const string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
vector<DataType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, parsedExpression, children);
}
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
function->isVarLength ? function->parameterTypeIDs[0] : function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
DataType returnType;
if (function->bindFunc) {
function->bindFunc(childrenTypes, function, returnType);
} else {
returnType = DataType(function->returnTypeID);
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(FUNCTION, returnType, std::move(childrenAfterCast),
function->execFunc, function->selectFunc, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const string& functionName, bool isDistinct) {
auto builtInFunctions = binder->catalog.getBuiltInAggregateFunction();
vector<DataType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
// rewrite aggregate on node or rel as aggregate on their internal IDs.
// e.g. COUNT(a) -> COUNT(a._id)
if (child->dataType.typeID == NODE || child->dataType.typeID == REL) {
child = bindInternalIDExpression(*child);
}
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes, isDistinct);
auto uniqueExpressionName =
AggregateFunctionExpression::getUniqueName(function->name, children, function->isDistinct);
if (children.empty()) {
uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName);
}
return make_shared<AggregateFunctionExpression>(DataType(function->returnTypeID),
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::staticEvaluate(const string& functionName,
const ParsedExpression& parsedExpression, const expression_vector& children) {
assert(children[0]->expressionType == common::LITERAL);
auto strVal = ((LiteralExpression*)children[0].get())->getValue()->getValue<string>();
if (functionName == CAST_TO_DATE_FUNC_NAME) {
return make_shared<LiteralExpression>(
make_unique<Value>(Date::FromCString(strVal.c_str(), strVal.length())));
} else if (functionName == CAST_TO_TIMESTAMP_FUNC_NAME) {
return make_shared<LiteralExpression>(
make_unique<Value>(Timestamp::FromCString(strVal.c_str(), strVal.length())));
} else {
assert(functionName == CAST_TO_INTERVAL_FUNC_NAME);
return make_shared<LiteralExpression>(
make_unique<Value>(Interval::FromCString(strVal.c_str(), strVal.length())));
}
}

shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const ParsedExpression& parsedExpression) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(*child, unordered_set<DataTypeID>{NODE, REL});
return bindInternalIDExpression(*child);
}

shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(const Expression& expression) {
if (expression.dataType.typeID == NODE) {
auto& node = (NodeExpression&)expression;
return node.getInternalIDProperty();
} else {
assert(expression.dataType.typeID == REL);
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
}
}

unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
const Expression& expression) {
auto& node = (NodeExpression&)expression;
unordered_map<table_id_t, property_id_t> propertyIDPerTable;
for (auto tableID : node.getTableIDs()) {
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
auto result = make_unique<PropertyExpression>(
DataType(NODE_ID), INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable));
return result;
}

shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const ParsedExpression& parsedExpression) {
// bind child node
auto child = bindExpression(*parsedExpression.getChild(0));
assert(child->dataType.typeID == common::NODE);
return bindNodeLabelFunction(*child);
}

shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
return make_shared<LiteralExpression>(make_unique<Value>(labelName));
}
// bind string node labels as list literal
auto nodeTableIDs = catalogContent->getNodeTableIDs();
table_id_t maxNodeTableID = *std::max_element(nodeTableIDs.begin(), nodeTableIDs.end());
vector<unique_ptr<Value>> nodeLabels;
nodeLabels.resize(maxNodeTableID + 1);
for (auto i = 0; i < nodeLabels.size(); ++i) {
if (catalogContent->containNodeTable(i)) {
nodeLabels[i] = make_unique<Value>(catalogContent->getTableName(i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
nodeLabels[i] = make_unique<Value>(string(""));
}
}
auto literalDataType = DataType(LIST, make_unique<DataType>(STRING));
expression_vector children;
children.push_back(node.getInternalIDProperty());
children.push_back(
make_shared<LiteralExpression>(make_unique<Value>(literalDataType, std::move(nodeLabels))));
auto execFunc = NodeLabelVectorOperation::execFunction;
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(
FUNCTION, DataType(STRING), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
25 changes: 25 additions & 0 deletions src/binder/bind_expression/bind_literal_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_literal_expression.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindLiteralExpression(
const ParsedExpression& parsedExpression) {
auto& literalExpression = (ParsedLiteralExpression&)parsedExpression;
auto value = literalExpression.getValue();
if (value->isNull()) {
return bindNullLiteralExpression();
}
return make_shared<LiteralExpression>(value->copy());
}

shared_ptr<Expression> ExpressionBinder::bindNullLiteralExpression() {
return make_shared<LiteralExpression>(
make_unique<Value>(Value::createNullValue()), binder->getUniqueExpressionName("NULL"));
}

} // namespace binder
} // namespace kuzu
24 changes: 24 additions & 0 deletions src/binder/bind_expression/bind_null_operator_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "binder/expression/function_expression.h"
#include "binder/expression_binder.h"
#include "function/null/vector_null_operations.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
children.push_back(bindExpression(*parsedExpression.getChild(i)));
}
auto expressionType = parsedExpression.getExpressionType();
auto functionName = expressionTypeToString(expressionType);
auto execFunc = VectorNullOperations::bindExecFunction(expressionType, children);
auto selectFunc = VectorNullOperations::bindSelectFunction(expressionType, children);
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children);
return make_shared<ScalarFunctionExpression>(expressionType, DataType(BOOL),
std::move(children), std::move(execFunc), std::move(selectFunc), uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
Loading

0 comments on commit 0a8d48f

Please sign in to comment.