Skip to content

Commit

Permalink
Merge pull request #1894 from kuzudb/merge-operator
Browse files Browse the repository at this point in the history
Merge operator
  • Loading branch information
andyfengHKU committed Aug 6, 2023
2 parents db8866f + 6b008f0 commit 68e86ef
Show file tree
Hide file tree
Showing 51 changed files with 981 additions and 166 deletions.
9 changes: 7 additions & 2 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,14 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindMergeClause(
// bindGraphPattern will update scope.
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(mergeClause.getPatternElementsRef());
std::shared_ptr<Expression> predicate = nullptr;
for (auto& [key, val] : propertyCollection->getKeyVals()) {
predicate = expressionBinder.combineConjunctiveExpressions(
expressionBinder.createEqualityComparisonExpression(key, val), predicate);
}
auto createInfos = bindCreateInfos(*queryGraphCollection, *propertyCollection, nodeRelScope);
auto boundMergeClause =
std::make_unique<BoundMergeClause>(std::move(queryGraphCollection), std::move(createInfos));
auto boundMergeClause = std::make_unique<BoundMergeClause>(
std::move(queryGraphCollection), std::move(predicate), std::move(createInfos));
if (mergeClause.hasOnMatchSetItems()) {
for (auto i = 0u; i < mergeClause.getNumOnMatchSetItems(); ++i) {
auto setPropertyInfo = bindSetPropertyInfo(mergeClause.getOnMatchSetItem(i));
Expand Down
15 changes: 9 additions & 6 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ void BoundStatementVisitor::visitExplain(const BoundStatement& statement) {

void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& readingClause) {
switch (readingClause.getClauseType()) {
case common::ClauseType::MATCH: {
case ClauseType::MATCH: {
visitMatch(readingClause);
} break;
case common::ClauseType::UNWIND: {
case ClauseType::UNWIND: {
visitUnwind(readingClause);
} break;
case common::ClauseType::InQueryCall: {
case ClauseType::InQueryCall: {
visitInQueryCall(readingClause);
} break;
default:
Expand All @@ -100,15 +100,18 @@ void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& reading

void BoundStatementVisitor::visitUpdatingClause(const BoundUpdatingClause& updatingClause) {
switch (updatingClause.getClauseType()) {
case common::ClauseType::SET: {
case ClauseType::SET: {
visitSet(updatingClause);
} break;
case common::ClauseType::DELETE_: {
case ClauseType::DELETE_: {
visitDelete(updatingClause);
} break;
case common::ClauseType::CREATE: {
case ClauseType::CREATE: {
visitCreate(updatingClause);
} break;
case ClauseType::MERGE: {
visitMerge(updatingClause);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visitUpdatingClause");
}
Expand Down
64 changes: 64 additions & 0 deletions src/binder/query/bound_merge_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace binder {
BoundMergeClause::BoundMergeClause(const BoundMergeClause& other)
: BoundUpdatingClause{common::ClauseType::MERGE} {
queryGraphCollection = other.queryGraphCollection->copy();
predicate = other.predicate;
for (auto& createInfo : other.createInfos) {
createInfos.push_back(createInfo->copy());
}
Expand All @@ -17,5 +18,68 @@ BoundMergeClause::BoundMergeClause(const BoundMergeClause& other)
}
}

// Returns true as soon as one entry of createInfos satisfies `check`;
// returns false when none does (including when createInfos is empty).
bool BoundMergeClause::hasCreateInfo(
    const std::function<bool(const BoundCreateInfo&)>& check) const {
    auto it = createInfos.begin();
    while (it != createInfos.end()) {
        // Each element is a unique_ptr; hand the pointee to the predicate.
        if (check(**it)) {
            return true;
        }
        ++it;
    }
    return false;
}

std::vector<BoundCreateInfo*> BoundMergeClause::getCreateInfos(
const std::function<bool(const BoundCreateInfo&)>& check) const {
std::vector<BoundCreateInfo*> result;
for (auto& info : createInfos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}

// Reports whether any entry of onMatchSetPropertyInfos satisfies `check`.
// Evaluation stops at the first entry the predicate accepts.
bool BoundMergeClause::hasOnMatchSetInfo(
    const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
    bool found = false;
    for (auto it = onMatchSetPropertyInfos.begin();
         !found && it != onMatchSetPropertyInfos.end(); ++it) {
        found = check(**it);
    }
    return found;
}

std::vector<BoundSetPropertyInfo*> BoundMergeClause::getOnMatchSetInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<BoundSetPropertyInfo*> result;
for (auto& info : onMatchSetPropertyInfos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}

// Returns true if at least one entry of onCreateSetPropertyInfos satisfies
// `check`; the scan ends at the first accepted entry.
bool BoundMergeClause::hasOnCreateSetInfo(
    const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
    for (auto i = 0u; i < onCreateSetPropertyInfos.size(); ++i) {
        const auto& setInfo = *onCreateSetPropertyInfos[i];
        if (check(setInfo)) {
            return true;
        }
    }
    return false;
}

std::vector<BoundSetPropertyInfo*> BoundMergeClause::getOnCreateSetInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<BoundSetPropertyInfo*> result;
for (auto& info : onCreateSetPropertyInfos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}

} // namespace binder
} // namespace kuzu
28 changes: 27 additions & 1 deletion src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "binder/query/updating_clause/bound_create_clause.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
#include "binder/query/updating_clause/bound_set_clause.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

Expand All @@ -21,7 +24,7 @@ expression_vector PropertyCollector::getProperties() {
void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
auto& matchClause = (BoundMatchClause&)readingClause;
for (auto& rel : matchClause.getQueryGraphCollection()->getQueryRels()) {
if (rel->getRelType() == common::QueryRelType::NON_RECURSIVE) {
if (rel->getRelType() == QueryRelType::NON_RECURSIVE) {
properties.insert(rel->getInternalIDProperty());
}
}
Expand Down Expand Up @@ -63,6 +66,29 @@ void PropertyCollector::visitCreate(const BoundUpdatingClause& updatingClause) {
}
}

void PropertyCollector::visitMerge(const BoundUpdatingClause& updatingClause) {
auto& boundMergeClause = (BoundMergeClause&)updatingClause;
for (auto& rel : boundMergeClause.getQueryGraphCollection()->getQueryRels()) {
if (rel->getRelType() == QueryRelType::NON_RECURSIVE) {
properties.insert(rel->getInternalIDProperty());
}
}
if (boundMergeClause.hasPredicate()) {
collectPropertyExpressions(boundMergeClause.getPredicate());
}
for (auto& info : boundMergeClause.getCreateInfosRef()) {
for (auto& setItem : info->setItems) {
collectPropertyExpressions(setItem.second);
}
}
for (auto& info : boundMergeClause.getOnMatchSetInfosRef()) {
collectPropertyExpressions(info->setItem.second);
}
for (auto& info : boundMergeClause.getOnCreateSetInfosRef()) {
collectPropertyExpressions(info->setItem.second);
}
}

void PropertyCollector::visitProjectionBody(const BoundProjectionBody& projectionBody) {
for (auto& expression : projectionBody.getProjectionExpressions()) {
collectPropertyExpressions(expression);
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class BoundStatementVisitor {
virtual void visitSet(const BoundUpdatingClause& updatingClause) {}
virtual void visitDelete(const BoundUpdatingClause& updatingClause) {}
virtual void visitCreate(const BoundUpdatingClause& updatingClause) {}
// Hook for MERGE updating clauses; the base implementation does nothing, so
// subclasses override it only when they need to inspect merge clauses.
virtual void visitMerge(const BoundUpdatingClause& updatingClause) {}

virtual void visitProjectionBody(const BoundProjectionBody& projectionBody) {}
virtual void visitProjectionBodyPredicate(const std::shared_ptr<Expression>& predicate) {}
Expand Down
103 changes: 100 additions & 3 deletions src/include/binder/query/updating_clause/bound_merge_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,93 @@ namespace binder {
class BoundMergeClause : public BoundUpdatingClause {
public:
BoundMergeClause(std::unique_ptr<QueryGraphCollection> queryGraphCollection,
std::shared_ptr<Expression> predicate,
std::vector<std::unique_ptr<BoundCreateInfo>> createInfos)
: BoundUpdatingClause{common::ClauseType::MERGE},
queryGraphCollection{std::move(queryGraphCollection)}, createInfos{
std::move(createInfos)} {}
: BoundUpdatingClause{common::ClauseType::MERGE}, queryGraphCollection{std::move(
queryGraphCollection)},
predicate{std::move(predicate)}, createInfos{std::move(createInfos)} {}
BoundMergeClause(const BoundMergeClause& other);

inline QueryGraphCollection* getQueryGraphCollection() const {
return queryGraphCollection.get();
}
inline bool hasPredicate() const { return predicate != nullptr; }
inline std::shared_ptr<Expression> getPredicate() const { return predicate; }

inline const std::vector<std::unique_ptr<BoundCreateInfo>>& getCreateInfosRef() const {
return createInfos;
}
inline const std::vector<std::unique_ptr<BoundSetPropertyInfo>>& getOnMatchSetInfosRef() const {
return onMatchSetPropertyInfos;
}
inline const std::vector<std::unique_ptr<BoundSetPropertyInfo>>&
getOnCreateSetInfosRef() const {
return onCreateSetPropertyInfos;
}

inline bool hasCreateNodeInfo() const {
return hasCreateInfo([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<BoundCreateInfo*> getCreateNodeInfos() const {
return getCreateInfos([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasCreateRelInfo() const {
return hasCreateInfo([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<BoundCreateInfo*> getCreateRelInfos() const {
return getCreateInfos([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline bool hasOnMatchSetNodeInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnMatchSetNodeInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasOnMatchSetRelInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnMatchSetRelInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline bool hasOnCreateSetNodeInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnCreateSetNodeInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasOnCreateSetRelInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnCreateSetRelInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline void addOnMatchSetPropertyInfo(std::unique_ptr<BoundSetPropertyInfo> setPropertyInfo) {
onMatchSetPropertyInfos.push_back(std::move(setPropertyInfo));
}
Expand All @@ -28,9 +109,25 @@ class BoundMergeClause : public BoundUpdatingClause {
return std::make_unique<BoundMergeClause>(*this);
}

private:
bool hasCreateInfo(const std::function<bool(const BoundCreateInfo& info)>& check) const;
std::vector<BoundCreateInfo*> getCreateInfos(
const std::function<bool(const BoundCreateInfo& info)>& check) const;

bool hasOnMatchSetInfo(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<BoundSetPropertyInfo*> getOnMatchSetInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

bool hasOnCreateSetInfo(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<BoundSetPropertyInfo*> getOnCreateSetInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

private:
// Pattern to match.
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> predicate;
// Pattern to create on match failure.
std::vector<std::unique_ptr<BoundCreateInfo>> createInfos;
// Update on match
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PropertyCollector : public BoundStatementVisitor {
void visitSet(const BoundUpdatingClause& updatingClause) final;
void visitDelete(const BoundUpdatingClause& updatingClause) final;
void visitCreate(const BoundUpdatingClause& updatingClause) final;
void visitMerge(const BoundUpdatingClause& updatingClause) final;

void visitProjectionBody(const BoundProjectionBody& projectionBody) final;
void visitProjectionBodyPredicate(const std::shared_ptr<Expression>& predicate) final;
Expand Down
1 change: 1 addition & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
void visitDeleteRel(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;
void visitMerge(planner::LogicalOperator* op) override;
void visitCopyTo(planner::LogicalOperator* op) override;

std::shared_ptr<planner::LogicalOperator> appendFlattens(
Expand Down
8 changes: 7 additions & 1 deletion src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,14 @@ class LogicalOperatorVisitor {
return op;
}

// Visit hook for a merge operator; the base implementation is a no-op.
// NOTE(review): presumably invoked for LogicalOperatorType::MERGE operators —
// the dispatch code is not visible here, confirm against the visitor's switch.
virtual void visitMerge(planner::LogicalOperator* op) {}
// Replacing variant of the merge visit hook: returns the operator that should
// take the place of `op`. The base implementation returns `op` unchanged.
virtual std::shared_ptr<planner::LogicalOperator> visitMergeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitCopyTo(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCopyTo(
virtual std::shared_ptr<planner::LogicalOperator> visitCopyToReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}
Expand Down
1 change: 1 addition & 0 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
void visitCreateRel(planner::LogicalOperator* op) override;
void visitDeleteNode(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;
void visitMerge(planner::LogicalOperator* op) override;

void collectExpressionsInUse(std::shared_ptr<binder::Expression> expression);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum class LogicalOperatorType : uint8_t {
INDEX_SCAN_NODE,
INTERSECT,
LIMIT,
MERGE,
MULTIPLICITY_REDUCER,
NODE_LABEL_FILTER,
ORDER_BY,
Expand Down
Loading

0 comments on commit 68e86ef

Please sign in to comment.