Skip to content

Commit

Permalink
Add equality predicate to recursive join
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jun 26, 2023
1 parent 67eb0c2 commit 4d64fec
Show file tree
Hide file tree
Showing 19 changed files with 140 additions and 102 deletions.
41 changes: 23 additions & 18 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,7 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
std::shared_ptr<RelExpression> queryRel;
if (QueryRelTypeUtils::isRecursive(relPattern.getRelType())) {
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
queryRel = createRecursiveQueryRel(relPattern.getVariableName(), relPattern.getRelType(),
lowerBound, upperBound, tableIDs, srcNode, dstNode, directionType);
queryRel = createRecursiveQueryRel(relPattern, tableIDs, srcNode, dstNode, directionType);
} else {
tableIDs = pruneRelTableIDs(catalog, tableIDs, *srcNode, *dstNode);
if (tableIDs.empty()) {
Expand All @@ -217,18 +215,17 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
}
queryRel = createNonRecursiveQueryRel(
relPattern.getVariableName(), tableIDs, srcNode, dstNode, directionType);
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindRelPropertyExpression(*queryRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
collection.addKeyVal(queryRel, propertyName, std::make_pair(boundLhs, boundRhs));
}
}
queryRel->setAlias(parsedName);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryRel);
}
for (auto i = 0u; i < relPattern.getNumPropertyKeyValPairs(); ++i) {
auto [propertyName, rhs] = relPattern.getProperty(i);
auto boundLhs = expressionBinder.bindRelPropertyExpression(*queryRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
collection.addPropertyKeyValPair(*queryRel, std::make_pair(boundLhs, boundRhs));
}
queryGraph.addQueryRel(queryRel);
}

Expand All @@ -242,8 +239,7 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
return queryRel;
}

