Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add binding for MERGE clause #1861

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ kU_QueryPart

oC_UpdatingClause
: oC_Create
| oC_Merge
| oC_Set
| oC_Delete
;
Expand Down Expand Up @@ -218,6 +219,19 @@ oC_Create

CREATE : ( 'C' | 'c' ) ( 'R' | 'r' ) ( 'E' | 'e' ) ( 'A' | 'a' ) ( 'T' | 't' ) ( 'E' | 'e' ) ;

// For unknown reason, openCypher use oC_PatternPart instead of oC_Pattern. There should be no difference in terms of planning.
// So we choose to be consistent with oC_Create and use oC_Pattern instead.
oC_Merge : MERGE SP? oC_Pattern ( SP oC_MergeAction )* ;

MERGE : ( 'M' | 'm' ) ( 'E' | 'e' ) ( 'R' | 'r' ) ( 'G' | 'g' ) ( 'E' | 'e' ) ;

oC_MergeAction
: ( ON SP MATCH SP oC_Set )
| ( ON SP CREATE SP oC_Set )
;

ON : ( 'O' | 'o' ) ( 'N' | 'n' ) ;

oC_Set
: SET SP? oC_SetItem ( SP? ',' SP? oC_SetItem )* ;

Expand Down
169 changes: 108 additions & 61 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "binder/binder.h"
#include "binder/query/updating_clause/bound_create_clause.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
#include "binder/query/updating_clause/bound_set_clause.h"
#include "parser/query/updating_clause/create_clause.h"
#include "parser/query/updating_clause/delete_clause.h"
#include "parser/query/updating_clause/merge_clause.h"
#include "parser/query/updating_clause/set_clause.h"

using namespace kuzu::common;
Expand All @@ -18,57 +20,95 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindUpdatingClause(
case ClauseType::CREATE: {
return bindCreateClause(updatingClause);
}
case ClauseType::MERGE: {
return bindMergeClause(updatingClause);
}
case ClauseType::SET: {
return bindSetClause(updatingClause);
}
case ClauseType::DELETE_: {
return bindDeleteClause(updatingClause);
}
default:
throw NotImplementedException("bindUpdatingClause().");
throw NotImplementedException("Binder::bindUpdatingClause");
}
}

static expression_set populateNodeRelScope(const BinderScope& scope) {
expression_set result;
for (auto& expression : scope.getExpressions()) {
if (ExpressionUtil::isNodeVariable(*expression) ||
ExpressionUtil::isRelVariable(*expression)) {
result.insert(expression);
}
}
return result;
}

