Skip to content

Commit

Permalink
Merge pull request #1956 from kuzudb/rewrite-with-clause-projection-list
Browse files Browse the repository at this point in the history
Add with clause projection list rewriter
  • Loading branch information
andyfengHKU committed Aug 25, 2023
2 parents dc4266a + fae0134 commit 9526e5e
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 44 deletions.
2 changes: 2 additions & 0 deletions src/binder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ add_subdirectory(bind)
add_subdirectory(bind_expression)
add_subdirectory(expression)
add_subdirectory(query)
add_subdirectory(rewriter)
add_subdirectory(visitor)

add_library(kuzu_binder
OBJECT
binder.cpp
bound_statement_result.cpp
bound_statement_rewriter.cpp
bound_statement_visitor.cpp
expression_binder.cpp
expression_visitor.cpp)
Expand Down
47 changes: 29 additions & 18 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,44 @@ using namespace kuzu::parser;
namespace kuzu {
namespace binder {

std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
auto projectionBody = withClause.getProjectionBody();
auto projectionExpressions =
bindProjectionExpressions(projectionBody->getProjectionExpressions());
validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
expression_vector newProjectionExpressions;
for (auto& expression : projectionExpressions) {
// WITH clause is like SQL CTE. So the projection list of WITH clause should be explicitly
// evaluated. This, however, creates problem in the following case
// MATCH (a) WITH a RETURN a.age;
// Although only a.age is needed for further processing. The CTE "MATCH (a) WITH a" require us to
// fully materialize all columns of "a". Note that we cannot rely on projection push down to
// optimize this because projection pushdown assumes all columns in WITH/RETURN are needed.
// Our solution is:
// First rewrite node and rel as their INTERNAL ID property in WITH clause. So
// MATCH (a) WITH a._id RETURN a.age;
// And then apply WithClauseProjectionRewriter after binding to rewrite as
// MATCH (a) WITH a._id, a.age RETURN a.age
static expression_vector rewriteProjectionInWithClause(const expression_vector& expressions) {
expression_vector result;
for (auto& expression : expressions) {
if (ExpressionUtil::isNodeVariable(*expression)) {
auto node = (NodeExpression*)expression.get();
newProjectionExpressions.push_back(node->getInternalIDProperty());
for (auto& property : node->getPropertyExpressions()) {
newProjectionExpressions.push_back(property->copy());
}
result.push_back(node->getInternalIDProperty());
} else if (ExpressionUtil::isRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
for (auto& property : rel->getPropertyExpressions()) {
newProjectionExpressions.push_back(property->copy());
}
result.push_back(rel->getInternalIDProperty());
} else if (ExpressionUtil::isRecursiveRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
newProjectionExpressions.push_back(expression);
newProjectionExpressions.push_back(rel->getLengthExpression());
result.push_back(expression);
result.push_back(rel->getLengthExpression());
} else {
newProjectionExpressions.push_back(expression);
result.push_back(expression);
}
}
auto boundProjectionBody = bindProjectionBody(*projectionBody, newProjectionExpressions);
return result;
}

std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
auto projectionBody = withClause.getProjectionBody();
auto projectionExpressions =
bindProjectionExpressions(projectionBody->getProjectionExpressions());
validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
auto boundProjectionBody =
bindProjectionBody(*projectionBody, rewriteProjectionInWithClause(projectionExpressions));
validateOrderByFollowedBySkipOrLimitInWithClause(*boundProjectionBody);
scope->clear();
addExpressionsToScope(projectionExpressions);
Expand Down
56 changes: 30 additions & 26 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"

#include "binder/bound_statement_rewriter.h"
#include "binder/expression/variable_expression.h"
#include "common/string_utils.h"

Expand All @@ -11,49 +12,52 @@ namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
std::unique_ptr<BoundStatement> boundStatement;
switch (statement.getStatementType()) {
case StatementType::CREATE_NODE_TABLE: {
return bindCreateNodeTableClause(statement);
}
boundStatement = bindCreateNodeTableClause(statement);
} break;
case StatementType::CREATE_REL_TABLE: {
return bindCreateRelTableClause(statement);
}
boundStatement = bindCreateRelTableClause(statement);
} break;
case StatementType::COPY_FROM: {
return bindCopyFromClause(statement);
}
boundStatement = bindCopyFromClause(statement);
} break;
case StatementType::COPY_TO: {
return bindCopyToClause(statement);
}
boundStatement = bindCopyToClause(statement);
} break;
case StatementType::DROP_TABLE: {
return bindDropTableClause(statement);
}
boundStatement = bindDropTableClause(statement);
} break;
case StatementType::RENAME_TABLE: {
return bindRenameTableClause(statement);
}
boundStatement = bindRenameTableClause(statement);
} break;
case StatementType::ADD_PROPERTY: {
return bindAddPropertyClause(statement);
}
boundStatement = bindAddPropertyClause(statement);
} break;
case StatementType::DROP_PROPERTY: {
return bindDropPropertyClause(statement);
}
boundStatement = bindDropPropertyClause(statement);
} break;
case StatementType::RENAME_PROPERTY: {
return bindRenamePropertyClause(statement);
}
boundStatement = bindRenamePropertyClause(statement);
} break;
case StatementType::QUERY: {
return bindQuery((const RegularQuery&)statement);
}
boundStatement = bindQuery((const RegularQuery&)statement);
} break;
case StatementType::STANDALONE_CALL: {
return bindStandaloneCall(statement);
}
boundStatement = bindStandaloneCall(statement);
} break;
case StatementType::EXPLAIN: {
return bindExplain(statement);
}
boundStatement = bindExplain(statement);
} break;
case StatementType::CREATE_MACRO: {
return bindCreateMacro(statement);
}
boundStatement = bindCreateMacro(statement);
} break;
default:
throw NotImplementedException("Binder::bind");
}
BoundStatementRewriter::rewrite(*boundStatement);
return boundStatement;
}

