Skip to content

Commit

Permalink
Merge pull request #1858 from kuzudb/enable-optional-match-cross-product
Browse files Browse the repository at this point in the history
Enable optional match cross product
  • Loading branch information
andyfengHKU committed Jul 25, 2023
2 parents 8516e87 + 14c0cee commit c1896a3
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 39 deletions.
6 changes: 2 additions & 4 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ class JoinOrderEnumerator {
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::LEFT, nullptr /* mark */, probePlan, buildPlan);
}
// Forwards to appendCrossProduct with the same probe and build plans; it adds no logic of
// its own. Both plans are taken by non-const reference and handed on unchanged.
inline void planCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
appendCrossProduct(probePlan, buildPlan);
}
void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendCrossProduct(
common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan);

private:
std::vector<std::unique_ptr<LogicalPlan>> planCrossProduct(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
#pragma once

#include "base_logical_operator.h"
#include "common/join_type.h"
#include "sink_util.h"

namespace kuzu {
namespace planner {

class LogicalCrossProduct : public LogicalOperator {
public:
LogicalCrossProduct(std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::CROSS_PRODUCT, std::move(probeSideChild),
std::move(buildSideChild)} {}
LogicalCrossProduct(common::AccumulateType accumulateType,
std::shared_ptr<LogicalOperator> probeChild, std::shared_ptr<LogicalOperator> buildChild)
: LogicalOperator{LogicalOperatorType::CROSS_PRODUCT, std::move(probeChild),
std::move(buildChild)},
accumulateType{accumulateType} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string(); }

inline Schema* getBuildSideSchema() const { return children[1]->getSchema(); }
inline common::AccumulateType getAccumulateType() const { return accumulateType; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalCrossProduct>(children[0]->copy(), children[1]->copy());
return make_unique<LogicalCrossProduct>(
accumulateType, children[0]->copy(), children[1]->copy());
}

private:
common::AccumulateType accumulateType;
};

} // namespace planner
Expand Down
2 changes: 0 additions & 2 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,12 @@ class QueryPlanner {
LogicalPlan& plan);
void appendCreateRel(
const std::vector<std::unique_ptr<binder::BoundCreateRel>>& createRels, LogicalPlan& plan);

void appendSetNodeProperty(
const std::vector<std::unique_ptr<binder::BoundSetNodeProperty>>& setNodeProperties,
LogicalPlan& plan);
void appendSetRelProperty(
const std::vector<std::unique_ptr<binder::BoundSetRelProperty>>& setRelProperties,
LogicalPlan& plan);

void appendDeleteNode(const std::vector<std::unique_ptr<binder::BoundDeleteNode>>& deleteNodes,
LogicalPlan& plan);
void appendDeleteRel(
Expand Down
15 changes: 1 addition & 14 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "binder/expression/expression_visitor.h"
#include "planner/join_order/cost_model.h"
#include "planner/logical_plan/logical_operator/logical_cross_product.h"
#include "planner/logical_plan/logical_operator/logical_scan_node.h"
#include "planner/query_planner.h"

Expand Down Expand Up @@ -59,7 +58,7 @@ std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::planCrossProduct(
for (auto& rightPlan : rightPlans) {
auto leftPlanCopy = leftPlan->shallowCopy();
auto rightPlanCopy = rightPlan->shallowCopy();
appendCrossProduct(*leftPlanCopy, *rightPlanCopy);
appendCrossProduct(common::AccumulateType::REGULAR, *leftPlanCopy, *rightPlanCopy);
result.push_back(std::move(leftPlanCopy));
}
}
Expand Down Expand Up @@ -468,18 +467,6 @@ void JoinOrderEnumerator::planJoin(const expression_vector& joinNodeIDs, JoinTyp
}
}

// Stacks a LogicalCrossProduct on top of probePlan, using the last operators of probePlan and
// buildPlan as its two children, then refreshes probePlan's cost and cardinality.
// buildPlan is only read; probePlan receives the new last operator.
void JoinOrderEnumerator::appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
    auto crossProductOp = make_shared<LogicalCrossProduct>(
        probePlan.getLastOperator(), buildPlan.getLastOperator());
    crossProductOp->computeFactorizedSchema();

    // The cost is derived from both input cardinalities, so it is set before the probe plan's
    // cardinality is overwritten with the cross-product estimate.
    const auto inputCardinalitySum = probePlan.getCardinality() + buildPlan.getCardinality();
    probePlan.setCost(inputCardinalitySum);

    const auto estimatedCardinality =
        queryPlanner->cardinalityEstimator->estimateCrossProduct(probePlan, buildPlan);
    probePlan.setCardinality(estimatedCardinality);

    probePlan.setLastOperator(std::move(crossProductOp));
}

