Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split expression binding to multiple cpp files #1210

Merged
merged 1 commit into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/binder/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(bind)
add_subdirectory(bind_expression)
add_subdirectory(expression)
add_subdirectory(query)

Expand Down
17 changes: 17 additions & 0 deletions src/binder/bind_expression/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_library(
kuzu_binder_bind_expression
OBJECT
bind_boolean_expression.cpp
bind_case_expression.cpp
bind_comparison_expression.cpp
bind_function_expression.cpp
bind_literal_expression.cpp
bind_null_operator_expression.cpp
bind_parameter_expression.cpp
bind_property_expression.cpp
bind_subquery_expression.cpp
bind_variable_expression.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_bind_expression>
PARENT_SCOPE)
35 changes: 35 additions & 0 deletions src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "binder/expression/function_expression.h"
#include "binder/expression_binder.h"
#include "function/boolean/vector_boolean_operations.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
children.push_back(bindExpression(*parsedExpression.getChild(i)));
}
return bindBooleanExpression(parsedExpression.getExpressionType(), children);
}

shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
ExpressionType expressionType, const expression_vector& children) {
expression_vector childrenAfterCast;
for (auto& child : children) {
childrenAfterCast.push_back(implicitCastIfNecessary(child, BOOL));
}
auto functionName = expressionTypeToString(expressionType);
auto execFunc = VectorBooleanOperations::bindExecFunction(expressionType, childrenAfterCast);
auto selectFunc =
VectorBooleanOperations::bindSelectFunction(expressionType, childrenAfterCast);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(expressionType, DataType(BOOL),
std::move(childrenAfterCast), std::move(execFunc), std::move(selectFunc),
uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
53 changes: 53 additions & 0 deletions src/binder/bind_expression/bind_case_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "binder/binder.h"
#include "binder/expression/case_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_case_expression.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindCaseExpression(
const ParsedExpression& parsedExpression) {
auto& parsedCaseExpression = (ParsedCaseExpression&)parsedExpression;
auto anchorCaseAlternative = parsedCaseExpression.getCaseAlternative(0);
auto outDataType = bindExpression(*anchorCaseAlternative->thenExpression)->dataType;
auto name = binder->getUniqueExpressionName(parsedExpression.getRawName());
// bind ELSE ...
shared_ptr<Expression> elseExpression;
if (parsedCaseExpression.hasElseExpression()) {
elseExpression = bindExpression(*parsedCaseExpression.getElseExpression());
} else {
elseExpression = bindNullLiteralExpression();
}
elseExpression = implicitCastIfNecessary(elseExpression, outDataType);
auto boundCaseExpression =
make_shared<CaseExpression>(outDataType, std::move(elseExpression), name);
// bind WHEN ... THEN ...
if (parsedCaseExpression.hasCaseExpression()) {
auto boundCase = bindExpression(*parsedCaseExpression.getCaseExpression());
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, boundCase->dataType);
// rewrite "CASE a.age WHEN 1" as "CASE WHEN a.age = 1"
boundWhen = bindComparisonExpression(
EQUALS, vector<shared_ptr<Expression>>{boundCase, boundWhen});
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, outDataType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
} else {
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, BOOL);
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, outDataType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
}
return boundCaseExpression;
}

} // namespace binder
} // namespace kuzu
40 changes: 40 additions & 0 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression_binder.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
children.push_back(std::move(child));
}
return bindComparisonExpression(parsedExpression.getExpressionType(), std::move(children));
}

shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
auto functionName = expressionTypeToString(expressionType);
vector<DataType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
implicitCastIfNecessary(children[i], function->parameterTypeIDs[i]));
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(expressionType, DataType(function->returnTypeID),
std::move(childrenAfterCast), function->execFunc, function->selectFunc,
uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
175 changes: 175 additions & 0 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "function/node/vector_node_operations.h"
#include "parser/expression/parsed_function_expression.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
const ParsedExpression& parsedExpression) {
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
// check for special function binding
if (functionName == ID_FUNC_NAME) {
return bindInternalIDExpression(parsedExpression);
} else if (functionName == LABEL_FUNC_NAME) {
return bindLabelFunction(parsedExpression);
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
return bindScalarFunctionExpression(parsedExpression, functionName);
} else {
assert(functionType == AGGREGATE_FUNCTION);
return bindAggregateFunctionExpression(
parsedExpression, functionName, parsedFunctionExpression.getIsDistinct());
}
}

shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const ParsedExpression& parsedExpression, const string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
vector<DataType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, parsedExpression, children);
}
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
function->isVarLength ? function->parameterTypeIDs[0] : function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
DataType returnType;
if (function->bindFunc) {
function->bindFunc(childrenTypes, function, returnType);
} else {
returnType = DataType(function->returnTypeID);
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(FUNCTION, returnType, std::move(childrenAfterCast),
function->execFunc, function->selectFunc, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const string& functionName, bool isDistinct) {
auto builtInFunctions = binder->catalog.getBuiltInAggregateFunction();
vector<DataType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
// rewrite aggregate on node or rel as aggregate on their internal IDs.
// e.g. COUNT(a) -> COUNT(a._id)
if (child->dataType.typeID == NODE || child->dataType.typeID == REL) {
child = bindInternalIDExpression(*child);
}
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes, isDistinct);
auto uniqueExpressionName =
AggregateFunctionExpression::getUniqueName(function->name, children, function->isDistinct);
if (children.empty()) {
uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName);
}
return make_shared<AggregateFunctionExpression>(DataType(function->returnTypeID),
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::staticEvaluate(const string& functionName,
const ParsedExpression& parsedExpression, const expression_vector& children) {
assert(children[0]->expressionType == common::LITERAL);
auto strVal = ((LiteralExpression*)children[0].get())->getValue()->getValue<string>();
if (functionName == CAST_TO_DATE_FUNC_NAME) {
return make_shared<LiteralExpression>(
make_unique<Value>(Date::FromCString(strVal.c_str(), strVal.length())));
} else if (functionName == CAST_TO_TIMESTAMP_FUNC_NAME) {
return make_shared<LiteralExpression>(
make_unique<Value>(Timestamp::FromCString(strVal.c_str(), strVal.length())));
} else {
assert(functionName == CAST_TO_INTERVAL_FUNC_NAME);
return make_shared<LiteralExpression>(
make_unique<Value>(Interval::FromCString(strVal.c_str(), strVal.length())));
}
}

shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const ParsedExpression& parsedExpression) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(*child, unordered_set<DataTypeID>{NODE, REL});
return bindInternalIDExpression(*child);
}

shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(const Expression& expression) {
if (expression.dataType.typeID == NODE) {
auto& node = (NodeExpression&)expression;
return node.getInternalIDProperty();
} else {
assert(expression.dataType.typeID == REL);
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
}
}

unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
const Expression& expression) {
auto& node = (NodeExpression&)expression;
unordered_map<table_id_t, property_id_t> propertyIDPerTable;
for (auto tableID : node.getTableIDs()) {
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
auto result = make_unique<PropertyExpression>(
DataType(NODE_ID), INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable));
return result;
}

shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const ParsedExpression& parsedExpression) {
// bind child node
auto child = bindExpression(*parsedExpression.getChild(0));
assert(child->dataType.typeID == common::NODE);
return bindNodeLabelFunction(*child);
}

shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
return make_shared<LiteralExpression>(make_unique<Value>(labelName));
}
// bind string node labels as list literal
auto nodeTableIDs = catalogContent->getNodeTableIDs();
table_id_t maxNodeTableID = *std::max_element(nodeTableIDs.begin(), nodeTableIDs.end());
vector<unique_ptr<Value>> nodeLabels;
nodeLabels.resize(maxNodeTableID + 1);
for (auto i = 0; i < nodeLabels.size(); ++i) {
if (catalogContent->containNodeTable(i)) {
nodeLabels[i] = make_unique<Value>(catalogContent->getTableName(i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
nodeLabels[i] = make_unique<Value>(string(""));
}
}
auto literalDataType = DataType(LIST, make_unique<DataType>(STRING));
expression_vector children;
children.push_back(node.getInternalIDProperty());
children.push_back(
make_shared<LiteralExpression>(make_unique<Value>(literalDataType, std::move(nodeLabels))));
auto execFunc = NodeLabelVectorOperation::execFunction;
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(
FUNCTION, DataType(STRING), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
25 changes: 25 additions & 0 deletions src/binder/bind_expression/bind_literal_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_literal_expression.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindLiteralExpression(
const ParsedExpression& parsedExpression) {
auto& literalExpression = (ParsedLiteralExpression&)parsedExpression;
auto value = literalExpression.getValue();
if (value->isNull()) {
return bindNullLiteralExpression();
}
return make_shared<LiteralExpression>(value->copy());
}

shared_ptr<Expression> ExpressionBinder::bindNullLiteralExpression() {
return make_shared<LiteralExpression>(
make_unique<Value>(Value::createNullValue()), binder->getUniqueExpressionName("NULL"));
}

} // namespace binder
} // namespace kuzu
24 changes: 24 additions & 0 deletions src/binder/bind_expression/bind_null_operator_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "binder/expression/function_expression.h"
#include "binder/expression_binder.h"
#include "function/null/vector_null_operations.h"

namespace kuzu {
namespace binder {

shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
const ParsedExpression& parsedExpression) {
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
children.push_back(bindExpression(*parsedExpression.getChild(i)));
}
auto expressionType = parsedExpression.getExpressionType();
auto functionName = expressionTypeToString(expressionType);
auto execFunc = VectorNullOperations::bindExecFunction(expressionType, children);
auto selectFunc = VectorNullOperations::bindSelectFunction(expressionType, children);
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children);
return make_shared<ScalarFunctionExpression>(expressionType, DataType(BOOL),
std::move(children), std::move(execFunc), std::move(selectFunc), uniqueExpressionName);
}

} // namespace binder
} // namespace kuzu
Loading