Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive join key equality comparison #1721

Merged
merged 1 commit into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,7 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
std::shared_ptr<RelExpression> queryRel;
if (QueryRelTypeUtils::isRecursive(relPattern.getRelType())) {
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
queryRel = createRecursiveQueryRel(relPattern.getVariableName(), relPattern.getRelType(),
lowerBound, upperBound, tableIDs, srcNode, dstNode, directionType);
queryRel = createRecursiveQueryRel(relPattern, tableIDs, srcNode, dstNode, directionType);
} else {
tableIDs = pruneRelTableIDs(catalog, tableIDs, *srcNode, *dstNode);
if (tableIDs.empty()) {
Expand All @@ -217,18 +215,17 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
}
queryRel = createNonRecursiveQueryRel(
relPattern.getVariableName(), tableIDs, srcNode, dstNode, directionType);
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindRelPropertyExpression(*queryRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
collection.addKeyVal(queryRel, propertyName, std::make_pair(boundLhs, boundRhs));
}
}
queryRel->setAlias(parsedName);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryRel);
}
for (auto i = 0u; i < relPattern.getNumPropertyKeyValPairs(); ++i) {
auto [propertyName, rhs] = relPattern.getProperty(i);
auto boundLhs = expressionBinder.bindRelPropertyExpression(*queryRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
collection.addPropertyKeyValPair(*queryRel, std::make_pair(boundLhs, boundRhs));
}
queryGraph.addQueryRel(queryRel);
}

Expand All @@ -242,8 +239,7 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
return queryRel;
}