expression_vector JoinOrderEnumerator::getNewlyMatchedExpressions(
const std::vector<SubqueryGraph>& prevSubgraphs, const SubqueryGraph& newSubgraph,
const expression_vector& expressions) {
Expand Down
1 change: 1 addition & 0 deletions src/planner/plan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(kuzu_planner_plan_operator
OBJECT
append_accumulate.cpp
append_create.cpp
append_cross_product.cpp
append_delete.cpp
append_set.cpp
plan_update.cpp)
Expand Down
22 changes: 22 additions & 0 deletions src/planner/plan/append_cross_product.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "planner/join_order_enumerator.h"
#include "planner/logical_plan/logical_operator/logical_cross_product.h"
#include "planner/query_planner.h"

namespace kuzu {
namespace planner {

// Appends a LogicalCrossProduct carrying the given accumulateType on top of probePlan.
// The last operators of probePlan and buildPlan become the probe and build children.
// Afterwards probePlan holds the new operator together with an updated cost and cardinality;
// buildPlan is only read.
void JoinOrderEnumerator::appendCrossProduct(
    common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
    auto probeRoot = probePlan.getLastOperator();
    auto buildRoot = buildPlan.getLastOperator();
    auto crossProductOp = make_shared<LogicalCrossProduct>(
        accumulateType, std::move(probeRoot), std::move(buildRoot));
    crossProductOp->computeFactorizedSchema();

    // Cost: sum of both input cardinalities, taken before the probe cardinality changes below.
    const auto cost = probePlan.getCardinality() + buildPlan.getCardinality();
    probePlan.setCost(cost);

    // Cardinality: delegated to the planner's cardinality estimator.
    const auto outputCardinality =
        queryPlanner->cardinalityEstimator->estimateCrossProduct(probePlan, buildPlan);
    probePlan.setCardinality(outputCardinality);

    probePlan.setLastOperator(std::move(crossProductOp));
}

} // namespace planner
} // namespace kuzu
18 changes: 12 additions & 6 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ void QueryPlanner::planInQueryCall(
if (!plan->isEmpty()) {
auto inQueryCallPlan = std::make_shared<LogicalPlan>();
appendInQueryCall(*boundInQueryCall, *inQueryCallPlan);
joinOrderEnumerator.appendCrossProduct(*plan, *inQueryCallPlan);
joinOrderEnumerator.appendCrossProduct(
common::AccumulateType::REGULAR, *plan, *inQueryCallPlan);
} else {
appendInQueryCall(*boundInQueryCall, *plan);
}
Expand Down Expand Up @@ -187,6 +188,7 @@ static expression_vector getJoinNodeIDs(expression_vector& expressions) {
void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection,
const expression_vector& predicates, LogicalPlan& leftPlan) {
if (leftPlan.isEmpty()) {
// Optional match is the first clause. No left plan to join.
auto plan = planJoins(queryGraphCollection, predicates);
leftPlan.setLastOperator(plan->getLastOperator());
appendAccumulate(AccumulateType::OPTIONAL_, leftPlan);
Expand All @@ -195,12 +197,16 @@ void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphColle
auto correlatedExpressions =
getCorrelatedExpressions(queryGraphCollection, predicates, leftPlan.getSchema());
if (correlatedExpressions.empty()) {
throw NotImplementedException("Optional match is disconnected with previous MATCH clause.");
// No join condition, apply cross product.
auto rightPlan = planJoins(queryGraphCollection, predicates);
joinOrderEnumerator.appendCrossProduct(
common::AccumulateType::OPTIONAL_, leftPlan, *rightPlan);
return;
}
if (ExpressionUtil::allExpressionsHaveDataType(
correlatedExpressions, LogicalTypeID::INTERNAL_ID)) {
// All join conditions are internal IDs, unnest as left hash join.
auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions);
// When correlated variables are all NODE IDs, the subquery can be un-nested as left join.
// Join nodes are scanned twice in both outer and inner. However, we make sure inner table
// scan only scans node ID and does not scan from storage (i.e. no property scan).
auto rightPlan = planJoinsInNewContext(joinNodeIDs, queryGraphCollection, predicates);
Expand Down Expand Up @@ -230,7 +236,8 @@ void QueryPlanner::planRegularMatch(const QueryGraphCollection& queryGraphCollec
// planning an un-nest subquery.
auto rightPlan = planJoinsInNewContext(joinNodeIDs, queryGraphCollection, predicatesToPushDown);
if (joinNodeIDs.empty()) {
joinOrderEnumerator.planCrossProduct(leftPlan, *rightPlan);
joinOrderEnumerator.appendCrossProduct(
common::AccumulateType::REGULAR, leftPlan, *rightPlan);
} else {
joinOrderEnumerator.planInnerHashJoin(joinNodeIDs, leftPlan, *rightPlan);
}
Expand All @@ -241,9 +248,9 @@ void QueryPlanner::planRegularMatch(const QueryGraphCollection& queryGraphCollec

void QueryPlanner::planExistsSubquery(
std::shared_ptr<Expression> expression, LogicalPlan& outerPlan) {

assert(expression->expressionType == EXISTENTIAL_SUBQUERY);
auto subquery = static_pointer_cast<ExistentialSubqueryExpression>(expression);
auto predicates = subquery->getPredicatesSplitOnAnd();
auto correlatedExpressions = outerPlan.getSchema()->getSubExpressionsInScope(subquery);
if (correlatedExpressions.empty()) {
throw NotImplementedException("Subquery is disconnected with outer query.");
Expand All @@ -252,7 +259,6 @@ void QueryPlanner::planExistsSubquery(
correlatedExpressions, LogicalTypeID::INTERNAL_ID)) {
auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions);
// Unnest as mark join. See planOptionalMatch for unnesting logic.
auto predicates = subquery->getPredicatesSplitOnAnd();
auto innerPlan =
planJoinsInNewContext(joinNodeIDs, *subquery->getQueryGraphCollection(), predicates);
joinOrderEnumerator.planMarkJoin(joinNodeIDs, expression, outerPlan, *innerPlan);
Expand Down
12 changes: 6 additions & 6 deletions src/processor/mapper/map_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ namespace processor {
std::unique_ptr<PhysicalOperator> PlanMapper::mapCrossProduct(LogicalOperator* logicalOperator) {
auto logicalCrossProduct = (LogicalCrossProduct*)logicalOperator;
auto outSchema = logicalCrossProduct->getSchema();
auto buildChild = logicalCrossProduct->getChild(1);
// map build side
auto buildSideSchema = logicalCrossProduct->getBuildSideSchema();
auto buildSidePrevOperator = mapOperator(logicalCrossProduct->getChild(1).get());
auto resultCollector = createResultCollector(common::AccumulateType::REGULAR,
buildSideSchema->getExpressionsInScope(), buildSideSchema,
std::move(buildSidePrevOperator));
auto buildSchema = buildChild->getSchema();
auto buildSidePrevOperator = mapOperator(buildChild.get());
auto expressions = buildSchema->getExpressionsInScope();
auto resultCollector = createResultCollector(logicalCrossProduct->getAccumulateType(),
expressions, buildSchema, std::move(buildSidePrevOperator));
// map probe side
auto probeSidePrevOperator = mapOperator(logicalCrossProduct->getChild(0).get());
std::vector<DataPos> outVecPos;
std::vector<uint32_t> colIndicesToScan;
auto expressions = buildSideSchema->getExpressionsInScope();
for (auto i = 0u; i < expressions.size(); ++i) {
auto expression = expressions[i];
outVecPos.emplace_back(outSchema->getExpressionPos(*expression));
Expand Down
11 changes: 10 additions & 1 deletion test/test_files/tinysnb/optional_match/optional_match.test
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
#---- 1
#26

-LOG CrossProductOptionalMatch1
-STATEMENT MATCH (a:person) WHERE a.fName = 'Alice' OPTIONAL MATCH (b:person) WHERE b.fName = 'A' RETURN a.ID, b.ID
---- 1
0|

-LOG CrossProductOptionalMatch2
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.fName = 'a' MATCH (b:person) WHERE b.fName = 'Alice' RETURN a.ID, b.ID
---- 1
|0

-LOG InitOptionalMatch
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.ID<0 RETURN COUNT(*)
---- 1
Expand All @@ -33,7 +43,6 @@
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.fName = 'a' RETURN a.age
---- 1


-LOG InitOptionalMatch2
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.ID < 6 RETURN a.fName
---- 4
Expand Down

0 comments on commit c1896a3

Please sign in to comment.