std::shared_ptr<Expression> Binder::bindWhereExpression(const ParsedExpression& parsedExpression) {
Expand Down
14 changes: 14 additions & 0 deletions src/binder/bound_statement_rewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "binder/bound_statement_rewriter.h"

#include "binder/rewriter/with_clause_projection_rewriter.h"

namespace kuzu {
namespace binder {

void BoundStatementRewriter::rewrite(BoundStatement& boundStatement) {
auto withClauseProjectionRewriter = WithClauseProjectionRewriter();
withClauseProjectionRewriter.visit(boundStatement);
}

} // namespace binder
} // namespace kuzu
8 changes: 8 additions & 0 deletions src/binder/rewriter/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_library(
kuzu_binder_rewriter
OBJECT
with_clause_projection_rewriter.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_rewriter>
PARENT_SCOPE)
62 changes: 62 additions & 0 deletions src/binder/rewriter/with_clause_projection_rewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "binder/rewriter/with_clause_projection_rewriter.h"

#include "binder/visitor/property_collector.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

static expression_vector getPropertiesOfSameVariable(
const expression_vector& expressions, const std::string& variableName) {
expression_vector result;
for (auto& expression : expressions) {
auto propertyExpression = (PropertyExpression*)expression.get();
if (propertyExpression->getVariableName() != variableName) {
continue;
}
result.push_back(expression);
}
return result;
}

static expression_vector rewriteExpressions(
const expression_vector& expressions, const expression_vector& properties) {
expression_set distinctResult;
for (auto& expression : expressions) {
if (expression->expressionType != common::PROPERTY) {
distinctResult.insert(expression);
continue;
}
auto propertyExpression = (PropertyExpression*)expression.get();
if (!propertyExpression->isInternalID()) {
distinctResult.insert(expression);
continue;
}
// Expression is internal ID. Perform rewrite as all properties with the same variable.
auto variableName = propertyExpression->getVariableName();
for (auto& property : getPropertiesOfSameVariable(properties, variableName)) {
distinctResult.insert(property);
}
}
return expression_vector{distinctResult.begin(), distinctResult.end()};
}

void WithClauseProjectionRewriter::visitSingleQuery(const NormalizedSingleQuery& singleQuery) {
auto propertyCollector = PropertyCollector();
propertyCollector.visitSingleQuery(singleQuery);
auto properties = propertyCollector.getProperties();
for (auto i = 0; i < singleQuery.getNumQueryParts() - 1; ++i) {
auto queryPart = singleQuery.getQueryPart(i);
auto projectionBody = queryPart->getProjectionBody();
auto newProjectionExpressions =
rewriteExpressions(projectionBody->getProjectionExpressions(), properties);
projectionBody->setProjectionExpressions(std::move(newProjectionExpressions));
auto newGroupByExpressions =
rewriteExpressions(projectionBody->getGroupByExpressions(), properties);
projectionBody->setGroupByExpressions(std::move(newGroupByExpressions));
}
}

} // namespace binder
} // namespace kuzu
12 changes: 12 additions & 0 deletions src/include/binder/bound_statement_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "bound_statement.h"

namespace kuzu {
namespace binder {

class BoundStatementRewriter {
public:
static void rewrite(BoundStatement& boundStatement);
};

} // namespace binder
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class BoundProjectionBody {

inline bool getIsDistinct() const { return isDistinct; }

inline void setProjectionExpressions(expression_vector expressions) {
projectionExpressions = std::move(expressions);
}
inline expression_vector getProjectionExpressions() const { return projectionExpressions; }

inline void setGroupByExpressions(expression_vector expressions) {
Expand Down
19 changes: 19 additions & 0 deletions src/include/binder/rewriter/with_clause_projection_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "binder/bound_statement_visitor.h"

namespace kuzu {
namespace binder {

// WithClauseProjectionRewriter first analyze the properties need to be scanned for each query. And
// then rewrite node/rel expression in WITH clause as their properties. So We avoid eagerly evaluate
// node/rel in WITH clause projection. E.g.
// MATCH (a) WITH a MATCH (a)->(b);
// will be rewritten as
// MATCH (a) WITH a._id MATCH (a)->(b);
// See bind_projection_clause.cpp for more details.
class WithClauseProjectionRewriter : public BoundStatementVisitor {
public:
void visitSingleQuery(const NormalizedSingleQuery& singleQuery) override;
};

} // namespace binder
} // namespace kuzu
12 changes: 12 additions & 0 deletions test/optimizer/optimizer_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "graph_test/graph_test.h"
#include "planner/logical_plan/extend/logical_recursive_extend.h"
#include "planner/logical_plan/logical_plan_util.h"
#include "planner/logical_plan/scan/logical_scan_node_property.h"

namespace kuzu {
namespace testing {
Expand All @@ -20,6 +21,17 @@ class OptimizerTest : public DBTest {
}
};

TEST_F(OptimizerTest, WithClauseProjectionListRewriterTest) {
auto op = getRoot("MATCH (a:person) WITH a RETURN a.gender;");
ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);
op = op->getChild(0);
ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);
op = op->getChild(0);
ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::SCAN_NODE_PROPERTY);
auto scanNodeProperty = (planner::LogicalScanNodeProperty*)op.get();
ASSERT_EQ(scanNodeProperty->getProperties().size(), 1);
}

TEST_F(OptimizerTest, FilterPushDownTest) {
auto op = getRoot("MATCH (a:person) WHERE a.ID < 0 AND a.fName='Alice' RETURN a.gender;");
ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);
Expand Down

0 comments on commit 9526e5e

Please sign in to comment.