std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const std::string& parsedName,
common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound,
std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::RelPattern& relPattern,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType) {
std::unordered_set<common::table_id_t> recursiveNodeTableIDs;
Expand All @@ -257,12 +253,22 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const std::string
recursiveNodeTableIDs.end()});
auto tmpRel = createNonRecursiveQueryRel(
InternalKeyword::ANONYMOUS, tableIDs, nullptr, nullptr, directionType);
expression_vector predicates;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindRelPropertyExpression(*tmpRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
predicates.push_back(
expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs));
}
auto parsedName = relPattern.getVariableName();
auto queryRel = make_shared<RelExpression>(*getRecursiveRelLogicalType(*tmpNode, *tmpRel),
getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode),
std::move(dstNode), directionType, relType);
std::move(dstNode), directionType, relPattern.getRelType());
auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel);
auto recursiveInfo = std::make_unique<RecursiveInfo>(
lowerBound, upperBound, std::move(tmpNode), std::move(tmpRel), std::move(lengthExpression));
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto recursiveInfo = std::make_unique<RecursiveInfo>(lowerBound, upperBound, std::move(tmpNode),
std::move(tmpRel), std::move(lengthExpression), std::move(predicates));
queryRel->setRecursiveInfo(std::move(recursiveInfo));
bindQueryRelProperties(*queryRel);
return queryRel;
Expand Down Expand Up @@ -314,12 +320,11 @@ std::shared_ptr<NodeExpression> Binder::bindQueryNode(
} else {
queryNode = createQueryNode(nodePattern);
}
for (auto i = 0u; i < nodePattern.getNumPropertyKeyValPairs(); ++i) {
auto [propertyName, rhs] = nodePattern.getProperty(i);
for (auto& [propertyName, rhs] : nodePattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodePropertyExpression(*queryNode, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
collection.addPropertyKeyValPair(*queryNode, std::make_pair(boundLhs, boundRhs));
collection.addKeyVal(queryNode, propertyName, std::make_pair(boundLhs, boundRhs));
}
queryGraph.addQueryNode(queryNode);
return queryNode;
Expand Down
13 changes: 4 additions & 9 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,10 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
whereExpression = bindWhereExpression(*matchClause.getWhereClause());
}
// Rewrite key value pairs in MATCH clause as predicate
for (auto& keyValPairs : propertyCollection->getAllPropertyKeyValPairs()) {
auto predicate = expressionBinder.bindComparisonExpression(
EQUALS, expression_vector{keyValPairs.first, keyValPairs.second});
if (whereExpression != nullptr) {
whereExpression = expressionBinder.bindBooleanExpression(
AND, expression_vector{whereExpression, predicate});
} else {
whereExpression = predicate;
}
for (auto& [key, val] : propertyCollection->getKeyVals()) {
auto predicate = expressionBinder.createEqualityComparisonExpression(key, val);
whereExpression =
expressionBinder.combineConjunctiveExpressions(predicate, whereExpression);
}
boundMatchClause->setWhereExpression(std::move(whereExpression));
return boundMatchClause;
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ std::unique_ptr<BoundCreateNode> Binder::bindCreateNode(
auto primaryKey = nodeTableSchema->getPrimaryKey();
std::shared_ptr<Expression> primaryKeyExpression;
std::vector<expression_pair> setItems;
for (auto& [key, val] : collection.getPropertyKeyValPairs(*node)) {
for (auto& [key, val] : collection.getKeyVals(node)) {
auto propertyExpression = static_pointer_cast<PropertyExpression>(key);
if (propertyExpression->getPropertyID(nodeTableID) == primaryKey.propertyID) {
primaryKeyExpression = val;
Expand Down Expand Up @@ -94,8 +94,8 @@ std::unique_ptr<BoundCreateRel> Binder::bindCreateRel(
// null if user does not specify a property in the query.
std::vector<expression_pair> setItems;
for (auto& property : catalogContent->getRelProperties(relTableID)) {
if (collection.hasPropertyKeyValPair(*rel, property.name)) {
setItems.push_back(collection.getPropertyKeyValPair(*rel, property.name));
if (collection.hasKeyVal(rel, property.name)) {
setItems.push_back(collection.getKeyVal(rel, property.name));
} else {
auto propertyExpression =
expressionBinder.bindRelPropertyExpression(*rel, property.name);
Expand Down
11 changes: 11 additions & 0 deletions src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,16 @@ std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::combineConjunctiveExpressions(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right) {
if (left == nullptr) {
return right;
} else if (right == nullptr) {
return left;
} else {
return bindBooleanExpression(AND, expression_vector{std::move(left), std::move(right)});
}
}

} // namespace binder
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::createEqualityComparisonExpression(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right) {
return bindComparisonExpression(
common::EQUALS, expression_vector{std::move(left), std::move(right)});
}

} // namespace binder
} // namespace kuzu
57 changes: 25 additions & 32 deletions src/binder/query/query_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,58 +227,51 @@ std::unique_ptr<QueryGraphCollection> QueryGraphCollection::copy() const {
return result;
}

void PropertyKeyValCollection::addPropertyKeyValPair(
const Expression& variable, expression_pair propertyKeyValPair) {
auto varName = variable.getUniqueName();
if (!varNameToPropertyKeyValPairs.contains(varName)) {
varNameToPropertyKeyValPairs.insert(
{varName, std::unordered_map<std::string, expression_pair>{}});
void PropertyKeyValCollection::addKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName, expression_pair keyVal) {
if (!propertyKeyValMap.contains(variable)) {
propertyKeyValMap.insert({variable, std::unordered_map<std::string, expression_pair>{}});
}
auto property = (PropertyExpression*)propertyKeyValPair.first.get();
assert(!varNameToPropertyKeyValPairs.at(varName).contains(property->getPropertyName()));
varNameToPropertyKeyValPairs.at(varName).insert(
{property->getPropertyName(), std::move(propertyKeyValPair)});
propertyKeyValMap.at(variable).insert({propertyName, std::move(keyVal)});
}

std::vector<expression_pair> PropertyKeyValCollection::getPropertyKeyValPairs(
const kuzu::binder::Expression& variable) const {
auto varName = variable.getUniqueName();
if (!varNameToPropertyKeyValPairs.contains(varName)) {
return std::vector<expression_pair>{};
}
std::vector<expression_pair> PropertyKeyValCollection::getKeyVals() const {
std::vector<expression_pair> result;
for (auto& [_, setItem] : varNameToPropertyKeyValPairs.at(varName)) {
result.push_back(setItem);
for (auto& [_, keyVals] : propertyKeyValMap) {
for (auto& [_, keyVal] : keyVals) {
result.push_back(keyVal);
}
}
return result;
}

std::vector<expression_pair> PropertyKeyValCollection::getAllPropertyKeyValPairs() const {
std::vector<expression_pair> PropertyKeyValCollection::getKeyVals(
std::shared_ptr<Expression> variable) const {
std::vector<expression_pair> result;
for (auto& [varName, keyValPairsMap] : varNameToPropertyKeyValPairs) {
for (auto& [propertyName, keyValPairs] : keyValPairsMap) {
result.push_back(keyValPairs);
}
if (!propertyKeyValMap.contains(variable)) {
return result;
}
for (auto& [_, keyVal] : propertyKeyValMap.at(variable)) {
result.push_back(keyVal);
}
return result;
}

bool PropertyKeyValCollection::hasPropertyKeyValPair(
const Expression& variable, const std::string& propertyName) const {
auto varName = variable.getUniqueName();
if (!varNameToPropertyKeyValPairs.contains(varName)) {
bool PropertyKeyValCollection::hasKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName) const {
if (!propertyKeyValMap.contains(variable)) {
return false;
}
if (!varNameToPropertyKeyValPairs.at(varName).contains(propertyName)) {
if (!propertyKeyValMap.at(variable).contains(propertyName)) {
return false;
}
return true;
}

expression_pair PropertyKeyValCollection::getPropertyKeyValPair(
const Expression& variable, const std::string& propertyName) const {
assert(hasPropertyKeyValPair(variable, propertyName));
return varNameToPropertyKeyValPairs.at(variable.getUniqueName()).at(propertyName);
expression_pair PropertyKeyValCollection::getKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName) const {
assert(hasKeyVal(variable, propertyName));
return propertyKeyValMap.at(variable).at(propertyName);
}

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ class Binder {
std::shared_ptr<RelExpression> createNonRecursiveQueryRel(const std::string& parsedName,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType);
std::shared_ptr<RelExpression> createRecursiveQueryRel(const std::string& parsedName,
common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound,
std::shared_ptr<RelExpression> createRecursiveQueryRel(const parser::RelPattern& relPattern,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType);
std::pair<uint64_t, uint64_t> bindVariableLengthRelBound(const parser::RelPattern& relPattern);
Expand All @@ -189,6 +188,7 @@ class Binder {
std::shared_ptr<NodeExpression> createQueryNode(
const std::string& parsedName, const std::vector<common::table_id_t>& tableIDs);
void bindQueryNodeProperties(NodeExpression& node);

inline std::vector<common::table_id_t> bindNodeTableIDs(
const std::vector<std::string>& tableNames) {
return bindTableIDs(tableNames, common::LogicalTypeID::NODE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class ExistentialSubqueryExpression : public Expression {
}
inline bool hasWhereExpression() const { return whereExpression != nullptr; }
inline std::shared_ptr<Expression> getWhereExpression() const { return whereExpression; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

expression_vector getChildren() const override;

Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class NodeOrRelExpression : public Expression {
protected:
std::string variableName;
std::vector<common::table_id_t> tableIDs;
std::unordered_map<std::string, size_t> propertyNameToIdx;
std::unordered_map<std::string, common::vector_idx_t> propertyNameToIdx;
std::vector<std::unique_ptr<Expression>> properties;
};

Expand Down
11 changes: 7 additions & 4 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ struct RecursiveInfo {
std::shared_ptr<NodeExpression> node;
std::shared_ptr<RelExpression> rel;
std::shared_ptr<Expression> lengthExpression;
expression_vector predicates;

RecursiveInfo(size_t lowerBound, size_t upperBound, std::shared_ptr<NodeExpression> node,
std::shared_ptr<RelExpression> rel, std::shared_ptr<Expression> lengthExpression)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)},
rel{std::move(rel)}, lengthExpression{std::move(lengthExpression)} {}
RecursiveInfo(uint64_t lowerBound, uint64_t upperBound, std::shared_ptr<NodeExpression> node,
std::shared_ptr<RelExpression> rel, std::shared_ptr<Expression> lengthExpression,
expression_vector predicates)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)}, rel{std::move(
rel)},
lengthExpression{std::move(lengthExpression)}, predicates{std::move(predicates)} {}
};

class RelExpression : public NodeOrRelExpression {
Expand Down
4 changes: 4 additions & 0 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindBooleanExpression(
common::ExpressionType expressionType, const expression_vector& children);
std::shared_ptr<Expression> combineConjunctiveExpressions(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right);

std::shared_ptr<Expression> bindComparisonExpression(
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindComparisonExpression(
common::ExpressionType expressionType, const expression_vector& children);
std::shared_ptr<Expression> createEqualityComparisonExpression(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right);

std::shared_ptr<Expression> bindNullOperatorExpression(
const parser::ParsedExpression& parsedExpression);
Expand Down
5 changes: 3 additions & 2 deletions src/include/binder/query/reading_clause/bound_match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ class BoundMatchClause : public BoundReadingClause {
inline void setWhereExpression(std::shared_ptr<Expression> expression) {
whereExpression = std::move(expression);
}

inline bool hasWhereExpression() const { return whereExpression != nullptr; }

inline std::shared_ptr<Expression> getWhereExpression() const { return whereExpression; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

inline bool getIsOptional() const { return isOptional; }

Expand Down
21 changes: 10 additions & 11 deletions src/include/binder/query/reading_clause/query_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,25 +166,24 @@ class PropertyKeyValCollection {
public:
PropertyKeyValCollection() = default;
PropertyKeyValCollection(const PropertyKeyValCollection& other)
: varNameToPropertyKeyValPairs{other.varNameToPropertyKeyValPairs} {}
: propertyKeyValMap{other.propertyKeyValMap} {}

void addPropertyKeyValPair(const Expression& variable, expression_pair propertyKeyValPair);
std::vector<expression_pair> getPropertyKeyValPairs(const Expression& variable) const;
std::vector<expression_pair> getAllPropertyKeyValPairs() const;

bool hasPropertyKeyValPair(const Expression& variable, const std::string& propertyName) const;
expression_pair getPropertyKeyValPair(
const Expression& variable, const std::string& propertyName) const;
void addKeyVal(std::shared_ptr<Expression> variable, const std::string& propertyName,
expression_pair keyVal);
std::vector<expression_pair> getKeyVals() const;
std::vector<expression_pair> getKeyVals(std::shared_ptr<Expression> variable) const;
bool hasKeyVal(std::shared_ptr<Expression> variable, const std::string& propertyName) const;
expression_pair getKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName) const;

inline std::unique_ptr<PropertyKeyValCollection> copy() const {
return std::make_unique<PropertyKeyValCollection>(*this);
}

private:
// First indexed on variable name, then indexed on property name.
// First indexed on variable, then indexed on property name.
// a -> { age -> pair<a.age,12>, name -> pair<name,'Alice'>}
std::unordered_map<std::string, std::unordered_map<std::string, expression_pair>>
varNameToPropertyKeyValPairs;
expression_map<std::unordered_map<std::string, expression_pair>> propertyKeyValMap;
};

} // namespace binder
Expand Down
Loading

0 comments on commit 4d64fec

Please sign in to comment.