Skip to content

Commit

Permalink
Merge pull request #1125 from kuzudb/case-expression
Browse files Browse the repository at this point in the history
Case expression
  • Loading branch information
andyfengHKU committed Dec 23, 2022
2 parents c49cb1c + e62505b commit 059ec91
Show file tree
Hide file tree
Showing 42 changed files with 4,060 additions and 3,095 deletions.
17 changes: 17 additions & 0 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ oC_PropertyOrLabelsExpression
oC_Atom
: oC_Literal
| oC_Parameter
| oC_CaseExpression
| oC_ParenthesizedExpression
| oC_FunctionInvocation
| oC_ExistentialSubquery
Expand Down Expand Up @@ -428,6 +429,22 @@ EXISTS : ( 'E' | 'e' ) ( 'X' | 'x' ) ( 'I' | 'i' ) ( 'S' | 's' ) ( 'T' | 't' ) (
oC_PropertyLookup
: '.' SP? ( oC_PropertyKeyName ) ;

oC_CaseExpression
: ( ( CASE ( SP? oC_CaseAlternative )+ ) | ( CASE SP? oC_Expression ( SP? oC_CaseAlternative )+ ) ) ( SP? ELSE SP? oC_Expression )? SP? END ;

CASE : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

ELSE : ( 'E' | 'e' ) ( 'L' | 'l' ) ( 'S' | 's' ) ( 'E' | 'e' ) ;

END : ( 'E' | 'e' ) ( 'N' | 'n' ) ( 'D' | 'd' ) ;

oC_CaseAlternative
: WHEN SP? oC_Expression SP? THEN SP? oC_Expression ;

WHEN : ( 'W' | 'w' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

THEN : ( 'T' | 't' ) ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'N' | 'n' ) ;

oC_Variable
: oC_SymbolicName ;

Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ unique_ptr<BoundCreateRel> Binder::bindCreateRel(
} else {
auto propertyExpression =
expressionBinder.bindRelPropertyExpression(rel, property.name);
shared_ptr<Expression> nullExpression =
LiteralExpression::createNullLiteralExpression(getUniqueExpressionName("NULL"));
auto nullExpression = expressionBinder.bindNullLiteralExpression();
nullExpression = ExpressionBinder::implicitCastIfNecessary(
nullExpression, propertyExpression->dataType);
setItems.emplace_back(std::move(propertyExpression), std::move(nullExpression));
Expand Down
1 change: 1 addition & 0 deletions src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(
kuzu_binder_expression
OBJECT
case_expression.cpp
existential_subquery_expression.cpp
expression.cpp)

Expand Down
17 changes: 17 additions & 0 deletions src/binder/expression/case_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "binder/expression/case_expression.h"

namespace kuzu {
namespace binder {

expression_vector CaseExpression::getChildren() const {
expression_vector result;
for (auto& caseAlternative : caseAlternatives) {
result.push_back(caseAlternative->whenExpression);
result.push_back(caseAlternative->thenExpression);
}
result.push_back(elseExpression);
return result;
}

} // namespace binder
} // namespace kuzu
10 changes: 0 additions & 10 deletions src/binder/expression/existential_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
namespace kuzu {
namespace binder {

unordered_set<string> ExistentialSubqueryExpression::getDependentVariableNames() {
unordered_set<string> result;
for (auto& expression : getChildren()) {
for (auto& variableName : expression->getDependentVariableNames()) {
result.insert(variableName);
}
}
return result;
}

// The children of subquery expressions is defined as all expressions in the subquery, i.e.
// expressions from predicates and return clause. Plus nodeID expressions from query graph.
expression_vector ExistentialSubqueryExpression::getChildren() const {
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ unordered_set<string> Expression::getDependentVariableNames() {
result.insert(getUniqueName());
return result;
}
for (auto& child : children) {
for (auto& child : getChildren()) {
for (auto& variableName : child->getDependentVariableNames()) {
result.insert(variableName);
}
Expand Down
55 changes: 53 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/expression_binder.h"

#include "binder/binder.h"
#include "binder/expression/case_expression.h"
#include "binder/expression/existential_subquery_expression.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
Expand All @@ -9,6 +10,7 @@
#include "common/type_utils.h"
#include "function/boolean/vector_boolean_operations.h"
#include "function/null/vector_null_operations.h"
#include "parser/expression/parsed_case_expression.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"
#include "parser/expression/parsed_parameter_expression.h"
Expand Down Expand Up @@ -42,6 +44,8 @@ shared_ptr<Expression> ExpressionBinder::bindExpression(const ParsedExpression&
expression = bindVariableExpression(parsedExpression);
} else if (EXISTENTIAL_SUBQUERY == expressionType) {
expression = bindExistentialSubqueryExpression(parsedExpression);
} else if (CASE_ELSE == expressionType) {
expression = bindCaseExpression(parsedExpression);
} else {
throw NotImplementedException(
"bindExpression(" + expressionTypeToString(expressionType) + ").");
Expand Down Expand Up @@ -376,12 +380,16 @@ shared_ptr<Expression> ExpressionBinder::bindLiteralExpression(
auto& literalExpression = (ParsedLiteralExpression&)parsedExpression;
auto literal = literalExpression.getLiteral();
if (literal->isNull()) {
return LiteralExpression::createNullLiteralExpression(
binder->getUniqueExpressionName("NULL"));
return bindNullLiteralExpression();
}
return make_shared<LiteralExpression>(literal->dataType, make_unique<Literal>(*literal));
}

shared_ptr<Expression> ExpressionBinder::bindNullLiteralExpression() {
return make_shared<LiteralExpression>(
DataType(ANY), make_unique<Literal>(), binder->getUniqueExpressionName("NULL"));
}

shared_ptr<Expression> ExpressionBinder::bindVariableExpression(
const ParsedExpression& parsedExpression) {
auto& variableExpression = (ParsedVariableExpression&)parsedExpression;
Expand All @@ -408,6 +416,49 @@ shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
return boundSubqueryExpression;
}

shared_ptr<Expression> ExpressionBinder::bindCaseExpression(
const ParsedExpression& parsedExpression) {
auto& parsedCaseExpression = (ParsedCaseExpression&)parsedExpression;
auto anchorCaseAlternative = parsedCaseExpression.getCaseAlternative(0);
auto outDataType = bindExpression(*anchorCaseAlternative->thenExpression)->dataType;
auto name = binder->getUniqueExpressionName(parsedExpression.getRawName());
// bind ELSE ...
shared_ptr<Expression> elseExpression;
if (parsedCaseExpression.hasElseExpression()) {
elseExpression = bindExpression(*parsedCaseExpression.getElseExpression());
} else {
elseExpression = bindNullLiteralExpression();
}
elseExpression = implicitCastIfNecessary(elseExpression, outDataType);
auto boundCaseExpression =
make_shared<CaseExpression>(outDataType, std::move(elseExpression), name);
// bind WHEN ... THEN ...
if (parsedCaseExpression.hasCaseExpression()) {
auto boundCase = bindExpression(*parsedCaseExpression.getCaseExpression());
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, boundCase->dataType);
// rewrite "CASE a.age WHEN 1" as "CASE WHEN a.age = 1"
boundWhen = bindComparisonExpression(
EQUALS, vector<shared_ptr<Expression>>{boundCase, boundWhen});
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, outDataType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
} else {
for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) {
auto caseAlternative = parsedCaseExpression.getCaseAlternative(i);
auto boundWhen = bindExpression(*caseAlternative->whenExpression);
boundWhen = implicitCastIfNecessary(boundWhen, BOOL);
auto boundThen = bindExpression(*caseAlternative->thenExpression);
boundThen = implicitCastIfNecessary(boundThen, outDataType);
boundCaseExpression->addCaseAlternative(boundWhen, boundThen);
}
}
return boundCaseExpression;
}

shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const shared_ptr<Expression>& expression, DataType targetType) {
if (targetType.typeID == ANY || expression->dataType == targetType) {
Expand Down
5 changes: 2 additions & 3 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ ValueVector::ValueVector(DataType dataType, MemoryManager* memoryManager)

void ValueVector::addString(uint32_t pos, char* value, uint64_t len) const {
assert(dataType.typeID == STRING);
auto vectorData = (ku_string_t*)valueBuffer.get();
auto& result = vectorData[pos];
InMemOverflowBufferUtils::copyString(value, len, result, *inMemOverflowBuffer);
auto& entry = ((ku_string_t*)getData())[pos];
InMemOverflowBufferUtils::copyString(value, len, entry, *inMemOverflowBuffer);
}

bool NodeIDVector::discardNull(ValueVector& vector) {
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_expression_evaluator
OBJECT
base_evaluator.cpp
case_evaluator.cpp
function_evaluator.cpp
literal_evaluator.cpp
reference_evaluator.cpp)
Expand Down
25 changes: 15 additions & 10 deletions src/expression_evaluator/base_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
namespace kuzu {
namespace evaluator {

BaseExpressionEvaluator::BaseExpressionEvaluator(unique_ptr<BaseExpressionEvaluator> child) {
children.push_back(move(child));
}

BaseExpressionEvaluator::BaseExpressionEvaluator(
unique_ptr<BaseExpressionEvaluator> left, unique_ptr<BaseExpressionEvaluator> right) {
children.push_back(move(left));
children.push_back(move(right));
}

void BaseExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager* memoryManager) {
for (auto& child : children) {
child->init(resultSet, memoryManager);
}
resolveResultVector(resultSet, memoryManager);
}

void BaseExpressionEvaluator::resolveResultStateFromChildren(
const vector<BaseExpressionEvaluator*>& inputEvaluators) {
for (auto& input : inputEvaluators) {
if (!input->isResultFlat()) {
isResultFlat_ = false;
resultVector->state = input->resultVector->state;
return;
}
}
// All children are flat.
isResultFlat_ = true;
resultVector->state = DataChunkState::getSingleValueDataChunkState();
}

} // namespace evaluator
Expand Down
Loading

0 comments on commit 059ec91

Please sign in to comment.