diff --git a/src/common/copier_config/copier_config.cpp b/src/common/copier_config/copier_config.cpp index e746a3d4d62..a138338bb69 100644 --- a/src/common/copier_config/copier_config.cpp +++ b/src/common/copier_config/copier_config.cpp @@ -36,6 +36,8 @@ std::string CopyDescription::getFileTypeName(FileType fileType) { case FileType::NPY: { return "npy"; } + default: + throw InternalException("Unimplemented getFileTypeName()."); } } diff --git a/src/include/common/constants.h b/src/include/common/constants.h index 7438d85bf05..082789e123d 100644 --- a/src/include/common/constants.h +++ b/src/include/common/constants.h @@ -130,10 +130,13 @@ struct LoggerConstants { }; }; -struct EnumeratorKnobs { +struct PlannerKnobs { static constexpr double NON_EQUALITY_PREDICATE_SELECTIVITY = 0.1; static constexpr double EQUALITY_PREDICATE_SELECTIVITY = 0.01; static constexpr uint64_t BUILD_PENALTY = 2; + // Avoid doing probe to build SIP if we have to accumulate a probe side that is much bigger than + // build side. Also avoid doing build to probe SIP if probe side is not much bigger than build. + static constexpr uint64_t ACC_HASH_JOIN_RATIO = 5; }; struct ClientContextConstants { diff --git a/src/include/optimizer/asp_optimizer.h b/src/include/optimizer/acc_hash_join_optimizer.h similarity index 60% rename from src/include/optimizer/asp_optimizer.h rename to src/include/optimizer/acc_hash_join_optimizer.h index 9dfac603c63..ecc446b90e8 100644 --- a/src/include/optimizer/asp_optimizer.h +++ b/src/include/optimizer/acc_hash_join_optimizer.h @@ -4,9 +4,9 @@ namespace kuzu { namespace optimizer { -// This optimizer enables the ASP join algorithm as introduced in paper "Kuzu Graph Database -// Management System". -class ASPOptimizer : public LogicalOperatorVisitor { +// This optimizer enables the Accumulated hash join algorithm as introduced in paper "Kuzu Graph +// Database Management System". +class AccHashJoinOptimizer : public LogicalOperatorVisitor { public: void rewrite(planner::LogicalPlan* plan); @@ -14,6 +14,10 @@ class ASPOptimizer : public LogicalOperatorVisitor { void visitOperator(planner::LogicalOperator* op); void visitHashJoin(planner::LogicalOperator* op) override; + + bool tryLeftToRightHashJoinSIP(planner::LogicalOperator* op); + bool tryRightToLeftHashJoinSIP(planner::LogicalOperator* op); + void visitIntersect(planner::LogicalOperator* op) override; bool isProbeSideQualified(planner::LogicalOperator* probeRoot); @@ -22,7 +26,10 @@ class ASPOptimizer : public LogicalOperatorVisitor { const binder::expression_vector& nodeIDCandidates, const std::vector& buildRoots); - void applyASP( + std::shared_ptr applySemiMasks( + const binder::expression_map>& nodeIDToScanNodes, + std::shared_ptr root); + void applyAccHashJoin( const binder::expression_map>& nodeIDToScanNodes, planner::LogicalOperator* op); }; diff --git a/src/include/optimizer/logical_operator_collector.h b/src/include/optimizer/logical_operator_collector.h index 0c34250e269..29b2e21437e 100644 --- a/src/include/optimizer/logical_operator_collector.h +++ b/src/include/optimizer/logical_operator_collector.h @@ -31,5 +31,10 @@ class LogicalScanNodeCollector : public LogicalOperatorCollector { void visitScanNode(planner::LogicalOperator* op) override { ops.push_back(op); } }; +class LogicalIndexScanNodeCollector : public LogicalOperatorCollector { +protected: + void visitIndexScanNode(planner::LogicalOperator* op) override { ops.push_back(op); } +}; + } // namespace optimizer } // namespace kuzu diff --git a/src/include/optimizer/logical_operator_visitor.h b/src/include/optimizer/logical_operator_visitor.h index 189c09234e0..d71d042a14f 100644 --- a/src/include/optimizer/logical_operator_visitor.h +++ b/src/include/optimizer/logical_operator_visitor.h @@ -27,6 +27,12 @@ class LogicalOperatorVisitor { return op; } + virtual void visitIndexScanNode(planner::LogicalOperator* op) {} + virtual std::shared_ptr visitIndexScanNodeReplace( + std::shared_ptr op) { + return op; + } + virtual void visitExtend(planner::LogicalOperator* op) {} virtual std::shared_ptr visitExtendReplace( std::shared_ptr op) { diff --git a/src/include/planner/join_order_enumerator.h b/src/include/planner/join_order_enumerator.h index 002e0b8f050..3af980a77f8 100644 --- a/src/include/planner/join_order_enumerator.h +++ b/src/include/planner/join_order_enumerator.h @@ -36,14 +36,6 @@ class JoinOrderEnumerator { std::shared_ptr mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) { planJoin(joinNodeIDs, common::JoinType::MARK, mark, probePlan, buildPlan); } - inline void planInnerHashJoin(const std::vector>& joinNodes, - LogicalPlan& probePlan, LogicalPlan& buildPlan) { - binder::expression_vector joinNodeIDs; - for (auto& joinNode : joinNodes) { - joinNodeIDs.push_back(joinNode->getInternalIDProperty()); - } - planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan); - } inline void planInnerHashJoin(const binder::expression_vector& joinNodeIDs, LogicalPlan& probePlan, LogicalPlan& buildPlan) { planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan); diff --git a/src/include/planner/logical_plan/logical_operator/side_way_info_passing.h b/src/include/planner/logical_plan/logical_operator/side_way_info_passing.h index 66f28ae2238..6e3f188a85b 100644 --- a/src/include/planner/logical_plan/logical_operator/side_way_info_passing.h +++ b/src/include/planner/logical_plan/logical_operator/side_way_info_passing.h @@ -8,6 +8,9 @@ namespace planner { enum class SidewaysInfoPassing : uint8_t { NONE = 0, LEFT_TO_RIGHT = 1, + PROHIBIT_LEFT_TO_RIGHT = 2, + RIGHT_TO_LEFT = 3, + PROHIBIT_RIGHT_TO_LEFT = 4, }; } // namespace planner diff --git a/src/include/planner/subplans_table.h b/src/include/planner/subplans_table.h index a5d47dc5f97..3a264b15a86 100644 --- a/src/include/planner/subplans_table.h +++ b/src/include/planner/subplans_table.h @@ -29,6 +29,8 @@ class SubgraphPlans { public: SubgraphPlans(const SubqueryGraph& subqueryGraph); + inline uint64_t getMaxCost() const { return maxCost; } + void addPlan(std::unique_ptr plan); std::vector>& getPlans() { return plans; } @@ -42,6 +44,7 @@ class SubgraphPlans { constexpr static uint32_t MAX_NUM_PLANS = 10; private: + uint64_t maxCost = UINT64_MAX; binder::expression_vector nodeIDsToEncode; std::vector> plans; std::unordered_map, common::vector_idx_t> @@ -77,6 +80,8 @@ class SubPlansTable { public: void resize(uint32_t newSize); + uint64_t getMaxCost(const SubqueryGraph& subqueryGraph) const; + bool containSubgraphPlans(const SubqueryGraph& subqueryGraph) const; std::vector>& getSubgraphPlans(const SubqueryGraph& subqueryGraph); diff --git a/src/include/processor/mapper/plan_mapper.h b/src/include/processor/mapper/plan_mapper.h index 33ea42f535b..39074fc7a94 100644 --- a/src/include/processor/mapper/plan_mapper.h +++ b/src/include/processor/mapper/plan_mapper.h @@ -126,7 +126,7 @@ class PlanMapper { BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema, const binder::expression_vector& keys, const binder::expression_vector& payloads); - void mapASP(PhysicalOperator* probe); + void mapAccHashJoin(PhysicalOperator* probe); public: storage::StorageManager& storageManager; diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt index 535686f93dc..ce4ffd08b9c 100644 --- a/src/optimizer/CMakeLists.txt +++ b/src/optimizer/CMakeLists.txt @@ -1,6 +1,6 @@ add_library(kuzu_optimizer OBJECT - asp_optimizer.cpp + acc_hash_join_optimizer.cpp factorization_rewriter.cpp filter_push_down_optimizer.cpp logical_operator_collector.cpp diff --git a/src/optimizer/asp_optimizer.cpp b/src/optimizer/acc_hash_join_optimizer.cpp similarity index 55% rename from src/optimizer/asp_optimizer.cpp rename to src/optimizer/acc_hash_join_optimizer.cpp index 420e9f072ea..764e144dfca 100644 --- a/src/optimizer/asp_optimizer.cpp +++ b/src/optimizer/acc_hash_join_optimizer.cpp @@ -1,4 +1,4 @@ -#include "optimizer/asp_optimizer.h" +#include "optimizer/acc_hash_join_optimizer.h" #include "optimizer/logical_operator_collector.h" #include "planner/logical_plan/logical_operator/logical_accumulate.h" @@ -12,11 +12,11 @@ using namespace kuzu::planner; namespace kuzu { namespace optimizer { -void ASPOptimizer::rewrite(planner::LogicalPlan* plan) { +void AccHashJoinOptimizer::rewrite(planner::LogicalPlan* plan) { visitOperator(plan->getLastOperator().get()); } -void ASPOptimizer::visitOperator(planner::LogicalOperator* op) { +void AccHashJoinOptimizer::visitOperator(planner::LogicalOperator* op) { // bottom up traversal for (auto i = 0u; i < op->getNumChildren(); ++i) { visitOperator(op->getChild(i).get()); @@ -24,24 +24,71 @@ void ASPOptimizer::visitOperator(planner::LogicalOperator* op) { visitOperatorSwitch(op); } -void ASPOptimizer::visitHashJoin(planner::LogicalOperator* op) { +void AccHashJoinOptimizer::visitHashJoin(planner::LogicalOperator* op) { + if (tryLeftToRightHashJoinSIP(op)) { + return; + } + tryRightToLeftHashJoinSIP(op); +} + +bool AccHashJoinOptimizer::tryLeftToRightHashJoinSIP(planner::LogicalOperator* op) { auto hashJoin = (LogicalHashJoin*)op; + if (hashJoin->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_LEFT_TO_RIGHT) { + return false; + } if (!isProbeSideQualified(op->getChild(0).get())) { - return; + return false; } std::vector buildRoots; buildRoots.push_back(hashJoin->getChild(1).get()); auto scanNodes = resolveScanNodesToApplySemiMask(hashJoin->getJoinNodeIDs(), buildRoots); if (scanNodes.empty()) { - return; + return false; } - // apply ASP + // apply accumulated hash join hashJoin->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT); - applyASP(scanNodes, op); + applyAccHashJoin(scanNodes, op); + return true; } -void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) { +static bool subPlanContainsFilter(LogicalOperator* root) { + auto filterCollector = LogicalFilterCollector(); + filterCollector.collect(root); + auto indexScanNodeCollector = LogicalIndexScanNodeCollector(); + indexScanNodeCollector.collect(root); + if (!filterCollector.hasOperators() && !indexScanNodeCollector.hasOperators()) { + return false; + } + return true; +} + +bool AccHashJoinOptimizer::tryRightToLeftHashJoinSIP(planner::LogicalOperator* op) { + auto hashJoin = (LogicalHashJoin*)op; + if (hashJoin->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_RIGHT_TO_LEFT) { + return false; + } + if (hashJoin->getJoinType() != common::JoinType::INNER) { + return false; + } + if (!subPlanContainsFilter(hashJoin->getChild(1).get())) { + return false; + } + std::vector roots; + roots.push_back(hashJoin->getChild(0).get()); + auto scanNodes = resolveScanNodesToApplySemiMask(hashJoin->getJoinNodeIDs(), roots); + if (scanNodes.empty()) { + return false; + } + hashJoin->setSIP(planner::SidewaysInfoPassing::RIGHT_TO_LEFT); + hashJoin->setChild(1, applySemiMasks(scanNodes, op->getChild(1))); + return true; +} + +void AccHashJoinOptimizer::visitIntersect(planner::LogicalOperator* op) { auto intersect = (LogicalIntersect*)op; + if (intersect->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_LEFT_TO_RIGHT) { + return; + } if (!isProbeSideQualified(op->getChild(0).get())) { return; } @@ -54,26 +101,22 @@ void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) { return; } intersect->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT); - applyASP(scanNodes, op); + applyAccHashJoin(scanNodes, op); } // Probe side is qualified if it is selective. -bool ASPOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) { +bool AccHashJoinOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) { if (probeRoot->getOperatorType() == LogicalOperatorType::ACCUMULATE) { - // No ASP if probe side has already been accumulated. This can be solved. + // No Acc hash join if probe side has already been accumulated. This can be solved. return false; } - auto filterCollector = LogicalFilterCollector(); - filterCollector.collect(probeRoot); - if (!filterCollector.hasOperators()) { - // Probe side is not selective. So we don't apply ASP. - return false; - } - return true; + // Probe side is not selective. So we don't apply acc hash join. + return subPlanContainsFilter(probeRoot); } binder::expression_map> -ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& nodeIDCandidates, +AccHashJoinOptimizer::resolveScanNodesToApplySemiMask( + const binder::expression_vector& nodeIDCandidates, const std::vector& buildRoots) { binder::expression_map> nodeIDToScanOperatorsMap; for (auto& buildRoot : buildRoots) { @@ -100,16 +143,23 @@ ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& n return result; } -void ASPOptimizer::applyASP( +std::shared_ptr AccHashJoinOptimizer::applySemiMasks( const binder::expression_map>& nodeIDToScanNodes, - planner::LogicalOperator* op) { - auto currentChild = op->getChild(0); + std::shared_ptr root) { + auto currentRoot = root; for (auto& [nodeID, scanNodes] : nodeIDToScanNodes) { - auto semiMasker = std::make_shared(nodeID, scanNodes, currentChild); + auto semiMasker = std::make_shared(nodeID, scanNodes, currentRoot); semiMasker->computeFlatSchema(); - currentChild = semiMasker; + currentRoot = semiMasker; } - auto accumulate = std::make_shared(std::move(currentChild)); + return currentRoot; +} + +void AccHashJoinOptimizer::applyAccHashJoin( + const binder::expression_map>& nodeIDToScanNodes, + planner::LogicalOperator* op) { + auto currentRoot = applySemiMasks(nodeIDToScanNodes, op->getChild(0)); + auto accumulate = std::make_shared(std::move(currentRoot)); accumulate->computeFlatSchema(); op->setChild(0, std::move(accumulate)); } diff --git a/src/optimizer/logical_operator_visitor.cpp b/src/optimizer/logical_operator_visitor.cpp index e08eab69a9c..39e648a697f 100644 --- a/src/optimizer/logical_operator_visitor.cpp +++ b/src/optimizer/logical_operator_visitor.cpp @@ -13,6 +13,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(planner::LogicalOperator* op) { case LogicalOperatorType::SCAN_NODE: { visitScanNode(op); } break; + case LogicalOperatorType::INDEX_SCAN_NODE: { + visitIndexScanNode(op); + } break; case LogicalOperatorType::EXTEND: { visitExtend(op); } break; @@ -84,6 +87,9 @@ std::shared_ptr LogicalOperatorVisitor::visitOperatorR case LogicalOperatorType::SCAN_NODE: { return visitScanNodeReplace(op); } + case LogicalOperatorType::INDEX_SCAN_NODE: { + return visitIndexScanNodeReplace(op); + } case LogicalOperatorType::EXTEND: { return visitExtendReplace(op); } diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 7433635b88d..3aa46e62e2e 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -1,6 +1,6 @@ #include "optimizer/optimizer.h" -#include "optimizer/asp_optimizer.h" +#include "optimizer/acc_hash_join_optimizer.h" #include "optimizer/factorization_rewriter.h" #include "optimizer/filter_push_down_optimizer.h" #include "optimizer/projection_push_down_optimizer.h" @@ -22,8 +22,8 @@ void Optimizer::optimize(planner::LogicalPlan* plan) { filterPushDownOptimizer.rewrite(plan); // ASP optimizer should be applied after optimizers that manipulate hash join. - auto aspOptimizer = ASPOptimizer(); - aspOptimizer.rewrite(plan); + auto accHashJoinOptimizer = AccHashJoinOptimizer(); + accHashJoinOptimizer.rewrite(plan); auto projectionPushDownOptimizer = ProjectionPushDownOptimizer(); projectionPushDownOptimizer.rewrite(plan); diff --git a/src/planner/join_order/cardinality_estimator.cpp b/src/planner/join_order/cardinality_estimator.cpp index 0580e068f38..b644b63941a 100644 --- a/src/planner/join_order/cardinality_estimator.cpp +++ b/src/planner/join_order/cardinality_estimator.cpp @@ -58,7 +58,7 @@ uint64_t CardinalityEstimator::estimateIntersect(const binder::expression_vector const LogicalPlan& probePlan, const std::vector>& buildPlans) { // Formula 1: treat intersect as a Filter on probe side. uint64_t estCardinality1 = - probePlan.estCardinality * common::EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY; + probePlan.estCardinality * common::PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY; // Formula 2: assume independence on join conditions. auto denominator = 1u; for (auto& joinNodeID : joinNodeIDs) { @@ -93,11 +93,11 @@ uint64_t CardinalityEstimator::estimateFilter( return 1; } else { return atLeastOne( - childPlan.estCardinality * common::EnumeratorKnobs::EQUALITY_PREDICATE_SELECTIVITY); + childPlan.estCardinality * common::PlannerKnobs::EQUALITY_PREDICATE_SELECTIVITY); } } else { return atLeastOne( - childPlan.estCardinality * common::EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY); + childPlan.estCardinality * common::PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY); } } diff --git a/src/planner/join_order/cost_model.cpp b/src/planner/join_order/cost_model.cpp index f3b1192600c..8cdc9e7551c 100644 --- a/src/planner/join_order/cost_model.cpp +++ b/src/planner/join_order/cost_model.cpp @@ -15,7 +15,7 @@ uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNod cost += probe.getCost(); cost += build.getCost(); cost += probe.getCardinality(); - cost += common::EnumeratorKnobs::BUILD_PENALTY * build.getCardinality(); + cost += common::PlannerKnobs::BUILD_PENALTY * build.getCardinality(); return cost; } diff --git a/src/planner/join_order_enumerator.cpp b/src/planner/join_order_enumerator.cpp index 25032fe4df4..26b5004c01e 100644 --- a/src/planner/join_order_enumerator.cpp +++ b/src/planner/join_order_enumerator.cpp @@ -394,21 +394,29 @@ void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph, bool flipPlan) { auto newSubgraph = subgraph; newSubgraph.addSubqueryGraph(otherSubgraph); + auto maxCost = context->subPlansTable->getMaxCost(newSubgraph); + binder::expression_vector joinNodeIDs; + for (auto& joinNode : joinNodes) { + joinNodeIDs.push_back(joinNode->getInternalIDProperty()); + } auto predicates = getNewlyMatchedExpressions(std::vector{subgraph, otherSubgraph}, newSubgraph, context->getWhereExpressions()); for (auto& leftPlan : context->getPlans(subgraph)) { for (auto& rightPlan : context->getPlans(otherSubgraph)) { - auto leftPlanProbeCopy = leftPlan->shallowCopy(); - auto rightPlanBuildCopy = rightPlan->shallowCopy(); - auto leftPlanBuildCopy = leftPlan->shallowCopy(); - auto rightPlanProbeCopy = rightPlan->shallowCopy(); - planInnerHashJoin(joinNodes, *leftPlanProbeCopy, *rightPlanBuildCopy); - planFiltersForHashJoin(predicates, *leftPlanProbeCopy); - context->addPlan(newSubgraph, std::move(leftPlanProbeCopy)); + if (CostModel::computeHashJoinCost(joinNodeIDs, *leftPlan, *rightPlan) < maxCost) { + auto leftPlanProbeCopy = leftPlan->shallowCopy(); + auto rightPlanBuildCopy = rightPlan->shallowCopy(); + planInnerHashJoin(joinNodeIDs, *leftPlanProbeCopy, *rightPlanBuildCopy); + planFiltersForHashJoin(predicates, *leftPlanProbeCopy); + context->addPlan(newSubgraph, std::move(leftPlanProbeCopy)); + } // flip build and probe side to get another HashJoin plan - if (flipPlan) { - planInnerHashJoin(joinNodes, *rightPlanProbeCopy, *leftPlanBuildCopy); + if (flipPlan && + CostModel::computeHashJoinCost(joinNodeIDs, *rightPlan, *leftPlan) < maxCost) { + auto leftPlanBuildCopy = leftPlan->shallowCopy(); + auto rightPlanProbeCopy = rightPlan->shallowCopy(); + planInnerHashJoin(joinNodeIDs, *rightPlanProbeCopy, *leftPlanBuildCopy); planFiltersForHashJoin(predicates, *rightPlanProbeCopy); context->addPlan(newSubgraph, std::move(rightPlanProbeCopy)); } @@ -493,6 +501,12 @@ void JoinOrderEnumerator::appendHashJoin(const expression_vector& joinNodeIDs, J queryPlanner->appendFlattens(hashJoin->getGroupsPosToFlattenOnBuildSide(), buildPlan); hashJoin->setChild(1, buildPlan.getLastOperator()); hashJoin->computeFactorizedSchema(); + auto ratio = probePlan.getCardinality() / buildPlan.getCardinality(); + if (ratio > common::PlannerKnobs::ACC_HASH_JOIN_RATIO) { + hashJoin->setSIP(SidewaysInfoPassing::PROHIBIT_LEFT_TO_RIGHT); + } else { + hashJoin->setSIP(SidewaysInfoPassing::PROHIBIT_RIGHT_TO_LEFT); + } // update cost probePlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan)); // update cardinality @@ -535,6 +549,10 @@ void JoinOrderEnumerator::appendIntersect(const std::shared_ptr& int queryPlanner->appendFlattens( intersect->getGroupsPosToFlattenOnBuildSide(i), *buildPlans[i]); intersect->setChild(i + 1, buildPlans[i]->getLastOperator()); + if (probePlan.getCardinality() / buildPlans[i]->getCardinality() > + common::PlannerKnobs::ACC_HASH_JOIN_RATIO) { + intersect->setSIP(SidewaysInfoPassing::PROHIBIT_LEFT_TO_RIGHT); + } } intersect->computeFactorizedSchema(); // update cost diff --git a/src/planner/subplans_table.cpp b/src/planner/subplans_table.cpp index 9fe555e92cc..07469b80c27 100644 --- a/src/planner/subplans_table.cpp +++ b/src/planner/subplans_table.cpp @@ -12,6 +12,7 @@ SubgraphPlans::SubgraphPlans(const kuzu::binder::SubqueryGraph& subqueryGraph) { subqueryGraph.queryGraph.getQueryNode(i)->getInternalIDProperty()); } } + maxCost = UINT64_MAX; } void SubgraphPlans::addPlan(std::unique_ptr plan) { @@ -21,10 +22,21 @@ void SubgraphPlans::addPlan(std::unique_ptr plan) { auto planCode = encodePlan(*plan); if (!encodedPlan2PlanIdx.contains(planCode)) { encodedPlan2PlanIdx.insert({planCode, plans.size()}); + if (maxCost == UINT64_MAX || plan->getCost() > maxCost) { // update max cost + maxCost = plan->getCost(); + } plans.push_back(std::move(plan)); } else { auto planIdx = encodedPlan2PlanIdx.at(planCode); if (plan->getCost() < plans[planIdx]->getCost()) { + if (plans[planIdx]->getCost() == maxCost) { // update max cost + maxCost = 0; + for (auto& plan_ : plans) { + if (plan_->getCost() > maxCost) { + maxCost = plan_->getCost(); + } + } + } plans[planIdx] = std::move(plan); } } @@ -67,6 +79,12 @@ void SubPlansTable::resize(uint32_t newSize) { } } +uint64_t SubPlansTable::getMaxCost(const SubqueryGraph& subqueryGraph) const { + return containSubgraphPlans(subqueryGraph) ? + getDPLevel(subqueryGraph)->getSubgraphPlans(subqueryGraph)->getMaxCost() : + UINT64_MAX; +} + bool SubPlansTable::containSubgraphPlans(const SubqueryGraph& subqueryGraph) const { return getDPLevel(subqueryGraph)->contains(subqueryGraph); } @@ -74,7 +92,7 @@ bool SubPlansTable::containSubgraphPlans(const SubqueryGraph& subqueryGraph) con std::vector>& SubPlansTable::getSubgraphPlans( const SubqueryGraph& subqueryGraph) { auto dpLevel = getDPLevel(subqueryGraph); - KU_ASSERT(dpLevel->contains(subqueryGraph)); + assert(dpLevel->contains(subqueryGraph)); return dpLevel->getSubgraphPlans(subqueryGraph)->getPlans(); } diff --git a/src/processor/mapper/CMakeLists.txt b/src/processor/mapper/CMakeLists.txt index c81007d8473..34435bd431a 100644 --- a/src/processor/mapper/CMakeLists.txt +++ b/src/processor/mapper/CMakeLists.txt @@ -3,7 +3,7 @@ add_library(kuzu_processor_mapper expression_mapper.cpp map_accumulate.cpp map_aggregate.cpp - map_asp.cpp + map_acc_hash_join.cpp map_create.cpp map_cross_product.cpp map_ddl.cpp diff --git a/src/processor/mapper/map_asp.cpp b/src/processor/mapper/map_acc_hash_join.cpp similarity index 96% rename from src/processor/mapper/map_asp.cpp rename to src/processor/mapper/map_acc_hash_join.cpp index 2e8feaa90b4..03487d0b6b8 100644 --- a/src/processor/mapper/map_asp.cpp +++ b/src/processor/mapper/map_acc_hash_join.cpp @@ -18,7 +18,7 @@ static FactorizedTableScan* getTableScanForAccHashJoin(PhysicalOperator* probe) return (FactorizedTableScan*)op; } -void PlanMapper::mapASP(kuzu::processor::PhysicalOperator* probe) { +void PlanMapper::mapAccHashJoin(kuzu::processor::PhysicalOperator* probe) { auto tableScan = getTableScanForAccHashJoin(probe); auto resultCollector = tableScan->moveUnaryChild(); probe->addChild(std::move(resultCollector)); diff --git a/src/processor/mapper/map_hash_join.cpp b/src/processor/mapper/map_hash_join.cpp index 92c4f6992b4..8e120b1ae63 100644 --- a/src/processor/mapper/map_hash_join.cpp +++ b/src/processor/mapper/map_hash_join.cpp @@ -40,8 +40,15 @@ std::unique_ptr PlanMapper::mapLogicalHashJoinToPhysical( auto hashJoin = (LogicalHashJoin*)logicalOperator; auto outSchema = hashJoin->getSchema(); auto buildSchema = hashJoin->getChild(1)->getSchema(); - auto buildSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(1)); - auto probeSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(0)); + std::unique_ptr probeSidePrevOperator; + std::unique_ptr buildSidePrevOperator; + if (hashJoin->getSIP() == planner::SidewaysInfoPassing::RIGHT_TO_LEFT) { + probeSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(0)); + buildSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(1)); + } else { + buildSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(1)); + probeSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(0)); + } // Populate build side and probe side std::vector positions auto paramsString = hashJoin->getExpressionsForPrinting(); auto buildDataInfo = generateBuildDataInfo( @@ -72,7 +79,7 @@ std::unique_ptr PlanMapper::mapLogicalHashJoinToPhysical( probeDataInfo, std::move(probeSidePrevOperator), std::move(hashJoinBuild), getOperatorID(), paramsString); if (hashJoin->getSIP() == planner::SidewaysInfoPassing::LEFT_TO_RIGHT) { - mapASP(hashJoinProbe.get()); + mapAccHashJoin(hashJoinProbe.get()); } return hashJoinProbe; } diff --git a/src/processor/mapper/map_intersect.cpp b/src/processor/mapper/map_intersect.cpp index 463e39619b8..488a850f3b8 100644 --- a/src/processor/mapper/map_intersect.cpp +++ b/src/processor/mapper/map_intersect.cpp @@ -52,7 +52,7 @@ std::unique_ptr PlanMapper::mapLogicalIntersectToPhysical( auto intersect = make_unique(outputDataPos, intersectDataInfos, sharedStates, std::move(children), getOperatorID(), logicalIntersect->getExpressionsForPrinting()); if (logicalIntersect->getSIP() == SidewaysInfoPassing::LEFT_TO_RIGHT) { - mapASP(intersect.get()); + mapAccHashJoin(intersect.get()); } return intersect; } diff --git a/test/runner/e2e_read_test.cpp b/test/runner/e2e_read_test.cpp index b699a9375f6..88a053df47c 100644 --- a/test/runner/e2e_read_test.cpp +++ b/test/runner/e2e_read_test.cpp @@ -45,8 +45,8 @@ TEST_F(TinySnbReadTest, Filter) { runTest(TestHelper::appendKuzuRootPath("test/test_files/tinysnb/filter/multi_label.test")); } -TEST_F(TinySnbReadTest, Asp) { - runTest(TestHelper::appendKuzuRootPath("test/test_files/tinysnb/asp/asp.test")); +TEST_F(TinySnbReadTest, AccHJ) { + runTest(TestHelper::appendKuzuRootPath("test/test_files/tinysnb/acc/acc_hj.test")); } TEST_F(TinySnbReadTest, Function) { diff --git a/test/test_files/tinysnb/asp/asp.test b/test/test_files/tinysnb/acc/acc_hj.test similarity index 100% rename from test/test_files/tinysnb/asp/asp.test rename to test/test_files/tinysnb/acc/acc_hj.test