Skip to content

Commit

Permalink
Merge pull request #1930 from kuzudb/unnest-subquery
Browse files Browse the repository at this point in the history
Unnest arbitrary subquery
  • Loading branch information
andyfengHKU committed Aug 15, 2023
2 parents 9529003 + e4bb187 commit c90fb84
Show file tree
Hide file tree
Showing 45 changed files with 670 additions and 171 deletions.
15 changes: 13 additions & 2 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ expression_vector Expression::splitOnAND() {
return result;
}

bool ExpressionUtil::allExpressionsHaveDataType(
expression_vector& expressions, LogicalTypeID dataTypeID) {
bool ExpressionUtil::isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != dataTypeID) {
return false;
Expand All @@ -29,6 +29,17 @@ bool ExpressionUtil::allExpressionsHaveDataType(
return true;
}

// Collects, in their original order, the expressions whose logical type ID equals dataTypeID.
// Returns an empty vector when no expression matches.
expression_vector ExpressionUtil::getExpressionsWithDataType(
    const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
    expression_vector matched;
    for (const auto& candidate : expressions) {
        // Guard clause: skip anything whose type differs from the requested one.
        if (candidate->dataType.getLogicalTypeID() != dataTypeID) {
            continue;
        }
        matched.push_back(candidate);
    }
    return matched;
}

uint32_t ExpressionUtil::find(Expression* target, expression_vector expressions) {
for (auto i = 0u; i < expressions.size(); ++i) {
if (target->getUniqueName() == expressions[i]->getUniqueName()) {
Expand Down
6 changes: 4 additions & 2 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ struct ExpressionEquality {
};

struct ExpressionUtil {
static bool allExpressionsHaveDataType(
expression_vector& expressions, common::LogicalTypeID dataTypeID);
static bool isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID);
static expression_vector getExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID);

static uint32_t find(Expression* target, expression_vector expressions);

Expand Down
30 changes: 30 additions & 0 deletions src/include/optimizer/correlated_subquery_unnest_solver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

// Logical-plan visitor used while unnesting correlated subqueries. Only the declarations are
// visible in this header; the traversal logic lives in the matching .cpp file.
// NOTE(review): presumably this links expressions-scan operators found in the subquery plan to
// the accumulate operator passed at construction -- confirm against the implementation.
class CorrelatedSubqueryUnnestSolver : public LogicalOperatorVisitor {
public:
    // explicit: a raw LogicalOperator* must never convert silently into a solver.
    // accumulateOp is stored as a non-owning pointer (this class declares no destructor).
    explicit CorrelatedSubqueryUnnestSolver(planner::LogicalOperator* accumulateOp)
        : accumulateOp{accumulateOp} {}

    // Entry point: runs the solver on the plan rooted at root_.
    void solve(planner::LogicalOperator* root_);

private:
    void visitOperator(planner::LogicalOperator* op);
    void visitExpressionsScan(planner::LogicalOperator* op) final;

    // True when op is a HASH_JOIN whose first child (children[0]) is an ACCUMULATE operator.
    // The child is only inspected after the operator type check succeeds (short-circuit &&).
    inline bool isAccHashJoin(planner::LogicalOperator* op) const {
        return op->getOperatorType() == planner::LogicalOperatorType::HASH_JOIN &&
               op->getChild(0)->getOperatorType() == planner::LogicalOperatorType::ACCUMULATE;
    }
    void solveAccHashJoin(planner::LogicalOperator* op) const;

private:
    // Accumulate operator supplied by the caller; not owned by the solver.
    planner::LogicalOperator* accumulateOp;
};

} // namespace optimizer
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitAccumulate(planner::LogicalOperator* op) override;
void visitAggregate(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitSkip(planner::LogicalOperator* op) override;
Expand Down
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class LogicalOperatorVisitor {
return op;
}

// Visit hook for an expressions-scan operator. The default implementation does nothing;
// subclasses override it to act on the operator.
virtual void visitExpressionsScan(planner::LogicalOperator* op) {}
// "Replace" variant of the expressions-scan hook. The default implementation returns op
// unchanged.
// NOTE(review): presumably the returned operator takes op's place in the plan -- confirm
// against the caller.
virtual std::shared_ptr<planner::LogicalOperator> visitExpressionsScanReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitScanNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitScanNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
2 changes: 1 addition & 1 deletion src/include/planner/join_order/cardinality_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CardinalityEstimator {
void initNodeIDDom(binder::QueryGraph* queryGraph);

uint64_t estimateScanNode(LogicalOperator* op);
uint64_t estimateHashJoin(const binder::expression_vector& joinNodeIDs,
uint64_t estimateHashJoin(const binder::expression_vector& joinKeys,
const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
uint64_t estimateCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
uint64_t estimateIntersect(const binder::expression_vector& joinNodeIDs,
Expand Down
26 changes: 16 additions & 10 deletions src/include/planner/join_order_enumerator_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@
namespace kuzu {
namespace planner {

// Kind of subquery a planning context is working on. JoinOrderEnumeratorContext stores one of
// these and default-constructs it to NONE.
// NOTE(review): judging by the names, NONE = not planning a correlated subquery,
// INTERNAL_ID_CORRELATED = correlated only through internal-ID expressions, CORRELATED =
// correlated through arbitrary expressions -- confirm against the planner.
enum class SubqueryType : uint8_t {
NONE = 0,
INTERNAL_ID_CORRELATED = 1,
CORRELATED = 2,
};

class JoinOrderEnumeratorContext {
friend class QueryPlanner;

public:
JoinOrderEnumeratorContext()
: currentLevel{0}, maxLevel{0}, subPlansTable{std::make_unique<SubPlansTable>()},
queryGraph{nullptr} {}
queryGraph{nullptr}, subqueryType{SubqueryType::NONE}, correlatedExpressionsCardinality{
1} {}

void init(QueryGraph* queryGraph, const expression_vector& predicates);

Expand All @@ -35,15 +42,12 @@ class JoinOrderEnumeratorContext {

inline QueryGraph* getQueryGraph() { return queryGraph; }

inline bool nodeToScanFromInnerAndOuter(NodeExpression* node) {
for (auto& nodeID : nodeIDsToScanFromInnerAndOuter) {
if (nodeID->getUniqueName() == node->getInternalIDPropertyName()) {
return true;
}
}
return false;
inline binder::expression_vector getCorrelatedExpressions() const {
return correlatedExpressions;
}
inline binder::expression_set getCorrelatedExpressionsSet() const {
return binder::expression_set{correlatedExpressions.begin(), correlatedExpressions.end()};
}

void resetState();

private:
Expand All @@ -55,7 +59,9 @@ class JoinOrderEnumeratorContext {
std::unique_ptr<SubPlansTable> subPlansTable;
QueryGraph* queryGraph;

expression_vector nodeIDsToScanFromInnerAndOuter;
SubqueryType subqueryType;
expression_vector correlatedExpressions;
uint64_t correlatedExpressionsCardinality;
};

} // namespace planner
Expand Down
15 changes: 10 additions & 5 deletions src/include/planner/logical_plan/logical_accumulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,31 @@ namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(common::AccumulateType accumulateType, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, accumulateType{
accumulateType} {}
LogicalAccumulate(common::AccumulateType accumulateType,
binder::expression_vector expressionsToFlatten, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)},
accumulateType{accumulateType}, expressionsToFlatten{std::move(expressionsToFlatten)} {}

void computeFactorizedSchema() final;
void computeFlatSchema() final;

f_group_pos_set getGroupPositionsToFlatten() const;

inline std::string getExpressionsForPrinting() const final { return std::string{}; }

inline common::AccumulateType getAccumulateType() const { return accumulateType; }
inline binder::expression_vector getExpressions() const {
inline binder::expression_vector getExpressionsToAccumulate() const {
return children[0]->getSchema()->getExpressionsInScope();
}

inline std::unique_ptr<LogicalOperator> copy() final {
return make_unique<LogicalAccumulate>(accumulateType, children[0]->copy());
return make_unique<LogicalAccumulate>(
accumulateType, expressionsToFlatten, children[0]->copy());
}

private:
common::AccumulateType accumulateType;
binder::expression_vector expressionsToFlatten;
};

} // namespace planner
Expand Down
12 changes: 11 additions & 1 deletion src/include/planner/logical_plan/logical_hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ namespace planner {
// We only support equality comparison as join condition
using join_condition_t = binder::expression_pair;

// Order in which the two sub-plans (probe and build) of a hash join are solved.
// LogicalHashJoin initializes its order to ANY.
// NOTE(review): PROBE_BUILD / BUILD_PROBE presumably mean "probe side first" / "build side
// first" -- confirm against the code that consumes getJoinSubPlanSolveOrder().
enum class JoinSubPlanSolveOrder : uint8_t {
ANY = 0,
PROBE_BUILD = 1,
BUILD_PROBE = 2,
};

// Probe side on left, i.e. children[0]. Build side on right, i.e. children[1].
class LogicalHashJoin : public LogicalOperator {
public:
Expand All @@ -35,7 +41,7 @@ class LogicalHashJoin : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinConditions(std::move(joinConditions)), joinType{joinType}, mark{std::move(mark)},
sip{SidewaysInfoPassing::NONE} {}
sip{SidewaysInfoPassing::NONE}, order{JoinSubPlanSolveOrder::ANY} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();
Expand All @@ -61,6 +67,9 @@ class LogicalHashJoin : public LogicalOperator {
inline void setSIP(SidewaysInfoPassing sip_) { sip = sip_; }
inline SidewaysInfoPassing getSIP() const { return sip; }

inline void setJoinSubPlanSolveOrder(JoinSubPlanSolveOrder order_) { order = order_; }
inline JoinSubPlanSolveOrder getJoinSubPlanSolveOrder() const { return order; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(
joinConditions, joinType, mark, children[0]->copy(), children[1]->copy());
Expand All @@ -84,6 +93,7 @@ class LogicalHashJoin : public LogicalOperator {
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
SidewaysInfoPassing sip;
JoinSubPlanSolveOrder order; // sip introduce join dependency
};

} // namespace planner
Expand Down
1 change: 1 addition & 0 deletions src/include/planner/logical_plan/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ enum class LogicalOperatorType : uint8_t {
DROP_TABLE,
DUMMY_SCAN,
EXPLAIN,
EXPRESSIONS_SCAN,
EXTEND,
FILTER,
FLATTEN,
Expand Down
39 changes: 39 additions & 0 deletions src/include/planner/logical_plan/scan/logical_expressions_scan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include "planner/logical_plan/logical_operator.h"

namespace kuzu {
namespace planner {

// LogicalExpressionsScan scans the given expressions from an outer factorized table.
class LogicalExpressionsScan : public LogicalOperator {
public:
    // explicit: an expression_vector must not convert silently into an operator.
    // The outer accumulate link starts out as nullptr until setOuterAccumulate() is called.
    explicit LogicalExpressionsScan(binder::expression_vector expressions)
        : LogicalOperator{LogicalOperatorType::EXPRESSIONS_SCAN}, expressions{
                                                                      std::move(expressions)} {}

    // Both schema flavours are computed by the same private helper.
    inline void computeFactorizedSchema() final { computeSchema(); }
    inline void computeFlatSchema() final { computeSchema(); }

    inline std::string getExpressionsForPrinting() const final {
        return binder::ExpressionUtil::toString(expressions);
    }

    // Returns a copy of the scanned expressions.
    inline binder::expression_vector getExpressions() const { return expressions; }
    // Stores a non-owning pointer to the outer accumulate operator.
    inline void setOuterAccumulate(LogicalOperator* op) { outerAccumulate = op; }
    // Returns the pointer set via setOuterAccumulate(), or nullptr if it was never set.
    inline LogicalOperator* getOuterAccumulate() const { return outerAccumulate; }

    // The copy carries the expressions only; its outer accumulate link is nullptr and has to be
    // set again by whoever owns the copied plan.
    inline std::unique_ptr<LogicalOperator> copy() final {
        return std::make_unique<LogicalExpressionsScan>(expressions);
    }

private:
    void computeSchema();

private:
    binder::expression_vector expressions;
    // Previously left uninitialized, so reading it before setOuterAccumulate() was undefined
    // behaviour; default it to nullptr so "not linked yet" is detectable.
    LogicalOperator* outerAccumulate{nullptr};
};

} // namespace planner
} // namespace kuzu
27 changes: 19 additions & 8 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,25 @@ class QueryPlanner {
std::unique_ptr<LogicalPlan> planQueryGraphCollection(
const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates);
std::unique_ptr<LogicalPlan> planQueryGraphCollectionInNewContext(
const binder::expression_vector& expressionsToExcludeScan,
std::unique_ptr<LogicalPlan> planQueryGraphCollectionInNewContext(SubqueryType subqueryType,
const binder::expression_vector& correlatedExpressions, uint64_t cardinality,
const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates);
std::vector<std::unique_ptr<LogicalPlan>> enumerateQueryGraphCollection(
const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates);
std::vector<std::unique_ptr<LogicalPlan>> enumerateQueryGraph(
binder::QueryGraph* queryGraph, binder::expression_vector& predicates);
std::vector<std::unique_ptr<LogicalPlan>> enumerateQueryGraph(SubqueryType subqueryType,
const expression_vector& correlatedExpressions, binder::QueryGraph* queryGraph,
binder::expression_vector& predicates);

// Plan node/rel table scan
void planBaseTableScan();
void planBaseTableScans(
SubqueryType subqueryType, const expression_vector& correlatedExpressions);
void planCorrelatedExpressionsScan(const binder::expression_vector& correlatedExpressions);
std::unique_ptr<LogicalPlan> getCorrelatedExpressionScanPlan(
const binder::expression_vector& correlatedExpressions);
void planNodeScan(uint32_t nodePos);
void planNodeIDScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);
void appendExtendAndFilter(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
Expand Down Expand Up @@ -164,6 +170,7 @@ class QueryPlanner {
void appendSkip(uint64_t skipNumber, LogicalPlan& plan);

// Append scan operators
void appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan);
void appendScanNodeID(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);
void appendScanNodeProperties(const expression_vector& propertyExpressions,
std::shared_ptr<NodeExpression> node, LogicalPlan& plan);
Expand Down Expand Up @@ -198,7 +205,11 @@ class QueryPlanner {
void appendCrossProduct(
common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan);

void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan);
inline void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan) {
appendAccumulate(accumulateType, expression_vector{}, plan);
}
void appendAccumulate(common::AccumulateType accumulateType,
const binder::expression_vector& expressionsToFlatten, LogicalPlan& plan);

void appendDummyScan(LogicalPlan& plan);

Expand All @@ -221,8 +232,8 @@ class QueryPlanner {

expression_vector getProperties(const binder::Expression& nodeOrRel);

std::unique_ptr<JoinOrderEnumeratorContext> enterContext(
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
std::unique_ptr<JoinOrderEnumeratorContext> enterContext(SubqueryType subqueryType,
const expression_vector& correlatedExpressions, uint64_t cardinality);
void exitContext(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

private:
Expand Down
9 changes: 8 additions & 1 deletion src/include/processor/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,21 @@ class PlanMapper {
std::unique_ptr<PhysicalOperator> mapStandaloneCall(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapInQueryCall(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapExplain(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapExpressionsScan(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCreateMacro(planner::LogicalOperator* logicalOperator);

std::unique_ptr<ResultCollector> createResultCollector(common::AccumulateType accumulateType,
const binder::expression_vector& expressions, planner::Schema* schema,
std::unique_ptr<PhysicalOperator> prevOperator);
std::unique_ptr<PhysicalOperator> createFactorizedTableScan(
const binder::expression_vector& expressions, std::vector<ft_col_idx_t> colIndices,
planner::Schema* schema, std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize,
std::unique_ptr<PhysicalOperator> prevOperator);
// Assumes all columns of the table are scanned in the same order as the given expressions.
std::unique_ptr<PhysicalOperator> createFactorizedTableScanAligned(
const binder::expression_vector& expressions, planner::Schema* schema,
std::shared_ptr<FactorizedTable> table, std::unique_ptr<PhysicalOperator> prevOperator);
std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize,
std::unique_ptr<PhysicalOperator> prevOperator);
std::unique_ptr<HashJoinBuildInfo> createHashBuildInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);
std::unique_ptr<PhysicalOperator> createHashAggregate(
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(kuzu_optimizer
OBJECT
acc_hash_join_optimizer.cpp
agg_key_dependency_optimizer.cpp
correlated_subquery_unnest_solver.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Loading

0 comments on commit c90fb84

Please sign in to comment.