Skip to content

Commit

Permalink
Fix multi-label recursive join bug
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jul 17, 2023
1 parent f6159a4 commit ace38f7
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 56 deletions.
1 change: 1 addition & 0 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class NodeOrRelExpression : public Expression {
}
}
inline bool isMultiLabeled() const { return tableIDs.size() > 1; }
inline uint32_t getNumTableIDs() const { return tableIDs.size(); }
inline std::vector<common::table_id_t> getTableIDs() const { return tableIDs; }
inline std::unordered_set<common::table_id_t> getTableIDsSet() const {
return {tableIDs.begin(), tableIDs.end()};
Expand Down
3 changes: 3 additions & 0 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class JoinOrderEnumerator {
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> recursiveRel,
ExtendDirection direction, LogicalPlan& plan);

void appendNodeLabelFilter(std::shared_ptr<Expression> nodeID,
std::unordered_set<common::table_id_t> tableIDSet, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ enum class LogicalOperatorType : uint8_t {
ACCUMULATE,
ADD_PROPERTY,
AGGREGATE,
STANDALONE_CALL,
IN_QUERY_CALL,
COPY,
CREATE_NODE,
CREATE_REL,
Expand All @@ -29,9 +27,12 @@ enum class LogicalOperatorType : uint8_t {
FLATTEN,
FTABLE_SCAN,
HASH_JOIN,
IN_QUERY_CALL,
INDEX_SCAN_NODE,
INTERSECT,
LIMIT,
MULTIPLICITY_REDUCER,
NODE_LABEL_FILTER,
ORDER_BY,
PATH_PROPERTY_PROBE,
PROJECTION,
Expand All @@ -40,12 +41,12 @@ enum class LogicalOperatorType : uint8_t {
RENAME_PROPERTY,
SCAN_FRONTIER,
SCAN_NODE,
INDEX_SCAN_NODE,
SCAN_NODE_PROPERTY,
SEMI_MASKER,
SET_NODE_PROPERTY,
SET_REL_PROPERTY,
SKIP,
STANDALONE_CALL,
UNION_ALL,
UNWIND,
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "base_logical_operator.h"

namespace kuzu {
namespace planner {

// Logical plan operator of type NODE_LABEL_FILTER with exactly one child.
// It carries the node ID expression to test and the set of node table IDs to
// test it against; the schema is taken unchanged from the child.
// NOTE(review): presumably tuples whose node table ID is not in tableIDSet are
// dropped by the physical operator this maps to - confirm against that operator.
class LogicalNodeLabelFilter : public LogicalOperator {
public:
    // nodeID: expression identifying the node whose table ID is checked.
    // tableIDSet: node table IDs stored for the check.
    // child: the single input operator of this filter.
    LogicalNodeLabelFilter(std::shared_ptr<binder::Expression> nodeID,
        std::unordered_set<common::table_id_t> tableIDSet, std::shared_ptr<LogicalOperator> child)
        : LogicalOperator{LogicalOperatorType::NODE_LABEL_FILTER, std::move(child)},
          nodeID{std::move(nodeID)}, tableIDSet{std::move(tableIDSet)} {}

    // Both schema flavours are a plain copy of the child's schema.
    void computeFactorizedSchema() final { copyChildSchema(0); }
    void computeFlatSchema() final { copyChildSchema(0); }

    // Only the node ID expression is printed; the table ID set is not included.
    std::string getExpressionsForPrinting() const final { return nodeID->toString(); }

    std::shared_ptr<binder::Expression> getNodeID() const { return nodeID; }
    // Returns a copy of the stored table ID set.
    std::unordered_set<common::table_id_t> getTableIDSet() const { return tableIDSet; }

    // Copies the child subtree and wraps it in a new filter that shares the
    // same nodeID expression and gets its own copy of the table ID set.
    std::unique_ptr<LogicalOperator> copy() final {
        auto childCopy = children[0]->copy();
        return std::make_unique<LogicalNodeLabelFilter>(
            nodeID, tableIDSet, std::move(childCopy));
    }

private:
    std::shared_ptr<binder::Expression> nodeID;
    std::unordered_set<common::table_id_t> tableIDSet;
};

} // namespace planner
} // namespace kuzu
2 changes: 2 additions & 0 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class PlanMapper {
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalMultiplicityReducerToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalNodeLabelFilterToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalSkipToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalLimitToPhysical(
Expand Down
69 changes: 68 additions & 1 deletion src/planner/join_order/append_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "planner/join_order/cost_model.h"
#include "planner/join_order_enumerator.h"
#include "planner/logical_plan/logical_operator/logical_extend.h"
#include "planner/logical_plan/logical_operator/logical_node_label_filter.h"
#include "planner/logical_plan/logical_operator/logical_recursive_extend.h"
#include "planner/query_planner.h"

Expand All @@ -26,9 +27,59 @@ static bool extendHasAtMostOneNbrGuarantee(RelExpression& rel, NodeExpression& b
rel.getSingleTableID(), relDirection);
}

// Collects, over every rel table ID of `rel`, the table ID reported by the rel
// table schema's getBoundTableID() for the requested extend direction.
// FWD and BWD contribute one lookup per rel table; BOTH contributes the FWD
// result followed by the BWD result.
// Throws common::NotImplementedException for any other direction, but only
// once a rel table is actually visited (an empty rel table list never throws).
static std::unordered_set<common::table_id_t> getBoundNodeTableIDSet(
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    const bool includeFwd =
        extendDirection == ExtendDirection::FWD || extendDirection == ExtendDirection::BOTH;
    const bool includeBwd =
        extendDirection == ExtendDirection::BWD || extendDirection == ExtendDirection::BOTH;
    std::unordered_set<common::table_id_t> boundTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        // The schema lookup stays ahead of the direction check so that the
        // order of observable effects is the same as a per-iteration dispatch.
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (!includeFwd && !includeBwd) {
            throw common::NotImplementedException("getBoundNodeTableIDSet");
        }
        if (includeFwd) {
            boundTableIDs.insert(relSchema->getBoundTableID(RelDataDirection::FWD));
        }
        if (includeBwd) {
            boundTableIDs.insert(relSchema->getBoundTableID(RelDataDirection::BWD));
        }
    }
    return boundTableIDs;
}

// Collects, over every rel table ID of `rel`, the table ID reported by the rel
// table schema's getNbrTableID() for the requested extend direction.
// FWD and BWD contribute one lookup per rel table; BOTH contributes the FWD
// result followed by the BWD result.
// Throws common::NotImplementedException for any other direction, but only
// once a rel table is actually visited (an empty rel table list never throws).
static std::unordered_set<common::table_id_t> getNbrNodeTableIDSet(
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    const bool includeFwd =
        extendDirection == ExtendDirection::FWD || extendDirection == ExtendDirection::BOTH;
    const bool includeBwd =
        extendDirection == ExtendDirection::BWD || extendDirection == ExtendDirection::BOTH;
    std::unordered_set<common::table_id_t> nbrTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        // The schema lookup stays ahead of the direction check so that the
        // order of observable effects is the same as a per-iteration dispatch.
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (!includeFwd && !includeBwd) {
            throw common::NotImplementedException("getNbrNodeTableIDSet");
        }
        if (includeFwd) {
            nbrTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
        }
        if (includeBwd) {
            nbrTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        }
    }
    return nbrTableIDs;
}

void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
ExtendDirection direction, const expression_vector& properties, LogicalPlan& plan) {
auto boundNodeTableIDSet = getBoundNodeTableIDSet(*rel, direction, catalog);
if (boundNode->getNumTableIDs() > boundNodeTableIDSet.size()) {
appendNodeLabelFilter(boundNode->getInternalIDProperty(), boundNodeTableIDSet, plan);
}
auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog);
auto extend = make_shared<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, plan.getLastOperator());
Expand All @@ -44,6 +95,10 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpressio
group->setMultiplier(extensionRate);
}
plan.setLastOperator(std::move(extend));
auto nbrNodeTableIDSet = getNbrNodeTableIDSet(*rel, direction, catalog);
if (nbrNodeTableIDSet.size() > nbrNode->getNumTableIDs()) {
appendNodeLabelFilter(nbrNode->getInternalIDProperty(), nbrNode->getTableIDsSet(), plan);
}
}

