Skip to content

Commit

Permalink
Merge pull request #1761 from kuzudb/issue-1727
Browse files Browse the repository at this point in the history
Issue 1727
  • Loading branch information
acquamarin committed Jul 5, 2023
2 parents 7ecf3b4 + cb6838f commit f8f3a93
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 54 deletions.
38 changes: 10 additions & 28 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,9 @@ Binder::bindGraphPattern(const std::vector<std::unique_ptr<PatternElement>>& gra
return std::make_pair(std::move(queryGraphCollection), std::move(propertyCollection));
}

// An undirected pattern (A)-[R]-(B) has to be matched along both the FWD and the BWD direction
// of R. Because computation always starts from a single node, the two end points of every
// undirected rel are relabelled with the union of their table names, i.e. (A|B)-[R]-(A|B).
// If either end point has no table name, both end points are given an empty table name list.
void static rewriteNodeTableNameForUndirectedRel(const PatternElement& patternElement) {
    auto prevNode = patternElement.getFirstNodePattern();
    for (auto chainIdx = 0u; chainIdx < patternElement.getNumPatternElementChains();
         ++chainIdx) {
        auto chain = patternElement.getPatternElementChain(chainIdx);
        auto currNode = chain->getNodePattern();
        const bool isUndirected = chain->getRelPattern()->getDirection() == ArrowDirection::BOTH;
        if (isUndirected) {
            std::vector<std::string> mergedTableNames = {};
            auto prevTableNames = prevNode->getTableNames();
            auto currTableNames = currNode->getTableNames();
            // Only build the union when both sides are explicitly labelled.
            const bool eitherSideUnlabelled = prevTableNames.empty() || currTableNames.empty();
            if (!eitherSideUnlabelled) {
                mergedTableNames.insert(
                    mergedTableNames.end(), prevTableNames.begin(), prevTableNames.end());
                mergedTableNames.insert(
                    mergedTableNames.end(), currTableNames.begin(), currTableNames.end());
            }
            prevNode->setTableNames(mergedTableNames);
            currNode->setTableNames(mergedTableNames);
        }
        // The right end of this chain is the left end of the next one.
        prevNode = currNode;
    }
}

