Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 28, 2023
1 parent 4b4b43d commit 169aff7
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 59 deletions.
3 changes: 2 additions & 1 deletion src/include/planner/join_order/cardinality_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class CardinalityEstimator {
private:
const storage::NodesStatisticsAndDeletedIDs& nodesStatistics;
const storage::RelsStatistics& relsStatistics;
// The domain of nodeID is defined as the number of unique value of nodeID, i.e. num nodes.
std::unordered_map<std::string, uint64_t> nodeIDName2domCardinality;
};

} // namespace planner
} // namespace kuzu
} // namespace kuzu
3 changes: 1 addition & 2 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace planner {
class CostModel {
public:
static uint64_t computeExtendCost(const LogicalPlan& childPlan);

static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
Expand All @@ -16,4 +15,4 @@ class CostModel {
};

} // namespace planner
} // namespace kuzu
} // namespace kuzu
12 changes: 5 additions & 7 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class JoinOrderEnumeratorContext;
*/
class JoinOrderEnumerator {
public:
JoinOrderEnumerator(const catalog::Catalog& catalog,
QueryPlanner* queryPlanner)
JoinOrderEnumerator(const catalog::Catalog& catalog, QueryPlanner* queryPlanner)
: catalog{catalog},
queryPlanner{queryPlanner}, context{std::make_unique<JoinOrderEnumeratorContext>()} {};

Expand All @@ -37,9 +36,8 @@ class JoinOrderEnumerator {
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::MARK, mark, probePlan, buildPlan);
}
inline void planInnerHashJoin(
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes, LogicalPlan& probePlan,
LogicalPlan& buildPlan) {
inline void planInnerHashJoin(const std::vector<std::shared_ptr<NodeExpression>>& joinNodes,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
binder::expression_vector joinNodeIDs;
for (auto& joinNode : joinNodes) {
joinNodeIDs.push_back(joinNode->getInternalIDProperty());
Expand Down Expand Up @@ -101,8 +99,8 @@ class JoinOrderEnumerator {

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendHashJoin(const binder::expression_vector& joinNodeIDs,
common::JoinType joinType, LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class LogicalHashJoin : public LogicalOperator {
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(joinNodeIDs) ;
// + std::to_string(cost) + "|" + std::to_string(lc) + "|" + std::to_string(rc)
return binder::ExpressionUtil::toString(joinNodeIDs);
}

binder::expression_vector getExpressionsToMaterialize() const;
Expand All @@ -56,19 +55,10 @@ class LogicalHashJoin : public LogicalOperator {
inline SidewaysInfoPassing getSIP() const { return sip; }

inline std::unique_ptr<LogicalOperator> copy() override {
auto c = make_unique<LogicalHashJoin>(
return make_unique<LogicalHashJoin>(
joinNodeIDs, joinType, mark, children[0]->copy(), children[1]->copy());
c->cost = cost;
c->lc = lc;
c->rc = rc;
return c;
}

public:
uint64_t cost;
uint64_t lc;
uint64_t rc;

private:
// Flat probe side key group in either of the following two cases:
// 1. there are multiple join nodes;
Expand Down
1 change: 0 additions & 1 deletion src/include/planner/subplans_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "binder/query/reading_clause/query_graph.h"
#include "planner/logical_plan/logical_plan.h"

// TODO: remove
using namespace kuzu::binder;

namespace kuzu {
Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {

// ASP optimizer should be applied after optimizers that manipulate hash join.
auto aspOptimizer = ASPOptimizer();
// aspOptimizer.rewrite(plan);
aspOptimizer.rewrite(plan);

auto projectionPushDownOptimizer = ProjectionPushDownOptimizer();
projectionPushDownOptimizer.rewrite(plan);
Expand Down
20 changes: 12 additions & 8 deletions src/planner/join_order/cardinality_estimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,23 @@ uint64_t CardinalityEstimator::estimateScanNode(LogicalOperator* op) {
return atLeastOne(getNumNodes(scanNode->getNode()->getInternalIDPropertyName()));
}

static uint64_t getHashCardinality(const binder::expression_vector& joinNodeIDs, const LogicalPlan& plan) {
auto schema = plan.getSchema();
// Although we may not flatten join key in Build operator computation. We do need to calculate join
// cardinality based on flat join key cardinality.
static uint64_t getJoinKeysFlatCardinality(
const binder::expression_vector& joinNodeIDs, const LogicalPlan& buildPlan) {
auto schema = buildPlan.getSchema();
f_group_pos_set unFlatGroupsPos;
for (auto& joinID : joinNodeIDs) {
auto groupPos = schema->getGroupPos(*joinID);
if (!schema->getGroup(groupPos)->isFlat()) {
unFlatGroupsPos.insert(groupPos);
}
}
auto cost = plan.getCardinality();
auto cardinality = buildPlan.getCardinality();
for (auto groupPos : unFlatGroupsPos) {
cost *= schema->getGroup(groupPos)->getMultiplier();
cardinality *= schema->getGroup(groupPos)->getMultiplier();
}
return cost;
return cardinality;
}

uint64_t CardinalityEstimator::estimateHashJoin(const binder::expression_vector& joinNodeIDs,
Expand All @@ -43,7 +46,8 @@ uint64_t CardinalityEstimator::estimateHashJoin(const binder::expression_vector&
for (auto& joinNodeID : joinNodeIDs) {
denominator *= getNumNodes(joinNodeID->getUniqueName());
}
return atLeastOne(probePlan.estCardinality * getHashCardinality(joinNodeIDs, buildPlan) / denominator);
return atLeastOne(probePlan.estCardinality *
getJoinKeysFlatCardinality(joinNodeIDs, buildPlan) / denominator);
}

uint64_t CardinalityEstimator::estimateCrossProduct(
Expand Down Expand Up @@ -113,8 +117,8 @@ uint64_t CardinalityEstimator::getNumRels(const binder::RelExpression& rel) {

double CardinalityEstimator::getExtensionRate(
const binder::RelExpression& rel, const binder::NodeExpression& boundNode) {
auto numBoundNodes = (double )getNumNodes(boundNode);
auto numRels = (double )getNumRels(rel);
auto numBoundNodes = (double)getNumNodes(boundNode);
auto numRels = (double)getNumRels(rel);
return numRels / numBoundNodes;
}

Expand Down
26 changes: 4 additions & 22 deletions src/planner/join_order/cost_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,13 @@ uint64_t CostModel::computeExtendCost(const LogicalPlan& childPlan) {
return childPlan.estCardinality;
}

// Although we may not flatten join key in Build operator. The hash is computed on every flat tuple,
// so the cost should also be measured on flat cardinality.
static uint64_t getHashCost(const binder::expression_vector& joinNodeIDs, const LogicalPlan& plan) {
auto schema = plan.getSchema();
f_group_pos_set unFlatGroupsPos;
for (auto& joinID : joinNodeIDs) {
auto groupPos = schema->getGroupPos(*joinID);
if (!schema->getGroup(groupPos)->isFlat()) {
unFlatGroupsPos.insert(groupPos);
}
}
auto cost = plan.getCardinality();
// for (auto groupPos : unFlatGroupsPos) {
// cost *= schema->getGroup(groupPos)->getMultiplier();
// }
return cost;
}

uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build) {
auto cost = 0u;
cost += probe.getCost();
cost += build.getCost();
cost += probe.getCardinality();
cost += 2 * build.getCardinality();
cost += common::EnumeratorKnobs::BUILD_PENALTY * build.getCardinality();
return cost;
}

Expand All @@ -46,14 +28,14 @@ uint64_t CostModel::computeIntersectCost(const kuzu::planner::LogicalPlan& probe
const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans) {
auto cost = 0u;
cost += probePlan.getCost();
// TODO(Xiyang): think of how to calculate intersect cost such that it will be picked in worst
// case.
cost += probePlan.getCardinality();
for (auto& buildPlan : buildPlans) {
cost += buildPlan->getCost();
// During planning we guarantee keys are already flatten on build side.
// cost += buildPlan->getCardinality();
}
return cost;
}

} // namespace planner
} // namespace kuzu
} // namespace kuzu
5 changes: 1 addition & 4 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "planner/join_order_enumerator.h"

#include "planner/join_order/cost_model.h"
#include "planner/logical_plan/logical_operator/logical_cross_product.h"
#include "planner/logical_plan/logical_operator/logical_extend.h"
#include "planner/logical_plan/logical_operator/logical_ftable_scan.h"
Expand All @@ -8,7 +9,6 @@
#include "planner/logical_plan/logical_operator/logical_scan_node.h"
#include "planner/projection_planner.h"
#include "planner/query_planner.h"
#include "planner/join_order/cost_model.h"

using namespace kuzu::common;

Expand Down Expand Up @@ -495,9 +495,6 @@ void JoinOrderEnumerator::appendHashJoin(const expression_vector& joinNodeIDs, J
hashJoin->computeFactorizedSchema();
// update cost
probePlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan));
hashJoin->cost = probePlan.getCost();
hashJoin->lc = probePlan.getCardinality();
hashJoin->rc = buildPlan.getCardinality();
// update cardinality
probePlan.setCardinality(
queryPlanner->cardinalityEstimator->estimateHashJoin(joinNodeIDs, probePlan, buildPlan));
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/tinysnb/asp/asp.test
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-NAME AspBasic
-QUERY MATCH (a:person)-[e1:knows]->(b:person) WHERE a.age > 35 AND b.age > 0 RETURN b.fName
-QUERY MATCH (a:person)-[e1:knows]->(b:person) WHERE a.age > 35 RETURN b.fName
-ENCODED_JOIN HJ(b._id){E(b)S(a)}{S(b)}
---- 3
Alice
Expand Down

0 comments on commit 169aff7

Please sign in to comment.