Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable optional match cross product #1858

Merged
merged 1 commit into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ class JoinOrderEnumerator {
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::LEFT, nullptr /* mark */, probePlan, buildPlan);
}
// Plans a cross product of probePlan with buildPlan by forwarding both plans, unchanged, to
// appendCrossProduct.
inline void planCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
appendCrossProduct(probePlan, buildPlan);
}
void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendCrossProduct(
common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan);

private:
std::vector<std::unique_ptr<LogicalPlan>> planCrossProduct(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
#pragma once

#include "base_logical_operator.h"
#include "common/join_type.h"
#include "sink_util.h"

namespace kuzu {
namespace planner {

// Logical operator for the cross product of two child plans. The probe-side child is passed to
// the base operator first and the build-side child second; getBuildSideSchema() reads
// children[1]. The accumulate type records how the build side should be accumulated (e.g.
// REGULAR for a plain cross product, OPTIONAL_ for an OPTIONAL MATCH with no join condition).
class LogicalCrossProduct : public LogicalOperator {
public:
    // Two-child form kept for existing callers. Delegates with AccumulateType::REGULAR so that
    // accumulateType is always initialized.
    LogicalCrossProduct(std::shared_ptr<LogicalOperator> probeSideChild,
        std::shared_ptr<LogicalOperator> buildSideChild)
        : LogicalCrossProduct(common::AccumulateType::REGULAR, std::move(probeSideChild),
              std::move(buildSideChild)) {}

    // Full form: the caller chooses the accumulate type.
    LogicalCrossProduct(common::AccumulateType accumulateType,
        std::shared_ptr<LogicalOperator> probeChild, std::shared_ptr<LogicalOperator> buildChild)
        : LogicalOperator{LogicalOperatorType::CROSS_PRODUCT, std::move(probeChild),
              std::move(buildChild)},
          accumulateType{accumulateType} {}

    void computeFactorizedSchema() override;
    void computeFlatSchema() override;

    // A cross product has no expressions to print.
    inline std::string getExpressionsForPrinting() const override { return std::string(); }

    // Schema of the build side, i.e. of children[1].
    inline Schema* getBuildSideSchema() const { return children[1]->getSchema(); }
    inline common::AccumulateType getAccumulateType() const { return accumulateType; }

    // Copies both children and carries accumulateType over to the copy. There is exactly one
    // return path, so the accumulate type can never be dropped.
    inline std::unique_ptr<LogicalOperator> copy() override {
        return std::make_unique<LogicalCrossProduct>(
            accumulateType, children[0]->copy(), children[1]->copy());
    }

private:
    common::AccumulateType accumulateType;
};

} // namespace planner
Expand Down
2 changes: 0 additions & 2 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,12 @@ class QueryPlanner {
LogicalPlan& plan);
void appendCreateRel(
const std::vector<std::unique_ptr<binder::BoundCreateRel>>& createRels, LogicalPlan& plan);

void appendSetNodeProperty(
const std::vector<std::unique_ptr<binder::BoundSetNodeProperty>>& setNodeProperties,
LogicalPlan& plan);
void appendSetRelProperty(
const std::vector<std::unique_ptr<binder::BoundSetRelProperty>>& setRelProperties,
LogicalPlan& plan);

void appendDeleteNode(const std::vector<std::unique_ptr<binder::BoundDeleteNode>>& deleteNodes,
LogicalPlan& plan);
void appendDeleteRel(
Expand Down
15 changes: 1 addition & 14 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "binder/expression/expression_visitor.h"
#include "planner/join_order/cost_model.h"
#include "planner/logical_plan/logical_operator/logical_cross_product.h"
#include "planner/logical_plan/logical_operator/logical_scan_node.h"
#include "planner/query_planner.h"

Expand Down Expand Up @@ -59,7 +58,7 @@ std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::planCrossProduct(
for (auto& rightPlan : rightPlans) {
auto leftPlanCopy = leftPlan->shallowCopy();
auto rightPlanCopy = rightPlan->shallowCopy();
appendCrossProduct(*leftPlanCopy, *rightPlanCopy);
appendCrossProduct(common::AccumulateType::REGULAR, *leftPlanCopy, *rightPlanCopy);
result.push_back(std::move(leftPlanCopy));
}
}
Expand Down Expand Up @@ -468,18 +467,6 @@ void JoinOrderEnumerator::planJoin(const expression_vector& joinNodeIDs, JoinTyp
}
}

// Appends a LogicalCrossProduct to probePlan. The new operator takes probePlan's last operator as
// its first child and buildPlan's last operator as its second child. probePlan's cost,
// cardinality and last operator are updated in place.
void JoinOrderEnumerator::appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
auto crossProduct =
make_shared<LogicalCrossProduct>(probePlan.getLastOperator(), buildPlan.getLastOperator());
crossProduct->computeFactorizedSchema();
// Cost is set to the sum of the two input cardinalities. This has to happen before
// setCardinality below, because it reads probePlan's cardinality from before the cross product.
probePlan.setCost(probePlan.getCardinality() + buildPlan.getCardinality());
// Replace probePlan's cardinality with the estimator's cross product estimate.
probePlan.setCardinality(
queryPlanner->cardinalityEstimator->estimateCrossProduct(probePlan, buildPlan));
// The cross product becomes probePlan's last operator.
probePlan.setLastOperator(std::move(crossProduct));
}

expression_vector JoinOrderEnumerator::getNewlyMatchedExpressions(
const std::vector<SubqueryGraph>& prevSubgraphs, const SubqueryGraph& newSubgraph,
const expression_vector& expressions) {
Expand Down
1 change: 1 addition & 0 deletions src/planner/plan/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(kuzu_planner_plan_operator
OBJECT
append_accumulate.cpp
append_create.cpp
append_cross_product.cpp
append_delete.cpp
append_set.cpp
plan_update.cpp)
Expand Down
22 changes: 22 additions & 0 deletions src/planner/plan/append_cross_product.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "planner/join_order_enumerator.h"
#include "planner/logical_plan/logical_operator/logical_cross_product.h"
#include "planner/query_planner.h"

namespace kuzu {
namespace planner {

// Puts a LogicalCrossProduct with the given accumulate type on top of probePlan, using buildPlan's
// last operator as the second child. probePlan's cost, cardinality and last operator are rewritten
// in place.
void JoinOrderEnumerator::appendCrossProduct(
    common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
    auto probeOp = probePlan.getLastOperator();
    auto buildOp = buildPlan.getLastOperator();
    auto crossProductOp = std::make_shared<LogicalCrossProduct>(
        accumulateType, std::move(probeOp), std::move(buildOp));
    crossProductOp->computeFactorizedSchema();

    // Cost first: it is built from the input cardinalities, so it must be taken before the
    // probe cardinality is overwritten below.
    const auto inputCardinalitySum = probePlan.getCardinality() + buildPlan.getCardinality();
    probePlan.setCost(inputCardinalitySum);

    // Then the output cardinality, as estimated for the cross product of the two plans.
    const auto estimatedCardinality =
        queryPlanner->cardinalityEstimator->estimateCrossProduct(probePlan, buildPlan);
    probePlan.setCardinality(estimatedCardinality);

    // Finally make the cross product the plan's last operator.
    probePlan.setLastOperator(std::move(crossProductOp));
}

} // namespace planner
} // namespace kuzu
18 changes: 12 additions & 6 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ void QueryPlanner::planInQueryCall(
if (!plan->isEmpty()) {
auto inQueryCallPlan = std::make_shared<LogicalPlan>();
appendInQueryCall(*boundInQueryCall, *inQueryCallPlan);
joinOrderEnumerator.appendCrossProduct(*plan, *inQueryCallPlan);
joinOrderEnumerator.appendCrossProduct(
common::AccumulateType::REGULAR, *plan, *inQueryCallPlan);
} else {
appendInQueryCall(*boundInQueryCall, *plan);
}
Expand Down Expand Up @@ -187,6 +188,7 @@ static expression_vector getJoinNodeIDs(expression_vector& expressions) {
void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection,
const expression_vector& predicates, LogicalPlan& leftPlan) {
if (leftPlan.isEmpty()) {
// Optional match is the first clause. No left plan to join.
auto plan = planJoins(queryGraphCollection, predicates);
leftPlan.setLastOperator(plan->getLastOperator());
appendAccumulate(AccumulateType::OPTIONAL_, leftPlan);
Expand All @@ -195,12 +197,16 @@ void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphColle
auto correlatedExpressions =
getCorrelatedExpressions(queryGraphCollection, predicates, leftPlan.getSchema());
if (correlatedExpressions.empty()) {
throw NotImplementedException("Optional match is disconnected with previous MATCH clause.");
// No join condition, apply cross product.
auto rightPlan = planJoins(queryGraphCollection, predicates);
joinOrderEnumerator.appendCrossProduct(
common::AccumulateType::OPTIONAL_, leftPlan, *rightPlan);
return;
}
if (ExpressionUtil::allExpressionsHaveDataType(
correlatedExpressions, LogicalTypeID::INTERNAL_ID)) {
// All join conditions are internal IDs, unnest as left hash join.
auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions);
// When correlated variables are all NODE IDs, the subquery can be un-nested as left join.
// Join nodes are scanned twice in both outer and inner. However, we make sure inner table
// scan only scans node ID and does not scan from storage (i.e. no property scan).
auto rightPlan = planJoinsInNewContext(joinNodeIDs, queryGraphCollection, predicates);
Expand Down Expand Up @@ -230,7 +236,8 @@ void QueryPlanner::planRegularMatch(const QueryGraphCollection& queryGraphCollec
// planning an un-nest subquery.
auto rightPlan = planJoinsInNewContext(joinNodeIDs, queryGraphCollection, predicatesToPushDown);
if (joinNodeIDs.empty()) {
joinOrderEnumerator.planCrossProduct(leftPlan, *rightPlan);
joinOrderEnumerator.appendCrossProduct(
common::AccumulateType::REGULAR, leftPlan, *rightPlan);
} else {
joinOrderEnumerator.planInnerHashJoin(joinNodeIDs, leftPlan, *rightPlan);
}
Expand All @@ -241,9 +248,9 @@ void QueryPlanner::planRegularMatch(const QueryGraphCollection& queryGraphCollec

void QueryPlanner::planExistsSubquery(
std::shared_ptr<Expression> expression, LogicalPlan& outerPlan) {

assert(expression->expressionType == EXISTENTIAL_SUBQUERY);
auto subquery = static_pointer_cast<ExistentialSubqueryExpression>(expression);
auto predicates = subquery->getPredicatesSplitOnAnd();
auto correlatedExpressions = outerPlan.getSchema()->getSubExpressionsInScope(subquery);
if (correlatedExpressions.empty()) {
throw NotImplementedException("Subquery is disconnected with outer query.");
Expand All @@ -252,7 +259,6 @@ void QueryPlanner::planExistsSubquery(
correlatedExpressions, LogicalTypeID::INTERNAL_ID)) {
auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions);
// Unnest as mark join. See planOptionalMatch for unnesting logic.
auto predicates = subquery->getPredicatesSplitOnAnd();
auto innerPlan =
planJoinsInNewContext(joinNodeIDs, *subquery->getQueryGraphCollection(), predicates);
joinOrderEnumerator.planMarkJoin(joinNodeIDs, expression, outerPlan, *innerPlan);
Expand Down
12 changes: 6 additions & 6 deletions src/processor/mapper/map_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ namespace processor {
std::unique_ptr<PhysicalOperator> PlanMapper::mapCrossProduct(LogicalOperator* logicalOperator) {
auto logicalCrossProduct = (LogicalCrossProduct*)logicalOperator;
auto outSchema = logicalCrossProduct->getSchema();
auto buildChild = logicalCrossProduct->getChild(1);
// map build side
auto buildSideSchema = logicalCrossProduct->getBuildSideSchema();
auto buildSidePrevOperator = mapOperator(logicalCrossProduct->getChild(1).get());
auto resultCollector = createResultCollector(common::AccumulateType::REGULAR,
buildSideSchema->getExpressionsInScope(), buildSideSchema,
std::move(buildSidePrevOperator));
auto buildSchema = buildChild->getSchema();
auto buildSidePrevOperator = mapOperator(buildChild.get());
auto expressions = buildSchema->getExpressionsInScope();
auto resultCollector = createResultCollector(logicalCrossProduct->getAccumulateType(),
expressions, buildSchema, std::move(buildSidePrevOperator));
// map probe side
auto probeSidePrevOperator = mapOperator(logicalCrossProduct->getChild(0).get());
std::vector<DataPos> outVecPos;
std::vector<uint32_t> colIndicesToScan;
auto expressions = buildSideSchema->getExpressionsInScope();
for (auto i = 0u; i < expressions.size(); ++i) {
auto expression = expressions[i];
outVecPos.emplace_back(outSchema->getExpressionPos(*expression));
Expand Down
11 changes: 10 additions & 1 deletion test/test_files/tinysnb/optional_match/optional_match.test
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
#---- 1
#26

-LOG CrossProductOptionalMatch1
-STATEMENT MATCH (a:person) WHERE a.fName = 'Alice' OPTIONAL MATCH (b:person) WHERE b.fName = 'A' RETURN a.ID, b.ID
---- 1
0|

-LOG CrossProductOptionalMatch2
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.fName = 'a' MATCH (b:person) WHERE b.fName = 'Alice' RETURN a.ID, b.ID
---- 1
|0

-LOG InitOptionalMatch
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.ID<0 RETURN COUNT(*)
---- 1
Expand All @@ -33,7 +43,6 @@
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.fName = 'a' RETURN a.age
---- 1


-LOG InitOptionalMatch2
-STATEMENT OPTIONAL MATCH (a:person) WHERE a.ID < 6 RETURN a.fName
---- 4
Expand Down