void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
Expand All @@ -53,9 +108,13 @@ void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr<NodeExpression>
queryPlanner->appendAccumulate(plan);
// Create recursive plan
auto recursivePlan = std::make_unique<LogicalPlan>();
createRecursivePlan(boundNode, recursiveInfo->node, recursiveInfo->rel, direction,
createRecursivePlan(recursiveInfo->node, recursiveInfo->nodeCopy, recursiveInfo->rel, direction,
recursiveInfo->predicates, *recursivePlan);
// Create recursive extend
if (boundNode->getNumTableIDs() > recursiveInfo->node->getNumTableIDs()) {
appendNodeLabelFilter(
boundNode->getInternalIDProperty(), recursiveInfo->node->getTableIDsSet(), plan);
}
auto extend = std::make_shared<LogicalRecursiveExtend>(boundNode, nbrNode, rel, direction,
RecursiveJoinType::TRACK_PATH, plan.getLastOperator(), recursivePlan->getLastOperator());
queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan);
Expand Down Expand Up @@ -139,5 +198,13 @@ void JoinOrderEnumerator::createPathRelPropertyScanPlan(
queryPlanner->projectionPlanner.appendProjection(expressionsToProject, plan);
}

// Wraps the plan's current last operator in a LogicalNodeLabelFilter built
// from `nodeID` and `tableIDSet`, computes the filter's factorized schema, and
// installs the filter as the plan's new last operator.
void JoinOrderEnumerator::appendNodeLabelFilter(std::shared_ptr<Expression> nodeID,
    std::unordered_set<common::table_id_t> tableIDSet, LogicalPlan& plan) {
    auto child = plan.getLastOperator();
    auto labelFilter = std::make_shared<LogicalNodeLabelFilter>(
        std::move(nodeID), std::move(tableIDSet), std::move(child));
    labelFilter->computeFactorizedSchema();
    plan.setLastOperator(std::move(labelFilter));
}

} // namespace planner
} // namespace kuzu
21 changes: 12 additions & 9 deletions src/planner/operator/base_logical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::AGGREGATE: {
return "AGGREGATE";
}
case LogicalOperatorType::STANDALONE_CALL: {
return "STANDALONE_CALL";
}
case LogicalOperatorType::IN_QUERY_CALL: {
return "IN_QUERY_CALL";
}
case LogicalOperatorType::COPY: {
return "COPY";
}
Expand Down Expand Up @@ -76,6 +70,12 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::HASH_JOIN: {
return "HASH_JOIN";
}
case LogicalOperatorType::IN_QUERY_CALL: {
return "IN_QUERY_CALL";
}
case LogicalOperatorType::INDEX_SCAN_NODE: {
return "INDEX_SCAN_NODE";
}
case LogicalOperatorType::INTERSECT: {
return "INTERSECT";
}
Expand All @@ -85,6 +85,9 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::MULTIPLICITY_REDUCER: {
return "MULTIPLICITY_REDUCER";
}
case LogicalOperatorType::NODE_LABEL_FILTER: {
return "NODE_LABEL_FILTER";
}
case LogicalOperatorType::ORDER_BY: {
return "ORDER_BY";
}
Expand All @@ -109,9 +112,6 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::SCAN_NODE: {
return "SCAN_NODE";
}
case LogicalOperatorType::INDEX_SCAN_NODE: {
return "INDEX_SCAN_NODE";
}
case LogicalOperatorType::SCAN_NODE_PROPERTY: {
return "SCAN_NODE_PROPERTY";
}
Expand All @@ -127,6 +127,9 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::SKIP: {
return "SKIP";
}
case LogicalOperatorType::STANDALONE_CALL: {
return "STANDALONE_CALL";
}
case LogicalOperatorType::UNION_ALL: {
return "UNION_ALL";
}
Expand Down
1 change: 1 addition & 0 deletions src/processor/mapper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(kuzu_processor_mapper
map_flatten.cpp
map_hash_join.cpp
map_intersect.cpp
map_label_filter.cpp
map_limit.cpp
map_multiplicity_reducer.cpp
map_order_by.cpp
Expand Down
43 changes: 1 addition & 42 deletions src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,35 +82,6 @@ static std::unique_ptr<RelTableCollectionScanner> populateRelTableCollectionScan
return std::make_unique<RelTableCollectionScanner>(std::move(scanInfos));
}

