Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Recursive join rel filter #1726

Merged
merged 1 commit into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ oC_RelationshipPattern
;

oC_RelationshipDetail
: '[' SP? ( oC_Variable SP? )? ( oC_RelationshipTypes SP? )? ( oC_RangeLiteral SP? ) ? ( kU_Properties SP? ) ? ']' ;
: '[' SP? ( oC_Variable SP? )? ( oC_RelationshipTypes SP? )? ( oC_RangeLiteral SP? )? ( kU_Properties SP? )? ']' ;

// The original oC_Properties definition is oC_MapLiteral | oC_Parameter.
// We have chosen not to support parameters as properties, and we expect this decision to stand for a long time.
Expand All @@ -322,7 +322,7 @@ oC_NodeLabel
: ':' SP? oC_LabelName ;

oC_RangeLiteral
: '*' SP? ( SHORTEST | ALL SP SHORTEST )? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;
: '*' SP? ( SHORTEST | ALL SP SHORTEST )? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral (SP? '(' SP? oC_Variable SP? ',' SP? '_' SP? '|' SP? oC_Where SP? ')')? ;

SHORTEST : ( 'S' | 's' ) ( 'H' | 'h' ) ( 'O' | 'o' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'E' | 'e' ) ( 'S' | 's' ) ( 'T' | 't' ) ;

Expand Down
20 changes: 15 additions & 5 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,15 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
recursiveNodeTableIDs.insert(relTableSchema->srcTableID);
recursiveNodeTableIDs.insert(relTableSchema->dstTableID);
}
auto recursiveRelPatternInfo = relPattern.getRecursiveInfo();
auto tmpNode = createQueryNode(
InternalKeyword::ANONYMOUS, std::vector<common::table_id_t>{recursiveNodeTableIDs.begin(),
recursiveNodeTableIDs.end()});
auto prevScope = saveScope();
variableScope->clear();
auto tmpRel = createNonRecursiveQueryRel(
InternalKeyword::ANONYMOUS, tableIDs, nullptr, nullptr, directionType);
recursiveRelPatternInfo->relName, tableIDs, nullptr, nullptr, directionType);
variableScope->addExpression(tmpRel->toString(), tmpRel);
expression_vector predicates;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindRelPropertyExpression(*tmpRel, propertyName);
Expand All @@ -261,6 +265,11 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
predicates.push_back(
expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs));
}
if (recursiveRelPatternInfo->whereExpression != nullptr) {
predicates.push_back(
expressionBinder.bindExpression(*recursiveRelPatternInfo->whereExpression));
}
restoreScope(std::move(prevScope));
auto parsedName = relPattern.getVariableName();
auto queryRel = make_shared<RelExpression>(*getRecursiveRelLogicalType(*tmpNode, *tmpRel),
getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode),
Expand All @@ -276,10 +285,11 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel

std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
const kuzu::parser::RelPattern& relPattern) {
auto lowerBound = std::min(TypeUtils::convertToUint32(relPattern.getLowerBound().c_str()),
VAR_LENGTH_EXTEND_MAX_DEPTH);
auto upperBound = std::min(TypeUtils::convertToUint32(relPattern.getUpperBound().c_str()),
VAR_LENGTH_EXTEND_MAX_DEPTH);
auto recursiveInfo = relPattern.getRecursiveInfo();
auto lowerBound = std::min(
TypeUtils::convertToUint32(recursiveInfo->lowerBound.c_str()), VAR_LENGTH_EXTEND_MAX_DEPTH);
auto upperBound = std::min(
TypeUtils::convertToUint32(recursiveInfo->upperBound.c_str()), VAR_LENGTH_EXTEND_MAX_DEPTH);
if (lowerBound == 0 || upperBound == 0) {
throw BinderException("Lower and upper bound of a rel must be greater than 0.");
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
const ParsedExpression& parsedExpression) {
auto& subqueryExpression = (ParsedSubqueryExpression&)parsedExpression;
auto prevVariableScope = binder->enterSubquery();
auto prevVariableScope = binder->saveScope();
auto [queryGraph, _] = binder->bindGraphPattern(subqueryExpression.getPatternElements());
auto rawName = parsedExpression.getRawName();
auto uniqueName = binder->getUniqueExpressionName(rawName);
Expand All @@ -21,7 +21,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
boundSubqueryExpression->setWhereExpression(
binder->bindWhereExpression(*subqueryExpression.getWhereClause()));
}
binder->exitSubquery(std::move(prevVariableScope));
binder->restoreScope(std::move(prevVariableScope));
return boundSubqueryExpression;
}

Expand Down
4 changes: 2 additions & 2 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ std::string Binder::getUniqueExpressionName(const std::string& name) {
return "_" + std::to_string(lastExpressionId++) + "_" + name;
}

