Skip to content

Commit

Permalink
Fix multi query with filter result error (#3323)
Browse files Browse the repository at this point in the history
* Fix incorrect results for multi-part queries with filters

* Simplify design

---------

Co-authored-by: xiyang <x74feng@uwaterloo.ca>
  • Loading branch information
ted-wq-x and andyfengHKU committed Apr 20, 2024
1 parent 4917262 commit 58b8222
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 15 deletions.
10 changes: 6 additions & 4 deletions src/include/optimizer/filter_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class FilterPushDownOptimizer {
// And heuristically reorder equality predicates first in the filter.
std::shared_ptr<planner::LogicalOperator> finishPushDown(
std::shared_ptr<planner::LogicalOperator> op);
// Stack one filter operator per predicate on top of `child`, applying the predicates in the
// order given, and return the resulting root. Returns `child` itself when `predicates` is empty.
std::shared_ptr<planner::LogicalOperator> appendFilters(
const binder::expression_vector& predicates,
std::shared_ptr<planner::LogicalOperator> child);

std::shared_ptr<planner::LogicalOperator> appendScanNodeProperty(
std::shared_ptr<binder::Expression> nodeID, std::vector<common::table_id_t> nodeTableIDs,
Expand All @@ -55,17 +58,16 @@ class FilterPushDownOptimizer {
binder::expression_vector equalityPredicates;
binder::expression_vector nonEqualityPredicates;

inline bool isEmpty() const {
return equalityPredicates.empty() && nonEqualityPredicates.empty();
}
inline void clear() {
bool isEmpty() const { return equalityPredicates.empty() && nonEqualityPredicates.empty(); }
void clear() {
equalityPredicates.clear();
nonEqualityPredicates.clear();
}

void addPredicate(std::shared_ptr<binder::Expression> predicate);
std::shared_ptr<binder::Expression> popNodePKEqualityComparison(
const binder::Expression& nodeID);
// Return a new vector holding every equality predicate followed by every non-equality
// predicate. The stored predicates are not removed.
binder::expression_vector getAllPredicates();
};

private:
Expand Down
46 changes: 35 additions & 11 deletions src/optimizer/filter_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ std::shared_ptr<LogicalOperator> FilterPushDownOptimizer::visitCrossProductRepla
}
auto probeSchema = op->getChild(0)->getSchema();
auto buildSchema = op->getChild(1)->getSchema();
expression_vector predicates;
std::vector<join_condition_t> joinConditions;
for (auto& predicate : predicateSet->equalityPredicates) {
auto left = predicate->getChild(0);
Expand All @@ -81,16 +82,25 @@ std::shared_ptr<LogicalOperator> FilterPushDownOptimizer::visitCrossProductRepla
} else if (probeSchema->isExpressionInScope(*right) &&
buildSchema->isExpressionInScope(*left)) {
joinConditions.emplace_back(right, left);
} else {
// Collect predicates that cannot be rewritten as join conditions.
predicates.push_back(predicate);
}
}
if (joinConditions.empty()) {
if (joinConditions.empty()) { // Nothing to push down. Terminate.
return finishPushDown(op);
}
auto hashJoin = std::make_shared<LogicalHashJoin>(joinConditions, JoinType::INNER,
nullptr /* mark */, op->getChild(0), op->getChild(1));
hashJoin->setSIP(planner::SidewaysInfoPassing::PROHIBIT);
hashJoin->setSIP(SidewaysInfoPassing::PROHIBIT);
hashJoin->computeFlatSchema();
return hashJoin;
// Apply remaining predicates.
predicates.insert(predicates.end(), predicateSet->nonEqualityPredicates.begin(),
predicateSet->nonEqualityPredicates.end());
if (predicates.empty()) {
return hashJoin;
}
return appendFilters(predicates, hashJoin);
}

std::shared_ptr<planner::LogicalOperator> FilterPushDownOptimizer::visitScanNodePropertyReplace(
Expand Down Expand Up @@ -165,15 +175,10 @@ std::shared_ptr<planner::LogicalOperator> FilterPushDownOptimizer::finishPushDow
if (predicateSet->isEmpty()) {
return op;
}
auto currentRoot = op;
for (auto& predicate : predicateSet->equalityPredicates) {
currentRoot = appendFilter(predicate, currentRoot);
}
for (auto& predicate : predicateSet->nonEqualityPredicates) {
currentRoot = appendFilter(predicate, currentRoot);
}
auto predicates = predicateSet->getAllPredicates();
auto root = appendFilters(predicates, op);
predicateSet->clear();
return currentRoot;
return root;
}

std::shared_ptr<planner::LogicalOperator> FilterPushDownOptimizer::appendScanNodeProperty(
Expand All @@ -188,6 +193,18 @@ std::shared_ptr<planner::LogicalOperator> FilterPushDownOptimizer::appendScanNod
return scanNodeProperty;
}

// Wrap `child` in a chain of filters, one per entry of `predicates`, in iteration order.
// Each predicate is applied on top of the operator produced for the previous one, so the last
// predicate ends up as the outermost operator. With no predicates, `child` is returned as-is.
std::shared_ptr<LogicalOperator> FilterPushDownOptimizer::appendFilters(
    const expression_vector& predicates, std::shared_ptr<LogicalOperator> child) {
    // `child` is taken by value, so it can double as the running root of the chain.
    for (const auto& predicate : predicates) {
        child = appendFilter(predicate, child);
    }
    return child;
}

std::shared_ptr<planner::LogicalOperator> FilterPushDownOptimizer::appendFilter(
std::shared_ptr<binder::Expression> predicate,
std::shared_ptr<planner::LogicalOperator> child) {
Expand Down Expand Up @@ -246,5 +263,12 @@ FilterPushDownOptimizer::PredicateSet::popNodePKEqualityComparison(
return nullptr;
}

// Build a single list of the predicates held by this set: equality predicates first, then
// non-equality predicates. The set's own vectors are only read, never modified.
binder::expression_vector FilterPushDownOptimizer::PredicateSet::getAllPredicates() {
    // Start from a copy of the equality predicates, then append the rest behind them.
    auto allPredicates = equalityPredicates;
    allPredicates.insert(allPredicates.end(), nonEqualityPredicates.begin(),
        nonEqualityPredicates.end());
    return allPredicates;
}

} // namespace optimizer
} // namespace kuzu
23 changes: 23 additions & 0 deletions test/optimizer/optimizer_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "graph_test/graph_test.h"
#include "planner/operator/extend/logical_recursive_extend.h"
#include "planner/operator/logical_filter.h"
#include "planner/operator/logical_plan_util.h"
#include "planner/operator/scan/logical_scan_node_property.h"
#include "test_runner/test_runner.h"
Expand All @@ -22,6 +23,28 @@ class OptimizerTest : public DBTest {
}
};

