Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node rel evaluator #1743

Merged
merged 12 commits into from
Jul 6, 2023
2 changes: 1 addition & 1 deletion benchmark/queries/ldbc-sf100/join/q29.benchmark
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-NAME q29
-COMPARE_RESULT 1
-QUERY MATCH (a:Person)-[:knows]->(b:Person) RETURN MIN(a.birthday), MIN(b.birthday)
-ENCODED_JOIN HJ(b._id){E(b)S(a)}{S(b)}
-ENCODED_JOIN HJ(b._ID){E(b)S(a)}{S(b)}
---- 1
1980-02-01|1980-02-01
2 changes: 1 addition & 1 deletion benchmark/queries/ldbc-sf100/join/q30.benchmark
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-NAME q30
-COMPARE_RESULT 1
-QUERY MATCH (a:Person)-[:knows]->(b:Person)-[:knows]->(c:Person) RETURN MIN(a.birthday), MIN(b.birthday), MIN(c.birthday)
-ENCODED_JOIN HJ(c._id){HJ(a._id){E(c)E(a)S(b)}{S(a)}}{S(c)}
-ENCODED_JOIN HJ(c._ID){HJ(a._ID){E(c)E(a)S(b)}{S(a)}}{S(c)}
---- 1
1980-02-01|1980-02-01|1980-02-01
62 changes: 36 additions & 26 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,10 @@ getNodePropertyNameAndPropertiesPairs(const std::vector<NodeTableSchema*>& nodeT

static std::unique_ptr<LogicalType> getRecursiveRelLogicalType(
const NodeExpression& node, const RelExpression& rel) {
std::vector<std::unique_ptr<StructField>> nodeFields;
nodeFields.push_back(std::make_unique<StructField>(
InternalKeyword::ID, std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID)));
for (auto& expression : node.getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
nodeFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
auto nodeType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(nodeFields)));
auto nodesType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(std::move(nodeType)));
std::vector<std::unique_ptr<StructField>> relFields;
for (auto& expression : rel.getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
relFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
auto relType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(relFields)));
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(node.getDataType().copy()));
auto relsType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(std::move(relType)));
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(rel.getDataType().copy()));
std::vector<std::unique_ptr<StructField>> recursiveRelFields;
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::NODES, std::move(nodesType)));
Expand Down Expand Up @@ -213,7 +195,23 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
auto queryRel = make_shared<RelExpression>(LogicalType(LogicalTypeID::REL),
getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode),
std::move(dstNode), directionType, QueryRelType::NON_RECURSIVE);
queryRel->setAlias(parsedName);
bindQueryRelProperties(*queryRel);
queryRel->setLabelExpression(expressionBinder.bindLabelFunction(*queryRel));
std::vector<std::unique_ptr<StructField>> relFields;
relFields.push_back(std::make_unique<StructField>(
InternalKeyword::SRC, std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID)));
relFields.push_back(std::make_unique<StructField>(
InternalKeyword::DST, std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID)));
relFields.push_back(std::make_unique<StructField>(
InternalKeyword::LABEL, std::make_unique<LogicalType>(LogicalTypeID::STRING)));
for (auto& expression : queryRel->getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
relFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
common::RelType::setExtraTypeInfo(
queryRel->getDataTypeReference(), std::make_unique<StructTypeInfo>(std::move(relFields)));
return queryRel;
}