std::unique_ptr<VariableScope> Binder::enterSubquery() {
std::unique_ptr<VariableScope> Binder::saveScope() {
return variableScope->copy();
}

void Binder::exitSubquery(std::unique_ptr<VariableScope> prevVariableScope) {
void Binder::restoreScope(std::unique_ptr<VariableScope> prevVariableScope) {
variableScope = std::move(prevVariableScope);
}

Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ class Binder {
/*** helpers ***/
std::string getUniqueExpressionName(const std::string& name);

std::unique_ptr<VariableScope> enterSubquery();
void exitSubquery(std::unique_ptr<VariableScope> prevVariableScope);
std::unique_ptr<VariableScope> saveScope();
void restoreScope(std::unique_ptr<VariableScope> prevVariableScope);

private:
const catalog::Catalog& catalog;
Expand Down
31 changes: 21 additions & 10 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,44 @@ namespace kuzu {
namespace parser {

enum class ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };

/**
 * Parsed description of the recursive part of a relationship pattern: the "*lower..upper" range
 * literal plus the optional "(rel, _ | WHERE ...)" filter that follows it.
 */
struct RecursiveRelPatternInfo {
    // Range bounds, stored as the literal text taken from the query; they are converted to
    // integers later, when the pattern is bound.
    std::string lowerBound;
    std::string upperBound;
    // Variable that names the relationship inside the filter; empty when no filter was written.
    std::string relName;
    // Predicate of the filter; nullptr when no filter was written.
    std::unique_ptr<ParsedExpression> whereExpression;

    // Takes every argument by value and moves it into the corresponding member.
    RecursiveRelPatternInfo(std::string lowerBound, std::string upperBound, std::string relName,
        std::unique_ptr<ParsedExpression> whereExpression)
        : lowerBound(std::move(lowerBound)),
          upperBound(std::move(upperBound)),
          relName(std::move(relName)),
          whereExpression(std::move(whereExpression)) {}
};

/**
* RelationshipPattern represents "-[relName:RelTableName+]-"
*/
class RelPattern : public NodePattern {
public:
RelPattern(std::string name, std::vector<std::string> tableNames, common::QueryRelType relType,
std::string lowerBound, std::string upperBound, ArrowDirection arrowDirection,
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>> propertyKeyValPairs)
ArrowDirection arrowDirection,
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>> propertyKeyValPairs,
std::unique_ptr<RecursiveRelPatternInfo> recursiveInfo)
: NodePattern{std::move(name), std::move(tableNames), std::move(propertyKeyValPairs)},
relType{relType}, lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
arrowDirection{arrowDirection} {}
relType{relType}, arrowDirection{arrowDirection}, recursiveInfo{
std::move(recursiveInfo)} {}

~RelPattern() override = default;

inline common::QueryRelType getRelType() const { return relType; }

inline std::string getLowerBound() const { return lowerBound; }

inline std::string getUpperBound() const { return upperBound; }

inline ArrowDirection getDirection() const { return arrowDirection; }

inline RecursiveRelPatternInfo* getRecursiveInfo() const { return recursiveInfo.get(); }

private:
common::QueryRelType relType;
std::string lowerBound;
std::string upperBound;
ArrowDirection arrowDirection;
std::unique_ptr<RecursiveRelPatternInfo> recursiveInfo;
};

} // namespace parser
Expand Down
74 changes: 46 additions & 28 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,18 @@ std::unique_ptr<PatternElement> Transformer::transformPatternElement(

std::unique_ptr<NodePattern> Transformer::transformNodePattern(
CypherParser::OC_NodePatternContext& ctx) {
auto variable = ctx.oC_Variable() ? transformVariable(*ctx.oC_Variable()) : std::string();
auto nodeLabels = ctx.oC_NodeLabels() ? transformNodeLabels(*ctx.oC_NodeLabels()) :
std::vector<std::string>{};
auto properties = ctx.kU_Properties() ?
transformProperties(*ctx.kU_Properties()) :
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
auto variable = std::string();
if (ctx.oC_Variable()) {
variable = transformVariable(*ctx.oC_Variable());
}
auto nodeLabels = std::vector<std::string>{};
if (ctx.oC_NodeLabels()) {
nodeLabels = transformNodeLabels(*ctx.oC_NodeLabels());
}
auto properties = std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
if (ctx.kU_Properties()) {
properties = transformProperties(*ctx.kU_Properties());
}
return std::make_unique<NodePattern>(
std::move(variable), std::move(nodeLabels), std::move(properties));
}
Expand All @@ -288,38 +294,50 @@ std::unique_ptr<PatternElementChain> Transformer::transformPatternElementChain(
std::unique_ptr<RelPattern> Transformer::transformRelationshipPattern(
CypherParser::OC_RelationshipPatternContext& ctx) {
auto relDetail = ctx.oC_RelationshipDetail();
auto variable =
relDetail->oC_Variable() ? transformVariable(*relDetail->oC_Variable()) : std::string();
auto relTypes = relDetail->oC_RelationshipTypes() ?
transformRelTypes(*relDetail->oC_RelationshipTypes()) :
std::vector<std::string>{};
auto variable = std::string();
if (relDetail->oC_Variable()) {
variable = transformVariable(*relDetail->oC_Variable());
}
auto relTypes = std::vector<std::string>{};
if (relDetail->oC_RelationshipTypes()) {
relTypes = transformRelTypes(*relDetail->oC_RelationshipTypes());
}
auto properties = std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
if (relDetail->kU_Properties()) {
properties = transformProperties(*relDetail->kU_Properties());
}
ArrowDirection arrowDirection;
if (ctx.oC_LeftArrowHead()) {
arrowDirection = ArrowDirection::LEFT;
} else if (ctx.oC_RightArrowHead()) {
arrowDirection = ArrowDirection::RIGHT;
} else {
arrowDirection = ArrowDirection::BOTH;
}
auto relType = common::QueryRelType::NON_RECURSIVE;
std::string lowerBound = "1";
std::string upperBound = "1";
std::unique_ptr<RecursiveRelPatternInfo> recursiveInfo;
if (relDetail->oC_RangeLiteral()) {
lowerBound = relDetail->oC_RangeLiteral()->oC_IntegerLiteral()[0]->getText();
upperBound = relDetail->oC_RangeLiteral()->oC_IntegerLiteral()[1]->getText();
if (relDetail->oC_RangeLiteral()->ALL()) {
relType = common::QueryRelType::ALL_SHORTEST;
} else if (relDetail->oC_RangeLiteral()->SHORTEST()) {
relType = common::QueryRelType::SHORTEST;
} else {
relType = common::QueryRelType::VARIABLE_LENGTH;
}
auto range = relDetail->oC_RangeLiteral();
auto lowerBound = range->oC_IntegerLiteral()[0]->getText();
auto upperBound = range->oC_IntegerLiteral()[1]->getText();
auto recursiveRelName = std::string();
std::unique_ptr<ParsedExpression> whereExpression = nullptr;
if (range->oC_Where()) {
recursiveRelName = transformVariable(*range->oC_Variable());
whereExpression = transformWhere(*range->oC_Where());
}
recursiveInfo = std::make_unique<RecursiveRelPatternInfo>(
lowerBound, upperBound, recursiveRelName, std::move(whereExpression));
}
ArrowDirection arrowDirection;
if (ctx.oC_LeftArrowHead()) {
arrowDirection = ArrowDirection::LEFT;
} else if (ctx.oC_RightArrowHead()) {
arrowDirection = ArrowDirection::RIGHT;
} else {
arrowDirection = ArrowDirection::BOTH;
}
auto properties = relDetail->kU_Properties() ?
transformProperties(*relDetail->kU_Properties()) :
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
return std::make_unique<RelPattern>(
variable, relTypes, relType, lowerBound, upperBound, arrowDirection, std::move(properties));
return std::make_unique<RelPattern>(variable, relTypes, relType, arrowDirection,
std::move(properties), std::move(recursiveInfo));
}

std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>
Expand Down
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/var_length_extend/multi_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,8 @@
-STATEMENT MATCH (a:person)-[e*2..2]->(b:organisation) WHERE a.fName = 'Alice' RETURN COUNT(*)
---- 1
5


-STATEMENT MATCH (a:person)-[e*2..2 (r, _ | WHERE offset(id(r)) > 0)]->(b:organisation) WHERE a.fName = 'Alice' RETURN rels(e)
---- 1
[{_ID: 3:2, DATE: 2021-06-30, MEETTIME: 2012-12-11 20:07:22, VALIDINTERVAL: 10 days, COMMENTS: [ioji232,jifhe8w99u43434], YEAR: , PLACES: , LENGTH: , GRADING: , RATING: , LOCATION: , TIMES: , DATA: , USEDADDRESS: , ADDRESS: , NOTE: },{_ID: 5:1, DATE: , MEETTIME: , VALIDINTERVAL: , COMMENTS: , YEAR: 2010, PLACES: , LENGTH: , GRADING: [2.100000,4.400000], RATING: 7.600000, LOCATION: , TIMES: , DATA: , USEDADDRESS: , ADDRESS: , NOTE: }]
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/var_length_extend/n_n.test
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@ Greg
3:1|0:2|2
3:2|0:3|1
3:2|0:3|2

-LOG KnowsOneToTwoHopFilterTest2
-STATEMENT MATCH (a:person)-[e:knows*1..2 (r,_ | WHERE list_contains(r.comments, 'rnme'))]->(b:person) WHERE a.fName='Alice' RETURN COUNT(*)
---- 1
1
Loading