Skip to content

Commit

Permalink
add rel filter
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jun 28, 2023
1 parent 429d37f commit 687cfbd
Show file tree
Hide file tree
Showing 13 changed files with 2,937 additions and 2,767 deletions.
4 changes: 2 additions & 2 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ oC_RelationshipPattern
;

oC_RelationshipDetail
: '[' SP? ( oC_Variable SP? )? ( oC_RelationshipTypes SP? )? ( oC_RangeLiteral SP? ) ? ( kU_Properties SP? ) ? ']' ;
: '[' SP? ( oC_Variable SP? )? ( oC_RelationshipTypes SP? )? ( oC_RangeLiteral SP? )? ( kU_Properties SP? )? ']' ;

// The original oC_Properties definition is oC_MapLiteral | oC_Parameter.
// We choose to not support parameter as properties which will be the decision for a long time.
Expand All @@ -322,7 +322,7 @@ oC_NodeLabel
: ':' SP? oC_LabelName ;

oC_RangeLiteral
: '*' SP? ( SHORTEST | ALL SP SHORTEST )? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;
: '*' SP? ( SHORTEST | ALL SP SHORTEST )? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral (SP? '(' SP? oC_Variable SP? ',' SP? '_' SP? '|' SP? oC_Where SP? ')')? ;

SHORTEST : ( 'S' | 's' ) ( 'H' | 'h' ) ( 'O' | 'o' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'E' | 'e' ) ( 'S' | 's' ) ( 'T' | 't' ) ;

Expand Down
20 changes: 15 additions & 5 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,15 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
recursiveNodeTableIDs.insert(relTableSchema->srcTableID);
recursiveNodeTableIDs.insert(relTableSchema->dstTableID);
}
auto recursiveRelPatternInfo = relPattern.getRecursiveInfo();
auto tmpNode = createQueryNode(
InternalKeyword::ANONYMOUS, std::vector<common::table_id_t>{recursiveNodeTableIDs.begin(),
recursiveNodeTableIDs.end()});
auto prevScope = saveScope();
variableScope->clear();
auto tmpRel = createNonRecursiveQueryRel(
InternalKeyword::ANONYMOUS, tableIDs, nullptr, nullptr, directionType);
recursiveRelPatternInfo->relName, tableIDs, nullptr, nullptr, directionType);
variableScope->addExpression(tmpRel->toString(), tmpRel);
expression_vector predicates;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindRelPropertyExpression(*tmpRel, propertyName);
Expand All @@ -261,6 +265,11 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
predicates.push_back(
expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs));
}
if (recursiveRelPatternInfo->whereExpression != nullptr) {
predicates.push_back(
expressionBinder.bindExpression(*recursiveRelPatternInfo->whereExpression));
}
restoreScope(std::move(prevScope));
auto parsedName = relPattern.getVariableName();
auto queryRel = make_shared<RelExpression>(*getRecursiveRelLogicalType(*tmpNode, *tmpRel),
getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode),
Expand All @@ -276,10 +285,11 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel

