From fae01341790649fc412be155a4335fd27db92467 Mon Sep 17 00:00:00 2001
From: xiyang
Date: Thu, 24 Aug 2023 12:09:16 -0400
Subject: [PATCH] Add with clause projection list rewriter

---
 src/binder/CMakeLists.txt                     |  2 +
 src/binder/bind/bind_projection_clause.cpp    | 47 +++++++-----
 src/binder/binder.cpp                         | 56 +++++++-------
 src/binder/bound_statement_rewriter.cpp       | 14 ++++
 src/binder/rewriter/CMakeLists.txt            |  8 ++
 .../with_clause_projection_rewriter.cpp       | 73 +++++++++++++++++++
 src/include/binder/bound_statement_rewriter.h | 16 ++++
 .../bound_projection_body.h                   |  3 +
 .../with_clause_projection_rewriter.h         | 21 ++++++
 test/optimizer/optimizer_test.cpp             | 12 +++
 10 files changed, 208 insertions(+), 44 deletions(-)
 create mode 100644 src/binder/bound_statement_rewriter.cpp
 create mode 100644 src/binder/rewriter/CMakeLists.txt
 create mode 100644 src/binder/rewriter/with_clause_projection_rewriter.cpp
 create mode 100644 src/include/binder/bound_statement_rewriter.h
 create mode 100644 src/include/binder/rewriter/with_clause_projection_rewriter.h

diff --git a/src/binder/CMakeLists.txt b/src/binder/CMakeLists.txt
index 9861c29e82..5cc5bd577e 100644
--- a/src/binder/CMakeLists.txt
+++ b/src/binder/CMakeLists.txt
@@ -2,12 +2,14 @@ add_subdirectory(bind)
 add_subdirectory(bind_expression)
 add_subdirectory(expression)
 add_subdirectory(query)
+add_subdirectory(rewriter)
 add_subdirectory(visitor)
 
 add_library(kuzu_binder
         OBJECT
         binder.cpp
         bound_statement_result.cpp
+        bound_statement_rewriter.cpp
         bound_statement_visitor.cpp
         expression_binder.cpp
         expression_visitor.cpp)
