Skip to content

Commit

Permalink
Merge pull request #1114 from kuzudb/unlabeled-query
Browse files Browse the repository at this point in the history
Unlabeled query
  • Loading branch information
andyfengHKU committed Dec 13, 2022
2 parents de3b589 + 4dab228 commit ff546cc
Show file tree
Hide file tree
Showing 22 changed files with 213 additions and 116 deletions.
37 changes: 25 additions & 12 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,12 @@ shared_ptr<NodeExpression> Binder::bindQueryNode(
auto prevVariable = variablesInScope.at(parsedName);
ExpressionBinder::validateExpectedDataType(*prevVariable, NODE);
queryNode = static_pointer_cast<NodeExpression>(prevVariable);
auto otherTableIDs = bindNodeTableIDs(nodePattern.getTableNames());
queryNode->addTableIDs(otherTableIDs);
// E.g. MATCH (a:person) MATCH (a:organisation)
// We bind to single node a with both labels
if (!nodePattern.getTableNames().empty()) {
auto otherTableIDs = bindNodeTableIDs(nodePattern.getTableNames());
queryNode->addTableIDs(otherTableIDs);
}
} else {
queryNode = createQueryNode(nodePattern);
}
Expand Down Expand Up @@ -147,29 +151,38 @@ shared_ptr<NodeExpression> Binder::createQueryNode(const NodePattern& nodePatter
return queryNode;
}

unordered_set<table_id_t> Binder::bindTableIDs(
vector<table_id_t> Binder::bindTableIDs(
const vector<string>& tableNames, DataTypeID nodeOrRelType) {
unordered_set<table_id_t> result;
unordered_set<table_id_t> tableIDs;
switch (nodeOrRelType) {
case NODE: {
for (auto& tableName : tableNames) {
result.insert(bindNodeTableID(tableName));
if (tableNames.empty()) {
for (auto tableID : catalog.getReadOnlyVersion()->getNodeTableIDs()) {
tableIDs.insert(tableID);
}
} else {
for (auto& tableName : tableNames) {
tableIDs.insert(bindNodeTableID(tableName));
}
}

} break;
case REL: {
if (tableNames.empty()) {
for (auto tableID : catalog.getReadOnlyVersion()->getRelTableIDs()) {
tableIDs.insert(tableID);
}
}
for (auto& tableName : tableNames) {
result.insert(bindRelTableID(tableName));
tableIDs.insert(bindRelTableID(tableName));
}
} break;
default:
throw NotImplementedException(
"bindTableIDs(" + Types::dataTypeToString(nodeOrRelType) + ").");
}
for (auto& tableID : result) {
if (tableID == ANY_TABLE_ID) {
throw BinderException("Any-table is not supported.");
}
}
auto result = vector<table_id_t>{tableIDs.begin(), tableIDs.end()};
std::sort(result.begin(), result.end());
return result;
}

Expand Down
52 changes: 36 additions & 16 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@ unique_ptr<BoundUpdatingClause> Binder::bindCreateClause(const UpdatingClause& u
}
}
}

return boundCreateClause;
}

