Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unnest arbitrary subquery #1930

Merged
merged 1 commit into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
return result;
}

bool ExpressionUtil::allExpressionsHaveDataType(
expression_vector& expressions, LogicalTypeID dataTypeID) {
bool ExpressionUtil::isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != dataTypeID) {
return false;
Expand All @@ -29,6 +29,17 @@
return true;
}

expression_vector ExpressionUtil::getExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
expression_vector result;
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() == dataTypeID) {
result.push_back(expression);
}
}
return result;
}

Check warning on line 41 in src/binder/expression/expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression.cpp#L41

Added line #L41 was not covered by tests

uint32_t ExpressionUtil::find(Expression* target, expression_vector expressions) {
for (auto i = 0u; i < expressions.size(); ++i) {
if (target->getUniqueName() == expressions[i]->getUniqueName()) {
Expand Down
6 changes: 4 additions & 2 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ struct ExpressionEquality {
};

struct ExpressionUtil {
static bool allExpressionsHaveDataType(
expression_vector& expressions, common::LogicalTypeID dataTypeID);
static bool isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID);
static expression_vector getExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID);

static uint32_t find(Expression* target, expression_vector expressions);

Expand Down
30 changes: 30 additions & 0 deletions src/include/optimizer/correlated_subquery_unnest_solver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class CorrelatedSubqueryUnnestSolver : public LogicalOperatorVisitor {
public:
CorrelatedSubqueryUnnestSolver(planner::LogicalOperator* accumulateOp)
: accumulateOp{accumulateOp} {}
void solve(planner::LogicalOperator* root_);

private:
void visitOperator(planner::LogicalOperator* op);
void visitExpressionsScan(planner::LogicalOperator* op) final;

inline bool isAccHashJoin(planner::LogicalOperator* op) const {
return op->getOperatorType() == planner::LogicalOperatorType::HASH_JOIN &&
op->getChild(0)->getOperatorType() == planner::LogicalOperatorType::ACCUMULATE;
}
void solveAccHashJoin(planner::LogicalOperator* op) const;

private:
planner::LogicalOperator* accumulateOp;
};

} // namespace optimizer
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitAccumulate(planner::LogicalOperator* op) override;
void visitAggregate(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitSkip(planner::LogicalOperator* op) override;
Expand Down
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitExpressionsScan(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitExpressionsScanReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitScanNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitScanNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
2 changes: 1 addition & 1 deletion src/include/planner/join_order/cardinality_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CardinalityEstimator {
void initNodeIDDom(binder::QueryGraph* queryGraph);

uint64_t estimateScanNode(LogicalOperator* op);
uint64_t estimateHashJoin(const binder::expression_vector& joinNodeIDs,
uint64_t estimateHashJoin(const binder::expression_vector& joinKeys,
const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
uint64_t estimateCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
uint64_t estimateIntersect(const binder::expression_vector& joinNodeIDs,
Expand Down
26 changes: 16 additions & 10 deletions src/include/planner/join_order_enumerator_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@
namespace kuzu {
namespace planner {

enum class SubqueryType : uint8_t {
NONE = 0,
INTERNAL_ID_CORRELATED = 1,
CORRELATED = 2,
};

class JoinOrderEnumeratorContext {
friend class QueryPlanner;

public:
JoinOrderEnumeratorContext()
: currentLevel{0}, maxLevel{0}, subPlansTable{std::make_unique<SubPlansTable>()},
queryGraph{nullptr} {}
queryGraph{nullptr}, subqueryType{SubqueryType::NONE}, correlatedExpressionsCardinality{
1} {}

void init(QueryGraph* queryGraph, const expression_vector& predicates);

Expand All @@ -35,15 +42,12 @@ class JoinOrderEnumeratorContext {

inline QueryGraph* getQueryGraph() { return queryGraph; }

inline bool nodeToScanFromInnerAndOuter(NodeExpression* node) {
for (auto& nodeID : nodeIDsToScanFromInnerAndOuter) {
if (nodeID->getUniqueName() == node->getInternalIDPropertyName()) {
return true;
}
}
return false;
inline binder::expression_vector getCorrelatedExpressions() const {
return correlatedExpressions;
}
inline binder::expression_set getCorrelatedExpressionsSet() const {
return binder::expression_set{correlatedExpressions.begin(), correlatedExpressions.end()};
}

void resetState();

private:
Expand All @@ -55,7 +59,9 @@ class JoinOrderEnumeratorContext {
std::unique_ptr<SubPlansTable> subPlansTable;
QueryGraph* queryGraph;

expression_vector nodeIDsToScanFromInnerAndOuter;
SubqueryType subqueryType;
expression_vector correlatedExpressions;
uint64_t correlatedExpressionsCardinality;
};

} // namespace planner
Expand Down
15 changes: 10 additions & 5 deletions src/include/planner/logical_plan/logical_accumulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,31 @@ namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(common::AccumulateType accumulateType, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, accumulateType{
accumulateType} {}
LogicalAccumulate(common::AccumulateType accumulateType,
binder::expression_vector expressionsToFlatten, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)},
accumulateType{accumulateType}, expressionsToFlatten{std::move(expressionsToFlatten)} {}

void computeFactorizedSchema() final;
void computeFlatSchema() final;

f_group_pos_set getGroupPositionsToFlatten() const;

inline std::string getExpressionsForPrinting() const final { return std::string{}; }

inline common::AccumulateType getAccumulateType() const { return accumulateType; }
inline binder::expression_vector getExpressions() const {
inline binder::expression_vector getExpressionsToAccumulate() const {
return children[0]->getSchema()->getExpressionsInScope();
}

inline std::unique_ptr<LogicalOperator> copy() final {
return make_unique<LogicalAccumulate>(accumulateType, children[0]->copy());
return make_unique<LogicalAccumulate>(
accumulateType, expressionsToFlatten, children[0]->copy());
}

private:
common::AccumulateType accumulateType;
binder::expression_vector expressionsToFlatten;
};

} // namespace planner
Expand Down
12 changes: 11 additions & 1 deletion src/include/planner/logical_plan/logical_hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ namespace planner {
// We only support equality comparison as join condition
using join_condition_t = binder::expression_pair;

enum class JoinSubPlanSolveOrder : uint8_t {
ANY = 0,
PROBE_BUILD = 1,
BUILD_PROBE = 2,
};

// Probe side on left, i.e. children[0]. Build side on right, i.e. children[1].
class LogicalHashJoin : public LogicalOperator {
public:
Expand All @@ -35,7 +41,7 @@ class LogicalHashJoin : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinConditions(std::move(joinConditions)), joinType{joinType}, mark{std::move(mark)},
sip{SidewaysInfoPassing::NONE} {}
sip{SidewaysInfoPassing::NONE}, order{JoinSubPlanSolveOrder::ANY} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();
Expand All @@ -61,6 +67,9 @@ class LogicalHashJoin : public LogicalOperator {
inline void setSIP(SidewaysInfoPassing sip_) { sip = sip_; }
inline SidewaysInfoPassing getSIP() const { return sip; }

inline void setJoinSubPlanSolveOrder(JoinSubPlanSolveOrder order_) { order = order_; }
inline JoinSubPlanSolveOrder getJoinSubPlanSolveOrder() const { return order; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(
joinConditions, joinType, mark, children[0]->copy(), children[1]->copy());
Expand All @@ -84,6 +93,7 @@ class LogicalHashJoin : public LogicalOperator {
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
SidewaysInfoPassing sip;
JoinSubPlanSolveOrder order; // sip introduce join dependency
};

} // namespace planner
Expand Down
1 change: 1 addition & 0 deletions src/include/planner/logical_plan/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ enum class LogicalOperatorType : uint8_t {
DROP_TABLE,
DUMMY_SCAN,
EXPLAIN,
EXPRESSIONS_SCAN,
EXTEND,
FILTER,
FLATTEN,
Expand Down
39 changes: 39 additions & 0 deletions src/include/planner/logical_plan/scan/logical_expressions_scan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include "planner/logical_plan/logical_operator.h"

namespace kuzu {
namespace planner {

// LogicalExpressionsScan scans from an outer factorize table
class LogicalExpressionsScan : public LogicalOperator {
public:
LogicalExpressionsScan(binder::expression_vector expressions)
: LogicalOperator{LogicalOperatorType::EXPRESSIONS_SCAN}, expressions{
std::move(expressions)} {}

inline void computeFactorizedSchema() final { computeSchema(); }
inline void computeFlatSchema() final { computeSchema(); }

inline std::string getExpressionsForPrinting() const final {
return binder::ExpressionUtil::toString(expressions);
}

inline binder::expression_vector getExpressions() const { return expressions; }
inline void setOuterAccumulate(LogicalOperator* op) { outerAccumulate = op; }
inline LogicalOperator* getOuterAccumulate() const { return outerAccumulate; }

inline std::unique_ptr<LogicalOperator> copy() final {
return std::make_unique<LogicalExpressionsScan>(expressions);
}

private:
void computeSchema();

private:
binder::expression_vector expressions;
LogicalOperator* outerAccumulate;
};

} // namespace planner
} // namespace kuzu
27 changes: 19 additions & 8 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,25 @@ class QueryPlanner {
std::unique_ptr<LogicalPlan> planQueryGraphCollection(
const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates);
std::unique_ptr<LogicalPlan> planQueryGraphCollectionInNewContext(
const binder::expression_vector& expressionsToExcludeScan,
std::unique_ptr<LogicalPlan> planQueryGraphCollectionInNewContext(SubqueryType subqueryType,
const binder::expression_vector& correlatedExpressions, uint64_t cardinality,
const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates);
std::vector<std::unique_ptr<LogicalPlan>> enumerateQueryGraphCollection(
const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates);
std::vector<std::unique_ptr<LogicalPlan>> enumerateQueryGraph(
binder::QueryGraph* queryGraph, binder::expression_vector& predicates);
std::vector<std::unique_ptr<LogicalPlan>> enumerateQueryGraph(SubqueryType subqueryType,
const expression_vector& correlatedExpressions, binder::QueryGraph* queryGraph,
binder::expression_vector& predicates);

// Plan node/rel table scan
void planBaseTableScan();
void planBaseTableScans(
SubqueryType subqueryType, const expression_vector& correlatedExpressions);
void planCorrelatedExpressionsScan(const binder::expression_vector& correlatedExpressions);
std::unique_ptr<LogicalPlan> getCorrelatedExpressionScanPlan(
const binder::expression_vector& correlatedExpressions);
void planNodeScan(uint32_t nodePos);
void planNodeIDScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);
void appendExtendAndFilter(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
Expand Down Expand Up @@ -164,6 +170,7 @@ class QueryPlanner {
void appendSkip(uint64_t skipNumber, LogicalPlan& plan);

// Append scan operators
void appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan);
void appendScanNodeID(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);
void appendScanNodeProperties(const expression_vector& propertyExpressions,
std::shared_ptr<NodeExpression> node, LogicalPlan& plan);
Expand Down Expand Up @@ -198,7 +205,11 @@ class QueryPlanner {
void appendCrossProduct(
common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan);

void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan);
inline void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan) {
appendAccumulate(accumulateType, expression_vector{}, plan);
}
void appendAccumulate(common::AccumulateType accumulateType,
const binder::expression_vector& expressionsToFlatten, LogicalPlan& plan);

void appendDummyScan(LogicalPlan& plan);

Expand All @@ -221,8 +232,8 @@ class QueryPlanner {

expression_vector getProperties(const binder::Expression& nodeOrRel);

std::unique_ptr<JoinOrderEnumeratorContext> enterContext(
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
std::unique_ptr<JoinOrderEnumeratorContext> enterContext(SubqueryType subqueryType,
const expression_vector& correlatedExpressions, uint64_t cardinality);
void exitContext(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

private:
Expand Down
9 changes: 8 additions & 1 deletion src/include/processor/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,21 @@ class PlanMapper {
std::unique_ptr<PhysicalOperator> mapStandaloneCall(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapInQueryCall(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapExplain(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapExpressionsScan(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCreateMacro(planner::LogicalOperator* logicalOperator);

std::unique_ptr<ResultCollector> createResultCollector(common::AccumulateType accumulateType,
const binder::expression_vector& expressions, planner::Schema* schema,
std::unique_ptr<PhysicalOperator> prevOperator);
std::unique_ptr<PhysicalOperator> createFactorizedTableScan(
const binder::expression_vector& expressions, std::vector<ft_col_idx_t> colIndices,
planner::Schema* schema, std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize,
std::unique_ptr<PhysicalOperator> prevOperator);
// Assume scans all columns of table in the same order as given expressions.
std::unique_ptr<PhysicalOperator> createFactorizedTableScanAligned(
const binder::expression_vector& expressions, planner::Schema* schema,
std::shared_ptr<FactorizedTable> table, std::unique_ptr<PhysicalOperator> prevOperator);
std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize,
std::unique_ptr<PhysicalOperator> prevOperator);
std::unique_ptr<HashJoinBuildInfo> createHashBuildInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);
std::unique_ptr<PhysicalOperator> createHashAggregate(
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(kuzu_optimizer
OBJECT
acc_hash_join_optimizer.cpp
agg_key_dependency_optimizer.cpp
correlated_subquery_unnest_solver.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Loading