std::unique_ptr<BoundUpdatingClause> Binder::bindCreateClause(
const UpdatingClause& updatingClause) {
auto& createClause = (CreateClause&)updatingClause;
auto prevScope = scope->copy();
expression_set nodesScope;
expression_set relsScope;
for (auto& expression : scope->getExpressions()) {
if (ExpressionUtil::isNodeVariable(*expression)) {
nodesScope.insert(expression);
} else if (ExpressionUtil::isRelVariable(*expression)) {
relsScope.insert(expression);
}
}
auto nodeRelScope = populateNodeRelScope(*scope);
// bindGraphPattern will update scope.
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(createClause.getPatternElements());
auto boundCreateClause = std::make_unique<BoundCreateClause>();
for (auto i = 0u; i < queryGraphCollection->getNumQueryGraphs(); ++i) {
auto queryGraph = queryGraphCollection->getQueryGraph(i);
bindGraphPattern(createClause.getPatternElementsRef());
auto createInfos = bindCreateInfos(*queryGraphCollection, *propertyCollection, nodeRelScope);
return std::make_unique<BoundCreateClause>(std::move(createInfos));
}

std::unique_ptr<BoundUpdatingClause> Binder::bindMergeClause(
const parser::UpdatingClause& updatingClause) {
auto& mergeClause = (MergeClause&)updatingClause;
auto nodeRelScope = populateNodeRelScope(*scope);
// bindGraphPattern will update scope.
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(mergeClause.getPatternElementsRef());
auto createInfos = bindCreateInfos(*queryGraphCollection, *propertyCollection, nodeRelScope);
auto boundMergeClause =
std::make_unique<BoundMergeClause>(std::move(queryGraphCollection), std::move(createInfos));
if (mergeClause.hasOnMatchSetItems()) {
for (auto i = 0u; i < mergeClause.getNumOnMatchSetItems(); ++i) {
auto setPropertyInfo = bindSetPropertyInfo(mergeClause.getOnMatchSetItem(i));
boundMergeClause->addOnMatchSetPropertyInfo(std::move(setPropertyInfo));
}
}
if (mergeClause.hasOnCreateSetItems()) {
for (auto i = 0u; i < mergeClause.getNumOnCreateSetItems(); ++i) {
auto setPropertyInfo = bindSetPropertyInfo(mergeClause.getOnCreateSetItem(i));
boundMergeClause->addOnCreateSetPropertyInfo(std::move(setPropertyInfo));
}
}
return boundMergeClause;
}

std::vector<std::unique_ptr<BoundCreateInfo>> Binder::bindCreateInfos(
const QueryGraphCollection& queryGraphCollection,
const PropertyKeyValCollection& keyValCollection, const expression_set& nodeRelScope_) {
auto nodeRelScope = nodeRelScope_;
std::vector<std::unique_ptr<BoundCreateInfo>> result;
for (auto i = 0u; i < queryGraphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = queryGraphCollection.getQueryGraph(i);
for (auto j = 0u; j < queryGraph->getNumQueryNodes(); ++j) {
auto node = queryGraph->getQueryNode(j);
if (nodesScope.contains(node)) {
if (nodeRelScope.contains(node)) {
continue;
}
nodesScope.insert(node);
boundCreateClause->addCreateNode(bindCreateNode(node, *propertyCollection));
nodeRelScope.insert(node);
result.push_back(bindCreateNodeInfo(node, keyValCollection));
}
for (auto j = 0u; j < queryGraph->getNumQueryRels(); ++j) {
auto rel = queryGraph->getQueryRel(j);
if (relsScope.contains(rel)) {
if (nodeRelScope.contains(rel)) {
continue;
}
relsScope.insert(rel);
boundCreateClause->addCreateRel(bindCreateRel(rel, *propertyCollection));
nodeRelScope.insert(rel);
result.push_back(bindCreateRelInfo(rel, keyValCollection));
}
}
return boundCreateClause;
return result;
}

std::unique_ptr<BoundCreateNode> Binder::bindCreateNode(
std::unique_ptr<BoundCreateInfo> Binder::bindCreateNodeInfo(
std::shared_ptr<NodeExpression> node, const PropertyKeyValCollection& collection) {
if (node->isMultiLabeled()) {
throw BinderException(
Expand Down Expand Up @@ -102,11 +142,12 @@ std::unique_ptr<BoundCreateNode> Binder::bindCreateNode(
throw BinderException("Create node " + node->toString() + " expects primary key " +
primaryKey.name + " as input.");
}
return std::make_unique<BoundCreateNode>(
std::move(node), std::move(primaryKeyExpression), std::move(setItems));
auto extraInfo = std::make_unique<ExtraCreateNodeInfo>(std::move(primaryKeyExpression));
return std::make_unique<BoundCreateInfo>(
UpdateTableType::NODE, std::move(node), std::move(setItems), std::move(extraInfo));
}

std::unique_ptr<BoundCreateRel> Binder::bindCreateRel(
std::unique_ptr<BoundCreateInfo> Binder::bindCreateRelInfo(
std::shared_ptr<RelExpression> rel, const PropertyKeyValCollection& collection) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException(
Expand All @@ -130,49 +171,54 @@ std::unique_ptr<BoundCreateRel> Binder::bindCreateRel(
setItems.emplace_back(std::move(propertyExpression), std::move(nullExpression));
}
}
return std::make_unique<BoundCreateRel>(std::move(rel), std::move(setItems));
return std::make_unique<BoundCreateInfo>(
UpdateTableType::REL, std::move(rel), std::move(setItems), nullptr /* extraInfo */);
}

std::unique_ptr<BoundUpdatingClause> Binder::bindSetClause(const UpdatingClause& updatingClause) {
auto& setClause = (SetClause&)updatingClause;
auto boundSetClause = std::make_unique<BoundSetClause>();
for (auto i = 0u; i < setClause.getNumSetItems(); ++i) {
auto setItem = setClause.getSetItem(i);
auto nodeOrRel = expressionBinder.bindExpression(*setItem.first->getChild(0));
switch (nodeOrRel->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
auto node = static_pointer_cast<NodeExpression>(nodeOrRel);
boundSetClause->addSetNodeProperty(bindSetNodeProperty(node, setItem));
} break;
case LogicalTypeID::REL: {
auto rel = static_pointer_cast<RelExpression>(nodeOrRel);
boundSetClause->addSetRelProperty(bindSetRelProperty(rel, setItem));
} break;
default:
throw BinderException("Set " + expressionTypeToString(nodeOrRel->expressionType) +
" property is supported.");
}
boundSetClause->addInfo(bindSetPropertyInfo(setClause.getSetItem(i)));
}
return boundSetClause;
}

std::unique_ptr<BoundSetNodeProperty> Binder::bindSetNodeProperty(
std::shared_ptr<NodeExpression> node, std::pair<ParsedExpression*, ParsedExpression*> setItem) {
if (node->isMultiLabeled()) {
throw BinderException("Set property of node " + node->toString() +
static void validateSetNodeProperty(const Expression& expression) {
auto& node = (const NodeExpression&)expression;
if (node.isMultiLabeled()) {
throw BinderException("Set property of node " + node.toString() +
" with multiple node labels is not supported.");
}
return std::make_unique<BoundSetNodeProperty>(std::move(node), bindSetItem(setItem));
}

std::unique_ptr<BoundSetRelProperty> Binder::bindSetRelProperty(
std::shared_ptr<RelExpression> rel, std::pair<ParsedExpression*, ParsedExpression*> setItem) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException("Set property of rel " + rel->toString() +
static void validateSetRelProperty(const Expression& expression) {
auto& rel = (const RelExpression&)expression;
if (rel.isMultiLabeled() || rel.isBoundByMultiLabeledNode()) {
throw BinderException("Set property of rel " + rel.toString() +
" with multiple rel labels or bound by multiple node labels "
"is not supported.");
}
return std::make_unique<BoundSetRelProperty>(std::move(rel), bindSetItem(setItem));
}

std::unique_ptr<BoundSetPropertyInfo> Binder::bindSetPropertyInfo(
std::pair<parser::ParsedExpression*, parser::ParsedExpression*> setItem) {
auto left = expressionBinder.bindExpression(*setItem.first->getChild(0));
switch (left->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
validateSetNodeProperty(*left);
return std::make_unique<BoundSetPropertyInfo>(
UpdateTableType::NODE, left, bindSetItem(setItem));
}
case LogicalTypeID::REL: {
validateSetRelProperty(*left);
return std::make_unique<BoundSetPropertyInfo>(
UpdateTableType::REL, left, bindSetItem(setItem));
}
default:
throw BinderException(
"Set " + expressionTypeToString(left->expressionType) + " property is supported.");
}
}

expression_pair Binder::bindSetItem(std::pair<ParsedExpression*, ParsedExpression*> setItem) {
Expand All @@ -190,12 +236,13 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(
auto nodeOrRel = expressionBinder.bindExpression(*deleteClause.getExpression(i));
switch (nodeOrRel->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
auto deleteNode = bindDeleteNode(static_pointer_cast<NodeExpression>(nodeOrRel));
boundDeleteClause->addDeleteNode(std::move(deleteNode));
auto deleteNodeInfo =
bindDeleteNodeInfo(static_pointer_cast<NodeExpression>(nodeOrRel));
boundDeleteClause->addInfo(std::move(deleteNodeInfo));
} break;
case LogicalTypeID::REL: {
auto deleteRel = bindDeleteRel(static_pointer_cast<RelExpression>(nodeOrRel));
boundDeleteClause->addDeleteRel(std::move(deleteRel));
auto deleteRel = bindDeleteRelInfo(static_pointer_cast<RelExpression>(nodeOrRel));
boundDeleteClause->addInfo(std::move(deleteRel));
} break;
default:
throw BinderException("Delete " + expressionTypeToString(nodeOrRel->expressionType) +
Expand All @@ -205,8 +252,7 @@ std::unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(
return boundDeleteClause;
}

std::unique_ptr<BoundDeleteNode> Binder::bindDeleteNode(
const std::shared_ptr<NodeExpression>& node) {
std::unique_ptr<BoundDeleteInfo> Binder::bindDeleteNodeInfo(std::shared_ptr<NodeExpression> node) {
if (node->isMultiLabeled()) {
throw BinderException(
"Delete node " + node->toString() + " with multiple node labels is not supported.");
Expand All @@ -215,16 +261,17 @@ std::unique_ptr<BoundDeleteNode> Binder::bindDeleteNode(
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(nodeTableID);
auto primaryKeyExpression =
expressionBinder.bindNodePropertyExpression(*node, nodeTableSchema->getPrimaryKey().name);
return std::make_unique<BoundDeleteNode>(node, primaryKeyExpression);
auto extraInfo = std::make_unique<ExtraDeleteNodeInfo>(primaryKeyExpression);
return std::make_unique<BoundDeleteInfo>(UpdateTableType::NODE, node, std::move(extraInfo));
}

std::shared_ptr<RelExpression> Binder::bindDeleteRel(std::shared_ptr<RelExpression> rel) {
std::unique_ptr<BoundDeleteInfo> Binder::bindDeleteRelInfo(std::shared_ptr<RelExpression> rel) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException(
"Delete rel " + rel->toString() +
" with multiple rel labels or bound by multiple node labels is not supported.");
}
return rel;
return std::make_unique<BoundDeleteInfo>(UpdateTableType::REL, rel, nullptr /* extraInfo */);
}

} // namespace binder
Expand Down
1 change: 1 addition & 0 deletions src/binder/query/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(
OBJECT
bound_create_clause.cpp
bound_delete_clause.cpp
bound_merge_clause.cpp
bound_set_clause.cpp
query_graph.cpp)

Expand Down
37 changes: 24 additions & 13 deletions src/binder/query/bound_create_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,39 @@
namespace kuzu {
namespace binder {

BoundCreateClause::BoundCreateClause(const BoundCreateClause& other)
: BoundUpdatingClause{common::ClauseType::CREATE} {
for (auto& info : other.infos) {
infos.push_back(info->copy());
}
}

std::vector<expression_pair> BoundCreateClause::getAllSetItems() const {
std::vector<expression_pair> result;
for (auto& createNode : createNodes) {
for (auto& setItem : createNode->getSetItems()) {
result.push_back(setItem);
}
}
for (auto& createRel : createRels) {
for (auto& setItem : createRel->getSetItems()) {
for (auto& info : infos) {
for (auto& setItem : info->setItems) {
result.push_back(setItem);
}
}
return result;
}

std::unique_ptr<BoundUpdatingClause> BoundCreateClause::copy() {
auto result = std::make_unique<BoundCreateClause>();
for (auto& createNode : createNodes) {
result->addCreateNode(createNode->copy());
bool BoundCreateClause::hasInfo(const std::function<bool(const BoundCreateInfo&)>& check) const {
for (auto& info : infos) {
if (check(*info)) {
return true;
}
}
for (auto& createRel : createRels) {
result->addCreateRel(createRel->copy());
return false;
}

std::vector<BoundCreateInfo*> BoundCreateClause::getInfos(
const std::function<bool(const BoundCreateInfo&)>& check) const {
std::vector<BoundCreateInfo*> result;
for (auto& info : infos) {
if (check(*info)) {
result.push_back(info.get());
}
}
return result;
}
Expand Down
Loading