// Decides whether extending along `rel` needs a filter on `node`'s table IDs.
// First gathers every table ID returned by getNbrTableID() for each rel table
// in the requested direction (BOTH gathers FWD then BWD). If any gathered ID
// is not one of `node`'s table IDs, returns `node`'s full table ID set as the
// filter set; otherwise returns an empty set, meaning no filter is required.
// Throws common::NotImplementedException for an unhandled direction once a
// rel table is visited.
static std::unordered_set<common::table_id_t> getNodeIDFilterSet(const NodeExpression& node,
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    std::unordered_set<common::table_id_t> allowedTableIDs = node.getTableIDsSet();
    const bool includeFwd =
        extendDirection == ExtendDirection::FWD || extendDirection == ExtendDirection::BOTH;
    const bool includeBwd =
        extendDirection == ExtendDirection::BWD || extendDirection == ExtendDirection::BOTH;
    std::unordered_set<common::table_id_t> reachableTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (!includeFwd && !includeBwd) {
            throw common::NotImplementedException("getNbrTableIDFilterSet");
        }
        if (includeFwd) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
        }
        if (includeBwd) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        }
    }
    // A single reachable table outside the node's own set is enough to
    // require a post-extend filter.
    bool filterNeeded = false;
    for (auto& reachableID : reachableTableIDs) {
        if (!allowedTableIDs.contains(reachableID)) {
            filterNeeded = true;
            break;
        }
    }
    if (filterNeeded) {
        return allowedTableIDs;
    }
    return std::unordered_set<common::table_id_t>{};
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
LogicalOperator* logicalOperator) {
auto extend = (LogicalExtend*)logicalOperator;
Expand Down Expand Up @@ -153,20 +124,8 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
scanners.insert({boundNodeTableID, std::move(scanner)});
}
}
auto scanRel = std::make_unique<ScanMultiRelTable>(std::move(posInfo), std::move(scanners),
return std::make_unique<ScanMultiRelTable>(std::move(posInfo), std::move(scanners),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting());
auto nbrNodeIDFilterSet = getNodeIDFilterSet(*nbrNode, *rel, extendDirection, *catalog);
if (!nbrNodeIDFilterSet.empty()) {
auto nbrNodeVectorPos =
DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty()));
auto filterInfo =
std::make_unique<NodeLabelFilterInfo>(nbrNodeVectorPos, nbrNodeIDFilterSet);
auto filter = std::make_unique<NodeLabelFiler>(
std::move(filterInfo), std::move(scanRel), getOperatorID(), "");
return filter;
} else {
return scanRel;
}
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/processor/mapper/map_label_filter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "planner/logical_plan/logical_operator/logical_node_label_filter.h"
#include "processor/mapper/plan_mapper.h"
#include "processor/operator/filter.h"

