Skip to content

Commit

Permalink
rework group by aggregate binding
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jul 2, 2023
1 parent bd0f58d commit 424664d
Show file tree
Hide file tree
Showing 16 changed files with 197 additions and 182 deletions.
165 changes: 113 additions & 52 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_visitor.h"
#include "binder/expression/literal_expression.h"

using namespace kuzu::common;
Expand All @@ -9,12 +10,10 @@ namespace binder {

std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
auto projectionBody = withClause.getProjectionBody();
auto boundProjectionExpressions = bindProjectionExpressions(
auto projectionExpressions = bindProjectionExpressions(
projectionBody->getProjectionExpressions(), projectionBody->containsStar());
validateProjectionColumnsInWithClauseAreAliased(boundProjectionExpressions);
auto boundProjectionBody = std::make_unique<BoundProjectionBody>(
projectionBody->getIsDistinct(), std::move(boundProjectionExpressions));
bindOrderBySkipLimitIfNecessary(*boundProjectionBody, *projectionBody);
validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExpressions);
validateOrderByFollowedBySkipOrLimitInWithClause(*boundProjectionBody);
variableScope->clear();
addExpressionsToScope(boundProjectionBody->getProjectionExpressions());
Expand All @@ -39,32 +38,129 @@ std::unique_ptr<BoundReturnClause> Binder::bindReturnClause(const ReturnClause&
statementResult->addColumn(expression, expression_vector{expression});
}
}
auto boundProjectionBody = std::make_unique<BoundProjectionBody>(
projectionBody->getIsDistinct(), statementResult->getExpressionsToCollect());
bindOrderBySkipLimitIfNecessary(*boundProjectionBody, *projectionBody);
auto boundProjectionBody =
bindProjectionBody(*projectionBody, statementResult->getExpressionsToCollect());
return std::make_unique<BoundReturnClause>(
std::move(boundProjectionBody), std::move(statementResult));
}

// Reports whether `expression` is an aggregate function call or contains one somewhere in its
// expression tree (children are obtained through ExpressionChildrenCollector).
//
// An expression that carries an alias which `scope` already contains is never reported as an
// aggregate: the function returns false for it without inspecting its type or its children.
static bool isAggregateExpression(
    const std::shared_ptr<Expression>& expression, VariableScope* scope) {
    const bool aliasInScope = expression->hasAlias() && scope->contains(expression->getAlias());
    if (aliasInScope) {
        return false;
    }
    if (expression->expressionType == common::AGGREGATE_FUNCTION) {
        return true;
    }
    // Not an aggregate itself: it still counts if any descendant is one.
    bool foundInChildren = false;
    for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) {
        if (isAggregateExpression(child, scope)) {
            foundInChildren = true;
            break;
        }
    }
    return foundInChildren;
}

// Collects the aggregate function calls found in `expression`, in child traversal order.
//
// - An expression whose alias is already contained in `scope` contributes nothing.
// - An aggregate function call contributes itself only; its own children are not searched.
// - Any other expression contributes whatever its children contribute. Duplicates are kept.
static expression_vector getAggregateExpressions(
    const std::shared_ptr<Expression>& expression, VariableScope* scope) {
    expression_vector aggregates;
    const bool aliasInScope = expression->hasAlias() && scope->contains(expression->getAlias());
    if (!aliasInScope) {
        if (expression->expressionType == common::AGGREGATE_FUNCTION) {
            aggregates.push_back(expression);
        } else {
            for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) {
                auto nested = getAggregateExpressions(child, scope);
                aggregates.insert(aggregates.end(), nested.begin(), nested.end());
            }
        }
    }
    return aggregates;
}