std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const std::string& parsedName,
common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound,
std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::RelPattern& relPattern,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType) {
std::unordered_set<common::table_id_t> recursiveNodeTableIDs;
Expand All @@ -257,12 +253,22 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const std::string
recursiveNodeTableIDs.end()});
auto tmpRel = createNonRecursiveQueryRel(
InternalKeyword::ANONYMOUS, tableIDs, nullptr, nullptr, directionType);
expression_vector predicates;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindRelPropertyExpression(*tmpRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
predicates.push_back(
expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs));
}
auto parsedName = relPattern.getVariableName();
auto queryRel = make_shared<RelExpression>(*getRecursiveRelLogicalType(*tmpNode, *tmpRel),
getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode),
std::move(dstNode), directionType, relType);
std::move(dstNode), directionType, relPattern.getRelType());
auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel);
auto recursiveInfo = std::make_unique<RecursiveInfo>(
lowerBound, upperBound, std::move(tmpNode), std::move(tmpRel), std::move(lengthExpression));
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto recursiveInfo = std::make_unique<RecursiveInfo>(lowerBound, upperBound, std::move(tmpNode),
std::move(tmpRel), std::move(lengthExpression), std::move(predicates));
queryRel->setRecursiveInfo(std::move(recursiveInfo));
bindQueryRelProperties(*queryRel);
return queryRel;
Expand Down Expand Up @@ -314,12 +320,11 @@ std::shared_ptr<NodeExpression> Binder::bindQueryNode(
} else {
queryNode = createQueryNode(nodePattern);
}
for (auto i = 0u; i < nodePattern.getNumPropertyKeyValPairs(); ++i) {
auto [propertyName, rhs] = nodePattern.getProperty(i);
for (auto& [propertyName, rhs] : nodePattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodePropertyExpression(*queryNode, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
collection.addPropertyKeyValPair(*queryNode, std::make_pair(boundLhs, boundRhs));
collection.addKeyVal(queryNode, propertyName, std::make_pair(boundLhs, boundRhs));
}
queryGraph.addQueryNode(queryNode);
return queryNode;
Expand Down
13 changes: 4 additions & 9 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,10 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
whereExpression = bindWhereExpression(*matchClause.getWhereClause());
}
// Rewrite key value pairs in MATCH clause as predicate
for (auto& keyValPairs : propertyCollection->getAllPropertyKeyValPairs()) {
auto predicate = expressionBinder.bindComparisonExpression(
EQUALS, expression_vector{keyValPairs.first, keyValPairs.second});
if (whereExpression != nullptr) {
whereExpression = expressionBinder.bindBooleanExpression(
AND, expression_vector{whereExpression, predicate});
} else {
whereExpression = predicate;
}
for (auto& [key, val] : propertyCollection->getKeyVals()) {
auto predicate = expressionBinder.createEqualityComparisonExpression(key, val);
whereExpression =
expressionBinder.combineConjunctiveExpressions(predicate, whereExpression);
}
boundMatchClause->setWhereExpression(std::move(whereExpression));
return boundMatchClause;
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ std::unique_ptr<BoundCreateNode> Binder::bindCreateNode(
auto primaryKey = nodeTableSchema->getPrimaryKey();
std::shared_ptr<Expression> primaryKeyExpression;
std::vector<expression_pair> setItems;
for (auto& [key, val] : collection.getPropertyKeyValPairs(*node)) {
for (auto& [key, val] : collection.getKeyVals(node)) {
auto propertyExpression = static_pointer_cast<PropertyExpression>(key);
if (propertyExpression->getPropertyID(nodeTableID) == primaryKey.propertyID) {
primaryKeyExpression = val;
Expand Down Expand Up @@ -94,8 +94,8 @@ std::unique_ptr<BoundCreateRel> Binder::bindCreateRel(
// null if user does not specify a property in the query.
std::vector<expression_pair> setItems;
for (auto& property : catalogContent->getRelProperties(relTableID)) {
if (collection.hasPropertyKeyValPair(*rel, property.name)) {
setItems.push_back(collection.getPropertyKeyValPair(*rel, property.name));
if (collection.hasKeyVal(rel, property.name)) {
setItems.push_back(collection.getKeyVal(rel, property.name));
} else {
auto propertyExpression =
expressionBinder.bindRelPropertyExpression(*rel, property.name);
Expand Down
11 changes: 11 additions & 0 deletions src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,16 @@ std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::combineConjunctiveExpressions(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right) {
if (left == nullptr) {
return right;
} else if (right == nullptr) {
return left;
} else {
return bindBooleanExpression(AND, expression_vector{std::move(left), std::move(right)});
}
}

} // namespace binder
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::createEqualityComparisonExpression(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right) {
return bindComparisonExpression(
common::EQUALS, expression_vector{std::move(left), std::move(right)});
}

} // namespace binder
} // namespace kuzu
57 changes: 25 additions & 32 deletions src/binder/query/query_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,58 +227,51 @@ std::unique_ptr<QueryGraphCollection> QueryGraphCollection::copy() const {
return result;
}

void PropertyKeyValCollection::addPropertyKeyValPair(
const Expression& variable, expression_pair propertyKeyValPair) {
auto varName = variable.getUniqueName();
if (!varNameToPropertyKeyValPairs.contains(varName)) {
varNameToPropertyKeyValPairs.insert(
{varName, std::unordered_map<std::string, expression_pair>{}});
void PropertyKeyValCollection::addKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName, expression_pair keyVal) {
if (!propertyKeyValMap.contains(variable)) {
propertyKeyValMap.insert({variable, std::unordered_map<std::string, expression_pair>{}});
}
auto property = (PropertyExpression*)propertyKeyValPair.first.get();
assert(!varNameToPropertyKeyValPairs.at(varName).contains(property->getPropertyName()));
varNameToPropertyKeyValPairs.at(varName).insert(
{property->getPropertyName(), std::move(propertyKeyValPair)});
propertyKeyValMap.at(variable).insert({propertyName, std::move(keyVal)});
}

std::vector<expression_pair> PropertyKeyValCollection::getPropertyKeyValPairs(
const kuzu::binder::Expression& variable) const {
auto varName = variable.getUniqueName();
if (!varNameToPropertyKeyValPairs.contains(varName)) {
return std::vector<expression_pair>{};
}
std::vector<expression_pair> PropertyKeyValCollection::getKeyVals() const {
std::vector<expression_pair> result;
for (auto& [_, setItem] : varNameToPropertyKeyValPairs.at(varName)) {
result.push_back(setItem);
for (auto& [_, keyVals] : propertyKeyValMap) {
for (auto& [_, keyVal] : keyVals) {
result.push_back(keyVal);
}
}
return result;
}

std::vector<expression_pair> PropertyKeyValCollection::getAllPropertyKeyValPairs() const {
std::vector<expression_pair> PropertyKeyValCollection::getKeyVals(
std::shared_ptr<Expression> variable) const {
std::vector<expression_pair> result;
for (auto& [varName, keyValPairsMap] : varNameToPropertyKeyValPairs) {
for (auto& [propertyName, keyValPairs] : keyValPairsMap) {
result.push_back(keyValPairs);
}
if (!propertyKeyValMap.contains(variable)) {
return result;
}
for (auto& [_, keyVal] : propertyKeyValMap.at(variable)) {
result.push_back(keyVal);
}
return result;
}

bool PropertyKeyValCollection::hasPropertyKeyValPair(
const Expression& variable, const std::string& propertyName) const {
auto varName = variable.getUniqueName();
if (!varNameToPropertyKeyValPairs.contains(varName)) {
bool PropertyKeyValCollection::hasKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName) const {
if (!propertyKeyValMap.contains(variable)) {
return false;
}
if (!varNameToPropertyKeyValPairs.at(varName).contains(propertyName)) {
if (!propertyKeyValMap.at(variable).contains(propertyName)) {
return false;
}
return true;
}

expression_pair PropertyKeyValCollection::getPropertyKeyValPair(
const Expression& variable, const std::string& propertyName) const {
assert(hasPropertyKeyValPair(variable, propertyName));
return varNameToPropertyKeyValPairs.at(variable.getUniqueName()).at(propertyName);
expression_pair PropertyKeyValCollection::getKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName) const {
assert(hasKeyVal(variable, propertyName));
return propertyKeyValMap.at(variable).at(propertyName);
}

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ class Binder {
std::shared_ptr<RelExpression> createNonRecursiveQueryRel(const std::string& parsedName,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType);
std::shared_ptr<RelExpression> createRecursiveQueryRel(const std::string& parsedName,
common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound,
std::shared_ptr<RelExpression> createRecursiveQueryRel(const parser::RelPattern& relPattern,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType);
std::pair<uint64_t, uint64_t> bindVariableLengthRelBound(const parser::RelPattern& relPattern);
Expand All @@ -189,6 +188,7 @@ class Binder {
std::shared_ptr<NodeExpression> createQueryNode(
const std::string& parsedName, const std::vector<common::table_id_t>& tableIDs);
void bindQueryNodeProperties(NodeExpression& node);

inline std::vector<common::table_id_t> bindNodeTableIDs(
const std::vector<std::string>& tableNames) {
return bindTableIDs(tableNames, common::LogicalTypeID::NODE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class ExistentialSubqueryExpression : public Expression {
}
inline bool hasWhereExpression() const { return whereExpression != nullptr; }
inline std::shared_ptr<Expression> getWhereExpression() const { return whereExpression; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

expression_vector getChildren() const override;

Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class NodeOrRelExpression : public Expression {
protected:
std::string variableName;
std::vector<common::table_id_t> tableIDs;
std::unordered_map<std::string, size_t> propertyNameToIdx;
std::unordered_map<std::string, common::vector_idx_t> propertyNameToIdx;
std::vector<std::unique_ptr<Expression>> properties;
};

Expand Down
11 changes: 7 additions & 4 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ struct RecursiveInfo {
std::shared_ptr<NodeExpression> node;
std::shared_ptr<RelExpression> rel;
std::shared_ptr<Expression> lengthExpression;
expression_vector predicates;

RecursiveInfo(size_t lowerBound, size_t upperBound, std::shared_ptr<NodeExpression> node,
std::shared_ptr<RelExpression> rel, std::shared_ptr<Expression> lengthExpression)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)},
rel{std::move(rel)}, lengthExpression{std::move(lengthExpression)} {}
RecursiveInfo(uint64_t lowerBound, uint64_t upperBound, std::shared_ptr<NodeExpression> node,
std::shared_ptr<RelExpression> rel, std::shared_ptr<Expression> lengthExpression,
expression_vector predicates)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)}, rel{std::move(
rel)},
lengthExpression{std::move(lengthExpression)}, predicates{std::move(predicates)} {}
};

class RelExpression : public NodeOrRelExpression {
Expand Down
4 changes: 4 additions & 0 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindBooleanExpression(
common::ExpressionType expressionType, const expression_vector& children);
std::shared_ptr<Expression> combineConjunctiveExpressions(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right);

std::shared_ptr<Expression> bindComparisonExpression(
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindComparisonExpression(
common::ExpressionType expressionType, const expression_vector& children);
std::shared_ptr<Expression> createEqualityComparisonExpression(
std::shared_ptr<Expression> left, std::shared_ptr<Expression> right);

std::shared_ptr<Expression> bindNullOperatorExpression(
const parser::ParsedExpression& parsedExpression);
Expand Down
5 changes: 3 additions & 2 deletions src/include/binder/query/reading_clause/bound_match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ class BoundMatchClause : public BoundReadingClause {
inline void setWhereExpression(std::shared_ptr<Expression> expression) {
whereExpression = std::move(expression);
}

inline bool hasWhereExpression() const { return whereExpression != nullptr; }

inline std::shared_ptr<Expression> getWhereExpression() const { return whereExpression; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

inline bool getIsOptional() const { return isOptional; }

Expand Down
21 changes: 10 additions & 11 deletions src/include/binder/query/reading_clause/query_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,25 +166,24 @@ class PropertyKeyValCollection {
public:
PropertyKeyValCollection() = default;
PropertyKeyValCollection(const PropertyKeyValCollection& other)
: varNameToPropertyKeyValPairs{other.varNameToPropertyKeyValPairs} {}
: propertyKeyValMap{other.propertyKeyValMap} {}

void addPropertyKeyValPair(const Expression& variable, expression_pair propertyKeyValPair);
std::vector<expression_pair> getPropertyKeyValPairs(const Expression& variable) const;
std::vector<expression_pair> getAllPropertyKeyValPairs() const;

bool hasPropertyKeyValPair(const Expression& variable, const std::string& propertyName) const;
expression_pair getPropertyKeyValPair(
const Expression& variable, const std::string& propertyName) const;
void addKeyVal(std::shared_ptr<Expression> variable, const std::string& propertyName,
expression_pair keyVal);
std::vector<expression_pair> getKeyVals() const;
std::vector<expression_pair> getKeyVals(std::shared_ptr<Expression> variable) const;
bool hasKeyVal(std::shared_ptr<Expression> variable, const std::string& propertyName) const;
expression_pair getKeyVal(
std::shared_ptr<Expression> variable, const std::string& propertyName) const;

inline std::unique_ptr<PropertyKeyValCollection> copy() const {
return std::make_unique<PropertyKeyValCollection>(*this);
}

private:
// First indexed on variable name, then indexed on property name.
// First indexed on variable, then indexed on property name.
// a -> { age -> pair<a.age,12>, name -> pair<name,'Alice'>}
std::unordered_map<std::string, std::unordered_map<std::string, expression_pair>>
varNameToPropertyKeyValPairs;
expression_map<std::unordered_map<std::string, expression_pair>> propertyKeyValMap;
};

} // namespace binder
Expand Down
Loading
Loading