// Grammar ensures pattern element is always connected and thus can be bound as a query graph.
std::unique_ptr<QueryGraph> Binder::bindPatternElement(
const PatternElement& patternElement, PropertyKeyValCollection& collection) {
rewriteNodeTableNameForUndirectedRel(patternElement);
auto queryGraph = std::make_unique<QueryGraph>();
auto leftNode = bindQueryNode(*patternElement.getFirstNodePattern(), *queryGraph, collection);
for (auto i = 0u; i < patternElement.getNumPatternElementChains(); ++i) {
Expand Down Expand Up @@ -208,7 +182,11 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
if (QueryRelTypeUtils::isRecursive(relPattern.getRelType())) {
queryRel = createRecursiveQueryRel(relPattern, tableIDs, srcNode, dstNode, directionType);
} else {
tableIDs = pruneRelTableIDs(catalog, tableIDs, *srcNode, *dstNode);
if (directionType == RelDirectionType::SINGLE) {
// We perform table ID pruning as an optimization. The BOTH direction type requires a more
// advanced pruning logic because, by nature, it has no notion of src & dst.
tableIDs = pruneRelTableIDs(catalog, tableIDs, *srcNode, *dstNode);
}
if (tableIDs.empty()) {
throw BinderException("Nodes " + srcNode->toString() + " and " + dstNode->toString() +
" are not connected through rel " + parsedName + ".");
Expand Down Expand Up @@ -252,6 +230,9 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
auto tmpNode = createQueryNode(
InternalKeyword::ANONYMOUS, std::vector<common::table_id_t>{recursiveNodeTableIDs.begin(),
recursiveNodeTableIDs.end()});
auto tmpNodeCopy = createQueryNode(
InternalKeyword::ANONYMOUS, std::vector<common::table_id_t>{recursiveNodeTableIDs.begin(),
recursiveNodeTableIDs.end()});
auto prevScope = saveScope();
variableScope->clear();
auto tmpRel = createNonRecursiveQueryRel(
Expand All @@ -277,7 +258,8 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel);
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto recursiveInfo = std::make_unique<RecursiveInfo>(lowerBound, upperBound, std::move(tmpNode),
std::move(tmpRel), std::move(lengthExpression), std::move(predicates));
std::move(tmpNodeCopy), std::move(tmpRel), std::move(lengthExpression),
std::move(predicates));
queryRel->setRecursiveInfo(std::move(recursiveInfo));
bindQueryRelProperties(*queryRel);
return queryRel;
Expand Down
4 changes: 4 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ void ListVector::copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_
if (NullBuffer::isNull(srcNullBytes, i)) {
resultDataVector->setNull(dstListValuePos, true);
} else {
resultDataVector->setNull(dstListValuePos, false);
resultDataVector->copyFromRowData(dstListValuePos, srcListValues);
}
srcListValues += rowLayoutSize;
Expand Down Expand Up @@ -308,6 +309,7 @@ void ListVector::copyFromVectorData(ValueVector* dstVector, uint8_t* dstData,
if (srcDataVector->isNull(srcListEntry.offset + i)) {
dstDataVector->setNull(dstListEntry.offset + i, true);
} else {
dstDataVector->setNull(dstListEntry.offset + i, false);
dstDataVector->copyFromVectorData(dstListData, srcDataVector, srcListData);
}
srcListData += numBytesPerValue;
Expand All @@ -325,6 +327,7 @@ void StructVector::copyFromRowData(ValueVector* vector, uint32_t pos, const uint
if (NullBuffer::isNull(structNullBytes, i)) {
structField->setNull(pos, true /* isNull */);
} else {
structField->setNull(pos, false /* isNull */);
structField->copyFromRowData(pos, structValues);
}
structValues += LogicalTypeUtils::getRowLayoutSize(structField->dataType);
Expand Down Expand Up @@ -362,6 +365,7 @@ void StructVector::copyFromVectorData(ValueVector* dstVector, const uint8_t* dst
if (srcFieldVector->isNull(srcPos)) {
dstFieldVector->setNull(dstPos, true /* isNull */);
} else {
dstFieldVector->setNull(dstPos, false /* isNull */);
auto srcFieldVectorData =
srcFieldVector->getData() + srcFieldVector->getNumBytesPerValue() * srcPos;
auto dstFieldVectorData =
Expand Down
11 changes: 7 additions & 4 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ struct RecursiveInfo {
uint64_t lowerBound;
uint64_t upperBound;
std::shared_ptr<NodeExpression> node;
// nodeCopy has the same fields as node but a different unique name.
// We use nodeCopy when planning the recursive plan because boundNode and nbrNode cannot be the same.
std::shared_ptr<NodeExpression> nodeCopy;
std::shared_ptr<RelExpression> rel;
std::shared_ptr<Expression> lengthExpression;
expression_vector predicates;

RecursiveInfo(uint64_t lowerBound, uint64_t upperBound, std::shared_ptr<NodeExpression> node,
std::shared_ptr<RelExpression> rel, std::shared_ptr<Expression> lengthExpression,
expression_vector predicates)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)}, rel{std::move(
rel)},
std::shared_ptr<NodeExpression> nodeCopy, std::shared_ptr<RelExpression> rel,
std::shared_ptr<Expression> lengthExpression, expression_vector predicates)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)},
nodeCopy{std::move(nodeCopy)}, rel{std::move(rel)},
lengthExpression{std::move(lengthExpression)}, predicates{std::move(predicates)} {}
};

Expand Down
36 changes: 36 additions & 0 deletions src/include/processor/operator/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,41 @@ class Filter : public PhysicalOperator, public SelVectorOverWriter {
std::shared_ptr<common::DataChunk> dataChunkToSelect;
};

// Parameters of a node label filter: the position of the node ID vector to inspect and the set
// of node table IDs that are allowed to pass.
struct NodeLabelFilterInfo {
    // Position of the node ID vector inside the result set.
    DataPos nodeVectorPos;
    // Table IDs (labels) that a node ID must belong to in order to be kept.
    std::unordered_set<common::table_id_t> nodeLabelSet;

    NodeLabelFilterInfo(
        const DataPos& nodeVectorPos, std::unordered_set<common::table_id_t> nodeLabelSet)
        : nodeVectorPos{nodeVectorPos}, nodeLabelSet{std::move(nodeLabelSet)} {}
    // Member-wise copy is exactly what is needed, so let the compiler generate it.
    NodeLabelFilterInfo(const NodeLabelFilterInfo& other) = default;

    // Returns a heap-allocated copy of this info (used when cloning the owning operator).
    std::unique_ptr<NodeLabelFilterInfo> copy() const {
        return std::make_unique<NodeLabelFilterInfo>(*this);
    }
};

// Physical operator that filters tuples by the table ID (label) of a node ID vector: only
// positions whose node ID has a table ID contained in info->nodeLabelSet stay selected.
// The operator is registered under PhysicalOperatorType::FILTER.
class NodeLabelFiler : public PhysicalOperator, public SelVectorOverWriter {
public:
    // info:         filter parameters (node vector position + allowed table IDs).
    // child:        operator producing the tuples to filter.
    // id:           operator id forwarded to PhysicalOperator.
    // paramsString: printable parameter string forwarded to PhysicalOperator.
    NodeLabelFiler(std::unique_ptr<NodeLabelFilterInfo> info,
        std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
        : PhysicalOperator{PhysicalOperatorType::FILTER, std::move(child), id, paramsString},
          info{std::move(info)} {}

    void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override;

    bool getNextTuplesInternal(ExecutionContext* context) override;

    // Deep-copies the filter info and clones the single child operator.
    std::unique_ptr<PhysicalOperator> clone() final {
        return std::make_unique<NodeLabelFiler>(
            info->copy(), children[0]->clone(), id, paramsString);
    }

private:
    std::unique_ptr<NodeLabelFilterInfo> info;
    // Non-owning pointer to the node ID vector; assigned in initLocalStateInternal. Initialised
    // to nullptr so that the member never holds an indeterminate value before that call.
    common::ValueVector* nodeIDVector = nullptr;
};

} // namespace processor
} // namespace kuzu
7 changes: 0 additions & 7 deletions src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ void ProjectionPushDownOptimizer::visitPathPropertyProbe(planner::LogicalOperato
// TODO(Xiyang): we should remove pathPropertyProbe if we don't need to track path
pathPropertyProbe->setChildren(
std::vector<std::shared_ptr<LogicalOperator>>{pathPropertyProbe->getChild(0)});
} else {
// Pre-append projection to rel property build.
expression_vector properties;
for (auto& expression : recursiveInfo->rel->getPropertyExpressions()) {
properties.push_back(expression->copy());
}
preAppendProjection(op, 2, properties);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/planner/join_order/append_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr<NodeExpression>
createPathNodePropertyScanPlan(recursiveInfo->node, *pathNodePropertyScanPlan);
// Create path rel property scan plan
auto pathRelPropertyScanPlan = std::make_unique<LogicalPlan>();
createPathRelPropertyScanPlan(
recursiveInfo->node, nbrNode, recursiveInfo->rel, direction, *pathRelPropertyScanPlan);
createPathRelPropertyScanPlan(recursiveInfo->node, recursiveInfo->nodeCopy, recursiveInfo->rel,
direction, *pathRelPropertyScanPlan);
// Create path property probe
auto pathPropertyProbe = std::make_shared<LogicalPathPropertyProbe>(rel, extend,
pathNodePropertyScanPlan->getLastOperator(), pathRelPropertyScanPlan->getLastOperator());
Expand Down
44 changes: 43 additions & 1 deletion src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "planner/logical_plan/logical_operator/logical_extend.h"
#include "processor/mapper/plan_mapper.h"
#include "processor/operator/filter.h"
#include "processor/operator/scan/generic_scan_rel_tables.h"
#include "processor/operator/scan/scan_rel_table_columns.h"
#include "processor/operator/scan/scan_rel_table_lists.h"
Expand Down Expand Up @@ -81,6 +82,35 @@ static std::unique_ptr<RelTableCollectionScanner> populateRelTableCollectionScan
return std::make_unique<RelTableCollectionScanner>(std::move(scanInfos));
}

// Decides whether the node reached by extending along `rel` needs a post-extend label filter.
// The neighbour table IDs of every rel table of `rel` are collected for the given extend
// direction. If any collected table ID is not among the table IDs of `node`, the table ID set of
// `node` is returned so it can be used as the filter set. Otherwise an empty set is returned,
// which signals that no filter is required.
// Throws common::NotImplementedException for an extend direction other than FWD, BWD or BOTH.
static std::unordered_set<common::table_id_t> getNodeIDFilterSet(const NodeExpression& node,
    const RelExpression& rel, ExtendDirection extendDirection, const catalog::Catalog& catalog) {
    std::unordered_set<common::table_id_t> allowedTableIDs = node.getTableIDsSet();
    std::unordered_set<common::table_id_t> reachableTableIDs;
    for (auto relTableID : rel.getTableIDs()) {
        auto relSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        if (extendDirection == ExtendDirection::FWD) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
        } else if (extendDirection == ExtendDirection::BWD) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        } else if (extendDirection == ExtendDirection::BOTH) {
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::FWD));
            reachableTableIDs.insert(relSchema->getNbrTableID(RelDataDirection::BWD));
        } else {
            throw common::NotImplementedException("getNbrTableIDFilterSet");
        }
    }
    // A filter is needed as soon as the extend can produce a table ID the node does not allow.
    bool needsFilter = false;
    for (const auto& reachableTableID : reachableTableIDs) {
        if (!allowedTableIDs.contains(reachableTableID)) {
            needsFilter = true;
            break;
        }
    }
    if (needsFilter) {
        return allowedTableIDs;
    }
    return std::unordered_set<common::table_id_t>{};
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
LogicalOperator* logicalOperator) {
auto extend = (LogicalExtend*)logicalOperator;
Expand Down Expand Up @@ -123,8 +153,20 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
scanners.insert({boundNodeTableID, std::move(scanner)});
}
}
return std::make_unique<ScanMultiRelTable>(std::move(posInfo), std::move(scanners),
auto scanRel = std::make_unique<ScanMultiRelTable>(std::move(posInfo), std::move(scanners),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting());
auto nbrNodeIDFilterSet = getNodeIDFilterSet(*nbrNode, *rel, extendDirection, *catalog);
if (!nbrNodeIDFilterSet.empty()) {
auto nbrNodeVectorPos =
DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty()));
auto filterInfo =
std::make_unique<NodeLabelFilterInfo>(nbrNodeVectorPos, nbrNodeIDFilterSet);
auto filter = std::make_unique<NodeLabelFiler>(
std::move(filterInfo), std::move(scanRel), getOperatorID(), "");
return filter;
} else {
return scanRel;
}
}
}