diff --git a/src/binder/bind/bind_projection_clause.cpp b/src/binder/bind/bind_projection_clause.cpp
index af2ebe933e..65d92c50ff 100644
--- a/src/binder/bind/bind_projection_clause.cpp
+++ b/src/binder/bind/bind_projection_clause.cpp
@@ -11,33 +11,44 @@ using namespace kuzu::parser;
 namespace kuzu {
 namespace binder {
 
-std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
-    auto projectionBody = withClause.getProjectionBody();
-    auto projectionExpressions =
-        bindProjectionExpressions(projectionBody->getProjectionExpressions());
-    validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
-    expression_vector newProjectionExpressions;
-    for (auto& expression : projectionExpressions) {
+// WITH clause is like a SQL CTE, so the projection list of a WITH clause should be explicitly
+// evaluated. This, however, creates a problem in the following case:
+//     MATCH (a) WITH a RETURN a.age;
+// Although only a.age is needed for further processing, the CTE "MATCH (a) WITH a" requires us to
+// fully materialize all columns of "a". Note that we cannot rely on projection push down to
+// optimize this because projection push down assumes all columns in WITH/RETURN are needed.
+// Our solution is:
+// First, rewrite node and rel as their INTERNAL ID property in the WITH clause. So
+//     MATCH (a) WITH a._id RETURN a.age;
+// And then apply WithClauseProjectionRewriter after binding to rewrite it as
+//     MATCH (a) WITH a._id, a.age RETURN a.age;
+static expression_vector rewriteProjectionInWithClause(const expression_vector& expressions) {
+    expression_vector result;
+    for (auto& expression : expressions) {
         if (ExpressionUtil::isNodeVariable(*expression)) {
             auto node = (NodeExpression*)expression.get();
-            newProjectionExpressions.push_back(node->getInternalIDProperty());
-            for (auto& property : node->getPropertyExpressions()) {
-                newProjectionExpressions.push_back(property->copy());
-            }
+            result.push_back(node->getInternalIDProperty());
         } else if (ExpressionUtil::isRelVariable(*expression)) {
             auto rel = (RelExpression*)expression.get();
-            for (auto& property : rel->getPropertyExpressions()) {
-                newProjectionExpressions.push_back(property->copy());
-            }
+            result.push_back(rel->getInternalIDProperty());
         } else if (ExpressionUtil::isRecursiveRelVariable(*expression)) {
             auto rel = (RelExpression*)expression.get();
-            newProjectionExpressions.push_back(expression);
-            newProjectionExpressions.push_back(rel->getLengthExpression());
+            result.push_back(expression);
+            result.push_back(rel->getLengthExpression());
         } else {
-            newProjectionExpressions.push_back(expression);
+            result.push_back(expression);
         }
     }
-    auto boundProjectionBody = bindProjectionBody(*projectionBody, newProjectionExpressions);
+    return result;
+}
+
+std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
+    auto projectionBody = withClause.getProjectionBody();
+    auto projectionExpressions =
+        bindProjectionExpressions(projectionBody->getProjectionExpressions());
+    validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
+    auto boundProjectionBody =
+        bindProjectionBody(*projectionBody, rewriteProjectionInWithClause(projectionExpressions));
     validateOrderByFollowedBySkipOrLimitInWithClause(*boundProjectionBody);
     scope->clear();
     addExpressionsToScope(projectionExpressions);
diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp
index ed1882a027..b3fade8380 100644
--- a/src/binder/binder.cpp
+++ b/src/binder/binder.cpp
@@ -1,5 +1,6 @@
 #include "binder/binder.h"
 
+#include "binder/bound_statement_rewriter.h"
 #include "binder/expression/variable_expression.h"
 #include "common/string_utils.h"
 
@@ -11,49 +12,52 @@ namespace kuzu {
 namespace binder {
 
 std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
+    std::unique_ptr<BoundStatement> boundStatement;
     switch (statement.getStatementType()) {
     case StatementType::CREATE_NODE_TABLE: {
-        return bindCreateNodeTableClause(statement);
-    }
+        boundStatement = bindCreateNodeTableClause(statement);
+    } break;
     case StatementType::CREATE_REL_TABLE: {
-        return bindCreateRelTableClause(statement);
-    }
+        boundStatement = bindCreateRelTableClause(statement);
+    } break;
     case StatementType::COPY_FROM: {
-        return bindCopyFromClause(statement);
-    }
+        boundStatement = bindCopyFromClause(statement);
+    } break;
     case StatementType::COPY_TO: {
-        return bindCopyToClause(statement);
-    }
+        boundStatement = bindCopyToClause(statement);
+    } break;
     case StatementType::DROP_TABLE: {
-        return bindDropTableClause(statement);
-    }
+        boundStatement = bindDropTableClause(statement);
+    } break;
     case StatementType::RENAME_TABLE: {
-        return bindRenameTableClause(statement);
-    }
+        boundStatement = bindRenameTableClause(statement);
+    } break;
     case StatementType::ADD_PROPERTY: {
-        return bindAddPropertyClause(statement);
-    }
+        boundStatement = bindAddPropertyClause(statement);
+    } break;
     case StatementType::DROP_PROPERTY: {
-        return bindDropPropertyClause(statement);
-    }
+        boundStatement = bindDropPropertyClause(statement);
+    } break;
     case StatementType::RENAME_PROPERTY: {
-        return bindRenamePropertyClause(statement);
-    }
+        boundStatement = bindRenamePropertyClause(statement);
+    } break;
     case StatementType::QUERY: {
-        return bindQuery((const RegularQuery&)statement);
-    }
+        boundStatement = bindQuery((const RegularQuery&)statement);
+    } break;
     case StatementType::STANDALONE_CALL: {
-        return bindStandaloneCall(statement);
-    }
+        boundStatement = bindStandaloneCall(statement);
+    } break;
     case StatementType::EXPLAIN: {
-        return bindExplain(statement);
-    }
+        boundStatement = bindExplain(statement);
+    } break;
     case StatementType::CREATE_MACRO: {
-        return bindCreateMacro(statement);
-    }
+        boundStatement = bindCreateMacro(statement);
+    } break;
     default:
         throw NotImplementedException("Binder::bind");
     }
