Skip to content

Commit

Permalink
Improve cardinality estimation and cost model
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 28, 2023
1 parent 98a3349 commit 4792ec3
Show file tree
Hide file tree
Showing 18 changed files with 367 additions and 129 deletions.
2 changes: 1 addition & 1 deletion src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ struct LoggerConstants {
// Compile-time tuning constants grouped under one name.
// NOTE(review): the names suggest these feed the planner's selectivity / cost
// heuristics; the consumers are not visible here, so confirm before retuning.
struct EnumeratorKnobs {
    // Presumably the assumed fraction of tuples surviving a predicate
    // (non-equality vs. equality) - TODO confirm against the estimator.
    static constexpr double NON_EQUALITY_PREDICATE_SELECTIVITY{0.1};
    static constexpr double EQUALITY_PREDICATE_SELECTIVITY{0.01};

    // Penalty factors; how they are applied is defined by their users.
    static constexpr double FLAT_PROBE_PENALTY{10.0};
    static constexpr uint64_t BUILD_PENALTY{2};
};

struct ClientContextConstants {
Expand Down
48 changes: 48 additions & 0 deletions src/include/planner/join_order/cardinality_estimator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once

#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "binder/query/reading_clause/query_graph.h"
#include "planner/logical_plan/logical_plan.h"
#include "storage/store/nodes_statistics_and_deleted_ids.h"
#include "storage/store/rels_statistics.h"

namespace kuzu {
namespace planner {

/**
 * Holds references to node and rel statistics plus a per-nodeID "domain" table, and exposes
 * estimate* entry points for individual logical operators.
 *
 * The statistics objects are stored by reference, so they must outlive this estimator.
 *
 * NOTE(review): all estimate* / get* members other than the two inline helpers are only declared
 * here; their exact semantics live in the corresponding .cpp and should be confirmed there.
 */
class CardinalityEstimator {
public:
    CardinalityEstimator(const storage::NodesStatisticsAndDeletedIDs& nodesStatistics,
        const storage::RelsStatistics& relsStatistics)
        : nodesStatistics{nodesStatistics}, relsStatistics{relsStatistics} {}

    // Presumably fills nodeIDName2dom for the nodes of the given query graph; getNodeIDDom()
    // asserts that an entry exists, so this is expected to run before any lookup.
    void initNodeIDDom(binder::QueryGraph* queryGraph);

    // Per-operator estimates. Names suggest each returns the estimated output cardinality of the
    // named operator applied to the given input plan(s) - confirm against the definitions.
    uint64_t estimateScanNode(LogicalOperator* op);
    uint64_t estimateHashJoin(const binder::expression_vector& joinNodeIDs,
        const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
    uint64_t estimateCrossProduct(const LogicalPlan& probePlan, const LogicalPlan& buildPlan);
    uint64_t estimateIntersect(const binder::expression_vector& joinNodeIDs,
        const LogicalPlan& probePlan, const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
    uint64_t estimateFlatten(const LogicalPlan& childPlan, f_group_pos groupPosToFlatten);
    uint64_t estimateFilter(const LogicalPlan& childPlan, const binder::Expression& predicate);

    // Returns a fractional rate (double) for extending `rel` from `boundNode`.
    double getExtensionRate(
        const binder::RelExpression& rel, const binder::NodeExpression& boundNode);

private:
    // Clamps zero to one; every other value is returned unchanged. Uses no object state, hence
    // static constexpr.
    static constexpr uint64_t atLeastOne(uint64_t x) { return x == 0 ? 1 : x; }

    // Looks up the domain recorded for `nodeIDName`. The entry must already exist: this is
    // asserted in debug builds, and `at()` throws std::out_of_range otherwise.
    uint64_t getNodeIDDom(const std::string& nodeIDName) const {
        assert(nodeIDName2dom.contains(nodeIDName));
        return nodeIDName2dom.at(nodeIDName);
    }

    uint64_t getNumNodes(const binder::NodeExpression& node);

    uint64_t getNumRels(const binder::RelExpression& rel);

    const storage::NodesStatisticsAndDeletedIDs& nodesStatistics;
    const storage::RelsStatistics& relsStatistics;
    // The domain of a nodeID is defined as the number of unique values of that nodeID, i.e. the
    // number of nodes.
    std::unordered_map<std::string, uint64_t> nodeIDName2dom;
};

} // namespace planner
} // namespace kuzu
18 changes: 18 additions & 0 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <cstdint>
#include <memory>
#include <vector>

#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace planner {

/**
 * Collection of static cost functions; the class carries no state and every member is static.
 *
 * NOTE(review): only declarations are visible here. The names suggest each function returns the
 * cost of the named operator given its input plan(s); the formulas live in the corresponding
 * .cpp and should be confirmed there.
 */
class CostModel {
public:
    // Cost of an extend on top of `childPlan`.
    static uint64_t computeExtendCost(const LogicalPlan& childPlan);

    // Cost of a hash join on `joinNodeIDs` between a probe-side and a build-side plan.
    static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
        const LogicalPlan& probe, const LogicalPlan& build);

    // Cost of a mark join on `joinNodeIDs` between a probe-side and a build-side plan.
    static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
        const LogicalPlan& probe, const LogicalPlan& build);

    // Cost of an intersect with one probe plan and any number of build plans.
    static uint64_t computeIntersectCost(
        const LogicalPlan& probePlan, const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
};

} // namespace planner
} // namespace kuzu
36 changes: 14 additions & 22 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ class JoinOrderEnumeratorContext;
*/
class JoinOrderEnumerator {
public:
JoinOrderEnumerator(const catalog::Catalog& catalog,
const storage::NodesStatisticsAndDeletedIDs& nodesStatistics,
const storage::RelsStatistics& relsStatistics, QueryPlanner* queryPlanner)
: catalog{catalog}, nodesStatistics{nodesStatistics}, relsStatistics{relsStatistics},
JoinOrderEnumerator(const catalog::Catalog& catalog, QueryPlanner* queryPlanner)
: catalog{catalog},
queryPlanner{queryPlanner}, context{std::make_unique<JoinOrderEnumeratorContext>()} {};

std::vector<std::unique_ptr<LogicalPlan>> enumerate(
Expand All @@ -34,28 +32,27 @@ class JoinOrderEnumerator {
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
void exitSubquery(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

static inline void planMarkJoin(const binder::expression_vector& joinNodeIDs,
inline void planMarkJoin(const binder::expression_vector& joinNodeIDs,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::MARK, mark, probePlan, buildPlan);
}
static inline void planInnerHashJoin(
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes, LogicalPlan& probePlan,
LogicalPlan& buildPlan) {
inline void planInnerHashJoin(const std::vector<std::shared_ptr<NodeExpression>>& joinNodes,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
binder::expression_vector joinNodeIDs;
for (auto& joinNode : joinNodes) {
joinNodeIDs.push_back(joinNode->getInternalIDProperty());
}
planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan);
}
static inline void planInnerHashJoin(const binder::expression_vector& joinNodeIDs,
inline void planInnerHashJoin(const binder::expression_vector& joinNodeIDs,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan);
}
static inline void planLeftHashJoin(const binder::expression_vector& joinNodeIDs,
inline void planLeftHashJoin(const binder::expression_vector& joinNodeIDs,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::LEFT, nullptr /* mark */, probePlan, buildPlan);
}
static inline void planCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
inline void planCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
appendCrossProduct(probePlan, buildPlan);
}

Expand Down Expand Up @@ -100,19 +97,16 @@ class JoinOrderEnumerator {
common::RelDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);

static void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
static void appendHashJoin(const binder::expression_vector& joinNodeIDs,
common::JoinType joinType, LogicalPlan& probePlan, LogicalPlan& buildPlan);
static void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
static void appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
void appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
static void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

uint64_t getExtensionRate(
const RelExpression& rel, const NodeExpression& boundNode, common::RelDirection direction);
void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

static binder::expression_vector getNewlyMatchedExpressions(const SubqueryGraph& prevSubgraph,
const SubqueryGraph& newSubgraph, const binder::expression_vector& expressions) {
Expand All @@ -127,8 +121,6 @@ class JoinOrderEnumerator {

private:
const catalog::Catalog& catalog;
const storage::NodesStatisticsAndDeletedIDs& nodesStatistics;
const storage::RelsStatistics& relsStatistics;
QueryPlanner* queryPlanner;
std::unique_ptr<JoinOrderEnumeratorContext> context;
};
Expand Down
7 changes: 4 additions & 3 deletions src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ constexpr f_group_pos INVALID_F_GROUP_POS = UINT32_MAX;

class FactorizationGroup {
friend class Schema;
friend class CardinalityEstimator;

public:
FactorizationGroup() : flat{false}, singleState{false}, cardinalityMultiplier{1} {}
Expand All @@ -33,8 +34,8 @@ class FactorizationGroup {
}
inline bool isSingleState() const { return singleState; }

inline void setMultiplier(uint64_t multiplier) { cardinalityMultiplier = multiplier; }
inline uint64_t getMultiplier() const { return cardinalityMultiplier; }
inline void setMultiplier(double multiplier) { cardinalityMultiplier = multiplier; }
inline double getMultiplier() const { return cardinalityMultiplier; }

inline void insertExpression(const std::shared_ptr<binder::Expression>& expression) {
assert(!expressionNameToPos.contains(expression->getUniqueName()));
Expand All @@ -50,7 +51,7 @@ class FactorizationGroup {
private:
bool flat;
bool singleState;
uint64_t cardinalityMultiplier;
double cardinalityMultiplier;
binder::expression_vector expressions;
std::unordered_map<std::string, uint32_t> expressionNameToPos;
};
Expand Down
9 changes: 4 additions & 5 deletions src/include/planner/logical_plan/logical_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
namespace kuzu {
namespace planner {

class LogicalPlan;

class LogicalPlan {
friend class CardinalityEstimator;
friend class CostModel;

public:
LogicalPlan() : estCardinality{1}, cost{0} {}

Expand All @@ -20,12 +21,10 @@ class LogicalPlan {
inline std::shared_ptr<LogicalOperator> getLastOperator() const { return lastOperator; }
inline Schema* getSchema() const { return lastOperator->getSchema(); }

inline void multiplyCardinality(uint64_t factor) { estCardinality *= factor; }
inline void setCardinality(uint64_t cardinality) { estCardinality = cardinality; }
inline uint64_t getCardinality() const { return estCardinality; }

inline void multiplyCost(uint64_t factor) { cost *= factor; }
inline void increaseCost(uint64_t costToIncrease) { cost += costToIncrease; }
inline void setCost(uint64_t cost_) { cost = cost_; }
inline uint64_t getCost() const { return cost; }

inline std::string toString() const { return lastOperator->toString(); }
Expand Down
20 changes: 11 additions & 9 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "binder/bound_statement.h"
#include "binder/expression/existential_subquery_expression.h"
#include "join_order_enumerator.h"
#include "planner/join_order/cardinality_estimator.h"
#include "projection_planner.h"
#include "update_planner.h"

Expand All @@ -13,14 +14,14 @@ class QueryPlanner {
friend class JoinOrderEnumerator;
friend class ProjectionPlanner;
friend class UpdatePlanner;
friend class ASPOptimizer;

public:
explicit QueryPlanner(const catalog::Catalog& catalog,
const storage::NodesStatisticsAndDeletedIDs& nodesStatistics,
const storage::RelsStatistics& relsStatistics)
: catalog{catalog}, joinOrderEnumerator{catalog, nodesStatistics, relsStatistics, this},
projectionPlanner{this}, updatePlanner{} {}
: catalog{catalog}, cardinalityEstimator{std::make_unique<CardinalityEstimator>(
nodesStatistics, relsStatistics)},
joinOrderEnumerator{catalog, this}, projectionPlanner{this}, updatePlanner{this} {}

std::vector<std::unique_ptr<LogicalPlan>> getAllPlans(const BoundStatement& boundStatement);

Expand Down Expand Up @@ -52,16 +53,16 @@ class QueryPlanner {
void planExistsSubquery(std::shared_ptr<Expression>& subquery, LogicalPlan& outerPlan);
void planSubqueryIfNecessary(const std::shared_ptr<Expression>& expression, LogicalPlan& plan);

static void appendAccumulate(LogicalPlan& plan);
void appendAccumulate(LogicalPlan& plan);

static void appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan);
void appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan);

static void appendDistinct(const expression_vector& expressionsToDistinct, LogicalPlan& plan);
void appendDistinct(const expression_vector& expressionsToDistinct, LogicalPlan& plan);

static void appendUnwind(BoundUnwindClause& boundUnwindClause, LogicalPlan& plan);
void appendUnwind(BoundUnwindClause& boundUnwindClause, LogicalPlan& plan);

static void appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan);
static void appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan);
void appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan);
void appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan);

void appendFilters(const binder::expression_vector& predicates, LogicalPlan& plan);
void appendFilter(const std::shared_ptr<Expression>& predicate, LogicalPlan& plan);
Expand All @@ -82,6 +83,7 @@ class QueryPlanner {

private:
const catalog::Catalog& catalog;
std::unique_ptr<CardinalityEstimator> cardinalityEstimator;
expression_vector propertiesToScan;
JoinOrderEnumerator joinOrderEnumerator;
ProjectionPlanner projectionPlanner;
Expand Down
7 changes: 6 additions & 1 deletion src/include/planner/update_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
namespace kuzu {
namespace planner {

class QueryPlanner;

class UpdatePlanner {
public:
UpdatePlanner() = default;
UpdatePlanner(QueryPlanner* queryPlanner) : queryPlanner{queryPlanner} {};

inline void planUpdatingClause(binder::BoundUpdatingClause& updatingClause,
std::vector<std::unique_ptr<LogicalPlan>>& plans) {
Expand Down Expand Up @@ -43,6 +45,9 @@ class UpdatePlanner {
LogicalPlan& plan);
void appendDeleteRel(
const std::vector<std::shared_ptr<binder::RelExpression>>& deleteRels, LogicalPlan& plan);

private:
QueryPlanner* queryPlanner;
};

} // namespace planner
Expand Down
1 change: 1 addition & 0 deletions src/planner/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(join_order)
add_subdirectory(operator)

add_library(kuzu_planner
Expand Down
8 changes: 8 additions & 0 deletions src/planner/join_order/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# OBJECT library built from cardinality_estimator.cpp and cost_model.cpp.
add_library(kuzu_planner_join_order
OBJECT
cardinality_estimator.cpp
cost_model.cpp)

# Append this target's object files to ALL_OBJECT_FILES and publish the updated
# list to the parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_planner_join_order>
PARENT_SCOPE)
Loading

0 comments on commit 4792ec3

Please sign in to comment.