Skip to content

Commit

Permalink
Merge pull request #1319 from kuzudb/sink-projection-push-down
Browse files Browse the repository at this point in the history
Hash join build projection push down
  • Loading branch information
andyfengHKU committed Feb 25, 2023
2 parents 73c344f + f83841e commit b618ea4
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 7 deletions.
18 changes: 18 additions & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class Expression;
using expression_vector = std::vector<std::shared_ptr<Expression>>;
using expression_pair = std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>;

struct ExpressionHasher;
struct ExpressionEquality;
using expression_set =
std::unordered_set<std::shared_ptr<Expression>, ExpressionHasher, ExpressionEquality>;

class Expression : public std::enable_shared_from_this<Expression> {
public:
Expression(common::ExpressionType expressionType, common::DataType dataType,
Expand Down Expand Up @@ -112,6 +117,19 @@ class Expression : public std::enable_shared_from_this<Expression> {
expression_vector children;
};

struct ExpressionHasher {
std::size_t operator()(const std::shared_ptr<Expression>& expression) const {
return std::hash<std::string>{}(expression->getUniqueName());
}
};

struct ExpressionEquality {
bool operator()(
const std::shared_ptr<Expression>& left, const std::shared_ptr<Expression>& right) const {
return left->getUniqueName() == right->getUniqueName();
}
};

class ExpressionUtil {
public:
static bool allExpressionsHaveDataType(
Expand Down
6 changes: 4 additions & 2 deletions src/include/optimizer/index_nested_loop_join_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <vector>

#include "planner/logical_plan/logical_operator/base_logical_operator.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {
Expand All @@ -16,10 +16,12 @@ namespace optimizer {
// implemented.
class IndexNestedLoopJoinOptimizer {
public:
static void rewrite(planner::LogicalPlan* plan);

private:
static std::shared_ptr<planner::LogicalOperator> rewrite(
std::shared_ptr<planner::LogicalOperator> op);

private:
static std::shared_ptr<planner::LogicalOperator> rewriteFilter(
std::shared_ptr<planner::LogicalOperator> op);

Expand Down
36 changes: 36 additions & 0 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include "planner/logical_plan/logical_operator/base_logical_operator.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

// ProjectionPushDownOptimizer implements the logic to avoid materializing unnecessary properties
// for hash join build.
// Note the optimization is for properties only but not for general expressions. This is because
// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the
// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age
// is evaluate. For simplicity, we only consider the push down for property.
class ProjectionPushDownOptimizer {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);

void visitAccumulate(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);

void collectPropertiesInUse(std::shared_ptr<binder::Expression> expression);

private:
binder::expression_set propertiesInUse;
};

} // namespace optimizer
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class LogicalFilter : public LogicalOperator {
}

inline std::shared_ptr<binder::Expression> getPredicate() const { return expression; }

f_group_pos getGroupPosToSelect() const;

inline std::unique_ptr<LogicalOperator> copy() override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class LogicalHashJoin : public LogicalOperator {
return binder::ExpressionUtil::toString(joinNodeIDs);
}

inline void setExpressionsToMaterialize(binder::expression_set expressions) {
expressionsToMaterialize.clear();
for (auto& expression : expressions) {
expressionsToMaterialize.push_back(expression);
}
}
inline binder::expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(kuzu_optimizer
factorization_rewriter.cpp
index_nested_loop_join_optimizer.cpp
optimizer.cpp
projection_push_down_optimizer.cpp
remove_factorization_rewriter.cpp)

set(ALL_OBJECT_FILES
Expand Down
4 changes: 4 additions & 0 deletions src/optimizer/index_nested_loop_join_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ using namespace kuzu::planner;
namespace kuzu {
namespace optimizer {

void IndexNestedLoopJoinOptimizer::rewrite(planner::LogicalPlan* plan) {
rewrite(plan->getLastOperator());
}

std::shared_ptr<planner::LogicalOperator> IndexNestedLoopJoinOptimizer::rewrite(
std::shared_ptr<planner::LogicalOperator> op) {
if (op->getOperatorType() == LogicalOperatorType::FILTER) {
Expand Down
6 changes: 5 additions & 1 deletion src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "optimizer/factorization_rewriter.h"
#include "optimizer/index_nested_loop_join_optimizer.h"
#include "optimizer/projection_push_down_optimizer.h"
#include "optimizer/remove_factorization_rewriter.h"

namespace kuzu {
Expand All @@ -11,7 +12,10 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
auto removeFactorizationRewriter = RemoveFactorizationRewriter();
removeFactorizationRewriter.rewrite(plan);

IndexNestedLoopJoinOptimizer::rewrite(plan->getLastOperator());
IndexNestedLoopJoinOptimizer::rewrite(plan);

auto projectionPushDownOptimizer = ProjectionPushDownOptimizer();
projectionPushDownOptimizer.rewrite(plan);

auto factorizationRewriter = FactorizationRewriter();
factorizationRewriter.rewrite(plan);
Expand Down
135 changes: 135 additions & 0 deletions src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#include "optimizer/projection_push_down_optimizer.h"

#include "planner/logical_plan/logical_operator/logical_accumulate.h"
#include "planner/logical_plan/logical_operator/logical_filter.h"
#include "planner/logical_plan/logical_operator/logical_hash_join.h"
#include "planner/logical_plan/logical_operator/logical_intersect.h"
#include "planner/logical_plan/logical_operator/logical_order_by.h"
#include "planner/logical_plan/logical_operator/logical_projection.h"

using namespace kuzu::common;
using namespace kuzu::planner;
using namespace kuzu::binder;

namespace kuzu {
namespace optimizer {

void ProjectionPushDownOptimizer::rewrite(planner::LogicalPlan* plan) {
visitOperator(plan->getLastOperator().get());
}

void ProjectionPushDownOptimizer::visitOperator(LogicalOperator* op) {
switch (op->getOperatorType()) {
case LogicalOperatorType::ACCUMULATE: {
visitAccumulate(op);
} break;
case LogicalOperatorType::FILTER: {
visitFilter(op);
} break;
case LogicalOperatorType::HASH_JOIN: {
visitHashJoin(op);
} break;
case LogicalOperatorType::PROJECTION: {
visitProjection(op);
return;
}
case LogicalOperatorType::INTERSECT: {
visitIntersect(op);
} break;
case LogicalOperatorType::ORDER_BY: {
visitOrderBy(op);
} break;
default:
break;
}
for (auto i = 0; i < op->getNumChildren(); ++i) {
visitOperator(op->getChild(i).get());
}
}

void ProjectionPushDownOptimizer::visitAccumulate(planner::LogicalOperator* op) {
auto accumulate = (LogicalAccumulate*)op;
for (auto& expression : accumulate->getExpressions()) {
collectPropertiesInUse(expression);
}
}

void ProjectionPushDownOptimizer::visitFilter(planner::LogicalOperator* op) {
auto filter = (LogicalFilter*)op;
collectPropertiesInUse(filter->getPredicate());
}

void ProjectionPushDownOptimizer::visitHashJoin(planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
for (auto& joinNodeID : hashJoin->getJoinNodeIDs()) {
collectPropertiesInUse(joinNodeID);
}
if (hashJoin->getJoinType() == JoinType::MARK) { // no need to perform push down for mark join.
return;
}
auto expressionsBeforePruning = hashJoin->getExpressionsToMaterialize();
expression_set expressionsAfterPruning;
for (auto& expression : expressionsBeforePruning) {
if (expression->expressionType != common::PROPERTY ||
propertiesInUse.contains(expression)) {
expressionsAfterPruning.insert(expression);
}
}
if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) {
// TODO(Xiyang): replace this with a separate optimizer.
return;
}
hashJoin->setExpressionsToMaterialize(expressionsAfterPruning);
auto projectionExpressions =
expression_vector{expressionsAfterPruning.begin(), expressionsAfterPruning.end()};
auto projection = std::make_shared<LogicalProjection>(
std::move(projectionExpressions), hashJoin->getChild(1));
hashJoin->setChild(1, std::move(projection));
}

void ProjectionPushDownOptimizer::visitIntersect(planner::LogicalOperator* op) {
auto intersect = (LogicalIntersect*)op;
collectPropertiesInUse(intersect->getIntersectNodeID());
for (auto i = 0u; i < intersect->getNumBuilds(); ++i) {
auto buildInfo = intersect->getBuildInfo(i);
collectPropertiesInUse(buildInfo->keyNodeID);
for (auto& expression : buildInfo->expressionsToMaterialize) {
collectPropertiesInUse(expression);
}
}
}

void ProjectionPushDownOptimizer::visitProjection(LogicalOperator* op) {
// Projection operator defines the start of a projection push down until the next projection
// operator is seen.
ProjectionPushDownOptimizer optimizer;
auto projection = (LogicalProjection*)op;
for (auto& expression : projection->getExpressionsToProject()) {
optimizer.collectPropertiesInUse(expression);
}
optimizer.visitOperator(op->getChild(0).get());
}

void ProjectionPushDownOptimizer::visitOrderBy(planner::LogicalOperator* op) {
auto orderBy = (LogicalOrderBy*)op;
for (auto& expression : orderBy->getExpressionsToOrderBy()) {
collectPropertiesInUse(expression);
}
for (auto& expression : orderBy->getExpressionsToMaterialize()) {
collectPropertiesInUse(expression);
}
}

void ProjectionPushDownOptimizer::collectPropertiesInUse(
std::shared_ptr<binder::Expression> expression) {
if (expression->expressionType == common::PROPERTY) {
propertiesInUse.insert(std::move(expression));
return;
}
for (auto& child : expression->getChildren()) {
collectPropertiesInUse(child);
}
}

} // namespace optimizer
} // namespace kuzu
1 change: 1 addition & 0 deletions src/optimizer/remove_factorization_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ std::shared_ptr<planner::LogicalOperator> RemoveFactorizationRewriter::rewriteOp
if (op->getOperatorType() == planner::LogicalOperatorType::FLATTEN) {
return op->getChild(0);
}
op->getSchema()->clear();
return op;
}

Expand Down
6 changes: 6 additions & 0 deletions src/planner/operator/base_logical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::PROJECTION: {
return "PROJECTION";
}
case LogicalOperatorType::RENAME_TABLE: {
return "RENAME_TABLE";
}
case LogicalOperatorType::RENAME_PROPERTY: {
return "RENAME_PROPERTY";
}
case LogicalOperatorType::SCAN_NODE: {
return "SCAN_NODE";
}
Expand Down
5 changes: 2 additions & 3 deletions src/processor/mapper/map_accumulate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalAccumulateToPhysical(
auto inSchema = logicalAccumulate->getSchemaBeforeSink();
// append result collector
auto prevOperator = mapLogicalOperatorToPhysical(logicalAccumulate->getChild(0));
auto resultCollector = appendResultCollector(
inSchema->getExpressionsInScope(), *inSchema, std::move(prevOperator));
auto expressions = logicalAccumulate->getExpressions();
auto resultCollector = appendResultCollector(expressions, *inSchema, std::move(prevOperator));
// append factorized table scan
std::vector<DataPos> outDataPoses;
std::vector<uint32_t> colIndicesToScan;
auto expressions = logicalAccumulate->getExpressions();
for (auto i = 0u; i < expressions.size(); ++i) {
auto expression = expressions[i];
outDataPoses.emplace_back(outSchema->getExpressionPos(*expression));
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/tinysnb/filter/two_hop.test
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
---- 1
12

-NAME TwoHopKnowsFilteredTest
-NAME TwoHopKnowsFilteredTest2
-QUERY MATCH (a:person)-[e1:knows]->(b:person), (a:person)-[e2:knows {date:e1.date}]->(c:person) WHERE e1.date = e2.date RETURN COUNT(*)
-ENUMERATE
---- 1
Expand Down

0 comments on commit b618ea4

Please sign in to comment.