+    BoundStatementRewriter::rewrite(*boundStatement);
+    return boundStatement;
 }
 
 std::shared_ptr<Expression> Binder::bindWhereExpression(const ParsedExpression& parsedExpression) {
diff --git a/src/binder/bound_statement_rewriter.cpp b/src/binder/bound_statement_rewriter.cpp
new file mode 100644
index 0000000000..ec602f8ecc
--- /dev/null
+++ b/src/binder/bound_statement_rewriter.cpp
@@ -0,0 +1,14 @@
+#include "binder/bound_statement_rewriter.h"
+
+#include "binder/rewriter/with_clause_projection_rewriter.h"
+
+namespace kuzu {
+namespace binder {
+
+void BoundStatementRewriter::rewrite(BoundStatement& boundStatement) {
+    auto withClauseProjectionRewriter = WithClauseProjectionRewriter();
+    withClauseProjectionRewriter.visit(boundStatement);
+}
+
+} // namespace binder
+} // namespace kuzu
diff --git a/src/binder/rewriter/CMakeLists.txt b/src/binder/rewriter/CMakeLists.txt
new file mode 100644
index 0000000000..382ac8901f
--- /dev/null
+++ b/src/binder/rewriter/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_library(
+        kuzu_binder_rewriter
+        OBJECT
+        with_clause_projection_rewriter.cpp)
+
+set(ALL_OBJECT_FILES
+        ${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_rewriter>
+        PARENT_SCOPE)
diff --git a/src/binder/rewriter/with_clause_projection_rewriter.cpp b/src/binder/rewriter/with_clause_projection_rewriter.cpp
new file mode 100644
index 0000000000..d90261c939
--- /dev/null
+++ b/src/binder/rewriter/with_clause_projection_rewriter.cpp
@@ -0,0 +1,73 @@
+#include "binder/rewriter/with_clause_projection_rewriter.h"
+
+#include "binder/visitor/property_collector.h"
+
+using namespace kuzu::common;
+
+namespace kuzu {
+namespace binder {
+
+// Returns the entries of `expressions` whose variable name equals `variableName`.
+// Every entry is cast to PropertyExpression without a check, so callers must pass property
+// expressions only.
+static expression_vector getPropertiesOfSameVariable(
+    const expression_vector& expressions, const std::string& variableName) {
+    expression_vector result;
+    for (auto& expression : expressions) {
+        auto propertyExpression = (PropertyExpression*)expression.get();
+        if (propertyExpression->getVariableName() != variableName) {
+            continue;
+        }
+        result.push_back(expression);
+    }
+    return result;
+}
+
+// Rewrites a projection (or group by) list. Each internal ID property is replaced by the entries
+// of `properties` that belong to the same variable; any other expression is kept as is. The
+// internal ID itself therefore survives only if it is also an entry of `properties`.
+// The result is read back from an expression_set, i.e. it is deduplicated and follows the set's
+// iteration order rather than the input order.
+static expression_vector rewriteExpressions(
+    const expression_vector& expressions, const expression_vector& properties) {
+    expression_set distinctResult;
+    for (auto& expression : expressions) {
+        if (expression->expressionType != common::PROPERTY) {
+            distinctResult.insert(expression);
+            continue;
+        }
+        auto propertyExpression = (PropertyExpression*)expression.get();
+        if (!propertyExpression->isInternalID()) {
+            distinctResult.insert(expression);
+            continue;
+        }
+        // Expression is internal ID. Perform rewrite as all properties with the same variable.
+        auto variableName = propertyExpression->getVariableName();
+        for (auto& property : getPropertiesOfSameVariable(properties, variableName)) {
+            distinctResult.insert(property);
+        }
+    }
+    return expression_vector{distinctResult.begin(), distinctResult.end()};
+}
+
+// Collects properties of `singleQuery` with PropertyCollector, then rewrites the projection and
+// group by lists of every query part except the last one.
+void WithClauseProjectionRewriter::visitSingleQuery(const NormalizedSingleQuery& singleQuery) {
+    auto propertyCollector = PropertyCollector();
+    propertyCollector.visitSingleQuery(singleQuery);
+    auto properties = propertyCollector.getProperties();
+    // `i + 1 < n` instead of `i < n - 1`: the latter wraps around for an unsigned n == 0.
+    for (auto i = 0u; i + 1 < singleQuery.getNumQueryParts(); ++i) {
+        auto queryPart = singleQuery.getQueryPart(i);
+        auto projectionBody = queryPart->getProjectionBody();
+        auto newProjectionExpressions =
+            rewriteExpressions(projectionBody->getProjectionExpressions(), properties);
+        projectionBody->setProjectionExpressions(std::move(newProjectionExpressions));
+        auto newGroupByExpressions =
+            rewriteExpressions(projectionBody->getGroupByExpressions(), properties);
+        projectionBody->setGroupByExpressions(std::move(newGroupByExpressions));
+    }
+}
+
+} // namespace binder
+} // namespace kuzu
diff --git a/src/include/binder/bound_statement_rewriter.h b/src/include/binder/bound_statement_rewriter.h
new file mode 100644
index 0000000000..8446275fc3
--- /dev/null
+++ b/src/include/binder/bound_statement_rewriter.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "bound_statement.h"
+
+namespace kuzu {
+namespace binder {
+
+// Entry point for the rewrites that Binder::bind applies to a statement after binding it.
+class BoundStatementRewriter {
+public:
+    // Runs WithClauseProjectionRewriter over boundStatement, modifying it in place.
+    static void rewrite(BoundStatement& boundStatement);
+};
+
+} // namespace binder
+} // namespace kuzu
diff --git a/src/include/binder/query/return_with_clause/bound_projection_body.h b/src/include/binder/query/return_with_clause/bound_projection_body.h
index cbab98bf78..a56b041be4 100644
--- a/src/include/binder/query/return_with_clause/bound_projection_body.h
+++ b/src/include/binder/query/return_with_clause/bound_projection_body.h
@@ -22,6 +22,9 @@ class BoundProjectionBody {
 
     inline bool getIsDistinct() const { return isDistinct; }
 
+    inline void setProjectionExpressions(expression_vector expressions) {
+        projectionExpressions = std::move(expressions);
+    }
     inline expression_vector getProjectionExpressions() const { return projectionExpressions; }
 
     inline void setGroupByExpressions(expression_vector expressions) {
diff --git a/src/include/binder/rewriter/with_clause_projection_rewriter.h b/src/include/binder/rewriter/with_clause_projection_rewriter.h
new file mode 100644
index 0000000000..0b03ed76dd
--- /dev/null
+++ b/src/include/binder/rewriter/with_clause_projection_rewriter.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include "binder/bound_statement_visitor.h"
+
+namespace kuzu {
+namespace binder {
+
+// WithClauseProjectionRewriter first analyzes the properties that need to be scanned for each
+// query, and then rewrites node/rel expressions in the WITH clause as their properties, so that
+// we avoid eagerly evaluating node/rel in the WITH clause projection. E.g.
+//     MATCH (a) WITH a MATCH (a)->(b);
+// will be rewritten as
+//     MATCH (a) WITH a._id MATCH (a)->(b);
+// See bind_projection_clause.cpp for more details.
+class WithClauseProjectionRewriter : public BoundStatementVisitor {
+public:
+    void visitSingleQuery(const NormalizedSingleQuery& singleQuery) override;
+};
+
+} // namespace binder
+} // namespace kuzu
diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp
index 46435741f1..6bfe2743f7 100644
--- a/test/optimizer/optimizer_test.cpp
+++ b/test/optimizer/optimizer_test.cpp
@@ -1,6 +1,7 @@
 #include "graph_test/graph_test.h"
 #include "planner/logical_plan/extend/logical_recursive_extend.h"
 #include "planner/logical_plan/logical_plan_util.h"
+#include "planner/logical_plan/scan/logical_scan_node_property.h"
 
 namespace kuzu {
 namespace testing {
@@ -20,6 +21,17 @@ class OptimizerTest : public DBTest {
     }
 };
 
+TEST_F(OptimizerTest, WithClauseProjectionListRewriterTest) {
+    auto op = getRoot("MATCH (a:person) WITH a RETURN a.gender;");
+    ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);
+    op = op->getChild(0);
+    ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);
+    op = op->getChild(0);
+    ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::SCAN_NODE_PROPERTY);
+    auto scanNodeProperty = (planner::LogicalScanNodeProperty*)op.get();
+    ASSERT_EQ(scanNodeProperty->getProperties().size(), 1);
+}
+
 TEST_F(OptimizerTest, FilterPushDownTest) {
     auto op = getRoot("MATCH (a:person) WHERE a.ID < 0 AND a.fName='Alice' RETURN a.gender;");
     ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);