// Two separate MATCH clauses joined by an equality plus an inequality across `a` and `b`.
// Expected plan shape from the root down: PROJECTION -> FILTER -> HASH_JOIN.
TEST_F(OptimizerTest, CrossJoinWithFilterWithoutPushDownTest) {
    auto projection = getRoot(
        "MATCH (a:person) MATCH (b:person) where a.fName=b.fName and a.gender <> b.gender RETURN a.gender;");
    ASSERT_EQ(projection->getOperatorType(), planner::LogicalOperatorType::PROJECTION);

    // The filter is expected directly below the projection, on top of the join.
    auto filter = projection->getChild(0);
    ASSERT_EQ(filter->getOperatorType(), planner::LogicalOperatorType::FILTER);

    auto hashJoin = filter->getChild(0);
    ASSERT_EQ(hashJoin->getOperatorType(), planner::LogicalOperatorType::HASH_JOIN);
}

// Equality across `a` and `b` plus a predicate that only touches `a`.
// Expected plan shape from the root down, always following child 0:
// PROJECTION -> HASH_JOIN -> FLATTEN -> SCAN_NODE_PROPERTY -> FILTER.
TEST_F(OptimizerTest, CrossJoinWithFilterPushDownTest) {
    auto projection = getRoot(
        "MATCH (a:person) , (b:person) where a.fName=b.fName and a.fName is null RETURN a.gender;");
    ASSERT_EQ(projection->getOperatorType(), planner::LogicalOperatorType::PROJECTION);

    // No filter sits between the projection and the join in this case.
    auto hashJoin = projection->getChild(0);
    ASSERT_EQ(hashJoin->getOperatorType(), planner::LogicalOperatorType::HASH_JOIN);

    auto flatten = hashJoin->getChild(0);
    ASSERT_EQ(flatten->getOperatorType(), planner::LogicalOperatorType::FLATTEN);

    auto scanNodeProperty = flatten->getChild(0);
    ASSERT_EQ(scanNodeProperty->getOperatorType(),
        planner::LogicalOperatorType::SCAN_NODE_PROPERTY);

    // The filter is expected below the property scan on the first child side of the join.
    auto filter = scanNodeProperty->getChild(0);
    ASSERT_EQ(filter->getOperatorType(), planner::LogicalOperatorType::FILTER);
}

TEST_F(OptimizerTest, WithClauseProjectionListRewriterTest) {
auto op = getRoot("MATCH (a:person) WITH a RETURN a.gender;");
ASSERT_EQ(op->getOperatorType(), planner::LogicalOperatorType::PROJECTION);
Expand Down
18 changes: 18 additions & 0 deletions test/test_files/tinysnb/projection/multi_query_part.test
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,21 @@ Elizabeth|Elizabeth|20
Farooq|Farooq|25
Greg|Greg|40
Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|83

-LOG CrossProductWithFilterPushDownTest
-STATEMENT MATCH (a:person) WITH a.age AS s, a.fName as n MATCH (b:person) WHERE s=b.age and n='Dan' RETURN n,b.fName,s
-ENUMERATE
---- 2
Dan|Dan|20
Dan|Elizabeth|20
-STATEMENT MATCH (a:person) WITH a.age AS s, a.fName as n MATCH (b:person) WHERE s=b.age and n=b.fName RETURN n,b.fName,s
-ENUMERATE
---- 8
Alice|Alice|35
Bob|Bob|30
Carol|Carol|45
Dan|Dan|20
Elizabeth|Elizabeth|20
Farooq|Farooq|25
Greg|Greg|40
Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|83

0 comments on commit 58b8222

Please sign in to comment.