Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bind group by agg rework #1748

Merged
merged 1 commit into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 113 additions & 52 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_visitor.h"
#include "binder/expression/literal_expression.h"

using namespace kuzu::common;
Expand All @@ -9,12 +10,10 @@ namespace binder {

std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
auto projectionBody = withClause.getProjectionBody();
auto boundProjectionExpressions = bindProjectionExpressions(
auto projectionExpressions = bindProjectionExpressions(
projectionBody->getProjectionExpressions(), projectionBody->containsStar());
validateProjectionColumnsInWithClauseAreAliased(boundProjectionExpressions);
auto boundProjectionBody = std::make_unique<BoundProjectionBody>(
projectionBody->getIsDistinct(), std::move(boundProjectionExpressions));
bindOrderBySkipLimitIfNecessary(*boundProjectionBody, *projectionBody);
validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExpressions);
validateOrderByFollowedBySkipOrLimitInWithClause(*boundProjectionBody);
variableScope->clear();
addExpressionsToScope(boundProjectionBody->getProjectionExpressions());
Expand All @@ -39,32 +38,129 @@ std::unique_ptr<BoundReturnClause> Binder::bindReturnClause(const ReturnClause&
statementResult->addColumn(expression, expression_vector{expression});
}
}
auto boundProjectionBody = std::make_unique<BoundProjectionBody>(
projectionBody->getIsDistinct(), statementResult->getExpressionsToCollect());
bindOrderBySkipLimitIfNecessary(*boundProjectionBody, *projectionBody);
auto boundProjectionBody =
bindProjectionBody(*projectionBody, statementResult->getExpressionsToCollect());
return std::make_unique<BoundReturnClause>(
std::move(boundProjectionBody), std::move(statementResult));
}

// Reports whether `expression` is, or contains somewhere among the children
// returned by ExpressionChildrenCollector::collectChildren, an expression of
// type AGGREGATE_FUNCTION.
//
// An expression that has an alias which `scope` already contains is never
// searched: it is answered with `false` without looking at its type or children.
static bool isAggregateExpression(
    const std::shared_ptr<Expression>& expression, VariableScope* scope) {
    // `&&` keeps the original short-circuit: getAlias() is only consulted when hasAlias() holds.
    const bool aliasAlreadyInScope =
        expression->hasAlias() && scope->contains(expression->getAlias());
    if (aliasAlreadyInScope) {
        return false;
    }
    if (expression->expressionType == common::AGGREGATE_FUNCTION) {
        return true;
    }
    bool foundInChildren = false;
    for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) {
        if (isAggregateExpression(child, scope)) {
            // Stop at the first child sub-tree that holds an aggregate.
            foundInChildren = true;
            break;
        }
    }
    return foundInChildren;
}

// Collects the AGGREGATE_FUNCTION expressions found in `expression`, visiting
// children in the order ExpressionChildrenCollector::collectChildren returns them.
//
// - An expression whose alias is already contained in `scope` contributes nothing.
// - An aggregate expression is returned as-is; its own children are not searched.
// - Any other expression contributes the aggregates found in each of its children.
static expression_vector getAggregateExpressions(
    const std::shared_ptr<Expression>& expression, VariableScope* scope) {
    expression_vector aggregates;
    // `&&` keeps the original short-circuit: getAlias() is only consulted when hasAlias() holds.
    const bool aliasAlreadyInScope =
        expression->hasAlias() && scope->contains(expression->getAlias());
    if (!aliasAlreadyInScope) {
        if (expression->expressionType == common::AGGREGATE_FUNCTION) {
            aggregates.push_back(expression);
        } else {
            for (auto& child : ExpressionChildrenCollector::collectChildren(*expression)) {
                for (auto& nestedAggregate : getAggregateExpressions(child, scope)) {
                    aggregates.push_back(nestedAggregate);
                }
            }
        }
    }
    return aggregates;
}