unique_ptr<BoundCreateNode> Binder::bindCreateNode(
shared_ptr<NodeExpression> node, const PropertyKeyValCollection& collection) {
if (node->getNumTableIDs() > 1) {
if (node->isMultiLabeled()) {
throw BinderException(
"Create multi-labeled node " + node->getRawName() + "is not supported.");
"Create node " + node->getRawName() + " with multiple node labels is not supported.");
}
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(node->getTableID());
auto nodeTableID = node->getSingleTableID();
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(nodeTableID);
auto primaryKey = nodeTableSchema->getPrimaryKey();
shared_ptr<Expression> primaryKeyExpression;
vector<expression_pair> setItems;
for (auto& [key, val] : collection.getPropertyKeyValPairs(*node)) {
auto propertyExpression = static_pointer_cast<PropertyExpression>(key);
if (propertyExpression->getPropertyID(node->getTableID()) == primaryKey.propertyID) {
if (propertyExpression->getPropertyID(nodeTableID) == primaryKey.propertyID) {
primaryKeyExpression = val;
}
setItems.emplace_back(key, val);
Expand All @@ -78,15 +78,17 @@ unique_ptr<BoundCreateNode> Binder::bindCreateNode(

unique_ptr<BoundCreateRel> Binder::bindCreateRel(
shared_ptr<RelExpression> rel, const PropertyKeyValCollection& collection) {
if (rel->getNumTableIDs() > 1) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException(
"Create multi-labeled rel " + rel->getRawName() + "is not supported.");
"Create rel " + rel->getRawName() +
" with multiple rel labels or bound by multiple node labels is not supported.");
}
auto relTableID = rel->getSingleTableID();
auto catalogContent = catalog.getReadOnlyVersion();
// CreateRel requires all properties in schema as input. So we rewrite set property to
// null if user does not specify a property in the query.
vector<expression_pair> setItems;
for (auto& property : catalogContent->getRelProperties(rel->getTableID())) {
for (auto& property : catalogContent->getRelProperties(relTableID)) {
if (collection.hasPropertyKeyValPair(*rel, property.name)) {
setItems.push_back(collection.getPropertyKeyValPair(*rel, property.name));
} else {
Expand All @@ -108,8 +110,15 @@ unique_ptr<BoundUpdatingClause> Binder::bindSetClause(const UpdatingClause& upda
for (auto i = 0u; i < setClause.getNumSetItems(); ++i) {
auto setItem = setClause.getSetItem(i);
auto boundLhs = expressionBinder.bindExpression(*setItem->origin);
if (boundLhs->getChild(0)->dataType.typeID != NODE) {
throw BinderException("Only updating node properties is supported.");
auto boundNodeOrRel = boundLhs->getChild(0);
if (boundNodeOrRel->dataType.typeID != NODE) {
throw BinderException("Set " + Types::dataTypeToString(boundNodeOrRel->dataType) +
" property is supported.");
}
auto boundNode = static_pointer_cast<NodeExpression>(boundNodeOrRel);
if (boundNode->isMultiLabeled()) {
throw BinderException("Set property of node " + boundNode->getRawName() +
" with multiple node labels is not supported.");
}
auto boundRhs = expressionBinder.bindExpression(*setItem->target);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
Expand All @@ -124,10 +133,11 @@ unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(const UpdatingClause& u
for (auto i = 0u; i < deleteClause.getNumExpressions(); ++i) {
auto boundExpression = expressionBinder.bindExpression(*deleteClause.getExpression(i));
if (boundExpression->dataType.typeID == NODE) {
boundDeleteClause->addDeleteNode(
bindDeleteNode(static_pointer_cast<NodeExpression>(boundExpression)));
auto deleteNode = bindDeleteNode(static_pointer_cast<NodeExpression>(boundExpression));
boundDeleteClause->addDeleteNode(std::move(deleteNode));
} else if (boundExpression->dataType.typeID == REL) {
boundDeleteClause->addDeleteRel(static_pointer_cast<RelExpression>(boundExpression));
auto deleteRel = bindDeleteRel(static_pointer_cast<RelExpression>(boundExpression));
boundDeleteClause->addDeleteRel(std::move(deleteRel));
} else {
throw BinderException("Delete " +
expressionTypeToString(boundExpression->expressionType) +
Expand All @@ -138,16 +148,26 @@ unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(const UpdatingClause& u
}

unique_ptr<BoundDeleteNode> Binder::bindDeleteNode(shared_ptr<NodeExpression> node) {
if (node->getNumTableIDs() > 1) {
if (node->isMultiLabeled()) {
throw BinderException(
"Delete multi-labeled node " + node->getRawName() + "is not supported.");
"Delete node " + node->getRawName() + " with multiple node labels is not supported.");
}
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(node->getTableID());
auto nodeTableID = node->getSingleTableID();
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(nodeTableID);
auto primaryKey = nodeTableSchema->getPrimaryKey();
auto primaryKeyExpression =
expressionBinder.bindNodePropertyExpression(node, vector<Property>{primaryKey});
return make_unique<BoundDeleteNode>(node, primaryKeyExpression);
}

shared_ptr<RelExpression> Binder::bindDeleteRel(shared_ptr<RelExpression> rel) {
if (rel->isMultiLabeled() || rel->isBoundByMultiLabeledNode()) {
throw BinderException(
"Delete rel " + rel->getRawName() +
" with multiple rel labels or bound by multiple node labels is not supported.");
}
return rel;
}

} // namespace binder
} // namespace kuzu
6 changes: 0 additions & 6 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,13 @@ shared_ptr<Expression> Binder::bindWhereExpression(const ParsedExpression& parse
}

table_id_t Binder::bindRelTableID(const string& tableName) const {
if (tableName.empty()) {
return ANY_TABLE_ID;
}
if (!catalog.getReadOnlyVersion()->containRelTable(tableName)) {
throw BinderException("Rel table " + tableName + " does not exist.");
}
return catalog.getReadOnlyVersion()->getRelTableIDFromName(tableName);
}

table_id_t Binder::bindNodeTableID(const string& tableName) const {
if (tableName.empty()) {
return ANY_TABLE_ID;
}
if (!catalog.getReadOnlyVersion()->containNodeTable(tableName)) {
throw BinderException("Node table " + tableName + " does not exist.");
}
Expand Down
8 changes: 4 additions & 4 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Binder {
unique_ptr<BoundCreateRel> bindCreateRel(
shared_ptr<RelExpression> rel, const PropertyKeyValCollection& collection);
unique_ptr<BoundDeleteNode> bindDeleteNode(shared_ptr<NodeExpression> node);
shared_ptr<RelExpression> bindDeleteRel(shared_ptr<RelExpression> rel);

/*** bind projection clause ***/
unique_ptr<BoundWithClause> bindWithClause(const WithClause& withClause);
Expand Down Expand Up @@ -112,14 +113,13 @@ class Binder {
shared_ptr<NodeExpression> bindQueryNode(const NodePattern& nodePattern, QueryGraph& queryGraph,
PropertyKeyValCollection& collection);
shared_ptr<NodeExpression> createQueryNode(const NodePattern& nodePattern);
inline unordered_set<table_id_t> bindNodeTableIDs(const vector<string>& tableNames) {
inline vector<table_id_t> bindNodeTableIDs(const vector<string>& tableNames) {
return bindTableIDs(tableNames, NODE);
}
inline unordered_set<table_id_t> bindRelTableIDs(const vector<string>& tableNames) {
inline vector<table_id_t> bindRelTableIDs(const vector<string>& tableNames) {
return bindTableIDs(tableNames, REL);
}
unordered_set<table_id_t> bindTableIDs(
const vector<string>& tableNames, DataTypeID nodeOrRelType);
vector<table_id_t> bindTableIDs(const vector<string>& tableNames, DataTypeID nodeOrRelType);

/*** validations ***/
// E.g. Optional MATCH (a) RETURN a.age
Expand Down
18 changes: 4 additions & 14 deletions src/include/binder/expression/node_expression.h
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
#pragma once

#include "node_rel_expression.h"
#include "property_expression.h"

namespace kuzu {
namespace binder {

class NodeExpression : public Expression {
class NodeExpression : public NodeOrRelExpression {
public:
NodeExpression(const string& uniqueName, unordered_set<table_id_t> tableIDs)
: Expression{VARIABLE, NODE, uniqueName}, tableIDs{std::move(tableIDs)} {}

inline void addTableIDs(const unordered_set<table_id_t>& tableIDsToAdd) {
tableIDs.insert(tableIDsToAdd.begin(), tableIDsToAdd.end());
}
inline uint32_t getNumTableIDs() const { return tableIDs.size(); }
inline unordered_set<table_id_t> getTableIDs() const { return tableIDs; }
inline table_id_t getTableID() const {
assert(tableIDs.size() == 1);
return *tableIDs.begin();
}
NodeExpression(const string& uniqueName, vector<table_id_t> tableIDs)
: NodeOrRelExpression{NODE, uniqueName, std::move(tableIDs)} {}

inline void setInternalIDProperty(shared_ptr<Expression> expression) {
internalIDExpression = std::move(expression);
Expand All @@ -32,7 +23,6 @@ class NodeExpression : public Expression {
}

private:
unordered_set<table_id_t> tableIDs;
shared_ptr<Expression> internalIDExpression;
};

Expand Down
34 changes: 34 additions & 0 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include "expression.h"

namespace kuzu {
namespace binder {

class NodeOrRelExpression : public Expression {
public:
NodeOrRelExpression(
DataTypeID dataTypeID, const string& uniqueName, vector<table_id_t> tableIDs)
: Expression{VARIABLE, dataTypeID, uniqueName}, tableIDs{std::move(tableIDs)} {}

inline void addTableIDs(const vector<table_id_t>& tableIDsToAdd) {
auto tableIDsMap = unordered_set<table_id_t>(tableIDs.begin(), tableIDs.end());
for (auto tableID : tableIDsToAdd) {
if (!tableIDsMap.contains(tableID)) {
tableIDs.push_back(tableID);
}
}
}
inline bool isMultiLabeled() const { return tableIDs.size() > 1; }
inline vector<table_id_t> getTableIDs() const { return tableIDs; }
inline table_id_t getSingleTableID() const {
assert(tableIDs.size() == 1);
return tableIDs[0];
}

protected:
vector<table_id_t> tableIDs;
};

} // namespace binder
} // namespace kuzu
19 changes: 6 additions & 13 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,25 @@
namespace kuzu {
namespace binder {

class RelExpression : public Expression {
class RelExpression : public NodeOrRelExpression {
public:
RelExpression(const string& uniqueName, unordered_set<table_id_t> tableIDs,
RelExpression(const string& uniqueName, vector<table_id_t> tableIDs,
shared_ptr<NodeExpression> srcNode, shared_ptr<NodeExpression> dstNode, uint64_t lowerBound,
uint64_t upperBound)
: Expression{VARIABLE, REL, uniqueName}, tableIDs{std::move(tableIDs)}, srcNode{std::move(
srcNode)},
: NodeOrRelExpression{REL, uniqueName, std::move(tableIDs)}, srcNode{std::move(srcNode)},
dstNode{std::move(dstNode)}, lowerBound{lowerBound}, upperBound{upperBound} {}

inline table_id_t getTableID() const { return *tableIDs.begin(); }
inline uint32_t getNumTableIDs() const { return tableIDs.size(); }
inline unordered_set<table_id_t> getTableIDs() const { return tableIDs; }
inline bool isBoundByMultiLabeledNode() const {
return srcNode->isMultiLabeled() || dstNode->isMultiLabeled();
}

inline shared_ptr<NodeExpression> getSrcNode() const { return srcNode; }

inline string getSrcNodeName() const { return srcNode->getUniqueName(); }

inline shared_ptr<NodeExpression> getDstNode() const { return dstNode; }

inline string getDstNodeName() const { return dstNode->getUniqueName(); }

inline uint64_t getLowerBound() const { return lowerBound; }

inline uint64_t getUpperBound() const { return upperBound; }

inline bool isVariableLength() const { return !(lowerBound == 1 && upperBound == 1); }

inline void setInternalIDProperty(shared_ptr<Expression> expression) {
Expand All @@ -46,7 +40,6 @@ class RelExpression : public Expression {
}

private:
unordered_set<table_id_t> tableIDs;
shared_ptr<NodeExpression> srcNode;
shared_ptr<NodeExpression> dstNode;
uint64_t lowerBound;
Expand Down
1 change: 0 additions & 1 deletion src/include/binder/query/reading_clause/query_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
namespace kuzu {
namespace binder {

const table_id_t ANY_TABLE_ID = numeric_limits<uint32_t>::max();
const uint8_t MAX_NUM_VARIABLES = 64;

class QueryGraph;
Expand Down
14 changes: 14 additions & 0 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ class CatalogContent {
inline const vector<Property>& getRelProperties(table_id_t tableID) const {
return relTableSchemas.at(tableID)->properties;
}
inline vector<table_id_t> getNodeTableIDs() const {
vector<table_id_t> nodeTableIDs;
for (auto& [tableID, _] : nodeTableSchemas) {
nodeTableIDs.push_back(tableID);
}
return nodeTableIDs;
}
inline vector<table_id_t> getRelTableIDs() const {
vector<table_id_t> relTableIDs;
for (auto& [tableID, _] : relTableSchemas) {
relTableIDs.push_back(tableID);
}
return relTableIDs;
}
inline unordered_map<table_id_t, unique_ptr<NodeTableSchema>>& getNodeTableSchemas() {
return nodeTableSchemas;
}
Expand Down
2 changes: 1 addition & 1 deletion src/planner/asp_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ bool ASPOptimizer::canApplyASP(const vector<shared_ptr<NodeExpression>>& joinNod
}
auto rightScanNodeID = (LogicalScanNode*)rightScanNodeIDs[0];
// Semi mask cannot be applied to a ScanNodeID on multiple node tables.
if (rightScanNodeID->getNode()->getNumTableIDs() > 1) {
if (rightScanNodeID->getNode()->isMultiLabeled()) {
return false;
}
// Semi mask can only be pushed to ScanNodeIDs.
Expand Down
Loading

0 comments on commit ff546cc

Please sign in to comment.