Skip to content

Commit

Permalink
Merge pull request #1861 from kuzudb/merge
Browse files Browse the repository at this point in the history
Add binding for MERGE clause
  • Loading branch information
andyfengHKU committed Jul 26, 2023
2 parents c1896a3 + 598798c commit bbe9011
Show file tree
Hide file tree
Showing 34 changed files with 4,712 additions and 4,134 deletions.
14 changes: 14 additions & 0 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ kU_QueryPart

oC_UpdatingClause
: oC_Create
| oC_Merge
| oC_Set
| oC_Delete
;
Expand Down Expand Up @@ -218,6 +219,19 @@ oC_Create

CREATE : ( 'C' | 'c' ) ( 'R' | 'r' ) ( 'E' | 'e' ) ( 'A' | 'a' ) ( 'T' | 't' ) ( 'E' | 'e' ) ;

// For unknown reasons, openCypher uses oC_PatternPart instead of oC_Pattern here. There should be no difference in terms of planning,
// so we choose to be consistent with oC_Create and use oC_Pattern instead.
// A MERGE clause: the MERGE keyword, a pattern, then zero or more space-separated merge actions.
oC_Merge : MERGE SP? oC_Pattern ( SP oC_MergeAction )* ;

MERGE : ( 'M' | 'm' ) ( 'E' | 'e' ) ( 'R' | 'r' ) ( 'G' | 'g' ) ( 'E' | 'e' ) ;

// A merge action is a SET clause introduced by either ON MATCH or ON CREATE.
oC_MergeAction
: ( ON SP MATCH SP oC_Set )
| ( ON SP CREATE SP oC_Set )
;

ON : ( 'O' | 'o' ) ( 'N' | 'n' ) ;

oC_Set
: SET SP? oC_SetItem ( SP? ',' SP? oC_SetItem )* ;

Expand Down
169 changes: 108 additions & 61 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "binder/binder.h"
#include "binder/query/updating_clause/bound_create_clause.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
#include "binder/query/updating_clause/bound_set_clause.h"
#include "parser/query/updating_clause/create_clause.h"
#include "parser/query/updating_clause/delete_clause.h"
#include "parser/query/updating_clause/merge_clause.h"
#include "parser/query/updating_clause/set_clause.h"

using namespace kuzu::common;
Expand All @@ -18,57 +20,95 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindUpdatingClause(
case ClauseType::CREATE: {
return bindCreateClause(updatingClause);
}
case ClauseType::MERGE: {
return bindMergeClause(updatingClause);
}
case ClauseType::SET: {
return bindSetClause(updatingClause);
}
case ClauseType::DELETE_: {
return bindDeleteClause(updatingClause);
}
default:
throw NotImplementedException("bindUpdatingClause().");
throw NotImplementedException("Binder::bindUpdatingClause");
}
}

// Gathers the node and rel variables among the expressions of `scope`; every other
// expression is left out of the returned set.
static expression_set populateNodeRelScope(const BinderScope& scope) {
    expression_set nodeRelVariables;
    for (auto& candidate : scope.getExpressions()) {
        // Skip anything that is neither a node variable nor a rel variable.
        if (!ExpressionUtil::isNodeVariable(*candidate) &&
            !ExpressionUtil::isRelVariable(*candidate)) {
            continue;
        }
        nodeRelVariables.insert(candidate);
    }
    return nodeRelVariables;
}

