Skip to content

Commit

Permalink
Merge pull request #1949 from kuzudb/top-k-optmizer
Browse files Browse the repository at this point in the history
Add top K optimizer
  • Loading branch information
andyfengHKU committed Aug 22, 2023
2 parents e69f0a7 + beb15e6 commit 2c45803
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 28 deletions.
22 changes: 22 additions & 0 deletions src/include/optimizer/top_k_optimizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

// Logical-plan optimizer that folds a LIMIT sitting above an ORDER BY into the ORDER BY itself,
// so that the ORDER BY carries the skip/limit numbers and can be treated as a top-K.
class TopKOptimizer : public LogicalOperatorVisitor {
public:
// Runs the top-K rewrite over `plan`, starting from the plan's last operator.
void rewrite(planner::LogicalPlan* plan);

// Visits the children of `op` first (bottom-up), then `op` itself.
// Returns the operator that should take the place of `op` in the plan.
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

private:
// Replacement hook for LIMIT operators: returns either `op` unchanged or the operator that
// replaces it once the limit has been pushed into an ORDER BY.
std::shared_ptr<planner::LogicalOperator> visitLimitReplace(
std::shared_ptr<planner::LogicalOperator> op) override;
};

} // namespace optimizer
} // namespace kuzu
22 changes: 13 additions & 9 deletions src/include/planner/logical_plan/logical_order_by.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,32 @@ class LogicalOrderBy : public LogicalOperator {

f_group_pos_set getGroupsPosToFlatten();

void computeFactorizedSchema() override;
void computeFlatSchema() override;
void computeFactorizedSchema() final;
void computeFlatSchema() final;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressionsToOrderBy);
}
std::string getExpressionsForPrinting() const final;

inline binder::expression_vector getExpressionsToOrderBy() const {
return expressionsToOrderBy;
}
inline std::vector<bool> getIsAscOrders() const { return isAscOrders; }
inline binder::expression_vector getExpressionsToMaterialize() const {
return children[0]->getSchema()->getExpressionsInScope();
}

inline std::unique_ptr<LogicalOperator> copy() override {
inline bool isTopK() const { return hasLimitNum(); }
inline void setSkipNum(uint64_t num) { skipNum = num; }
inline uint64_t getSkipNum() const { return skipNum; }
inline void setLimitNum(uint64_t num) { limitNum = num; }
inline bool hasLimitNum() const { return limitNum != UINT64_MAX; }
inline uint64_t getLimitNum() const { return limitNum; }

inline std::unique_ptr<LogicalOperator> copy() final {
return make_unique<LogicalOrderBy>(expressionsToOrderBy, isAscOrders, children[0]->copy());
}

private:
binder::expression_vector expressionsToOrderBy;
std::vector<bool> isAscOrders;
uint64_t skipNum = UINT64_MAX;
uint64_t limitNum = UINT64_MAX;
};

} // namespace planner
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ add_library(kuzu_optimizer
optimizer.cpp
projection_push_down_optimizer.cpp
remove_factorization_rewriter.cpp
remove_unnecessary_join_optimizer.cpp)
remove_unnecessary_join_optimizer.cpp
top_k_optimizer.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_optimizer>
Expand Down
4 changes: 4 additions & 0 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "optimizer/projection_push_down_optimizer.h"
#include "optimizer/remove_factorization_rewriter.h"
#include "optimizer/remove_unnecessary_join_optimizer.h"
#include "optimizer/top_k_optimizer.h"