Expand Down
27 changes: 27 additions & 0 deletions src/processor/operator/filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,32 @@ bool Filter::getNextTuplesInternal(ExecutionContext* context) {
return true;
}

// Looks up the node ID vector at info->nodeVectorPos in the operator's result set and caches a
// raw pointer to it for use in getNextTuplesInternal.
void NodeLabelFiler::initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) {
    const auto& nodeIDVectorHandle = resultSet->getValueVector(info->nodeVectorPos);
    nodeIDVector = nodeIDVectorHandle.get();
}

// Pulls tuples from the child until at least one selected position of the node ID vector holds
// a node ID whose table ID is contained in info->nodeLabelSet, then narrows the selection to
// exactly those positions.
// Returns false as soon as the child reports no more tuples, true otherwise.
bool NodeLabelFiler::getNextTuplesInternal(ExecutionContext* context) {
common::sel_t numSelectValue;
do {
// NOTE(review): restoreSelVector/saveSelVector are not defined here (presumably inherited from
// SelVectorOverWriter); they appear to restore/snapshot the selection state around the
// overwrite performed below - confirm against their definition.
restoreSelVector(nodeIDVector->state->selVector);
if (!children[0]->getNextTuple(context)) {
return false;
}
saveSelVector(nodeIDVector->state->selVector);
numSelectValue = 0;
auto buffer = nodeIDVector->state->selVector->getSelectedPositionsBuffer();
for (auto i = 0; i < nodeIDVector->state->selVector->selectedSize; ++i) {
auto pos = nodeIDVector->state->selVector->selectedPositions[i];
// Branch-free compaction: always write the candidate position, but only advance the write
// cursor when the node's table ID is in the allowed set (contains() yields 0 or 1).
buffer[numSelectValue] = pos;
numSelectValue +=
info->nodeLabelSet.contains(nodeIDVector->getValue<common::nodeID_t>(pos).tableID);
}
// NOTE(review): presumably switches selectedPositions to the buffer filled above - confirm.
nodeIDVector->state->selVector->resetSelectorToValuePosBuffer();
// Nothing survived the filter in this batch: fetch the next batch from the child.
} while (numSelectValue == 0);
nodeIDVector->state->selVector->selectedSize = numSelectValue;
metrics->numOutputTuple.increase(nodeIDVector->state->selVector->selectedSize);
return true;
}

} // namespace processor
} // namespace kuzu
6 changes: 1 addition & 5 deletions test/test_files/demo_db/demo_db.test
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,11 @@ Noura|50

