Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multi-label recursive join bug #1823

Merged
merged 1 commit into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class NodeOrRelExpression : public Expression {
}
}
// True when this expression is bound to more than one table ID.
inline bool isMultiLabeled() const { return tableIDs.size() > 1; }
// Number of table IDs bound to this expression. The container size is narrowed
// to uint32_t explicitly so the conversion is visible rather than implicit.
inline uint32_t getNumTableIDs() const { return static_cast<uint32_t>(tableIDs.size()); }
// Returns the bound table IDs by value (the caller receives its own copy).
inline std::vector<common::table_id_t> getTableIDs() const { return tableIDs; }
inline std::unordered_set<common::table_id_t> getTableIDsSet() const {
return {tableIDs.begin(), tableIDs.end()};
Expand Down
3 changes: 3 additions & 0 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class JoinOrderEnumerator {
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> recursiveRel,
ExtendDirection direction, LogicalPlan& plan);

void appendNodeLabelFilter(std::shared_ptr<Expression> nodeID,
std::unordered_set<common::table_id_t> tableIDSet, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ enum class LogicalOperatorType : uint8_t {
ACCUMULATE,
ADD_PROPERTY,
AGGREGATE,
STANDALONE_CALL,
IN_QUERY_CALL,
COPY,
CREATE_NODE,
CREATE_REL,
Expand All @@ -29,9 +27,12 @@ enum class LogicalOperatorType : uint8_t {
FLATTEN,
FTABLE_SCAN,
HASH_JOIN,
IN_QUERY_CALL,
INDEX_SCAN_NODE,
INTERSECT,
LIMIT,
MULTIPLICITY_REDUCER,
NODE_LABEL_FILTER,
ORDER_BY,
PATH_PROPERTY_PROBE,
PROJECTION,
Expand All @@ -40,12 +41,12 @@ enum class LogicalOperatorType : uint8_t {
RENAME_PROPERTY,
SCAN_FRONTIER,
SCAN_NODE,
INDEX_SCAN_NODE,
SCAN_NODE_PROPERTY,
SEMI_MASKER,
SET_NODE_PROPERTY,
SET_REL_PROPERTY,
SKIP,
STANDALONE_CALL,
UNION_ALL,
UNWIND,
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "base_logical_operator.h"

namespace kuzu {
namespace planner {

/**
 * Logical plan operator of type NODE_LABEL_FILTER.
 *
 * Carries a node ID expression together with a set of table IDs and sits on top of a single
 * child operator. Presumably the set lists the node tables that are allowed to pass the
 * filter -- confirm against the physical operator it is mapped to.
 */
class LogicalNodeLabelFilter : public LogicalOperator {
public:
    LogicalNodeLabelFilter(std::shared_ptr<binder::Expression> nodeID,
        std::unordered_set<common::table_id_t> tableIDSet, std::shared_ptr<LogicalOperator> child)
        : LogicalOperator{LogicalOperatorType::NODE_LABEL_FILTER, std::move(child)},
          nodeID{std::move(nodeID)}, tableIDSet{std::move(tableIDSet)} {}

    // Both schema flavours are taken over unchanged from child 0.
    void computeFactorizedSchema() final { copyChildSchema(0); }
    void computeFlatSchema() final { copyChildSchema(0); }

    // The printable form of this operator is the node ID expression's string.
    std::string getExpressionsForPrinting() const final { return nodeID->toString(); }

    std::shared_ptr<binder::Expression> getNodeID() const { return nodeID; }
    // Returns the table ID set by value.
    std::unordered_set<common::table_id_t> getTableIDSet() const { return tableIDSet; }

    // Copies this operator; the child is copied through its own copy().
    std::unique_ptr<LogicalOperator> copy() final {
        auto childCopy = children[0]->copy();
        return std::make_unique<LogicalNodeLabelFilter>(nodeID, tableIDSet, std::move(childCopy));
    }

private:
    std::shared_ptr<binder::Expression> nodeID;
    std::unordered_set<common::table_id_t> tableIDSet;
};

} // namespace planner
} // namespace kuzu
2 changes: 2 additions & 0 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class PlanMapper {
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalMultiplicityReducerToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalNodeLabelFilterToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalSkipToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalLimitToPhysical(
Expand Down
69 changes: 68 additions & 1 deletion src/planner/join_order/append_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "planner/join_order/cost_model.h"
#include "planner/join_order_enumerator.h"
#include "planner/logical_plan/logical_operator/logical_extend.h"
#include "planner/logical_plan/logical_operator/logical_node_label_filter.h"
#include "planner/logical_plan/logical_operator/logical_recursive_extend.h"
#include "planner/query_planner.h"

Expand All @@ -26,9 +27,59 @@ static bool extendHasAtMostOneNbrGuarantee(RelExpression& rel, NodeExpression& b
rel.getSingleTableID(), relDirection);
}

// Collects, over every rel table of `rel`, the bound-side node table IDs for the given
// extend direction. FWD and BWD contribute the bound table of that data direction; BOTH
// contributes the bound tables of both data directions.
// Throws common::NotImplementedException for any other direction value.
static std::unordered_set<common::table_id_t> getBoundNodeTableIDSet(
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    std::unordered_set<common::table_id_t> boundTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (extendDirection == ExtendDirection::FWD) {
            boundTableIDs.insert(relSchema->getBoundTableID(RelDataDirection::FWD));
        } else if (extendDirection == ExtendDirection::BWD) {
            boundTableIDs.insert(relSchema->getBoundTableID(RelDataDirection::BWD));
        } else if (extendDirection == ExtendDirection::BOTH) {
            boundTableIDs.insert(relSchema->getBoundTableID(RelDataDirection::FWD));
            boundTableIDs.insert(relSchema->getBoundTableID(RelDataDirection::BWD));
        } else {
            throw common::NotImplementedException("getBoundNodeTableIDSet");
        }
    }
    return boundTableIDs;
}

// Collects, over every rel table of `rel`, the neighbour-side node table IDs for the given
// extend direction. FWD and BWD contribute the neighbour table of that data direction; BOTH
// contributes the neighbour tables of both data directions.
// Throws common::NotImplementedException for any other direction value.
static std::unordered_set<common::table_id_t> getNbrNodeTableIDSet(
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    std::unordered_set<common::table_id_t> nbrTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (extendDirection == ExtendDirection::FWD) {
            nbrTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
        } else if (extendDirection == ExtendDirection::BWD) {
            nbrTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        } else if (extendDirection == ExtendDirection::BOTH) {
            nbrTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
            nbrTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        } else {
            throw common::NotImplementedException("getNbrNodeTableIDSet");
        }
    }
    return nbrTableIDs;
}

void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
ExtendDirection direction, const expression_vector& properties, LogicalPlan& plan) {
auto boundNodeTableIDSet = getBoundNodeTableIDSet(*rel, direction, catalog);
if (boundNode->getNumTableIDs() > boundNodeTableIDSet.size()) {
appendNodeLabelFilter(boundNode->getInternalIDProperty(), boundNodeTableIDSet, plan);
}
auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog);
auto extend = make_shared<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, plan.getLastOperator());
Expand All @@ -44,6 +95,10 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpressio
group->setMultiplier(extensionRate);
}
plan.setLastOperator(std::move(extend));
auto nbrNodeTableIDSet = getNbrNodeTableIDSet(*rel, direction, catalog);
if (nbrNodeTableIDSet.size() > nbrNode->getNumTableIDs()) {
appendNodeLabelFilter(nbrNode->getInternalIDProperty(), nbrNode->getTableIDsSet(), plan);
}
}

