Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add top K optimizer #1949

Merged
merged 1 commit into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/include/optimizer/top_k_optimizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class TopKOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

private:
std::shared_ptr<planner::LogicalOperator> visitLimitReplace(
std::shared_ptr<planner::LogicalOperator> op) override;
};

} // namespace optimizer
} // namespace kuzu
22 changes: 13 additions & 9 deletions src/include/planner/logical_plan/logical_order_by.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,32 @@ class LogicalOrderBy : public LogicalOperator {

f_group_pos_set getGroupsPosToFlatten();

void computeFactorizedSchema() override;
void computeFlatSchema() override;
void computeFactorizedSchema() final;
void computeFlatSchema() final;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressionsToOrderBy);
}
std::string getExpressionsForPrinting() const final;

inline binder::expression_vector getExpressionsToOrderBy() const {
return expressionsToOrderBy;
}
inline std::vector<bool> getIsAscOrders() const { return isAscOrders; }
inline binder::expression_vector getExpressionsToMaterialize() const {
return children[0]->getSchema()->getExpressionsInScope();
}

inline std::unique_ptr<LogicalOperator> copy() override {
inline bool isTopK() const { return hasLimitNum(); }
inline void setSkipNum(uint64_t num) { skipNum = num; }
inline uint64_t getSkipNum() const { return skipNum; }
inline void setLimitNum(uint64_t num) { limitNum = num; }
inline bool hasLimitNum() const { return limitNum != UINT64_MAX; }
inline uint64_t getLimitNum() const { return limitNum; }

inline std::unique_ptr<LogicalOperator> copy() final {
return make_unique<LogicalOrderBy>(expressionsToOrderBy, isAscOrders, children[0]->copy());
}

private:
binder::expression_vector expressionsToOrderBy;
std::vector<bool> isAscOrders;
uint64_t skipNum = UINT64_MAX;
uint64_t limitNum = UINT64_MAX;
};

} // namespace planner
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ add_library(kuzu_optimizer
optimizer.cpp
projection_push_down_optimizer.cpp
remove_factorization_rewriter.cpp
remove_unnecessary_join_optimizer.cpp)
remove_unnecessary_join_optimizer.cpp
top_k_optimizer.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_optimizer>
Expand Down
4 changes: 4 additions & 0 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "optimizer/projection_push_down_optimizer.h"
#include "optimizer/remove_factorization_rewriter.h"
#include "optimizer/remove_unnecessary_join_optimizer.h"
#include "optimizer/top_k_optimizer.h"

