Skip to content

Commit

Permalink
Merge pull request #1057 from kuzudb/multi-labeled-node-scan
Browse files Browse the repository at this point in the history
Multi labeled node scan
  • Loading branch information
andyfengHKU committed Nov 25, 2022
2 parents b67ad7f + 8a846c7 commit 814ce9a
Show file tree
Hide file tree
Showing 28 changed files with 2,292 additions and 2,074 deletions.
7 changes: 5 additions & 2 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ oC_PatternElement
;

oC_NodePattern
: '(' SP? ( oC_Variable SP? )? ( oC_NodeLabel SP? )? ( kU_Properties SP? )? ')'
| SP? ( oC_Variable SP? )? ( oC_NodeLabel SP? )? ( kU_Properties SP? )? { notifyNodePatternWithoutParentheses($oC_Variable.text, $oC_Variable.start); }
: '(' SP? ( oC_Variable SP? )? ( oC_NodeLabels SP? )? ( kU_Properties SP? )? ')'
| SP? ( oC_Variable SP? )? ( oC_NodeLabels SP? )? ( kU_Properties SP? )? { notifyNodePatternWithoutParentheses($oC_Variable.text, $oC_Variable.start); }
;

oC_PatternElementChain
Expand All @@ -267,6 +267,9 @@ oC_RelationshipDetail
kU_Properties
: '{' SP? ( oC_PropertyKeyName SP? ':' SP? oC_Expression SP? ( ',' SP? oC_PropertyKeyName SP? ':' SP? oC_Expression SP? )* )? '}';

oC_NodeLabels
: oC_NodeLabel ( SP? oC_NodeLabel )* ;

