Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge operator #1894

Merged
merged 1 commit into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,14 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindMergeClause(
// bindGraphPattern will update scope.
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(mergeClause.getPatternElementsRef());
std::shared_ptr<Expression> predicate = nullptr;
for (auto& [key, val] : propertyCollection->getKeyVals()) {
predicate = expressionBinder.combineConjunctiveExpressions(
expressionBinder.createEqualityComparisonExpression(key, val), predicate);
}
auto createInfos = bindCreateInfos(*queryGraphCollection, *propertyCollection, nodeRelScope);
auto boundMergeClause =
std::make_unique<BoundMergeClause>(std::move(queryGraphCollection), std::move(createInfos));
auto boundMergeClause = std::make_unique<BoundMergeClause>(
std::move(queryGraphCollection), std::move(predicate), std::move(createInfos));
if (mergeClause.hasOnMatchSetItems()) {
for (auto i = 0u; i < mergeClause.getNumOnMatchSetItems(); ++i) {
auto setPropertyInfo = bindSetPropertyInfo(mergeClause.getOnMatchSetItem(i));
Expand Down
15 changes: 9 additions & 6 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ void BoundStatementVisitor::visitExplain(const BoundStatement& statement) {

void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& readingClause) {
switch (readingClause.getClauseType()) {
case common::ClauseType::MATCH: {
case ClauseType::MATCH: {
visitMatch(readingClause);
} break;
case common::ClauseType::UNWIND: {
case ClauseType::UNWIND: {
visitUnwind(readingClause);
} break;
case common::ClauseType::InQueryCall: {
case ClauseType::InQueryCall: {
visitInQueryCall(readingClause);
} break;
default:
Expand All @@ -100,15 +100,18 @@ void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& reading

void BoundStatementVisitor::visitUpdatingClause(const BoundUpdatingClause& updatingClause) {
switch (updatingClause.getClauseType()) {
case common::ClauseType::SET: {
case ClauseType::SET: {
visitSet(updatingClause);
} break;
case common::ClauseType::DELETE_: {
case ClauseType::DELETE_: {
visitDelete(updatingClause);
} break;
case common::ClauseType::CREATE: {
case ClauseType::CREATE: {
visitCreate(updatingClause);
} break;
case ClauseType::MERGE: {
visitMerge(updatingClause);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visitUpdatingClause");
}
Expand Down
64 changes: 64 additions & 0 deletions src/binder/query/bound_merge_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace binder {
BoundMergeClause::BoundMergeClause(const BoundMergeClause& other)
: BoundUpdatingClause{common::ClauseType::MERGE} {
queryGraphCollection = other.queryGraphCollection->copy();
predicate = other.predicate;
for (auto& createInfo : other.createInfos) {
createInfos.push_back(createInfo->copy());
}
Expand All @@ -17,5 +18,68 @@ BoundMergeClause::BoundMergeClause(const BoundMergeClause& other)
}
}

bool BoundMergeClause::hasCreateInfo(
const std::function<bool(const BoundCreateInfo&)>& check) const {
for (auto& info : createInfos) {
if (check(*info)) {
return true;
}
}
return false;
}

std::vector<BoundCreateInfo*> BoundMergeClause::getCreateInfos(
const std::function<bool(const BoundCreateInfo&)>& check) const {
std::vector<BoundCreateInfo*> result;
for (auto& info : createInfos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}

bool BoundMergeClause::hasOnMatchSetInfo(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
for (auto& info : onMatchSetPropertyInfos) {
if (check(*info)) {
return true;
}
}
return false;
}

std::vector<BoundSetPropertyInfo*> BoundMergeClause::getOnMatchSetInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<BoundSetPropertyInfo*> result;
for (auto& info : onMatchSetPropertyInfos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}

bool BoundMergeClause::hasOnCreateSetInfo(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
for (auto& info : onCreateSetPropertyInfos) {
if (check(*info)) {
return true;
}
}
return false;
}

std::vector<BoundSetPropertyInfo*> BoundMergeClause::getOnCreateSetInfos(
const std::function<bool(const BoundSetPropertyInfo&)>& check) const {
std::vector<BoundSetPropertyInfo*> result;
for (auto& info : onCreateSetPropertyInfos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}

} // namespace binder
} // namespace kuzu
28 changes: 27 additions & 1 deletion src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "binder/query/updating_clause/bound_create_clause.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
#include "binder/query/updating_clause/bound_set_clause.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

Expand All @@ -21,7 +24,7 @@ expression_vector PropertyCollector::getProperties() {
void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
auto& matchClause = (BoundMatchClause&)readingClause;
for (auto& rel : matchClause.getQueryGraphCollection()->getQueryRels()) {
if (rel->getRelType() == common::QueryRelType::NON_RECURSIVE) {
if (rel->getRelType() == QueryRelType::NON_RECURSIVE) {
properties.insert(rel->getInternalIDProperty());
}
}
Expand Down Expand Up @@ -63,6 +66,29 @@ void PropertyCollector::visitCreate(const BoundUpdatingClause& updatingClause) {
}
}

void PropertyCollector::visitMerge(const BoundUpdatingClause& updatingClause) {
auto& boundMergeClause = (BoundMergeClause&)updatingClause;
for (auto& rel : boundMergeClause.getQueryGraphCollection()->getQueryRels()) {
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
if (rel->getRelType() == QueryRelType::NON_RECURSIVE) {
properties.insert(rel->getInternalIDProperty());
}
}
if (boundMergeClause.hasPredicate()) {
collectPropertyExpressions(boundMergeClause.getPredicate());
}
for (auto& info : boundMergeClause.getCreateInfosRef()) {
for (auto& setItem : info->setItems) {
collectPropertyExpressions(setItem.second);
}
}
for (auto& info : boundMergeClause.getOnMatchSetInfosRef()) {
collectPropertyExpressions(info->setItem.second);
}
for (auto& info : boundMergeClause.getOnCreateSetInfosRef()) {
collectPropertyExpressions(info->setItem.second);
}
}

void PropertyCollector::visitProjectionBody(const BoundProjectionBody& projectionBody) {
for (auto& expression : projectionBody.getProjectionExpressions()) {
collectPropertyExpressions(expression);
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
virtual void visitSet(const BoundUpdatingClause& updatingClause) {}
virtual void visitDelete(const BoundUpdatingClause& updatingClause) {}
virtual void visitCreate(const BoundUpdatingClause& updatingClause) {}
virtual void visitMerge(const BoundUpdatingClause& updatingClause) {}

Check warning on line 41 in src/include/binder/bound_statement_visitor.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/bound_statement_visitor.h#L41

Added line #L41 was not covered by tests

virtual void visitProjectionBody(const BoundProjectionBody& projectionBody) {}
virtual void visitProjectionBodyPredicate(const std::shared_ptr<Expression>& predicate) {}
Expand Down
103 changes: 100 additions & 3 deletions src/include/binder/query/updating_clause/bound_merge_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,93 @@
class BoundMergeClause : public BoundUpdatingClause {
public:
BoundMergeClause(std::unique_ptr<QueryGraphCollection> queryGraphCollection,
std::shared_ptr<Expression> predicate,
std::vector<std::unique_ptr<BoundCreateInfo>> createInfos)
: BoundUpdatingClause{common::ClauseType::MERGE},
queryGraphCollection{std::move(queryGraphCollection)}, createInfos{
std::move(createInfos)} {}
: BoundUpdatingClause{common::ClauseType::MERGE}, queryGraphCollection{std::move(
queryGraphCollection)},
predicate{std::move(predicate)}, createInfos{std::move(createInfos)} {}
BoundMergeClause(const BoundMergeClause& other);

inline QueryGraphCollection* getQueryGraphCollection() const {
return queryGraphCollection.get();
}
inline bool hasPredicate() const { return predicate != nullptr; }
inline std::shared_ptr<Expression> getPredicate() const { return predicate; }

inline const std::vector<std::unique_ptr<BoundCreateInfo>>& getCreateInfosRef() const {
return createInfos;
}
inline const std::vector<std::unique_ptr<BoundSetPropertyInfo>>& getOnMatchSetInfosRef() const {
return onMatchSetPropertyInfos;
}
inline const std::vector<std::unique_ptr<BoundSetPropertyInfo>>&
getOnCreateSetInfosRef() const {
return onCreateSetPropertyInfos;
}

inline bool hasCreateNodeInfo() const {
return hasCreateInfo([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<BoundCreateInfo*> getCreateNodeInfos() const {
return getCreateInfos([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasCreateRelInfo() const {
return hasCreateInfo([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<BoundCreateInfo*> getCreateRelInfos() const {
return getCreateInfos([](const BoundCreateInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline bool hasOnMatchSetNodeInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnMatchSetNodeInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasOnMatchSetRelInfo() const {
return hasOnMatchSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnMatchSetRelInfos() const {
return getOnMatchSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}

inline bool hasOnCreateSetNodeInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnCreateSetNodeInfos() const {
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::NODE;
});
}
inline bool hasOnCreateSetRelInfo() const {
return hasOnCreateSetInfo([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});
}
inline std::vector<BoundSetPropertyInfo*> getOnCreateSetRelInfos() const {

Check warning on line 95 in src/include/binder/query/updating_clause/bound_merge_clause.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/updating_clause/bound_merge_clause.h#L95

Added line #L95 was not covered by tests
return getOnCreateSetInfos([](const BoundSetPropertyInfo& info) {
return info.updateTableType == UpdateTableType::REL;
});

Check warning on line 98 in src/include/binder/query/updating_clause/bound_merge_clause.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/updating_clause/bound_merge_clause.h#L97-L98

Added lines #L97 - L98 were not covered by tests
}

inline void addOnMatchSetPropertyInfo(std::unique_ptr<BoundSetPropertyInfo> setPropertyInfo) {
onMatchSetPropertyInfos.push_back(std::move(setPropertyInfo));
}
Expand All @@ -28,9 +109,25 @@
return std::make_unique<BoundMergeClause>(*this);
}

private:
bool hasCreateInfo(const std::function<bool(const BoundCreateInfo& info)>& check) const;
std::vector<BoundCreateInfo*> getCreateInfos(
const std::function<bool(const BoundCreateInfo& info)>& check) const;

bool hasOnMatchSetInfo(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<BoundSetPropertyInfo*> getOnMatchSetInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

bool hasOnCreateSetInfo(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;
std::vector<BoundSetPropertyInfo*> getOnCreateSetInfos(
const std::function<bool(const BoundSetPropertyInfo& info)>& check) const;

private:
// Pattern to match.
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> predicate;
// Pattern to create on match failure.
std::vector<std::unique_ptr<BoundCreateInfo>> createInfos;
// Update on match
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PropertyCollector : public BoundStatementVisitor {
void visitSet(const BoundUpdatingClause& updatingClause) final;
void visitDelete(const BoundUpdatingClause& updatingClause) final;
void visitCreate(const BoundUpdatingClause& updatingClause) final;
void visitMerge(const BoundUpdatingClause& updatingClause) final;

void visitProjectionBody(const BoundProjectionBody& projectionBody) final;
void visitProjectionBodyPredicate(const std::shared_ptr<Expression>& predicate) final;
Expand Down
1 change: 1 addition & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
void visitDeleteRel(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;
void visitMerge(planner::LogicalOperator* op) override;
void visitCopyTo(planner::LogicalOperator* op) override;

std::shared_ptr<planner::LogicalOperator> appendFlattens(
Expand Down
8 changes: 7 additions & 1 deletion src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,14 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitMerge(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitMergeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitCopyTo(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCopyTo(
virtual std::shared_ptr<planner::LogicalOperator> visitCopyToReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}
Expand Down
1 change: 1 addition & 0 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
void visitCreateRel(planner::LogicalOperator* op) override;
void visitDeleteNode(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;
void visitMerge(planner::LogicalOperator* op) override;

void collectExpressionsInUse(std::shared_ptr<binder::Expression> expression);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum class LogicalOperatorType : uint8_t {
INDEX_SCAN_NODE,
INTERSECT,
LIMIT,
MERGE,
MULTIPLICITY_REDUCER,
NODE_LABEL_FILTER,
ORDER_BY,
Expand Down
Loading
Loading