void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
Expand All @@ -53,9 +108,13 @@ void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr<NodeExpression>
queryPlanner->appendAccumulate(plan);
// Create recursive plan
auto recursivePlan = std::make_unique<LogicalPlan>();
createRecursivePlan(boundNode, recursiveInfo->node, recursiveInfo->rel, direction,
createRecursivePlan(recursiveInfo->node, recursiveInfo->nodeCopy, recursiveInfo->rel, direction,
recursiveInfo->predicates, *recursivePlan);
// Create recursive extend
if (boundNode->getNumTableIDs() > recursiveInfo->node->getNumTableIDs()) {
appendNodeLabelFilter(
boundNode->getInternalIDProperty(), recursiveInfo->node->getTableIDsSet(), plan);
}
auto extend = std::make_shared<LogicalRecursiveExtend>(boundNode, nbrNode, rel, direction,
RecursiveJoinType::TRACK_PATH, plan.getLastOperator(), recursivePlan->getLastOperator());
queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan);
Expand Down Expand Up @@ -139,5 +198,13 @@ void JoinOrderEnumerator::createPathRelPropertyScanPlan(
queryPlanner->projectionPlanner.appendProjection(expressionsToProject, plan);
}

// Stacks a LogicalNodeLabelFilter for `nodeID` / `tableIDSet` on top of the plan's current
// last operator, computes its factorized schema, and makes it the plan's new last operator.
void JoinOrderEnumerator::appendNodeLabelFilter(std::shared_ptr<Expression> nodeID,
    std::unordered_set<common::table_id_t> tableIDSet, LogicalPlan& plan) {
    auto childOperator = plan.getLastOperator();
    auto labelFilter = std::make_shared<LogicalNodeLabelFilter>(
        std::move(nodeID), std::move(tableIDSet), std::move(childOperator));
    labelFilter->computeFactorizedSchema();
    plan.setLastOperator(std::move(labelFilter));
}

} // namespace planner
} // namespace kuzu
21 changes: 12 additions & 9 deletions src/planner/operator/base_logical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::AGGREGATE: {
return "AGGREGATE";
}
case LogicalOperatorType::STANDALONE_CALL: {
return "STANDALONE_CALL";
}
case LogicalOperatorType::IN_QUERY_CALL: {
return "IN_QUERY_CALL";
}
case LogicalOperatorType::COPY: {
return "COPY";
}
Expand Down Expand Up @@ -76,6 +70,12 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::HASH_JOIN: {
return "HASH_JOIN";
}
case LogicalOperatorType::IN_QUERY_CALL: {
return "IN_QUERY_CALL";
}
case LogicalOperatorType::INDEX_SCAN_NODE: {
return "INDEX_SCAN_NODE";
}
case LogicalOperatorType::INTERSECT: {
return "INTERSECT";
}
Expand All @@ -85,6 +85,9 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::MULTIPLICITY_REDUCER: {
return "MULTIPLICITY_REDUCER";
}
case LogicalOperatorType::NODE_LABEL_FILTER: {
return "NODE_LABEL_FILTER";
}
case LogicalOperatorType::ORDER_BY: {
return "ORDER_BY";
}
Expand All @@ -109,9 +112,6 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::SCAN_NODE: {
return "SCAN_NODE";
}
case LogicalOperatorType::INDEX_SCAN_NODE: {
return "INDEX_SCAN_NODE";
}
case LogicalOperatorType::SCAN_NODE_PROPERTY: {
return "SCAN_NODE_PROPERTY";
}
Expand All @@ -127,6 +127,9 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
case LogicalOperatorType::SKIP: {
return "SKIP";
}
case LogicalOperatorType::STANDALONE_CALL: {
return "STANDALONE_CALL";
}
case LogicalOperatorType::UNION_ALL: {
return "UNION_ALL";
}
Expand Down
1 change: 1 addition & 0 deletions src/processor/mapper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(kuzu_processor_mapper
map_flatten.cpp
map_hash_join.cpp
map_intersect.cpp
map_label_filter.cpp
map_limit.cpp
map_multiplicity_reducer.cpp
map_order_by.cpp
Expand Down
43 changes: 1 addition & 42 deletions src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,35 +82,6 @@ static std::unique_ptr<RelTableCollectionScanner> populateRelTableCollectionScan
return std::make_unique<RelTableCollectionScanner>(std::move(scanInfos));
}