oC_NodeLabel
: ':' SP? oC_LabelName ;

Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ unique_ptr<BoundStatement> Binder::bindCreateRelClause(const Statement& statemen
auto relConnections = createRelClause.getRelConnections();
vector<pair<table_id_t, table_id_t>> srcDstTableIDs;
for (auto& [srcTableName, dstTableName] : relConnections) {
srcDstTableIDs.emplace_back(bindNodeTable(srcTableName), bindNodeTable(dstTableName));
srcDstTableIDs.emplace_back(bindNodeTableID(srcTableName), bindNodeTableID(dstTableName));
}
return make_unique<BoundCreateRelClause>(
tableName, move(propertyNameDataTypes), relMultiplicity, srcDstTableIDs);
Expand Down
56 changes: 42 additions & 14 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ unique_ptr<QueryGraph> Binder::bindPatternElement(
return queryGraph;
}

// TODO(Xiyang): remove this validation when we support full multi-labeled query
static void validateNodeRelConnectivity(table_id_t srcTableID, table_id_t dstTableID,
table_id_t relTableID, const CatalogContent& catalogContent) {
for (auto& [srcTableID_, dstTableID_] :
catalogContent.getRelTableSchema(relTableID)->srcDstTableIDs) {
if (srcTableID_ == srcTableID && dstTableID_ == dstTableID) {
return;
}
}
throw BinderException("Node table " + catalogContent.getNodeTableName(srcTableID) +
" doesn't connect to " + catalogContent.getNodeTableName(dstTableID) +
" through rel table " + catalogContent.getRelTableName(relTableID) + ".");
}

// E.g. MATCH (:person)-[:studyAt]->(:person) ...
static void validateNodeRelConnectivity(const Catalog& catalog_, const RelExpression& rel,
const NodeExpression& srcNode, const NodeExpression& dstNode) {
for (auto srcTableID : srcNode.getTableIDs()) {
for (auto dstTableID : dstNode.getTableIDs()) {
validateNodeRelConnectivity(
srcTableID, dstTableID, rel.getTableID(), *catalog_.getReadOnlyVersion());
}
}
}

void Binder::bindQueryRel(const RelPattern& relPattern, const shared_ptr<NodeExpression>& leftNode,
const shared_ptr<NodeExpression>& rightNode, QueryGraph& queryGraph,
PropertyKeyValCollection& collection) {
Expand All @@ -56,8 +81,6 @@ void Binder::bindQueryRel(const RelPattern& relPattern, const shared_ptr<NodeExp
auto isLeftNodeSrc = RIGHT == relPattern.getDirection();
auto srcNode = isLeftNodeSrc ? leftNode : rightNode;
auto dstNode = isLeftNodeSrc ? rightNode : leftNode;
validateNodeAndRelTableIsConnected(
catalog, tableID, srcNode->getTableID(), dstNode->getTableID());
if (srcNode->getUniqueName() == dstNode->getUniqueName()) {
throw BinderException("Self-loop rel " + parsedName + " is not supported.");
}
Expand All @@ -79,6 +102,7 @@ void Binder::bindQueryRel(const RelPattern& relPattern, const shared_ptr<NodeExp
if (!parsedName.empty()) {
variablesInScope.insert({parsedName, queryRel});
}
validateNodeRelConnectivity(catalog, *queryRel, *srcNode, *dstNode);
for (auto i = 0u; i < relPattern.getNumPropertyKeyValPairs(); ++i) {
auto [propertyName, rhs] = relPattern.getProperty(i);
auto boundLhs = expressionBinder.bindRelPropertyExpression(queryRel, propertyName);
Expand All @@ -97,12 +121,8 @@ shared_ptr<NodeExpression> Binder::bindQueryNode(
auto prevVariable = variablesInScope.at(parsedName);
ExpressionBinder::validateExpectedDataType(*prevVariable, NODE);
queryNode = static_pointer_cast<NodeExpression>(prevVariable);
auto otherTableID = bindNodeTable(nodePattern.getTableName());
KU_ASSERT(queryNode->getTableID() != ANY_TABLE_ID);
if (otherTableID != ANY_TABLE_ID && queryNode->getTableID() != otherTableID) {
throw BinderException(
"Multi-table is not supported. Node " + parsedName + " is given multiple tables.");
}
auto otherTableIDs = bindNodeTableIDs(nodePattern.getTableNames());
queryNode->addTableIDs(otherTableIDs);
} else {
queryNode = createQueryNode(nodePattern);
}
Expand All @@ -119,19 +139,27 @@ shared_ptr<NodeExpression> Binder::bindQueryNode(

shared_ptr<NodeExpression> Binder::createQueryNode(const NodePattern& nodePattern) {
auto parsedName = nodePattern.getVariableName();
auto tableID = bindNodeTable(nodePattern.getTableName());
auto queryNode = make_shared<NodeExpression>(getUniqueExpressionName(parsedName), tableID);
auto tableIDs = bindNodeTableIDs(nodePattern.getTableNames());
auto queryNode = make_shared<NodeExpression>(getUniqueExpressionName(parsedName), tableIDs);
queryNode->setAlias(parsedName);
queryNode->setRawName(parsedName);
if (ANY_TABLE_ID == tableID) {
throw BinderException(
"Any-table is not supported. " + parsedName + " does not have a table.");
}
if (!parsedName.empty()) {
variablesInScope.insert({parsedName, queryNode});
}
return queryNode;
}

unordered_set<table_id_t> Binder::bindNodeTableIDs(const vector<string>& nodeTableNames) {
unordered_set<table_id_t> result;
for (auto& nodeTableName : nodeTableNames) {
auto nodeTableID = bindNodeTableID(nodeTableName);
if (nodeTableID == ANY_TABLE_ID) {
throw BinderException("Any-table is not supported");
}
result.insert(nodeTableID);
}
return result;
}

} // namespace binder
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ expression_vector Binder::rewriteProjectionExpressions(const expression_vector&

expression_vector Binder::rewriteNodeAsAllProperties(const shared_ptr<Expression>& expression) {
auto& node = (NodeExpression&)*expression;
if (node.getNumTableIDs() > 1) {
throw BinderException(
"Cannot rewrite multi-labeled node " + node.getRawName() + " as all properties.");
}
expression_vector result;
for (auto& property : catalog.getReadOnlyVersion()->getAllNodeProperties(node.getTableID())) {
auto propertyExpression = expressionBinder.bindNodePropertyExpression(expression, property);
Expand Down
8 changes: 8 additions & 0 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ unique_ptr<BoundUpdatingClause> Binder::bindCreateClause(const UpdatingClause& u

unique_ptr<BoundCreateNode> Binder::bindCreateNode(
shared_ptr<NodeExpression> node, const PropertyKeyValCollection& collection) {
if (node->getNumTableIDs() > 1) {
throw BinderException(
"Create multi-labeled node " + node->getRawName() + "is not supported.");
}
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(node->getTableID());
auto primaryKey = nodeTableSchema->getPrimaryKey();
shared_ptr<Expression> primaryKeyExpression;
Expand Down Expand Up @@ -130,6 +134,10 @@ unique_ptr<BoundUpdatingClause> Binder::bindDeleteClause(const UpdatingClause& u
}

unique_ptr<BoundDeleteNode> Binder::bindDeleteNode(shared_ptr<NodeExpression> node) {
if (node->getNumTableIDs() > 1) {
throw BinderException(
"Delete multi-labeled node " + node->getRawName() + "is not supported.");
}
auto nodeTableSchema = catalog.getReadOnlyVersion()->getNodeTableSchema(node->getTableID());
auto primaryKey = nodeTableSchema->getPrimaryKey();
auto primaryKeyExpression = expressionBinder.bindNodePropertyExpression(node, primaryKey);
Expand Down
17 changes: 1 addition & 16 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ table_id_t Binder::bindRelTable(const string& tableName) const {
return catalog.getReadOnlyVersion()->getRelTableIDFromName(tableName);
}

table_id_t Binder::bindNodeTable(const string& tableName) const {
table_id_t Binder::bindNodeTableID(const string& tableName) const {
if (tableName.empty()) {
return ANY_TABLE_ID;
}
Expand Down Expand Up @@ -69,21 +69,6 @@ void Binder::validateFirstMatchIsNotOptional(const SingleQuery& singleQuery) {
}
}

void Binder::validateNodeAndRelTableIsConnected(
const Catalog& catalog_, table_id_t relTableID, table_id_t srcTableID, table_id_t dstTableID) {
assert(relTableID != ANY_TABLE_ID && srcTableID != ANY_TABLE_ID && dstTableID != ANY_TABLE_ID);
for (auto& [srcTableID_, dstTableID_] :
catalog_.getReadOnlyVersion()->getRelTableSchema(relTableID)->srcDstTableIDs) {
if (srcTableID_ == srcTableID && dstTableID_ == dstTableID) {
return;
}
}
throw BinderException(
"Node table " + catalog_.getReadOnlyVersion()->getNodeTableName(srcTableID) +
" doesn't connect to " + catalog_.getReadOnlyVersion()->getNodeTableName(dstTableID) +
" through rel table " + catalog_.getReadOnlyVersion()->getRelTableName(relTableID) + ".");
}

void Binder::validateProjectionColumnNamesAreUnique(const expression_vector& expressions) {
auto existColumnNames = unordered_set<string>();
for (auto& expression : expressions) {
Expand Down
3 changes: 3 additions & 0 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ shared_ptr<Expression> ExpressionBinder::bindNodePropertyExpression(
shared_ptr<Expression> node, const string& propertyName) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto nodeExpression = static_pointer_cast<NodeExpression>(node);
if (nodeExpression->getNumTableIDs() > 1) {
throw BinderException("Cannot bind property for multi-labeled node " + node->getRawName());
}
if (catalogContent->containNodeProperty(nodeExpression->getTableID(), propertyName)) {
auto& property =
catalogContent->getNodeProperty(nodeExpression->getTableID(), propertyName);
Expand Down
5 changes: 0 additions & 5 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,6 @@ const Property& CatalogContent::getRelProperty(
assert(false);
}

const Property& CatalogContent::getNodePrimaryKeyProperty(table_id_t tableID) const {
auto primaryKeyId = nodeTableSchemas.at(tableID)->primaryKeyPropertyIdx;
return nodeTableSchemas.at(tableID)->structuredProperties[primaryKeyId];
}

vector<Property> CatalogContent::getAllNodeProperties(table_id_t tableID) const {
return nodeTableSchemas.at(tableID)->getAllNodeProperties();
}
Expand Down
7 changes: 2 additions & 5 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Binder {
shared_ptr<Expression> bindWhereExpression(const ParsedExpression& parsedExpression);

table_id_t bindRelTable(const string& tableName) const;
table_id_t bindNodeTable(const string& tableName) const;
table_id_t bindNodeTableID(const string& tableName) const;

shared_ptr<Expression> createVariable(const string& name, const DataType& dataType);

Expand Down Expand Up @@ -112,17 +112,14 @@ class Binder {
shared_ptr<NodeExpression> bindQueryNode(const NodePattern& nodePattern, QueryGraph& queryGraph,
PropertyKeyValCollection& collection);
shared_ptr<NodeExpression> createQueryNode(const NodePattern& nodePattern);
unordered_set<table_id_t> bindNodeTableIDs(const vector<string>& nodeTableNames);

/*** validations ***/
// E.g. Optional MATCH (a) RETURN a.age
// Although this is doable in Neo4j, I don't think the semantic make a lot of sense because
// there is nothing to left join on.
static void validateFirstMatchIsNotOptional(const SingleQuery& singleQuery);

// E.g. MATCH (:person)-[:studyAt]->(:person) ...
static void validateNodeAndRelTableIsConnected(const Catalog& catalog_, table_id_t relTableID,
table_id_t srcTableID, table_id_t dstTableID);

// E.g. ... RETURN a, b AS a
static void validateProjectionColumnNamesAreUnique(const expression_vector& expressions);

Expand Down
17 changes: 12 additions & 5 deletions src/include/binder/expression/node_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@ namespace kuzu {
namespace binder {

class NodeExpression : public Expression {

public:
NodeExpression(const string& uniqueName, table_id_t tableID)
: Expression{VARIABLE, NODE, uniqueName}, tableID{tableID} {}
NodeExpression(const string& uniqueName, unordered_set<table_id_t> tableIDs)
: Expression{VARIABLE, NODE, uniqueName}, tableIDs{std::move(tableIDs)} {}

inline table_id_t getTableID() const { return tableID; }
inline void addTableIDs(const unordered_set<table_id_t>& tableIDsToAdd) {
tableIDs.insert(tableIDsToAdd.begin(), tableIDsToAdd.end());
}
inline uint32_t getNumTableIDs() const { return tableIDs.size(); }
inline unordered_set<table_id_t> getTableIDs() const { return tableIDs; }
inline table_id_t getTableID() const {
assert(tableIDs.size() == 1);
return *tableIDs.begin();
}

inline string getIDProperty() const { return uniqueName + "." + INTERNAL_ID_SUFFIX; }

Expand All @@ -21,7 +28,7 @@ class NodeExpression : public Expression {
}

private:
table_id_t tableID;
unordered_set<table_id_t> tableIDs;
};

} // namespace binder
Expand Down
4 changes: 1 addition & 3 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ class CatalogContent {
virtual const Property& getNodeProperty(table_id_t tableID, const string& propertyName) const;
virtual const Property& getRelProperty(table_id_t tableID, const string& propertyName) const;

const Property& getNodePrimaryKeyProperty(table_id_t tableID) const;

vector<Property> getAllNodeProperties(table_id_t tableID) const;
inline const vector<Property>& getRelProperties(table_id_t tableID) const {
return relTableSchemas.at(tableID)->properties;
Expand All @@ -116,7 +114,7 @@ class CatalogContent {
/**
* Graph topology functions.
*/

// TODO(Xiyang): remove
virtual const unordered_set<table_id_t>& getRelTableIDsForNodeTableDirection(
table_id_t tableID, RelDirection direction) const;

Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/node_id_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace common {

typedef uint64_t table_id_t;
typedef uint64_t node_offset_t;
constexpr node_offset_t INVALID_NODE_OFFSET = UINT64_MAX;

// System representation for nodeID.
struct nodeID_t {
Expand Down
12 changes: 6 additions & 6 deletions src/include/parser/query/graph_pattern/node_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,29 @@ namespace kuzu {
namespace parser {

/**
* NodePattern represents "(nodeName:NodeTable {p1:v1, p2:v2, ...})"
* NodePattern represents "(nodeName:NodeTable+ {p1:v1, p2:v2, ...})"
*/
class NodePattern {
public:
NodePattern(string name, string tableName,
NodePattern(string name, vector<string> tableNames,
vector<pair<string, unique_ptr<ParsedExpression>>> propertyKeyValPairs)
: variableName{std::move(name)}, tableName{std::move(tableName)},
: variableName{std::move(name)}, tableNames{std::move(tableNames)},
propertyKeyValPairs{std::move(propertyKeyValPairs)} {}

virtual ~NodePattern() = default;

inline string getVariableName() const { return variableName; }

inline string getTableName() const { return tableName; }
inline vector<string> getTableNames() const { return tableNames; }

inline uint32_t getNumPropertyKeyValPairs() const { return propertyKeyValPairs.size(); }
inline pair<string, ParsedExpression*> getProperty(uint32_t idx) const {
return make_pair(propertyKeyValPairs[idx].first, propertyKeyValPairs[idx].second.get());
}

private:
protected:
string variableName;
string tableName;
vector<string> tableNames;
vector<pair<string, unique_ptr<ParsedExpression>>> propertyKeyValPairs;
};

Expand Down
8 changes: 7 additions & 1 deletion src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@ class RelPattern : public NodePattern {
RelPattern(string name, string tableName, string lowerBound, string upperBound,
ArrowDirection arrowDirection,
vector<pair<string, unique_ptr<ParsedExpression>>> propertyKeyValPairs)
: NodePattern{std::move(name), std::move(tableName), std::move(propertyKeyValPairs)},
: NodePattern{std::move(name), vector<string>{std::move(tableName)},
std::move(propertyKeyValPairs)},
lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
arrowDirection{arrowDirection} {}

~RelPattern() = default;

inline string getTableName() const {
assert(tableNames.size() == 1);
return tableNames[0];
}

inline string getLowerBound() const { return lowerBound; }

inline string getUpperBound() const { return upperBound; }
Expand Down
Loading

0 comments on commit 814ce9a

Please sign in to comment.