Expand Down Expand Up @@ -311,6 +309,9 @@ std::shared_ptr<NodeExpression> Binder::bindQueryNode(
}
} else {
queryNode = createQueryNode(nodePattern);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryNode);
}
}
for (auto& [propertyName, rhs] : nodePattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodePropertyExpression(*queryNode, propertyName);
Expand All @@ -330,14 +331,24 @@ std::shared_ptr<NodeExpression> Binder::createQueryNode(const NodePattern& nodeP

std::shared_ptr<NodeExpression> Binder::createQueryNode(
const std::string& parsedName, const std::vector<common::table_id_t>& tableIDs) {
auto queryNode =
make_shared<NodeExpression>(getUniqueExpressionName(parsedName), parsedName, tableIDs);
auto queryNode = make_shared<NodeExpression>(LogicalType(common::LogicalTypeID::NODE),
getUniqueExpressionName(parsedName), parsedName, tableIDs);
queryNode->setAlias(parsedName);
queryNode->setInternalIDProperty(expressionBinder.createInternalNodeIDExpression(*queryNode));
bindQueryNodeProperties(*queryNode);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryNode);
queryNode->setInternalIDProperty(expressionBinder.createInternalNodeIDExpression(*queryNode));
queryNode->setLabelExpression(expressionBinder.bindLabelFunction(*queryNode));
std::vector<std::unique_ptr<StructField>> nodeFields;
nodeFields.push_back(std::make_unique<StructField>(
InternalKeyword::ID, std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID)));
nodeFields.push_back(std::make_unique<StructField>(
InternalKeyword::LABEL, std::make_unique<LogicalType>(LogicalTypeID::STRING)));
for (auto& expression : queryNode->getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
nodeFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
common::NodeType::setExtraTypeInfo(
queryNode->getDataTypeReference(), std::make_unique<StructTypeInfo>(std::move(nodeFields)));
return queryNode;
}

Expand Down Expand Up @@ -373,7 +384,6 @@ std::vector<table_id_t> Binder::bindTableIDs(
tableIDs.insert(bindNodeTableID(tableName));
}
}

} break;
case LogicalTypeID::REL: {
if (tableNames.empty()) {
Expand Down
67 changes: 24 additions & 43 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,27 @@ std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withCl
auto projectionExpressions = bindProjectionExpressions(
projectionBody->getProjectionExpressions(), projectionBody->containsStar());
validateProjectionColumnsInWithClauseAreAliased(projectionExpressions);
auto boundProjectionBody = bindProjectionBody(*projectionBody, projectionExpressions);
expression_vector newProjectionExpressions;
for (auto& expression : projectionExpressions) {
if (ExpressionUtil::isNodeVariable(*expression)) {
auto node = (NodeExpression*)expression.get();
newProjectionExpressions.push_back(node->getInternalIDProperty());
for (auto& property : node->getPropertyExpressions()) {
newProjectionExpressions.push_back(property->copy());
}
} else if (ExpressionUtil::isRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
for (auto& property : rel->getPropertyExpressions()) {
newProjectionExpressions.push_back(property->copy());
}
} else {
newProjectionExpressions.push_back(expression);
}
}
auto boundProjectionBody = bindProjectionBody(*projectionBody, newProjectionExpressions);
validateOrderByFollowedBySkipOrLimitInWithClause(*boundProjectionBody);
variableScope->clear();
addExpressionsToScope(boundProjectionBody->getProjectionExpressions());
addExpressionsToScope(projectionExpressions);
auto boundWithClause = std::make_unique<BoundWithClause>(std::move(boundProjectionBody));
if (withClause.hasWhereExpression()) {
boundWithClause->setWhereExpression(bindWhereExpression(*withClause.getWhereExpression()));
Expand All @@ -30,16 +47,9 @@ std::unique_ptr<BoundReturnClause> Binder::bindReturnClause(const ReturnClause&
projectionBody->getProjectionExpressions(), projectionBody->containsStar());
auto statementResult = std::make_unique<BoundStatementResult>();
for (auto& expression : boundProjectionExpressions) {
auto dataType = expression->getDataType();
if (dataType.getLogicalTypeID() == common::LogicalTypeID::NODE ||
dataType.getLogicalTypeID() == common::LogicalTypeID::REL) {
statementResult->addColumn(expression, rewriteNodeOrRelExpression(*expression));
} else {
statementResult->addColumn(expression, expression_vector{expression});
}
statementResult->addColumn(expression);
}
auto boundProjectionBody =
bindProjectionBody(*projectionBody, statementResult->getExpressionsToCollect());
auto boundProjectionBody = bindProjectionBody(*projectionBody, statementResult->getColumns());
return std::make_unique<BoundReturnClause>(
std::move(boundProjectionBody), std::move(statementResult));
}
Expand Down Expand Up @@ -100,6 +110,9 @@ std::unique_ptr<BoundProjectionBody> Binder::bindProjectionBody(
if (ExpressionUtil::isNodeVariable(*expression)) {
auto node = (NodeExpression*)expression.get();
augmentedGroupByExpressions.push_back(node->getInternalIDProperty());
} else if (ExpressionUtil::isRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
augmentedGroupByExpressions.push_back(rel->getInternalIDProperty());
}
}
boundProjectionBody->setGroupByExpressions(std::move(augmentedGroupByExpressions));
Expand Down Expand Up @@ -163,38 +176,6 @@ expression_vector Binder::bindProjectionExpressions(
return result;
}

expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) {
if (expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) {
return rewriteNodeExpression(expression);
} else {
assert(expression.dataType.getLogicalTypeID() == common::LogicalTypeID::REL);
return rewriteRelExpression(expression);
}
}

expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression& expression) {
expression_vector result;
auto& node = (NodeExpression&)expression;
result.push_back(node.getInternalIDProperty());
result.push_back(expressionBinder.bindLabelFunction(node));
for (auto& property : node.getPropertyExpressions()) {
result.push_back(property->copy());
}
return result;
}

expression_vector Binder::rewriteRelExpression(const Expression& expression) {
expression_vector result;
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
result.push_back(expressionBinder.bindLabelFunction(rel));
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
return result;
}