// Builds the bound projection body shared by WITH and RETURN.
//
// @param projectionBody        parsed projection body; supplies DISTINCT, ORDER BY, SKIP and LIMIT.
// @param projectionExpressions already-bound projection expressions.
// @return a BoundProjectionBody holding the projection expressions plus, when present,
//         group-by expressions, aggregate expressions, order-by expressions and skip/limit.
// @throws BinderException if the projection contains an aggregate and an ORDER BY expression is
//         not one of the projection expressions.
//
// Side effect: when the parsed body has ORDER BY, `projectionExpressions` are added to the
// binder's variable scope before the order-by expressions are bound.
std::unique_ptr<BoundProjectionBody> Binder::bindProjectionBody(
    const parser::ProjectionBody& projectionBody, const expression_vector& projectionExpressions) {
    auto boundProjectionBody = std::make_unique<BoundProjectionBody>(
        projectionBody.getIsDistinct(), projectionExpressions);
    // Bind group by & aggregate.
    // A projection expression containing an aggregate contributes its aggregate calls; every
    // other projection expression becomes a group-by key.
    auto* scope = variableScope.get();
    expression_vector groupByExpressions;
    expression_vector aggregateExpressions;
    for (auto& expression : projectionExpressions) {
        if (isAggregateExpression(expression, scope)) {
            for (auto& agg : getAggregateExpressions(expression, scope)) {
                aggregateExpressions.push_back(agg);
            }
        } else {
            groupByExpressions.push_back(expression);
        }
    }
    if (!groupByExpressions.empty()) {
        // Group-by keys are the original keys followed by the internal ID property of every
        // node variable among them.
        expression_vector augmentedGroupByExpressions = groupByExpressions;
        for (auto& expression : groupByExpressions) {
            if (ExpressionUtil::isNodeVariable(*expression)) {
                // Named cast instead of a C-style cast so that an invalid conversion is a
                // compile error rather than a silent reinterpretation.
                // NOTE(review): assumes every VARIABLE expression of NODE type is a
                // NodeExpression -- confirm against the expression binder.
                auto node = static_cast<NodeExpression*>(expression.get());
                augmentedGroupByExpressions.push_back(node->getInternalIDProperty());
            }
        }
        boundProjectionBody->setGroupByExpressions(std::move(augmentedGroupByExpressions));
    }
    if (!aggregateExpressions.empty()) {
        boundProjectionBody->setAggregateExpressions(std::move(aggregateExpressions));
    }
    // Bind order by
    if (projectionBody.hasOrderByExpressions()) {
        addExpressionsToScope(projectionExpressions);
        auto orderByExpressions = bindOrderByExpressions(projectionBody.getOrderByExpressions());
        // Cypher rule of ORDER BY expression scope: if projection contains aggregation, only
        // expressions in projection are available. Otherwise, expressions before projection are
        // also available
        if (boundProjectionBody->hasAggregateExpressions()) {
            // TODO(Xiyang): abstract return/with clause as a temporary table and introduce
            // reference expression to solve this. Our property expression should also be changed to
            // reference expression.
            auto projectionExpressionSet =
                expression_set{projectionExpressions.begin(), projectionExpressions.end()};
            for (auto& orderByExpression : orderByExpressions) {
                if (!projectionExpressionSet.contains(orderByExpression)) {
                    throw BinderException("Order by expression " + orderByExpression->toString() +
                                          " is not in RETURN or WITH clause.");
                }
            }
        }
        boundProjectionBody->setOrderByExpressions(
            std::move(orderByExpressions), projectionBody.getSortOrders());
    }
    // Bind skip
    if (projectionBody.hasSkipExpression()) {
        boundProjectionBody->setSkipNumber(
            bindSkipLimitExpression(*projectionBody.getSkipExpression()));
    }
    // Bind limit
    if (projectionBody.hasLimitExpression()) {
        boundProjectionBody->setLimitNumber(
            bindSkipLimitExpression(*projectionBody.getLimitExpression()));
    }
    return boundProjectionBody;
}