namespace kuzu {
namespace optimizer {
Expand All @@ -33,6 +34,9 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
auto hashJoinSIPOptimizer = HashJoinSIPOptimizer();
hashJoinSIPOptimizer.rewrite(plan);

// auto topKOptimizer = TopKOptimizer();
// topKOptimizer.rewrite(plan);

auto factorizationRewriter = FactorizationRewriter();
factorizationRewriter.rewrite(plan);

Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void ProjectionPushDownOptimizer::visitOrderBy(planner::LogicalOperator* op) {
for (auto& expression : orderBy->getExpressionsToOrderBy()) {
collectExpressionsInUse(expression);
}
auto expressionsBeforePruning = orderBy->getExpressionsToMaterialize();
auto expressionsBeforePruning = orderBy->getChild(0)->getSchema()->getExpressionsInScope();
auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning);
if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) {
return;
Expand Down
52 changes: 52 additions & 0 deletions src/optimizer/top_k_optimizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "optimizer/top_k_optimizer.h"

#include "planner/logical_plan/logical_limit.h"
#include "planner/logical_plan/logical_order_by.h"

using namespace kuzu::planner;

namespace kuzu {
namespace optimizer {

// Entry point of the optimizer: rewrites `plan` in place.
//
// visitOperator() returns the operator that replaces the one it was given. For inner operators
// the parent re-links the replacement via setChild(), but the plan root has no parent, so the
// result must be stored back on the plan. Otherwise a LIMIT at the root would stay in the plan
// even though its skip/limit numbers were already copied onto the ORDER BY below it, and the
// skip would be applied twice.
// NOTE(review): relies on LogicalPlan::setLastOperator(std::shared_ptr<LogicalOperator>) —
// confirm it is available on this branch.
void TopKOptimizer::rewrite(planner::LogicalPlan* plan) {
    plan->setLastOperator(visitOperator(plan->getLastOperator()));
}

// Rewrites the sub-plan rooted at `op` and returns the operator that should replace `op`.
//
// Children are rewritten first (bottom-up) and re-linked into `op`; afterwards `op` itself is
// dispatched through visitOperatorReplaceSwitch(). The flat schema of the resulting operator is
// recomputed because its children (or the operator itself) may have changed.
std::shared_ptr<LogicalOperator> TopKOptimizer::visitOperator(std::shared_ptr<LogicalOperator> op) {
    // bottom-up traversal
    // The index uses the same type as the child count to avoid a signed/unsigned comparison.
    for (decltype(op->getNumChildren()) childIdx = 0; childIdx < op->getNumChildren();
         ++childIdx) {
        op->setChild(childIdx, visitOperator(op->getChild(childIdx)));
    }
    auto result = visitOperatorReplaceSwitch(op);
    result->computeFlatSchema();
    return result;
}

// TODO(Xiyang): we should probably remove the projection between ORDER BY and MULTIPLICITY REDUCER
// We search for pattern
// ORDER BY -> PROJECTION -> MULTIPLICITY REDUCER -> LIMIT
// and rewrite as TOP_K
std::shared_ptr<LogicalOperator> TopKOptimizer::visitLimitReplace(
std::shared_ptr<LogicalOperator> op) {
auto limit = (LogicalLimit*)op.get();
if (!limit->hasLimitNum()) {
return op; // only skip no limit. No need to rewrite
}
auto multiplicityReducer = limit->getChild(0);
assert(multiplicityReducer->getOperatorType() == LogicalOperatorType::MULTIPLICITY_REDUCER);
if (multiplicityReducer->getChild(0)->getOperatorType() != LogicalOperatorType::PROJECTION) {
return op;
}
auto projection = multiplicityReducer->getChild(0);
if (projection->getChild(0)->getOperatorType() != LogicalOperatorType::ORDER_BY) {
return op;
}
auto orderBy = std::static_pointer_cast<LogicalOrderBy>(projection->getChild(0));
orderBy->setLimitNum(limit->getLimitNum());
auto skipNum = limit->hasSkipNum() ? limit->getSkipNum() : 0;
orderBy->setSkipNum(skipNum);
return projection;
}

} // namespace optimizer
} // namespace kuzu
4 changes: 2 additions & 2 deletions src/planner/operator/logical_limit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace planner {
std::string LogicalLimit::getExpressionsForPrinting() const {
std::string result;
if (hasSkipNum()) {
result += "SKIP " + std::to_string(skipNum) + " ";
result += "SKIP " + std::to_string(skipNum) + "\n";
}
if (hasLimitNum()) {
result += "LIMIT " + std::to_string(limitNum);
result += "LIMIT " + std::to_string(limitNum) + "\n";
}
return result;
}
Expand Down
13 changes: 11 additions & 2 deletions src/planner/operator/logical_order_by.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ f_group_pos_set LogicalOrderBy::getGroupsPosToFlatten() {

void LogicalOrderBy::computeFactorizedSchema() {
createEmptySchema();
SinkOperatorUtil::recomputeSchema(
*children[0]->getSchema(), getExpressionsToMaterialize(), *schema);
auto childSchema = children[0]->getSchema();
SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema);
}

void LogicalOrderBy::computeFlatSchema() {
Expand All @@ -47,5 +47,14 @@ void LogicalOrderBy::computeFlatSchema() {
}
}

// Builds the printable description of this operator: the ORDER BY expressions followed by a
// newline and, when a limit number has been set, a "SKIP <n>" line and a "LIMIT <n>" line.
std::string LogicalOrderBy::getExpressionsForPrinting() const {
    auto text = binder::ExpressionUtil::toString(expressionsToOrderBy) + "\n";
    if (!hasLimitNum()) {
        return text;
    }
    text.append("SKIP ").append(std::to_string(skipNum)).append("\n");
    text.append("LIMIT ").append(std::to_string(limitNum));
    return text;
}

} // namespace planner
} // namespace kuzu
4 changes: 0 additions & 4 deletions src/planner/plan/plan_projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ void QueryPlanner::planProjectionBody(
if (plan.isEmpty()) { // e.g. RETURN 1, COUNT(2)
appendDummyScan(plan);
}
// NOTE: As a temporary solution, we rewrite variables in WITH clause as all properties in scope
// during planning stage. The purpose is to avoid reading unnecessary properties for WITH.
// E.g. MATCH (a) WITH a RETURN a.age -> MATCH (a) WITH a.age RETURN a.age
// This rewrite should be removed once we add an optimizer that can remove unnecessary columns.
auto expressionsToProject = projectionBody.getProjectionExpressions();
auto expressionsToAggregate = projectionBody.getAggregateExpressions();
auto expressionsToGroupBy = projectionBody.getGroupByExpressions();
Expand Down
20 changes: 12 additions & 8 deletions src/processor/map/map_order_by.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,23 @@ namespace kuzu {
namespace processor {

std::unique_ptr<PhysicalOperator> PlanMapper::mapOrderBy(LogicalOperator* logicalOperator) {
auto& logicalOrderBy = (LogicalOrderBy&)*logicalOperator;
auto outSchema = logicalOrderBy.getSchema();
auto inSchema = logicalOrderBy.getChild(0)->getSchema();
auto prevOperator = mapOperator(logicalOrderBy.getChild(0).get());
auto paramsString = logicalOrderBy.getExpressionsForPrinting();
auto logicalOrderBy = (LogicalOrderBy*)logicalOperator;
if (logicalOrderBy->isTopK()) {
// TODO(Ziyi): fill
assert(false);
}
auto outSchema = logicalOrderBy->getSchema();
auto inSchema = logicalOrderBy->getChild(0)->getSchema();
auto prevOperator = mapOperator(logicalOrderBy->getChild(0).get());
auto paramsString = logicalOrderBy->getExpressionsForPrinting();
std::vector<std::pair<DataPos, LogicalType>> keysPosAndType;
for (auto& expression : logicalOrderBy.getExpressionsToOrderBy()) {
for (auto& expression : logicalOrderBy->getExpressionsToOrderBy()) {
keysPosAndType.emplace_back(inSchema->getExpressionPos(*expression), expression->dataType);
}
std::vector<std::pair<DataPos, LogicalType>> payloadsPosAndType;
std::vector<bool> isPayloadFlat;
std::vector<DataPos> outVectorPos;
for (auto& expression : logicalOrderBy.getExpressionsToMaterialize()) {
for (auto& expression : inSchema->getExpressionsInScope()) {
auto expressionName = expression->getUniqueName();
payloadsPosAndType.emplace_back(
inSchema->getExpressionPos(*expression), expression->dataType);
Expand All @@ -33,7 +37,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapOrderBy(LogicalOperator* logica
// See comment in planOrderBy in projectionPlanner.cpp
auto mayContainUnflatKey = inSchema->getNumGroups() == 1;
auto orderByDataInfo = OrderByDataInfo(keysPosAndType, payloadsPosAndType, isPayloadFlat,
logicalOrderBy.getIsAscOrders(), mayContainUnflatKey);
logicalOrderBy->getIsAscOrders(), mayContainUnflatKey);
auto orderBySharedState = std::make_shared<SharedFactorizedTablesAndSortedKeyBlocks>();

auto orderBy =
Expand Down
2 changes: 1 addition & 1 deletion src/storage/local_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void VarListLocalColumn::prepareCommitForChunk(node_group_idx_t nodeGroupIdx) {
}

StructLocalColumn::StructLocalColumn(NodeColumn* column) : LocalColumn{column} {
assert(column->getDataType().getLogicalTypeID() == LogicalTypeID::STRUCT);
assert(column->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT);
auto dataType = column->getDataType();
auto structFields = StructType::getFields(&dataType);
fields.resize(structFields.size());
Expand Down

0 comments on commit 2c45803

Please sign in to comment.