-LOG Undir2
-STATEMENT MATCH (a:User)-[:LivesIn]-(c:City) RETURN a.name, c.name;
---- 8
---- 4
Adam|Waterloo
Karissa|Waterloo
Zhang|Kitchener
Noura|Guelph
Waterloo|Karissa
Waterloo|Adam
Kitchener|Zhang
Guelph|Noura

-LOG Undir3
-STATEMENT MATCH ()-[]-() RETURN COUNT(*);
Expand Down
6 changes: 1 addition & 5 deletions test/test_files/demo_db/demo_db_parquet.test
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,11 @@ Noura|50

-LOG Undir2
-STATEMENT MATCH (a:User)-[:LivesIn]-(c:City) RETURN a.name, c.name;
---- 8
---- 4
Adam|Waterloo
Karissa|Waterloo
Zhang|Kitchener
Noura|Guelph
Waterloo|Karissa
Waterloo|Adam
Kitchener|Zhang
Guelph|Noura

-LOG Undir3
-STATEMENT MATCH ()-[]-() RETURN COUNT(*);
Expand Down
11 changes: 9 additions & 2 deletions test/test_files/tinysnb/match/undirected.test
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Carol|Dan
-STATEMENT MATCH (a:person)-[:studyAt|:meets]-(b:person:organisation) RETURN COUNT(*);
-ENUMERATE
---- 1
20
17

-LOG UndirMultiLabel3
-STATEMENT MATCH (a:person)-[:meets|:marries|:knows]-(b:person)-[:knows|:meets]-(c:person) WHERE c.fName = "Farooq" AND a.fName <> "Farooq" RETURN a.fName, b.fName;
Expand All @@ -63,11 +63,18 @@ Dan|Carol
-STATEMENT MATCH (a:person)-[]-() RETURN COUNT(*);
-ENUMERATE
---- 1
60
54

-LOG UndirPattern
-STATEMENT MATCH ()-[:studyAt]-(a)-[:meets]-()-[:workAt]-() RETURN a.fName;
-ENUMERATE
---- 2
Farooq
Bob

-STATEMENT MATCH (a:person)-[:workAt]-(b:organisation) RETURN a.ID, b.ID;
-ENUMERATE
---- 3
3|4
5|6
7|6

0 comments on commit f8f3a93

Please sign in to comment.