// Builds the BoundProjectionBody shared by the WITH and RETURN clause binders.
//
// projectionBody:        parsed projection; supplies the DISTINCT flag and the
//                        optional ORDER BY / SKIP / LIMIT parts.
// projectionExpressions: the already-bound projection expressions.
//
// Steps, in this order:
//   1. Every projection expression for which isAggregateExpression() holds
//      contributes its aggregate sub-expressions; every other projection
//      expression becomes a group-by expression.
//   2. For each group-by expression that is a node variable, the node's
//      internal ID property is appended to the group-by list.
//   3. ORDER BY expressions are bound after the projection expressions have
//      been added to the scope, so they can refer to projected names.
//   4. SKIP and LIMIT are bound when present.
//
// Throws BinderException when the projection contains aggregates and an
// ORDER BY expression is not one of the projection expressions.
std::unique_ptr<BoundProjectionBody> Binder::bindProjectionBody(
    const parser::ProjectionBody& projectionBody, const expression_vector& projectionExpressions) {
    auto boundProjectionBody = std::make_unique<BoundProjectionBody>(
        projectionBody.getIsDistinct(), projectionExpressions);
    // Bind group by & aggregate.
    expression_vector groupByExpressions;
    expression_vector aggregateExpressions;
    for (auto& expression : projectionExpressions) {
        if (isAggregateExpression(expression, variableScope.get())) {
            for (auto& agg : getAggregateExpressions(expression, variableScope.get())) {
                aggregateExpressions.push_back(agg);
            }
        } else {
            groupByExpressions.push_back(expression);
        }
    }
    if (!groupByExpressions.empty()) {
        // Start from a copy so the loop below can append without modifying the
        // vector it iterates over.
        expression_vector augmentedGroupByExpressions = groupByExpressions;
        for (auto& expression : groupByExpressions) {
            if (ExpressionUtil::isNodeVariable(*expression)) {
                // isNodeVariable() checked the expression type and logical type, so
                // the downcast is expressed as a named cast rather than a C-style cast.
                auto node = static_cast<NodeExpression*>(expression.get());
                augmentedGroupByExpressions.push_back(node->getInternalIDProperty());
            }
        }
        boundProjectionBody->setGroupByExpressions(std::move(augmentedGroupByExpressions));
    }
    if (!aggregateExpressions.empty()) {
        boundProjectionBody->setAggregateExpressions(std::move(aggregateExpressions));
    }
    // Bind order by
    if (projectionBody.hasOrderByExpressions()) {
        // Must run before bindOrderByExpressions(): the projection expressions are
        // put in scope first so the ORDER BY binder can see them.
        addExpressionsToScope(projectionExpressions);
        auto orderByExpressions = bindOrderByExpressions(projectionBody.getOrderByExpressions());
        // Cypher rule of ORDER BY expression scope: if projection contains aggregation, only
        // expressions in projection are available. Otherwise, expressions before projection are
        // also available
        if (boundProjectionBody->hasAggregateExpressions()) {
            // TODO(Xiyang): abstract return/with clause as a temporary table and introduce
            // reference expression to solve this. Our property expression should also be changed to
            // reference expression.
            auto projectionExpressionSet =
                expression_set{projectionExpressions.begin(), projectionExpressions.end()};
            for (auto& orderByExpression : orderByExpressions) {
                if (!projectionExpressionSet.contains(orderByExpression)) {
                    throw BinderException("Order by expression " + orderByExpression->toString() +
                                          " is not in RETURN or WITH clause.");
                }
            }
        }
        boundProjectionBody->setOrderByExpressions(
            std::move(orderByExpressions), projectionBody.getSortOrders());
    }
    // Bind skip
    if (projectionBody.hasSkipExpression()) {
        boundProjectionBody->setSkipNumber(
            bindSkipLimitExpression(*projectionBody.getSkipExpression()));
    }
    // Bind limit
    if (projectionBody.hasLimitExpression()) {
        boundProjectionBody->setLimitNumber(
            bindSkipLimitExpression(*projectionBody.getLimitExpression()));
    }
    return boundProjectionBody;
}