std::unique_ptr<BoundUpdatingClause> Binder::bindCreateClause(
const UpdatingClause& updatingClause) {
auto& createClause = (CreateClause&)updatingClause;
auto prevScope = scope->copy();
expression_set nodesScope;
expression_set relsScope;
for (auto& expression : scope->getExpressions()) {
if (ExpressionUtil::isNodeVariable(*expression)) {
nodesScope.insert(expression);
} else if (ExpressionUtil::isRelVariable(*expression)) {
relsScope.insert(expression);
}
}
auto nodeRelScope = populateNodeRelScope(*scope);
// bindGraphPattern will update scope.
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(createClause.getPatternElements());
auto boundCreateClause = std::make_unique<BoundCreateClause>();
for (auto i = 0u; i < queryGraphCollection->getNumQueryGraphs(); ++i) {
auto queryGraph = queryGraphCollection->getQueryGraph(i);
bindGraphPattern(createClause.getPatternElementsRef());
auto createInfos = bindCreateInfos(*queryGraphCollection, *propertyCollection, nodeRelScope);
return std::make_unique<BoundCreateClause>(std::move(createInfos));
}

// Binds a parsed MERGE clause into a BoundMergeClause.
// The bound clause is built from the bound pattern (query graph collection) and the create
// infos derived from it; ON MATCH / ON CREATE set items, when present, are bound one by one
// and attached to the result.
std::unique_ptr<BoundUpdatingClause> Binder::bindMergeClause(
    const parser::UpdatingClause& updatingClause) {
    auto& parsedMerge = (MergeClause&)updatingClause;
    // Snapshot the node/rel variables first: bindGraphPattern will update scope.
    auto variablesBeforePattern = populateNodeRelScope(*scope);
    auto [queryGraphCollection, propertyCollection] =
        bindGraphPattern(parsedMerge.getPatternElementsRef());
    auto createInfos =
        bindCreateInfos(*queryGraphCollection, *propertyCollection, variablesBeforePattern);
    auto result = std::make_unique<BoundMergeClause>(
        std::move(queryGraphCollection), std::move(createInfos));
    // SET items that follow ON MATCH.
    if (parsedMerge.hasOnMatchSetItems()) {
        for (auto idx = 0u; idx < parsedMerge.getNumOnMatchSetItems(); ++idx) {
            result->addOnMatchSetPropertyInfo(
                bindSetPropertyInfo(parsedMerge.getOnMatchSetItem(idx)));
        }
    }
    // SET items that follow ON CREATE.
    if (parsedMerge.hasOnCreateSetItems()) {
        for (auto idx = 0u; idx < parsedMerge.getNumOnCreateSetItems(); ++idx) {
            result->addOnCreateSetPropertyInfo(
                bindSetPropertyInfo(parsedMerge.getOnCreateSetItem(idx)));
        }
    }
    return result;
}

std::vector<std::unique_ptr<BoundCreateInfo>> Binder::bindCreateInfos(
const QueryGraphCollection& queryGraphCollection,
const PropertyKeyValCollection& keyValCollection, const expression_set& nodeRelScope_) {
auto nodeRelScope = nodeRelScope_;
std::vector<std::unique_ptr<BoundCreateInfo>> result;
for (auto i = 0u; i < queryGraphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = queryGraphCollection.getQueryGraph(i);
for (auto j = 0u; j < queryGraph->getNumQueryNodes(); ++j) {
auto node = queryGraph->getQueryNode(j);
if (nodesScope.contains(node)) {
if (nodeRelScope.contains(node)) {
continue;
}
nodesScope.insert(node);
boundCreateClause->addCreateNode(bindCreateNode(node, *propertyCollection));
nodeRelScope.insert(node);
result.push_back(bindCreateNodeInfo(node, keyValCollection));
}
for (auto j = 0u; j < queryGraph->getNumQueryRels(); ++j) {
auto rel = queryGraph->getQueryRel(j);
if (relsScope.contains(rel)) {
if (nodeRelScope.contains(rel)) {
continue;
}
relsScope.insert(rel);
boundCreateClause->addCreateRel(bindCreateRel(rel, *propertyCollection));
nodeRelScope.insert(rel);
result.push_back(bindCreateRelInfo(rel, keyValCollection));
}
}
return boundCreateClause;
return result;
}

std::unique_ptr<BoundCreateNode> Binder::bindCreateNode(
std::unique_ptr<BoundCreateInfo> Binder::bindCreateNodeInfo(
std::shared_ptr<NodeExpression> node, const PropertyKeyValCollection& collection) {
if (node->isMultiLabeled()) {
throw BinderException(
Expand Down Expand Up @@ -102,11 +142,12 @@ std::unique_ptr<BoundCreateNode> Binder::bindCreateNode(
throw BinderException("Create node " + node->toString() + " expects primary key " +
primaryKey.name + " as input.");
}
return std::make_unique<BoundCreateNode>(
std::move(node), std::move(primaryKeyExpression), std::move(setItems));
auto extraInfo = std::make_unique<ExtraCreateNodeInfo>(std::move(primaryKeyExpression));
return std::make_unique<BoundCreateInfo>(
UpdateTableType::NODE, std::move(node), std::move(setItems), std::move(extraInfo));
}

std::unique_ptr<BoundCreateRel> Binder::bindCreateRel(
std::unique_ptr<BoundCreateInfo> Binder::bindCreateRelInfo(
std::shared_ptr<RelExpression> rel, const PropertyKeyValCollection& collection) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException(
Expand All @@ -130,49 +171,54 @@ std::unique_ptr<BoundCreateRel> Binder::bindCreateRel(
setItems.emplace_back(std::move(propertyExpression), std::move(nullExpression));
}
}
return std::make_unique<BoundCreateRel>(std::move(rel), std::move(setItems));
return std::make_unique<BoundCreateInfo>(
UpdateTableType::REL, std::move(rel), std::move(setItems), nullptr /* extraInfo */);
}

std::unique_ptr<BoundUpdatingClause> Binder::bindSetClause(const UpdatingClause& updatingClause) {
auto& setClause = (SetClause&)updatingClause;
auto boundSetClause = std::make_unique<BoundSetClause>();
for (auto i = 0u; i < setClause.getNumSetItems(); ++i) {
auto setItem = setClause.getSetItem(i);
auto nodeOrRel = expressionBinder.bindExpression(*setItem.first->getChild(0));
switch (nodeOrRel->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
auto node = static_pointer_cast<NodeExpression>(nodeOrRel);
boundSetClause->addSetNodeProperty(bindSetNodeProperty(node, setItem));
} break;
case LogicalTypeID::REL: {
auto rel = static_pointer_cast<RelExpression>(nodeOrRel);
boundSetClause->addSetRelProperty(bindSetRelProperty(rel, setItem));
} break;
default:
throw BinderException("Set " + expressionTypeToString(nodeOrRel->expressionType) +
" property is supported.");
}
boundSetClause->addInfo(bindSetPropertyInfo(setClause.getSetItem(i)));
}
return boundSetClause;
}

std::unique_ptr<BoundSetNodeProperty> Binder::bindSetNodeProperty(
std::shared_ptr<NodeExpression> node, std::pair<ParsedExpression*, ParsedExpression*> setItem) {
if (node->isMultiLabeled()) {
throw BinderException("Set property of node " + node->toString() +
static void validateSetNodeProperty(const Expression& expression) {
auto& node = (const NodeExpression&)expression;
if (node.isMultiLabeled()) {
throw BinderException("Set property of node " + node.toString() +
" with multiple node labels is not supported.");
}
return std::make_unique<BoundSetNodeProperty>(std::move(node), bindSetItem(setItem));
}

std::unique_ptr<BoundSetRelProperty> Binder::bindSetRelProperty(
std::shared_ptr<RelExpression> rel, std::pair<ParsedExpression*, ParsedExpression*> setItem) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException("Set property of rel " + rel->toString() +
static void validateSetRelProperty(const Expression& expression) {
auto& rel = (const RelExpression&)expression;
if (rel.isMultiLabeled() || rel.isBoundByMultiLabeledNode()) {
throw BinderException("Set property of rel " + rel.toString() +
" with multiple rel labels or bound by multiple node labels "
"is not supported.");
}
return std::make_unique<BoundSetRelProperty>(std::move(rel), bindSetItem(setItem));
}

// Binds a single parsed set item into a BoundSetPropertyInfo.
//
// The first child of `setItem.first` is bound and its logical type decides the outcome:
//   - NODE: validated with validateSetNodeProperty, then bound as UpdateTableType::NODE.
//   - REL:  validated with validateSetRelProperty, then bound as UpdateTableType::REL.
//   - anything else: rejected with a BinderException.
// The validate helpers may themselves throw a BinderException.
std::unique_ptr<BoundSetPropertyInfo> Binder::bindSetPropertyInfo(
    std::pair<parser::ParsedExpression*, parser::ParsedExpression*> setItem) {
    auto left = expressionBinder.bindExpression(*setItem.first->getChild(0));
    switch (left->dataType.getLogicalTypeID()) {
    case LogicalTypeID::NODE: {
        validateSetNodeProperty(*left);
        return std::make_unique<BoundSetPropertyInfo>(
            UpdateTableType::NODE, left, bindSetItem(setItem));
    }
    case LogicalTypeID::REL: {
        validateSetRelProperty(*left);
        return std::make_unique<BoundSetPropertyInfo>(
            UpdateTableType::REL, left, bindSetItem(setItem));
    }
    default:
        // This branch rejects the input, so the message must say "not supported"; it
        // previously read "... property is supported.", which stated the opposite.
        throw BinderException("Set " + expressionTypeToString(left->expressionType) +
                              " property is not supported.");
    }
}

expression_pair Binder::bindSetItem(std::pair<ParsedExpression*, ParsedExpression*> setItem) {
Expand All @@ -190,12 +236,13 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(
auto nodeOrRel = expressionBinder.bindExpression(*deleteClause.getExpression(i));
switch (nodeOrRel->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
auto deleteNode = bindDeleteNode(static_pointer_cast<NodeExpression>(nodeOrRel));
boundDeleteClause->addDeleteNode(std::move(deleteNode));
auto deleteNodeInfo =
bindDeleteNodeInfo(static_pointer_cast<NodeExpression>(nodeOrRel));
boundDeleteClause->addInfo(std::move(deleteNodeInfo));
} break;
case LogicalTypeID::REL: {
auto deleteRel = bindDeleteRel(static_pointer_cast<RelExpression>(nodeOrRel));
boundDeleteClause->addDeleteRel(std::move(deleteRel));
auto deleteRel = bindDeleteRelInfo(static_pointer_cast<RelExpression>(nodeOrRel));
boundDeleteClause->addInfo(std::move(deleteRel));
} break;
default:
throw BinderException("Delete " + expressionTypeToString(nodeOrRel->expressionType) +
Expand All @@ -205,8 +252,7 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(
return boundDeleteClause;
}

std::unique_ptr<BoundDeleteNode> Binder::bindDeleteNode(
const std::shared_ptr<NodeExpression>& node) {
std::unique_ptr<BoundDeleteInfo> Binder::bindDeleteNodeInfo(std::shared_ptr<NodeExpression> node) {
if (node->isMultiLabeled()) {
throw BinderException(
"Delete node " + node->toString() + " with multiple node labels is not supported.");
Expand All @@ -215,16 +261,17 @@ std::unique_ptr<BoundDeleteNode> Binder::bindDeleteNode(
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(nodeTableID);
auto primaryKeyExpression =
expressionBinder.bindNodePropertyExpression(*node, nodeTableSchema->getPrimaryKey().name);
return std::make_unique<BoundDeleteNode>(node, primaryKeyExpression);
auto extraInfo = std::make_unique<ExtraDeleteNodeInfo>(primaryKeyExpression);
return std::make_unique<BoundDeleteInfo>(UpdateTableType::NODE, node, std::move(extraInfo));
}

std::shared_ptr<RelExpression> Binder::bindDeleteRel(std::shared_ptr<RelExpression> rel) {
std::unique_ptr<BoundDeleteInfo> Binder::bindDeleteRelInfo(std::shared_ptr<RelExpression> rel) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException(
"Delete rel " + rel->toString() +
" with multiple rel labels or bound by multiple node labels is not supported.");
}
return rel;
return std::make_unique<BoundDeleteInfo>(UpdateTableType::REL, rel, nullptr /* extraInfo */);
}

} // namespace binder
Expand Down
1 change: 1 addition & 0 deletions src/binder/query/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(
OBJECT
bound_create_clause.cpp
bound_delete_clause.cpp
bound_merge_clause.cpp
bound_set_clause.cpp
query_graph.cpp)

Expand Down
37 changes: 24 additions & 13 deletions src/binder/query/bound_create_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,39 @@
namespace kuzu {
namespace binder {

// Copy constructor: initializes the base as a CREATE clause, then fills `infos` with the
// result of calling copy() on each info held by `other`, in the same order.
BoundCreateClause::BoundCreateClause(const BoundCreateClause& other)
    : BoundUpdatingClause{common::ClauseType::CREATE} {
    for (const auto& otherInfo : other.infos) {
        infos.push_back(otherInfo->copy());
    }
}

std::vector<expression_pair> BoundCreateClause::getAllSetItems() const {
std::vector<expression_pair> result;
for (auto& createNode : createNodes) {
for (auto& setItem : createNode->getSetItems()) {
result.push_back(setItem);
}
}
for (auto& createRel : createRels) {
for (auto& setItem : createRel->getSetItems()) {
for (auto& info : infos) {
for (auto& setItem : info->setItems) {
result.push_back(setItem);
}
}
return result;
}

std::unique_ptr<BoundUpdatingClause> BoundCreateClause::copy() {
auto result = std::make_unique<BoundCreateClause>();
for (auto& createNode : createNodes) {
result->addCreateNode(createNode->copy());
bool BoundCreateClause::hasInfo(const std::function<bool(const BoundCreateInfo&)>& check) const {
for (auto& info : infos) {
if (check(*info)) {
return true;
}
}
for (auto& createRel : createRels) {
result->addCreateRel(createRel->copy());
return false;
}

std::vector<BoundCreateInfo*> BoundCreateClause::getInfos(
const std::function<bool(const BoundCreateInfo&)>& check) const {
std::vector<BoundCreateInfo*> result;
for (auto& info : infos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}
Expand Down
Loading

0 comments on commit bbe9011

Please sign in to comment.