// Decides whether a label filter is needed on `node` after extending along `rel`.
// First gathers the neighbour table IDs reachable through every rel table of `rel` in the
// given direction. If any reachable table ID is missing from the node's own table ID set,
// the node's set is returned (to be used as the filter set); otherwise an empty set is
// returned, meaning no filter is required.
// Throws common::NotImplementedException for an unsupported direction value.
static std::unordered_set<common::table_id_t> getNodeIDFilterSet(const NodeExpression& node,
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    std::unordered_set<common::table_id_t> declaredTableIDs = node.getTableIDsSet();
    std::unordered_set<common::table_id_t> reachableTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (extendDirection == ExtendDirection::FWD) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
        } else if (extendDirection == ExtendDirection::BWD) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        } else if (extendDirection == ExtendDirection::BOTH) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        } else {
            throw common::NotImplementedException("getNbrTableIDFilterSet");
        }
    }
    for (auto& reachableID : reachableTableIDs) {
        if (declaredTableIDs.count(reachableID) == 0) {
            // The extend can produce a table the node does not accept: filter on the node's set.
            return declaredTableIDs;
        }
    }
    return std::unordered_set<common::table_id_t>{};
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
LogicalOperator* logicalOperator) {
auto extend = (LogicalExtend*)logicalOperator;
Expand Down Expand Up @@ -153,20 +124,8 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
scanners.insert({boundNodeTableID, std::move(scanner)});
}
}
auto scanRel = std::make_unique<ScanMultiRelTable>(std::move(posInfo), std::move(scanners),
return std::make_unique<ScanMultiRelTable>(std::move(posInfo), std::move(scanners),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting());
auto nbrNodeIDFilterSet = getNodeIDFilterSet(*nbrNode, *rel, extendDirection, *catalog);
if (!nbrNodeIDFilterSet.empty()) {
auto nbrNodeVectorPos =
DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty()));
auto filterInfo =
std::make_unique<NodeLabelFilterInfo>(nbrNodeVectorPos, nbrNodeIDFilterSet);
auto filter = std::make_unique<NodeLabelFiler>(
std::move(filterInfo), std::move(scanRel), getOperatorID(), "");
return filter;
} else {
return scanRel;
}
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/processor/mapper/map_label_filter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "planner/logical_plan/logical_operator/logical_node_label_filter.h"
#include "processor/mapper/plan_mapper.h"
#include "processor/operator/filter.h"