namespace kuzu {
namespace optimizer {
Expand All @@ -33,6 +34,9 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
auto hashJoinSIPOptimizer = HashJoinSIPOptimizer();
hashJoinSIPOptimizer.rewrite(plan);

// auto topKOptimizer = TopKOptimizer();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

// topKOptimizer.rewrite(plan);

auto factorizationRewriter = FactorizationRewriter();
factorizationRewriter.rewrite(plan);

Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void ProjectionPushDownOptimizer::visitOrderBy(planner::LogicalOperator* op) {
for (auto& expression : orderBy->getExpressionsToOrderBy()) {
collectExpressionsInUse(expression);
}
auto expressionsBeforePruning = orderBy->getExpressionsToMaterialize();
auto expressionsBeforePruning = orderBy->getChild(0)->getSchema()->getExpressionsInScope();
auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning);
if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) {
return;
Expand Down
52 changes: 52 additions & 0 deletions src/optimizer/top_k_optimizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "optimizer/top_k_optimizer.h"

#include "planner/logical_plan/logical_limit.h"
#include "planner/logical_plan/logical_order_by.h"

using namespace kuzu::planner;

namespace kuzu {
namespace optimizer {

void TopKOptimizer::rewrite(planner::LogicalPlan* plan) {
visitOperator(plan->getLastOperator());
}

std::shared_ptr<LogicalOperator> TopKOptimizer::visitOperator(std::shared_ptr<LogicalOperator> op) {
// bottom-up traversal
for (auto i = 0; i < op->getNumChildren(); ++i) {
op->setChild(i, visitOperator(op->getChild(i)));
}
auto result = visitOperatorReplaceSwitch(op);
result->computeFlatSchema();
return result;
}

// TODO(Xiyang): we should probably remove the projection between ORDER BY and MULTIPLICITY REDUCER
// We search for pattern
// ORDER BY -> PROJECTION -> MULTIPLICITY REDUCER -> LIMIT
// and rewrite as TOP_K
std::shared_ptr<LogicalOperator> TopKOptimizer::visitLimitReplace(
std::shared_ptr<LogicalOperator> op) {
auto limit = (LogicalLimit*)op.get();
if (!limit->hasLimitNum()) {
return op; // only skip no limit. No need to rewrite
}
auto multiplicityReducer = limit->getChild(0);
assert(multiplicityReducer->getOperatorType() == LogicalOperatorType::MULTIPLICITY_REDUCER);
if (multiplicityReducer->getChild(0)->getOperatorType() != LogicalOperatorType::PROJECTION) {
return op;
}
auto projection = multiplicityReducer->getChild(0);
if (projection->getChild(0)->getOperatorType() != LogicalOperatorType::ORDER_BY) {
return op;
}
auto orderBy = std::static_pointer_cast<LogicalOrderBy>(projection->getChild(0));
orderBy->setLimitNum(limit->getLimitNum());
auto skipNum = limit->hasSkipNum() ? limit->getSkipNum() : 0;
orderBy->setSkipNum(skipNum);
return projection;
}

} // namespace optimizer
} // namespace kuzu
4 changes: 2 additions & 2 deletions src/planner/operator/logical_limit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace planner {
std::string LogicalLimit::getExpressionsForPrinting() const {
std::string result;
if (hasSkipNum()) {
result += "SKIP " + std::to_string(skipNum) + " ";
result += "SKIP " + std::to_string(skipNum) + "\n";
}
if (hasLimitNum()) {
result += "LIMIT " + std::to_string(limitNum);
result += "LIMIT " + std::to_string(limitNum) + "\n";
}
return result;
}
Expand Down
13 changes: 11 additions & 2 deletions src/planner/operator/logical_order_by.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@

void LogicalOrderBy::computeFactorizedSchema() {
createEmptySchema();
SinkOperatorUtil::recomputeSchema(
*children[0]->getSchema(), getExpressionsToMaterialize(), *schema);
auto childSchema = children[0]->getSchema();
SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema);
}

void LogicalOrderBy::computeFlatSchema() {
Expand All @@ -47,5 +47,14 @@
}
}

std::string LogicalOrderBy::getExpressionsForPrinting() const {
auto result = binder::ExpressionUtil::toString(expressionsToOrderBy) + "\n";
if (hasLimitNum()) {
result += "SKIP " + std::to_string(skipNum) + "\n";
result += "LIMIT " + std::to_string(limitNum);

Check warning on line 54 in src/planner/operator/logical_order_by.cpp

View check run for this annotation

Codecov / codecov/patch

src/planner/operator/logical_order_by.cpp#L53-L54

Added lines #L53 - L54 were not covered by tests
}
return result;
}

} // namespace planner
} // namespace kuzu
4 changes: 0 additions & 4 deletions src/planner/plan/plan_projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ void QueryPlanner::planProjectionBody(
if (plan.isEmpty()) { // e.g. RETURN 1, COUNT(2)
appendDummyScan(plan);
}
// NOTE: As a temporary solution, we rewrite variables in WITH clause as all properties in scope
// during planning stage. The purpose is to avoid reading unnecessary properties for WITH.
// E.g. MATCH (a) WITH a RETURN a.age -> MATCH (a) WITH a.age RETURN a.age
// This rewrite should be removed once we add an optimizer that can remove unnecessary columns.
auto expressionsToProject = projectionBody.getProjectionExpressions();
auto expressionsToAggregate = projectionBody.getAggregateExpressions();
auto expressionsToGroupBy = projectionBody.getGroupByExpressions();
Expand Down
20 changes: 12 additions & 8 deletions src/processor/map/map_order_by.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,23 @@ namespace kuzu {
namespace processor {

std::unique_ptr<PhysicalOperator> PlanMapper::mapOrderBy(LogicalOperator* logicalOperator) {
auto& logicalOrderBy = (LogicalOrderBy&)*logicalOperator;
auto outSchema = logicalOrderBy.getSchema();
auto inSchema = logicalOrderBy.getChild(0)->getSchema();
auto prevOperator = mapOperator(logicalOrderBy.getChild(0).get());
auto paramsString = logicalOrderBy.getExpressionsForPrinting();
auto logicalOrderBy = (LogicalOrderBy*)logicalOperator;
if (logicalOrderBy->isTopK()) {
// TODO(Ziyi): fill
assert(false);
}
auto outSchema = logicalOrderBy->getSchema();
auto inSchema = logicalOrderBy->getChild(0)->getSchema();
auto prevOperator = mapOperator(logicalOrderBy->getChild(0).get());
auto paramsString = logicalOrderBy->getExpressionsForPrinting();
std::vector<std::pair<DataPos, LogicalType>> keysPosAndType;
for (auto& expression : logicalOrderBy.getExpressionsToOrderBy()) {
for (auto& expression : logicalOrderBy->getExpressionsToOrderBy()) {
keysPosAndType.emplace_back(inSchema->getExpressionPos(*expression), expression->dataType);
}
std::vector<std::pair<DataPos, LogicalType>> payloadsPosAndType;
std::vector<bool> isPayloadFlat;
std::vector<DataPos> outVectorPos;
for (auto& expression : logicalOrderBy.getExpressionsToMaterialize()) {
for (auto& expression : inSchema->getExpressionsInScope()) {
auto expressionName = expression->getUniqueName();
payloadsPosAndType.emplace_back(
inSchema->getExpressionPos(*expression), expression->dataType);
Expand All @@ -33,7 +37,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapOrderBy(LogicalOperator* logica
// See comment in planOrderBy in projectionPlanner.cpp
auto mayContainUnflatKey = inSchema->getNumGroups() == 1;
auto orderByDataInfo = OrderByDataInfo(keysPosAndType, payloadsPosAndType, isPayloadFlat,
logicalOrderBy.getIsAscOrders(), mayContainUnflatKey);
logicalOrderBy->getIsAscOrders(), mayContainUnflatKey);
auto orderBySharedState = std::make_shared<SharedFactorizedTablesAndSortedKeyBlocks>();

auto orderBy =
Expand Down
2 changes: 1 addition & 1 deletion src/storage/local_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void VarListLocalColumn::prepareCommitForChunk(node_group_idx_t nodeGroupIdx) {
}

StructLocalColumn::StructLocalColumn(NodeColumn* column) : LocalColumn{column} {
assert(column->getDataType().getLogicalTypeID() == LogicalTypeID::STRUCT);
assert(column->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT);
auto dataType = column->getDataType();
auto structFields = StructType::getFields(&dataType);
fields.resize(structFields.size());
Expand Down