expression_vector Binder::bindProjectionExpressions(
const std::vector<std::unique_ptr<ParsedExpression>>& projectionExpressions,
bool containsStar) {
expression_vector boundProjectionExpressions;
const parsed_expression_vector& projectionExpressions, bool star) {
expression_vector result;
for (auto& expression : projectionExpressions) {
boundProjectionExpressions.push_back(expressionBinder.bindExpression(*expression));
result.push_back(expressionBinder.bindExpression(*expression));
}
if (containsStar) {
if (star) {
if (variableScope->empty()) {
throw BinderException(
"RETURN or WITH * is not allowed when there are no variables in scope.");
}
for (auto& expression : variableScope->getExpressions()) {
boundProjectionExpressions.push_back(expression);
result.push_back(expression);
}
}
resolveAnyDataTypeWithDefaultType(boundProjectionExpressions);
validateProjectionColumnNamesAreUnique(boundProjectionExpressions);
return boundProjectionExpressions;
resolveAnyDataTypeWithDefaultType(result);
validateProjectionColumnNamesAreUnique(result);
return result;
}

expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) {
Expand Down Expand Up @@ -99,41 +195,6 @@ expression_vector Binder::rewriteRelExpression(const Expression& expression) {
return result;
}

// Binds the optional ORDER BY, SKIP and LIMIT parts of `projectionBody` and
// stores the results on `boundProjectionBody`. Parts that are absent from the
// parsed projection leave the bound body untouched.
//
// Throws BinderException when the bound projection has aggregation expressions
// and an ORDER BY expression is not one of the projection expressions.
void Binder::bindOrderBySkipLimitIfNecessary(
    BoundProjectionBody& boundProjectionBody, const ProjectionBody& projectionBody) {
    auto projected = boundProjectionBody.getProjectionExpressions();
    if (projectionBody.hasOrderByExpressions()) {
        // The projected expressions go into scope before ORDER BY is bound.
        addExpressionsToScope(projected);
        auto boundOrderBy = bindOrderByExpressions(projectionBody.getOrderByExpressions());
        // Cypher rule of ORDER BY expression scope: if projection contains aggregation, only
        // expressions in projection are available. Otherwise, expressions before projection are
        // also available
        if (boundProjectionBody.hasAggregationExpressions()) {
            // TODO(Xiyang): abstract return/with clause as a temporary table and introduce
            // reference expression to solve this. Our property expression should also be changed to
            // reference expression.
            auto projectedSet = expression_set{projected.begin(), projected.end()};
            for (auto& candidate : boundOrderBy) {
                if (projectedSet.contains(candidate)) {
                    continue;
                }
                throw BinderException("Order by expression " + candidate->toString() +
                                      " is not in RETURN or WITH clause.");
            }
        }
        boundProjectionBody.setOrderByExpressions(
            std::move(boundOrderBy), projectionBody.getSortOrders());
    }
    if (projectionBody.hasSkipExpression()) {
        const auto skipCount = bindSkipLimitExpression(*projectionBody.getSkipExpression());
        boundProjectionBody.setSkipNumber(skipCount);
    }
    if (projectionBody.hasLimitExpression()) {
        const auto limitCount = bindSkipLimitExpression(*projectionBody.getLimitExpression());
        boundProjectionBody.setLimitNumber(limitCount);
    }
}

expression_vector Binder::bindOrderByExpressions(
const std::vector<std::unique_ptr<ParsedExpression>>& orderByExpressions) {
expression_vector boundOrderByExpressions;
Expand Down
1 change: 0 additions & 1 deletion src/binder/query/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ add_library(
OBJECT
bound_create_clause.cpp
bound_delete_clause.cpp
bound_projection_body.cpp
bound_set_clause.cpp
query_graph.cpp)

Expand Down
24 changes: 0 additions & 24 deletions src/binder/query/bound_projection_body.cpp

This file was deleted.

8 changes: 4 additions & 4 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ class Binder {
/*** bind projection clause ***/
std::unique_ptr<BoundWithClause> bindWithClause(const parser::WithClause& withClause);
std::unique_ptr<BoundReturnClause> bindReturnClause(const parser::ReturnClause& returnClause);
std::unique_ptr<BoundProjectionBody> bindProjectionBody(
const parser::ProjectionBody& projectionBody,
const expression_vector& projectionExpressions);

expression_vector bindProjectionExpressions(
const std::vector<std::unique_ptr<parser::ParsedExpression>>& projectionExpressions,
bool containsStar);
const parser::parsed_expression_vector& parsedExpressions, bool star);
// Rewrite variable "v" as all properties of "v"
expression_vector rewriteNodeOrRelExpression(const Expression& expression);
expression_vector rewriteNodeExpression(const Expression& expression);
expression_vector rewriteRelExpression(const Expression& expression);

void bindOrderBySkipLimitIfNecessary(
BoundProjectionBody& boundProjectionBody, const parser::ProjectionBody& projectionBody);
expression_vector bindOrderByExpressions(
const std::vector<std::unique_ptr<parser::ParsedExpression>>& orderByExpressions);
uint64_t bindSkipLimitExpression(const parser::ParsedExpression& expression);
Expand Down
5 changes: 5 additions & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ struct ExpressionUtil {

static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);

// True when `expression` is a VARIABLE expression whose data type has the
// logical type ID NODE. The data type is only inspected for VARIABLE expressions.
inline static bool isNodeVariable(const Expression& expression) {
    if (expression.expressionType != common::ExpressionType::VARIABLE) {
        return false;
    }
    return expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE;
}
};

} // namespace binder
Expand Down
37 changes: 23 additions & 14 deletions src/include/binder/query/return_with_clause/bound_projection_body.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,49 @@ namespace kuzu {
namespace binder {

class BoundProjectionBody {
static constexpr uint64_t INVALID_NUMBER = UINT64_MAX;

public:
explicit BoundProjectionBody(bool isDistinct, expression_vector projectionExpressions)
BoundProjectionBody(bool isDistinct, expression_vector projectionExpressions)
: isDistinct{isDistinct}, projectionExpressions{std::move(projectionExpressions)},
skipNumber{UINT64_MAX}, limitNumber{UINT64_MAX} {}

skipNumber{INVALID_NUMBER}, limitNumber{INVALID_NUMBER} {}
BoundProjectionBody(const BoundProjectionBody& other)
: isDistinct{other.isDistinct}, projectionExpressions{other.projectionExpressions},
groupByExpressions{other.groupByExpressions},
aggregateExpressions{other.aggregateExpressions},
orderByExpressions{other.orderByExpressions}, isAscOrders{other.isAscOrders},
skipNumber{other.skipNumber}, limitNumber{other.limitNumber} {}

~BoundProjectionBody() = default;

inline bool getIsDistinct() const { return isDistinct; }

inline expression_vector getProjectionExpressions() const { return projectionExpressions; }

bool hasAggregationExpressions() const;
inline void setGroupByExpressions(expression_vector expressions) {
groupByExpressions = std::move(expressions);
}
inline expression_vector getGroupByExpressions() const { return groupByExpressions; }

void setOrderByExpressions(expression_vector expressions, std::vector<bool> sortOrders);
inline void setAggregateExpressions(expression_vector expressions) {
aggregateExpressions = std::move(expressions);
}
inline bool hasAggregateExpressions() const { return !aggregateExpressions.empty(); }
inline expression_vector getAggregateExpressions() const { return aggregateExpressions; }

inline void setOrderByExpressions(expression_vector expressions, std::vector<bool> sortOrders) {
orderByExpressions = std::move(expressions);
isAscOrders = std::move(sortOrders);
}
inline bool hasOrderByExpressions() const { return !orderByExpressions.empty(); }

inline const expression_vector& getOrderByExpressions() const { return orderByExpressions; }

inline const std::vector<bool>& getSortingOrders() const { return isAscOrders; }

inline void setSkipNumber(uint64_t number) { skipNumber = number; }

inline bool hasSkip() const { return skipNumber != UINT64_MAX; }

inline bool hasSkip() const { return skipNumber != INVALID_NUMBER; }
inline uint64_t getSkipNumber() const { return skipNumber; }

inline void setLimitNumber(uint64_t number) { limitNumber = number; }

inline bool hasLimit() const { return limitNumber != UINT64_MAX; }

inline bool hasLimit() const { return limitNumber != INVALID_NUMBER; }
inline uint64_t getLimitNumber() const { return limitNumber; }

inline bool hasSkipOrLimit() const { return hasSkip() || hasLimit(); }
Expand All @@ -53,6 +60,8 @@ class BoundProjectionBody {
private:
bool isDistinct;
expression_vector projectionExpressions;
expression_vector groupByExpressions;
expression_vector aggregateExpressions;
expression_vector orderByExpressions;
std::vector<bool> isAscOrders;
uint64_t skipNumber;
Expand Down
5 changes: 4 additions & 1 deletion src/include/parser/expression/parsed_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
namespace kuzu {
namespace parser {

class ParsedExpression;
using parsed_expression_vector = std::vector<std::unique_ptr<ParsedExpression>>;

class ParsedExpression {
public:
ParsedExpression(
Expand Down Expand Up @@ -41,7 +44,7 @@ class ParsedExpression {
common::ExpressionType type;
std::string alias;
std::string rawName;
std::vector<std::unique_ptr<ParsedExpression>> children;
parsed_expression_vector children;
};

} // namespace parser
Expand Down
7 changes: 6 additions & 1 deletion src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,14 @@ class Schema {
f_group_pos createGroup();

void insertToScope(const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);

void insertToGroupAndScope(
const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);
// Use these unsafe insert functions only if the operator may work with duplicate expressions.
// E.g. group by a.age, a.age
void insertToScopeMayRepeat(
const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);
void insertToGroupAndScopeMayRepeat(
const std::shared_ptr<binder::Expression>& expression, uint32_t groupPos);

void insertToGroupAndScope(const binder::expression_vector& expressions, uint32_t groupPos);

Expand Down
Loading
Loading