std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
const kuzu::parser::RelPattern& relPattern) {
auto lowerBound = std::min(TypeUtils::convertToUint32(relPattern.getLowerBound().c_str()),
VAR_LENGTH_EXTEND_MAX_DEPTH);
auto upperBound = std::min(TypeUtils::convertToUint32(relPattern.getUpperBound().c_str()),
VAR_LENGTH_EXTEND_MAX_DEPTH);
auto recursiveInfo = relPattern.getRecursiveInfo();
auto lowerBound = std::min(
TypeUtils::convertToUint32(recursiveInfo->lowerBound.c_str()), VAR_LENGTH_EXTEND_MAX_DEPTH);
auto upperBound = std::min(
TypeUtils::convertToUint32(recursiveInfo->upperBound.c_str()), VAR_LENGTH_EXTEND_MAX_DEPTH);
if (lowerBound == 0 || upperBound == 0) {
throw BinderException("Lower and upper bound of a rel must be greater than 0.");
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace binder {
std::shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
const ParsedExpression& parsedExpression) {
auto& subqueryExpression = (ParsedSubqueryExpression&)parsedExpression;
auto prevVariableScope = binder->enterSubquery();
auto prevVariableScope = binder->saveScope();
auto [queryGraph, _] = binder->bindGraphPattern(subqueryExpression.getPatternElements());
auto rawName = parsedExpression.getRawName();
auto uniqueName = binder->getUniqueExpressionName(rawName);
Expand All @@ -21,7 +21,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
boundSubqueryExpression->setWhereExpression(
binder->bindWhereExpression(*subqueryExpression.getWhereClause()));
}
binder->exitSubquery(std::move(prevVariableScope));
binder->restoreScope(std::move(prevVariableScope));
return boundSubqueryExpression;
}

Expand Down
4 changes: 2 additions & 2 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ std::string Binder::getUniqueExpressionName(const std::string& name) {
return "_" + std::to_string(lastExpressionId++) + "_" + name;
}

std::unique_ptr<VariableScope> Binder::enterSubquery() {
std::unique_ptr<VariableScope> Binder::saveScope() {
return variableScope->copy();
}

void Binder::exitSubquery(std::unique_ptr<VariableScope> prevVariableScope) {
void Binder::restoreScope(std::unique_ptr<VariableScope> prevVariableScope) {
variableScope = std::move(prevVariableScope);
}

Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ class Binder {
/*** helpers ***/
std::string getUniqueExpressionName(const std::string& name);

std::unique_ptr<VariableScope> enterSubquery();
void exitSubquery(std::unique_ptr<VariableScope> prevVariableScope);
std::unique_ptr<VariableScope> saveScope();
void restoreScope(std::unique_ptr<VariableScope> prevVariableScope);

private:
const catalog::Catalog& catalog;
Expand Down
31 changes: 21 additions & 10 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,44 @@ namespace kuzu {
namespace parser {

enum class ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };

struct RecursiveRelPatternInfo {
std::string lowerBound;
std::string upperBound;
std::string relName;
std::unique_ptr<ParsedExpression> whereExpression;

RecursiveRelPatternInfo(std::string lowerBound, std::string upperBound, std::string relName,
std::unique_ptr<ParsedExpression> whereExpression)
: lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
relName{std::move(relName)}, whereExpression{std::move(whereExpression)} {}
};

/**
* RelationshipPattern represents "-[relName:RelTableName+]-"
*/
class RelPattern : public NodePattern {
public:
RelPattern(std::string name, std::vector<std::string> tableNames, common::QueryRelType relType,
std::string lowerBound, std::string upperBound, ArrowDirection arrowDirection,
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>> propertyKeyValPairs)
ArrowDirection arrowDirection,
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>> propertyKeyValPairs,
std::unique_ptr<RecursiveRelPatternInfo> recursiveInfo)
: NodePattern{std::move(name), std::move(tableNames), std::move(propertyKeyValPairs)},
relType{relType}, lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
arrowDirection{arrowDirection} {}
relType{relType}, arrowDirection{arrowDirection}, recursiveInfo{
std::move(recursiveInfo)} {}

~RelPattern() override = default;

inline common::QueryRelType getRelType() const { return relType; }

inline std::string getLowerBound() const { return lowerBound; }

inline std::string getUpperBound() const { return upperBound; }

inline ArrowDirection getDirection() const { return arrowDirection; }

inline RecursiveRelPatternInfo* getRecursiveInfo() const { return recursiveInfo.get(); }

private:
common::QueryRelType relType;
std::string lowerBound;
std::string upperBound;
ArrowDirection arrowDirection;
std::unique_ptr<RecursiveRelPatternInfo> recursiveInfo;
};

} // namespace parser
Expand Down
74 changes: 46 additions & 28 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,18 @@ std::unique_ptr<PatternElement> Transformer::transformPatternElement(

std::unique_ptr<NodePattern> Transformer::transformNodePattern(
CypherParser::OC_NodePatternContext& ctx) {
auto variable = ctx.oC_Variable() ? transformVariable(*ctx.oC_Variable()) : std::string();
auto nodeLabels = ctx.oC_NodeLabels() ? transformNodeLabels(*ctx.oC_NodeLabels()) :
std::vector<std::string>{};
auto properties = ctx.kU_Properties() ?
transformProperties(*ctx.kU_Properties()) :
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
auto variable = std::string();
if (ctx.oC_Variable()) {
variable = transformVariable(*ctx.oC_Variable());
}
auto nodeLabels = std::vector<std::string>{};
if (ctx.oC_NodeLabels()) {
nodeLabels = transformNodeLabels(*ctx.oC_NodeLabels());
}
auto properties = std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
if (ctx.kU_Properties()) {
properties = transformProperties(*ctx.kU_Properties());
}
return std::make_unique<NodePattern>(
std::move(variable), std::move(nodeLabels), std::move(properties));
}
Expand All @@ -288,38 +294,50 @@ std::unique_ptr<PatternElementChain> Transformer::transformPatternElementChain(
std::unique_ptr<RelPattern> Transformer::transformRelationshipPattern(
CypherParser::OC_RelationshipPatternContext& ctx) {
auto relDetail = ctx.oC_RelationshipDetail();
auto variable =
relDetail->oC_Variable() ? transformVariable(*relDetail->oC_Variable()) : std::string();
auto relTypes = relDetail->oC_RelationshipTypes() ?
transformRelTypes(*relDetail->oC_RelationshipTypes()) :
std::vector<std::string>{};
auto variable = std::string();
if (relDetail->oC_Variable()) {
variable = transformVariable(*relDetail->oC_Variable());
}
auto relTypes = std::vector<std::string>{};
if (relDetail->oC_RelationshipTypes()) {
relTypes = transformRelTypes(*relDetail->oC_RelationshipTypes());
}
auto properties = std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
if (relDetail->kU_Properties()) {
properties = transformProperties(*relDetail->kU_Properties());
}
ArrowDirection arrowDirection;
if (ctx.oC_LeftArrowHead()) {
arrowDirection = ArrowDirection::LEFT;
} else if (ctx.oC_RightArrowHead()) {
arrowDirection = ArrowDirection::RIGHT;
} else {
arrowDirection = ArrowDirection::BOTH;
}
auto relType = common::QueryRelType::NON_RECURSIVE;
std::string lowerBound = "1";
std::string upperBound = "1";
std::unique_ptr<RecursiveRelPatternInfo> recursiveInfo;
if (relDetail->oC_RangeLiteral()) {
lowerBound = relDetail->oC_RangeLiteral()->oC_IntegerLiteral()[0]->getText();
upperBound = relDetail->oC_RangeLiteral()->oC_IntegerLiteral()[1]->getText();
if (relDetail->oC_RangeLiteral()->ALL()) {
relType = common::QueryRelType::ALL_SHORTEST;
} else if (relDetail->oC_RangeLiteral()->SHORTEST()) {
relType = common::QueryRelType::SHORTEST;
} else {
relType = common::QueryRelType::VARIABLE_LENGTH;
}
auto range = relDetail->oC_RangeLiteral();
auto lowerBound = range->oC_IntegerLiteral()[0]->getText();
auto upperBound = range->oC_IntegerLiteral()[1]->getText();
auto recursiveRelName = std::string();
std::unique_ptr<ParsedExpression> whereExpression = nullptr;
if (range->oC_Where()) {
recursiveRelName = transformVariable(*range->oC_Variable());
whereExpression = transformWhere(*range->oC_Where());
}
recursiveInfo = std::make_unique<RecursiveRelPatternInfo>(
lowerBound, upperBound, recursiveRelName, std::move(whereExpression));
}
ArrowDirection arrowDirection;
if (ctx.oC_LeftArrowHead()) {
arrowDirection = ArrowDirection::LEFT;
} else if (ctx.oC_RightArrowHead()) {
arrowDirection = ArrowDirection::RIGHT;
} else {
arrowDirection = ArrowDirection::BOTH;
}
auto properties = relDetail->kU_Properties() ?
transformProperties(*relDetail->kU_Properties()) :
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>{};
return std::make_unique<RelPattern>(
variable, relTypes, relType, lowerBound, upperBound, arrowDirection, std::move(properties));
return std::make_unique<RelPattern>(variable, relTypes, relType, arrowDirection,
std::move(properties), std::move(recursiveInfo));
}

std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>>
Expand Down
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/var_length_extend/multi_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,8 @@
-STATEMENT MATCH (a:person)-[e*2..2]->(b:organisation) WHERE a.fName = 'Alice' RETURN COUNT(*)
---- 1
5


-STATEMENT MATCH (a:person)-[e*2..2 (r, _ | WHERE offset(id(r)) > 0)]->(b:organisation) WHERE a.fName = 'Alice' RETURN rels(e)
---- 1
[{_ID: 3:2, DATE: 2021-06-30, MEETTIME: 2012-12-11 20:07:22, VALIDINTERVAL: 10 days, COMMENTS: [ioji232,jifhe8w99u43434], YEAR: , PLACES: , LENGTH: , GRADING: , RATING: , LOCATION: , TIMES: , DATA: , USEDADDRESS: , ADDRESS: , NOTE: },{_ID: 5:1, DATE: , MEETTIME: , VALIDINTERVAL: , COMMENTS: , YEAR: 2010, PLACES: , LENGTH: , GRADING: [2.100000,4.400000], RATING: 7.600000, LOCATION: , TIMES: , DATA: , USEDADDRESS: , ADDRESS: , NOTE: }]
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/var_length_extend/n_n.test
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@ Greg
3:1|0:2|2
3:2|0:3|1
3:2|0:3|2

-LOG KnowsOneToTwoHopFilterTest2
-STATEMENT MATCH (a:person)-[e:knows*1..2 (r,_ | WHERE list_contains(r.comments, 'rnme'))]->(b:person) WHERE a.fName='Alice' RETURN COUNT(*)
---- 1
1
Loading

0 comments on commit 687cfbd

Please sign in to comment.