expression_vector Binder::bindOrderByExpressions(
const std::vector<std::unique_ptr<ParsedExpression>>& orderByExpressions) {
expression_vector boundOrderByExpressions;
Expand Down
21 changes: 8 additions & 13 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,19 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
std::shared_ptr<Expression> expression) {
switch (expression->getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
if (ExpressionUtil::isNodeVariable(*expression)) {
auto& node = (NodeExpression&)*expression;
return node.getInternalIDProperty();
}
case common::LogicalTypeID::REL: {
if (ExpressionUtil::isRelVariable(*expression)) {
return bindRelPropertyExpression(*expression, InternalKeyword::ID);
}
case common::LogicalTypeID::STRUCT: {
auto stringValue =
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, InternalKeyword::ID);
return bindScalarFunctionExpression(
expression_vector{expression, createLiteralExpression(std::move(stringValue))},
STRUCT_EXTRACT_FUNC_NAME);
}
default:
throw NotImplementedException("ExpressionBinder::bindInternalIDExpression");
}
assert(expression->dataType.getPhysicalType() == common::PhysicalTypeID::STRUCT);
auto stringValue =
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, InternalKeyword::ID);
return bindScalarFunctionExpression(
expression_vector{expression, createLiteralExpression(std::move(stringValue))},
STRUCT_EXTRACT_FUNC_NAME);
}

static std::vector<std::unique_ptr<Value>> populateLabelValues(
Expand Down
7 changes: 3 additions & 4 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindPropertyExpression(
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(*child,
std::vector<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL, LogicalTypeID::STRUCT});
auto childTypeID = child->dataType.getLogicalTypeID();
if (LogicalTypeID::NODE == childTypeID) {
if (ExpressionUtil::isNodeVariable(*child)) {
return bindNodePropertyExpression(*child, propertyName);
} else if (LogicalTypeID::REL == childTypeID) {
} else if (ExpressionUtil::isRelVariable(*child)) {
return bindRelPropertyExpression(*child, propertyName);
} else {
assert(LogicalTypeID::STRUCT == childTypeID);
assert(child->expressionType == common::FUNCTION);
auto stringValue =
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, propertyName);
return bindScalarFunctionExpression(
Expand Down
12 changes: 1 addition & 11 deletions src/binder/bound_statement_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,7 @@ std::unique_ptr<BoundStatementResult> BoundStatementResult::createSingleStringCo
auto value = std::make_unique<common::Value>(
common::LogicalType{common::LogicalTypeID::STRING}, columnName);
auto stringColumn = std::make_shared<LiteralExpression>(std::move(value), columnName);
result->addColumn(stringColumn, expression_vector{stringColumn});
return result;
}

expression_vector BoundStatementResult::getExpressionsToCollect() {
expression_vector result;
for (auto& expressionsToCollect : expressionsToCollectPerColumn) {
for (auto& expression : expressionsToCollect) {
result.push_back(expression);
}
}
result->addColumn(stringColumn);
return result;
}

Expand Down
36 changes: 36 additions & 0 deletions src/binder/expression/expression_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#include "binder/expression/case_expression.h"
#include "binder/expression/existential_subquery_expression.h"
#include "binder/expression/node_expression.h"
#include "binder/expression/property_expression.h"
#include "binder/expression/rel_expression.h"

namespace kuzu {
namespace binder {
Expand All @@ -15,6 +17,19 @@ expression_vector ExpressionChildrenCollector::collectChildren(const Expression&
case common::ExpressionType::EXISTENTIAL_SUBQUERY: {
return collectExistentialSubqueryChildren(expression);
}
case common::ExpressionType::VARIABLE: {
switch (expression.dataType.getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
return collectNodeChildren(expression);
}
case common::LogicalTypeID::REL: {
return collectRelChildren(expression);
}
default: {
return expression_vector{};
}
}
}
default: {
return expression.children;
}
Expand Down Expand Up @@ -46,6 +61,27 @@ expression_vector ExpressionChildrenCollector::collectExistentialSubqueryChildre
return result;
}

expression_vector ExpressionChildrenCollector::collectNodeChildren(const Expression& expression) {
expression_vector result;
auto& node = (NodeExpression&)expression;
for (auto& property : node.getPropertyExpressions()) {
result.push_back(property->copy());
}
result.push_back(node.getInternalIDProperty());
return result;
}

expression_vector ExpressionChildrenCollector::collectRelChildren(const Expression& expression) {
expression_vector result;
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
return result;
}

bool ExpressionVisitor::hasExpression(
const Expression& expression, const std::function<bool(const Expression&)>& condition) {
if (condition(expression)) {
Expand Down
Loading
Loading