using namespace kuzu::planner;

namespace kuzu {
namespace processor {

// Maps a LogicalNodeLabelFilter to its physical counterpart.
// The child is mapped first; the filter then reads the node ID vector located
// via this operator's schema and is configured with the logical operator's
// table ID set. `logicalOperator` is treated as a LogicalNodeLabelFilter
// without a runtime check.
std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalNodeLabelFilterToPhysical(
    LogicalOperator* logicalOperator) {
    auto labelFilter = static_cast<LogicalNodeLabelFilter*>(logicalOperator);
    // Map the child before requesting this operator's ID.
    auto child = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
    auto outSchema = logicalOperator->getSchema();
    // Position of the filtered node's internal ID within the result set.
    auto nodeIDPos = DataPos(outSchema->getExpressionPos(*labelFilter->getNodeID()));
    auto info = std::make_unique<NodeLabelFilterInfo>(nodeIDPos, labelFilter->getTableIDSet());
    return std::make_unique<NodeLabelFiler>(std::move(info), std::move(child), getOperatorID(),
        labelFilter->getExpressionsForPrinting());
}

} // namespace processor
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/processor/mapper/map_recursive_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
auto recursivePlanResultSetDescriptor =
std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
auto recursiveDstNodeIDPos = DataPos(
recursivePlanSchema->getExpressionPos(*recursiveInfo->node->getInternalIDProperty()));
recursivePlanSchema->getExpressionPos(*recursiveInfo->nodeCopy->getInternalIDProperty()));
auto recursiveEdgeIDPos = DataPos(
recursivePlanSchema->getExpressionPos(*recursiveInfo->rel->getInternalIDProperty()));
// Generate RecursiveJoin
Expand Down
3 changes: 3 additions & 0 deletions src/processor/mapper/plan_mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalOperatorToPhysical(
case LogicalOperatorType::MULTIPLICITY_REDUCER: {
physicalOperator = mapLogicalMultiplicityReducerToPhysical(logicalOperator.get());
} break;
case LogicalOperatorType::NODE_LABEL_FILTER: {
physicalOperator = mapLogicalNodeLabelFilterToPhysical(logicalOperator.get());
} break;
case LogicalOperatorType::SKIP: {
physicalOperator = mapLogicalSkipToPhysical(logicalOperator.get());
} break;
Expand Down
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/var_length_extend/multi_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@
-STATEMENT MATCH (a)-[e:studyAt|:knows*2..3]-(a) WHERE a.ID = 1 RETURN COUNT(*)
---- 1
7

-LOG MultiLabelSelfLoopTest2
-STATEMENT MATCH (a:organisation)-[e*2..2]-(a) RETURN COUNT(*)
---- 1
6

0 comments on commit ace38f7

Please sign in to comment.