using namespace kuzu::planner;

namespace kuzu {
namespace processor {

// Maps a LogicalNodeLabelFilter to its physical operator.
//
// The child operator is mapped first and becomes the input of the physical filter. The
// position of the filtered node ID expression is resolved in this operator's schema and,
// together with the logical operator's table ID set, forms the NodeLabelFilterInfo.
//
// @param logicalOperator  must point to a LogicalNodeLabelFilter (the dispatcher selects this
//                         mapper for LogicalOperatorType::NODE_LABEL_FILTER).
// @return the physical label filter stacked on the mapped child.
std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalNodeLabelFilterToPhysical(
    LogicalOperator* logicalOperator) {
    // Named cast instead of a C-style cast: the downcast stays a checked-at-compile-time
    // base-to-derived conversion and cannot silently turn into a reinterpret/const cast.
    auto logicalLabelFilter = static_cast<LogicalNodeLabelFilter*>(logicalOperator);
    auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
    auto schema = logicalOperator->getSchema();
    // The filtered node is not necessarily a neighbour node (the planner also filters bound
    // nodes), hence the neutral name.
    auto nodeIDVectorPos = DataPos(schema->getExpressionPos(*logicalLabelFilter->getNodeID()));
    auto filterInfo = std::make_unique<NodeLabelFilterInfo>(
        nodeIDVectorPos, logicalLabelFilter->getTableIDSet());
    return std::make_unique<NodeLabelFiler>(std::move(filterInfo), std::move(prevOperator),
        getOperatorID(), logicalLabelFilter->getExpressionsForPrinting());
}

} // namespace processor
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/processor/mapper/map_recursive_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
auto recursivePlanResultSetDescriptor =
std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
auto recursiveDstNodeIDPos = DataPos(
recursivePlanSchema->getExpressionPos(*recursiveInfo->node->getInternalIDProperty()));
recursivePlanSchema->getExpressionPos(*recursiveInfo->nodeCopy->getInternalIDProperty()));
auto recursiveEdgeIDPos = DataPos(
recursivePlanSchema->getExpressionPos(*recursiveInfo->rel->getInternalIDProperty()));
// Generate RecursiveJoin
Expand Down
3 changes: 3 additions & 0 deletions src/processor/mapper/plan_mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalOperatorToPhysical(
case LogicalOperatorType::MULTIPLICITY_REDUCER: {
physicalOperator = mapLogicalMultiplicityReducerToPhysical(logicalOperator.get());
} break;
case LogicalOperatorType::NODE_LABEL_FILTER: {
physicalOperator = mapLogicalNodeLabelFilterToPhysical(logicalOperator.get());
} break;
case LogicalOperatorType::SKIP: {
physicalOperator = mapLogicalSkipToPhysical(logicalOperator.get());
} break;
Expand Down
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/var_length_extend/multi_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@
-STATEMENT MATCH (a)-[e:studyAt|:knows*2..3]-(a) WHERE a.ID = 1 RETURN COUNT(*)
---- 1
7

-LOG MultiLabelSelfLoopTest2
-STATEMENT MATCH (a:organisation)-[e*2..2]-(a) RETURN COUNT(*)
---- 1
6
Loading