Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Case expression #1125

Merged
merged 1 commit into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ oC_PropertyOrLabelsExpression
oC_Atom
: oC_Literal
| oC_Parameter
| oC_CaseExpression
| oC_ParenthesizedExpression
| oC_FunctionInvocation
| oC_ExistentialSubquery
Expand Down Expand Up @@ -428,6 +429,22 @@ EXISTS : ( 'E' | 'e' ) ( 'X' | 'x' ) ( 'I' | 'i' ) ( 'S' | 's' ) ( 'T' | 't' ) (
// Property lookup: a '.' followed by an optional SP token and a property key name.
oC_PropertyLookup
: '.' SP? ( oC_PropertyKeyName ) ;

// CASE expression. Two forms are accepted; both share the optional ELSE tail and the
// mandatory closing END:
//   CASE ( WHEN ... THEN ... )+                 -- alternatives only
//   CASE <expression> ( WHEN ... THEN ... )+    -- a leading expression before the alternatives
// At least one alternative is required in either form (the `+` quantifier).
oC_CaseExpression
: ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_Expression ( SP? oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_Expression )? SP? END ;

// Keyword tokens. Every letter accepts its upper- or lower-case form, so the keywords match
// case-insensitively.
CASE : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

ELSE : ( 'E' | 'e' ) ( 'L' | 'l' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

END : ( 'E' | 'e' ) ( 'N' | 'n' ) ( 'D' | 'd' ) ;

// One alternative of a CASE expression: WHEN <expression> THEN <expression>.
oC_CaseAlternative
: WHEN SP? oC_Expression SP? THEN SP? oC_Expression ;

// Keyword tokens, matched case-insensitively (each letter accepts either case).
WHEN : ( 'W' | 'w' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

THEN : ( 'T' | 't' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

// A variable is written as a symbolic name.
oC_Variable
: oC_SymbolicName ;

Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ unique_ptr<BoundCreateRel> Binder::bindCreateRel(
} else {
auto propertyExpression =
expressionBinder.bindRelPropertyExpression(rel, property.name);
shared_ptr<Expression> nullExpression =
LiteralExpression::createNullLiteralExpression(getUniqueExpressionName("NULL"));
auto nullExpression = expressionBinder.bindNullLiteralExpression();
nullExpression = ExpressionBinder::implicitCastIfNecessary(
nullExpression, propertyExpression->dataType);
setItems.emplace_back(std::move(propertyExpression), std::move(nullExpression));
Expand Down
1 change: 1 addition & 0 deletions src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Object library compiled from the binder expression sources listed below.
add_library(
kuzu_binder_expression
OBJECT
case_expression.cpp
existential_subquery_expression.cpp
expression.cpp)

Expand Down
17 changes: 17 additions & 0 deletions src/binder/expression/case_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "binder/expression/case_expression.h"

namespace kuzu {
namespace binder {

// Flattens this CASE expression into a list of child expressions: for every alternative, its
// WHEN expression followed by its THEN expression (in alternative order), and finally the
// ELSE expression as the last element.
expression_vector CaseExpression::getChildren() const {
    expression_vector flattened;
    for (const auto& alternative : caseAlternatives) {
        flattened.push_back(alternative->whenExpression);
        flattened.push_back(alternative->thenExpression);
    }
    flattened.push_back(elseExpression);
    return flattened;
}

} // namespace binder
} // namespace kuzu
10 changes: 0 additions & 10 deletions src/binder/expression/existential_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
namespace kuzu {
namespace binder {

// Returns the union of the dependent variable names reported by every child expression
// (as given by getChildren()).
unordered_set<string> ExistentialSubqueryExpression::getDependentVariableNames() {
    unordered_set<string> names;
    for (auto& child : getChildren()) {
        auto childNames = child->getDependentVariableNames();
        names.insert(childNames.begin(), childNames.end());
    }
    return names;
}

// The children of subquery expressions is defined as all expressions in the subquery, i.e.
// expressions from predicates and return clause. Plus nodeID expressions from query graph.
expression_vector ExistentialSubqueryExpression::getChildren() const {
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ unordered_set<string> Expression::getDependentVariableNames() {
result.insert(getUniqueName());
return result;
}
for (auto& child : children) {
for (auto& child : getChildren()) {
for (auto& variableName : child->getDependentVariableNames()) {
result.insert(variableName);
}
Expand Down
55 changes: 53 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/expression_binder.h"

#include "binder/binder.h"
#include "binder/expression/case_expression.h"
#include "binder/expression/existential_subquery_expression.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
Expand All @@ -9,6 +10,7 @@
#include "common/type_utils.h"
#include "function/boolean/vector_boolean_operations.h"
#include "function/null/vector_null_operations.h"
#include "parser/expression/parsed_case_expression.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"
#include "parser/expression/parsed_parameter_expression.h"
Expand Down Expand Up @@ -42,6 +44,8 @@ shared_ptr<Expression> ExpressionBinder::bindExpression(const ParsedExpression&
expression = bindVariableExpression(parsedExpression);
} else if (EXISTENTIAL_SUBQUERY == expressionType) {
expression = bindExistentialSubqueryExpression(parsedExpression);
} else if (CASE_ELSE == expressionType) {
expression = bindCaseExpression(parsedExpression);
} else {
throw NotImplementedException(
"bindExpression(" + expressionTypeToString(expressionType) + ").");
Expand Down Expand Up @@ -376,12 +380,16 @@ shared_ptr<Expression> ExpressionBinder::bindLiteralExpression(
auto& literalExpression = (ParsedLiteralExpression&)parsedExpression;
auto literal = literalExpression.getLiteral();
if (literal->isNull()) {
return LiteralExpression::createNullLiteralExpression(
binder->getUniqueExpressionName("NULL"));
return bindNullLiteralExpression();
}
return make_shared<LiteralExpression>(literal->dataType, make_unique<Literal>(*literal));
}

// Builds a literal expression of data type ANY holding a default-constructed Literal, under a
// unique expression name requested from the binder for "NULL".
shared_ptr<Expression> ExpressionBinder::bindNullLiteralExpression() {
    auto uniqueName = binder->getUniqueExpressionName("NULL");
    auto nullValue = make_unique<Literal>();
    return make_shared<LiteralExpression>(
        DataType(ANY), std::move(nullValue), std::move(uniqueName));
}

shared_ptr<Expression> ExpressionBinder::bindVariableExpression(
const ParsedExpression& parsedExpression) {
auto& variableExpression = (ParsedVariableExpression&)parsedExpression;
Expand All @@ -408,6 +416,49 @@ shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
return boundSubqueryExpression;
}

// Binds a parsed CASE expression into a CaseExpression.
//
// - The result data type is taken from the THEN expression of the first alternative; the ELSE
//   expression and every THEN expression are implicitly cast to it when necessary.
// - When no ELSE is given, a NULL literal is used in its place.
// - "CASE x WHEN v THEN ..." is rewritten so that each WHEN becomes the comparison "x = v"
//   (v is first cast to the type of x); in the form without a leading expression each WHEN is
//   cast to BOOL instead.
shared_ptr<Expression> ExpressionBinder::bindCaseExpression(
    const ParsedExpression& parsedExpression) {
    auto& parsedCase = (ParsedCaseExpression&)parsedExpression;

    // The first alternative's THEN expression anchors the output type of the whole CASE.
    auto firstAlternative = parsedCase.getCaseAlternative(0);
    auto resultType = bindExpression(*firstAlternative->thenExpression)->dataType;
    auto uniqueName = binder->getUniqueExpressionName(parsedExpression.getRawName());

    // ELSE branch (NULL literal when omitted), cast to the output type.
    shared_ptr<Expression> boundElse = parsedCase.hasElseExpression() ?
                                           bindExpression(*parsedCase.getElseExpression()) :
                                           bindNullLiteralExpression();
    boundElse = implicitCastIfNecessary(boundElse, resultType);
    auto result = make_shared<CaseExpression>(resultType, std::move(boundElse), uniqueName);

    // Leading expression of "CASE x WHEN ..."; stays null for "CASE WHEN ...".
    const bool hasOperand = parsedCase.hasCaseExpression();
    shared_ptr<Expression> boundOperand;
    if (hasOperand) {
        boundOperand = bindExpression(*parsedCase.getCaseExpression());
    }

    // WHEN ... THEN ... alternatives.
    for (auto idx = 0u; idx < parsedCase.getNumCaseAlternative(); ++idx) {
        auto alternative = parsedCase.getCaseAlternative(idx);
        auto whenExpr = bindExpression(*alternative->whenExpression);
        if (hasOperand) {
            whenExpr = implicitCastIfNecessary(whenExpr, boundOperand->dataType);
            // rewrite "CASE a.age WHEN 1" as "CASE WHEN a.age = 1"
            whenExpr = bindComparisonExpression(
                EQUALS, vector<shared_ptr<Expression>>{boundOperand, whenExpr});
        } else {
            whenExpr = implicitCastIfNecessary(whenExpr, BOOL);
        }
        auto thenExpr = bindExpression(*alternative->thenExpression);
        thenExpr = implicitCastIfNecessary(thenExpr, resultType);
        result->addCaseAlternative(whenExpr, thenExpr);
    }
    return result;
}

shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const shared_ptr<Expression>& expression, DataType targetType) {
if (targetType.typeID == ANY || expression->dataType == targetType) {
Expand Down
5 changes: 2 additions & 3 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ ValueVector::ValueVector(DataType dataType, MemoryManager* memoryManager)

// Writes a string into the entry at `pos` of a STRING vector by handing `value`/`len` and the
// target ku_string_t entry to InMemOverflowBufferUtils::copyString together with this vector's
// in-memory overflow buffer.
// NOTE(review): this listing is a diff rendered without +/- markers ("2 additions & 3
// deletions"). The three statements after the assert look like the pre-change body and the
// last two like their replacement -- confirm against the real file before editing.
void ValueVector::addString(uint32_t pos, char* value, uint64_t len) const {
assert(dataType.typeID == STRING);
auto vectorData = (ku_string_t*)valueBuffer.get();
auto& result = vectorData[pos];
InMemOverflowBufferUtils::copyString(value, len, result, *inMemOverflowBuffer);
auto& entry = ((ku_string_t*)getData())[pos];
InMemOverflowBufferUtils::copyString(value, len, entry, *inMemOverflowBuffer);
}

bool NodeIDVector::discardNull(ValueVector& vector) {
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Object library compiled from the expression evaluator sources listed below.
add_library(kuzu_expression_evaluator
OBJECT
base_evaluator.cpp
case_evaluator.cpp
function_evaluator.cpp
literal_evaluator.cpp
reference_evaluator.cpp)
Expand Down
25 changes: 15 additions & 10 deletions src/expression_evaluator/base_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
namespace kuzu {
namespace evaluator {

// Single-child form: takes ownership of `child` and stores it as the only entry of `children`.
BaseExpressionEvaluator::BaseExpressionEvaluator(unique_ptr<BaseExpressionEvaluator> child) {
    children.emplace_back(std::move(child));
}

// Two-child form: takes ownership of both evaluators and stores them in `children` in the
// order left, right.
BaseExpressionEvaluator::BaseExpressionEvaluator(
    unique_ptr<BaseExpressionEvaluator> left, unique_ptr<BaseExpressionEvaluator> right) {
    children.emplace_back(std::move(left));
    children.emplace_back(std::move(right));
}

void BaseExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager* memoryManager) {
for (auto& child : children) {
child->init(resultSet, memoryManager);
}
resolveResultVector(resultSet, memoryManager);
}

void BaseExpressionEvaluator::resolveResultStateFromChildren(
const vector<BaseExpressionEvaluator*>& inputEvaluators) {
for (auto& input : inputEvaluators) {
if (!input->isResultFlat()) {
isResultFlat_ = false;
resultVector->state = input->resultVector->state;
return;
}
}
// All children are flat.
isResultFlat_ = true;
resultVector->state = DataChunkState::getSingleValueDataChunkState();
}

} // namespace evaluator
Expand Down
Loading