expression_vector Binder::bindProjectionExpressions(
const std::vector<std::unique_ptr<ParsedExpression>>& projectionExpressions,
bool containsStar) {
expression_vector boundProjectionExpressions;
const parsed_expression_vector& projectionExpressions, bool star) {
expression_vector result;
for (auto& expression : projectionExpressions) {
boundProjectionExpressions.push_back(expressionBinder.bindExpression(*expression));
result.push_back(expressionBinder.bindExpression(*expression));
}
if (containsStar) {
if (star) {
if (variableScope->empty()) {
throw BinderException(
"RETURN or WITH * is not allowed when there are no variables in scope.");
}
for (auto& expression : variableScope->getExpressions()) {
boundProjectionExpressions.push_back(expression);
result.push_back(expression);
}
}
resolveAnyDataTypeWithDefaultType(boundProjectionExpressions);
validateProjectionColumnNamesAreUnique(boundProjectionExpressions);
return boundProjectionExpressions;
resolveAnyDataTypeWithDefaultType(result);
validateProjectionColumnNamesAreUnique(result);
return result;
}

expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) {
Expand Down Expand Up @@ -99,41 +195,6 @@ expression_vector Binder::rewriteRelExpression(const Expression& expression) {
return result;
}

// Fills in the ORDER BY, SKIP and LIMIT parts of `boundProjectionBody` from the parsed
// `projectionBody`. Each part is bound only when the parsed body has it.
//
// When ORDER BY is present, the projection expressions are first added to the binder's scope.
// Throws BinderException if the projection has aggregation and an ORDER BY expression is not
// among the projection expressions.
void Binder::bindOrderBySkipLimitIfNecessary(
    BoundProjectionBody& boundProjectionBody, const ProjectionBody& projectionBody) {
    auto projected = boundProjectionBody.getProjectionExpressions();
    if (projectionBody.hasOrderByExpressions()) {
        addExpressionsToScope(projected);
        auto boundOrderBy = bindOrderByExpressions(projectionBody.getOrderByExpressions());
        // Cypher rule of ORDER BY expression scope: if projection contains aggregation, only
        // expressions in projection are available. Otherwise, expressions before projection are
        // also available
        if (boundProjectionBody.hasAggregationExpressions()) {
            // TODO(Xiyang): abstract return/with clause as a temporary table and introduce
            // reference expression to solve this. Our property expression should also be changed to
            // reference expression.
            auto projectedSet = expression_set{projected.begin(), projected.end()};
            for (auto& candidate : boundOrderBy) {
                if (projectedSet.contains(candidate)) {
                    continue;
                }
                throw BinderException("Order by expression " + candidate->toString() +
                                      " is not in RETURN or WITH clause.");
            }
        }
        boundProjectionBody.setOrderByExpressions(
            std::move(boundOrderBy), projectionBody.getSortOrders());
    }
    if (projectionBody.hasSkipExpression()) {
        auto skipNumber = bindSkipLimitExpression(*projectionBody.getSkipExpression());
        boundProjectionBody.setSkipNumber(skipNumber);
    }
    if (projectionBody.hasLimitExpression()) {
        auto limitNumber = bindSkipLimitExpression(*projectionBody.getLimitExpression());
        boundProjectionBody.setLimitNumber(limitNumber);
    }
}

expression_vector Binder::bindOrderByExpressions(
const std::vector<std::unique_ptr<ParsedExpression>>& orderByExpressions) {
expression_vector boundOrderByExpressions;
Expand Down
1 change: 0 additions & 1 deletion src/binder/query/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ add_library(
OBJECT
bound_create_clause.cpp
bound_delete_clause.cpp
bound_projection_body.cpp
bound_set_clause.cpp
query_graph.cpp)

Expand Down
24 changes: 0 additions & 24 deletions src/binder/query/bound_projection_body.cpp

This file was deleted.

8 changes: 4 additions & 4 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ class Binder {
/*** bind projection clause ***/
std::unique_ptr<BoundWithClause> bindWithClause(const parser::WithClause& withClause);
std::unique_ptr<BoundReturnClause> bindReturnClause(const parser::ReturnClause& returnClause);
std::unique_ptr<BoundProjectionBody> bindProjectionBody(
const parser::ProjectionBody& projectionBody,
const expression_vector& projectionExpressions);

expression_vector bindProjectionExpressions(
const std::vector<std::unique_ptr<parser::ParsedExpression>>& projectionExpressions,
bool containsStar);
const parser::parsed_expression_vector& parsedExpressions, bool star);
// Rewrite variable "v" as all properties of "v"
expression_vector rewriteNodeOrRelExpression(const Expression& expression);
expression_vector rewriteNodeExpression(const Expression& expression);
expression_vector rewriteRelExpression(const Expression& expression);

void bindOrderBySkipLimitIfNecessary(
BoundProjectionBody& boundProjectionBody, const parser::ProjectionBody& projectionBody);
expression_vector bindOrderByExpressions(
const std::vector<std::unique_ptr<parser::ParsedExpression>>& orderByExpressions);
uint64_t bindSkipLimitExpression(const parser::ParsedExpression& expression);
Expand Down
5 changes: 5 additions & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ struct ExpressionUtil {

static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);

// True when `expression` is a VARIABLE expression whose logical type is NODE.
inline static bool isNodeVariable(const Expression& expression) {
    if (expression.expressionType != common::ExpressionType::VARIABLE) {
        return false;
    }
    return expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE;
}
};

} // namespace binder
Expand Down
37 changes: 23 additions & 14 deletions src/include/binder/query/return_with_clause/bound_projection_body.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,49 @@ namespace kuzu {
namespace binder {

class BoundProjectionBody {
static constexpr uint64_t INVALID_NUMBER = UINT64_MAX;

public:
explicit BoundProjectionBody(bool isDistinct, expression_vector projectionExpressions)
BoundProjectionBody(bool isDistinct, expression_vector projectionExpressions)
: isDistinct{isDistinct}, projectionExpressions{std::move(projectionExpressions)},
skipNumber{UINT64_MAX}, limitNumber{UINT64_MAX} {}

skipNumber{INVALID_NUMBER}, limitNumber{INVALID_NUMBER} {}
BoundProjectionBody(const BoundProjectionBody& other)
: isDistinct{other.isDistinct}, projectionExpressions{other.projectionExpressions},
groupByExpressions{other.groupByExpressions},
aggregateExpressions{other.aggregateExpressions},
orderByExpressions{other.orderByExpressions}, isAscOrders{other.isAscOrders},
skipNumber{other.skipNumber}, limitNumber{other.limitNumber} {}

~BoundProjectionBody() = default;

inline bool getIsDistinct() const { return isDistinct; }

inline expression_vector getProjectionExpressions() const { return projectionExpressions; }

bool hasAggregationExpressions() const;
inline void setGroupByExpressions(expression_vector expressions) {
groupByExpressions = std::move(expressions);
}
inline expression_vector getGroupByExpressions() const { return groupByExpressions; }

void setOrderByExpressions(expression_vector expressions, std::vector<bool> sortOrders);
inline void setAggregateExpressions(expression_vector expressions) {
aggregateExpressions = std::move(expressions);
}
inline bool hasAggregateExpressions() const { return !aggregateExpressions.empty(); }
inline expression_vector getAggregateExpressions() const { return aggregateExpressions; }

inline void setOrderByExpressions(expression_vector expressions, std::vector<bool> sortOrders) {
orderByExpressions = std::move(expressions);
isAscOrders = std::move(sortOrders);
}
inline bool hasOrderByExpressions() const { return !orderByExpressions.empty(); }

inline const expression_vector& getOrderByExpressions() const { return orderByExpressions; }

inline const std::vector<bool>& getSortingOrders() const { return isAscOrders; }

inline void setSkipNumber(uint64_t number) { skipNumber = number; }

inline bool hasSkip() const { return skipNumber != UINT64_MAX; }

inline bool hasSkip() const { return skipNumber != INVALID_NUMBER; }
inline uint64_t getSkipNumber() const { return skipNumber; }

inline void setLimitNumber(uint64_t number) { limitNumber = number; }

inline bool hasLimit() const { return limitNumber != UINT64_MAX; }

inline bool hasLimit() const { return limitNumber != INVALID_NUMBER; }
inline uint64_t getLimitNumber() const { return limitNumber; }

inline bool hasSkipOrLimit() const { return hasSkip() || hasLimit(); }
Expand All @@ -53,6 +60,8 @@ class BoundProjectionBody {
private:
bool isDistinct;
expression_vector projectionExpressions;
expression_vector groupByExpressions;
expression_vector aggregateExpressions;
expression_vector orderByExpressions;
std::vector<bool> isAscOrders;
uint64_t skipNumber;
Expand Down
5 changes: 4 additions & 1 deletion src/include/parser/expression/parsed_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
namespace kuzu {
namespace parser {

class ParsedExpression;
using parsed_expression_vector = std::vector<std::unique_ptr<ParsedExpression>>;

class ParsedExpression {
public:
ParsedExpression(
Expand Down Expand Up @@ -41,7 +44,7 @@ class ParsedExpression {
common::ExpressionType type;
std::string alias;
std::string rawName;
std::vector<std::unique_ptr<ParsedExpression>> children;
parsed_expression_vector children;
};

} // namespace parser
Expand Down
7 changes: 6 additions & 1 deletion src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,14 @@ class Schema {
f_group_pos createGroup();

void insertToScope(const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);

void insertToGroupAndScope(
const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);
// Use these unsafe insert functions only if the operator may work with duplicate expressions.
// E.g. group by a.age, a.age
void insertToScopeMayRepeat(
const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);
void insertToGroupAndScopeMayRepeat(
const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);

void insertToGroupAndScope(const binder::expression_vector& expressions, uint32_t groupPos);

Expand Down
Loading

0 comments on commit 424664d

Please sign in to comment.