Skip to content

Commit

Permalink
Add property scan for path
Browse files: browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jun 20, 2023
1 parent d7c2cab commit 1dd0201
Show file tree
Hide file tree
Showing 42 changed files with 766 additions and 281 deletions.
172 changes: 110 additions & 62 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <set>

#include "binder/binder.h"
#include "binder/expression/property_expression.h"

using namespace kuzu::common;
using namespace kuzu::parser;
Expand Down Expand Up @@ -129,6 +130,39 @@ getNodePropertyNameAndPropertiesPairs(const std::vector<NodeTableSchema*>& nodeT
return getPropertyNameAndSchemasPairs(propertyNames, propertyNamesToSchemas);
}

static std::unique_ptr<LogicalType> getRecursiveRelLogicalType(
const NodeExpression& node, const RelExpression& rel) {
std::vector<std::unique_ptr<StructField>> nodeFields;
nodeFields.push_back(std::make_unique<StructField>(
InternalKeyword::ID, std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID)));
for (auto& expression : node.getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
nodeFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
auto nodeType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(nodeFields)));
auto nodesType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(std::move(nodeType)));
std::vector<std::unique_ptr<StructField>> relFields;
for (auto& expression : rel.getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
relFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
auto relType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(relFields)));
auto relsType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(std::move(relType)));
std::vector<std::unique_ptr<StructField>> recursiveRelFields;
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::NODES, std::move(nodesType)));
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::RELS, std::move(relsType)));
return std::make_unique<LogicalType>(LogicalTypeID::RECURSIVE_REL,
std::make_unique<StructTypeInfo>(std::move(recursiveRelFields)));
}

void Binder::bindQueryRel(const RelPattern& relPattern,
const std::shared_ptr<NodeExpression>& leftNode,
const std::shared_ptr<NodeExpression>& rightNode, QueryGraph& queryGraph,
Expand Down Expand Up @@ -170,61 +204,21 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
throw BinderException("Self-loop rel " + parsedName + " is not supported.");
}
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto isVariableLength = !(lowerBound == 1 && upperBound == 1);
if (!isVariableLength) {
std::shared_ptr<RelExpression> queryRel;
if (QueryRelTypeUtils::isRecursive(relPattern.getRelType())) {
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
queryRel = createRecursiveQueryRel(relPattern.getVariableName(), relPattern.getRelType(),
lowerBound, upperBound, tableIDs, srcNode, dstNode, directionType);
} else {
tableIDs = pruneRelTableIDs(catalog, tableIDs, *srcNode, *dstNode);
if (tableIDs.empty()) {
throw BinderException("Nodes " + srcNode->toString() + " and " + dstNode->toString() +
" are not connected through rel " + parsedName + ".");
}
}
common::LogicalType dataType;
if (isVariableLength) {
std::vector<std::unique_ptr<StructField>> structFields;
auto varListTypeInfo = std::make_unique<common::VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID));
auto nodeStructField = std::make_unique<StructField>(InternalKeyword::NODES,
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, varListTypeInfo->copy()));
auto relStructField = std::make_unique<StructField>(InternalKeyword::RELS,
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, varListTypeInfo->copy()));
structFields.push_back(std::move(nodeStructField));
structFields.push_back(std::move(relStructField));
auto structTypeInfo = std::make_unique<StructTypeInfo>(std::move(structFields));
dataType =
common::LogicalType(common::LogicalTypeID::RECURSIVE_REL, std::move(structTypeInfo));
} else {
dataType = common::LogicalType(common::LogicalTypeID::REL);
}
auto queryRel = make_shared<RelExpression>(dataType, getUniqueExpressionName(parsedName),
parsedName, tableIDs, srcNode, dstNode, directionType, relPattern.getRelType());
if (isVariableLength) {
std::unordered_set<common::table_id_t> recursiveNodeTableIDs;
for (auto relTableID : tableIDs) {
auto relTableSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
recursiveNodeTableIDs.insert(relTableSchema->srcTableID);
recursiveNodeTableIDs.insert(relTableSchema->dstTableID);
}
auto recursiveNode =
createQueryNode("", std::vector<common::table_id_t>{
recursiveNodeTableIDs.begin(), recursiveNodeTableIDs.end()});
auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel);
auto recursiveInfo = std::make_unique<RecursiveInfo>(
lowerBound, upperBound, std::move(recursiveNode), std::move(lengthExpression));
queryRel->setRecursiveInfo(std::move(recursiveInfo));
queryRel = createNonRecursiveQueryRel(
relPattern.getVariableName(), tableIDs, srcNode, dstNode, directionType);
}
queryRel->setAlias(parsedName);
// resolve properties associated with rel table
std::vector<RelTableSchema*> relTableSchemas;
for (auto tableID : tableIDs) {
relTableSchemas.push_back(catalog.getReadOnlyVersion()->getRelTableSchema(tableID));
}
for (auto& [propertyName, propertySchemas] :
getRelPropertyNameAndPropertiesPairs(relTableSchemas)) {
auto propertyExpression = expressionBinder.createPropertyExpression(
*queryRel, propertySchemas, false /* isPrimaryKey */);
queryRel->addPropertyExpression(propertyName, std::move(propertyExpression));
}
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryRel);
}
Expand All @@ -238,6 +232,42 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
queryGraph.addQueryRel(queryRel);
}

// Creates a REL-typed expression for a single-hop (non-recursive) rel pattern over `tableIDs`
// and binds the rel properties found in those tables onto it.
std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::string& parsedName,
    const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
    std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType) {
    auto uniqueName = getUniqueExpressionName(parsedName);
    auto rel = make_shared<RelExpression>(LogicalType(LogicalTypeID::REL), std::move(uniqueName),
        parsedName, tableIDs, std::move(srcNode), std::move(dstNode), directionType,
        QueryRelType::NON_RECURSIVE);
    // Property expressions are resolved from the schemas of the rel's own table IDs.
    bindQueryRelProperties(*rel);
    return rel;
}

std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const std::string& parsedName,
common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType) {
std::unordered_set<common::table_id_t> recursiveNodeTableIDs;
for (auto relTableID : tableIDs) {
auto relTableSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
recursiveNodeTableIDs.insert(relTableSchema->srcTableID);
recursiveNodeTableIDs.insert(relTableSchema->dstTableID);
}
auto tmpNode = createQueryNode(
InternalKeyword::ANONYMOUS, std::vector<common::table_id_t>{recursiveNodeTableIDs.begin(),
recursiveNodeTableIDs.end()});
auto tmpRel = createNonRecursiveQueryRel(
InternalKeyword::ANONYMOUS, tableIDs, nullptr, nullptr, directionType);
auto queryRel = make_shared<RelExpression>(*getRecursiveRelLogicalType(*tmpNode, *tmpRel),
getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode),
std::move(dstNode), directionType, relType);
auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel);
auto recursiveInfo = std::make_unique<RecursiveInfo>(
lowerBound, upperBound, std::move(tmpNode), std::move(tmpRel), std::move(lengthExpression));
queryRel->setRecursiveInfo(std::move(recursiveInfo));
bindQueryRelProperties(*queryRel);
return queryRel;
}

std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
const kuzu::parser::RelPattern& relPattern) {
auto lowerBound = std::min(TypeUtils::convertToUint32(relPattern.getLowerBound().c_str()),
Expand All @@ -254,6 +284,19 @@ std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
return std::make_pair(lowerBound, upperBound);
}

// Attaches one property expression to `rel` for every property name produced by
// getRelPropertyNameAndPropertiesPairs over the schemas of the rel's table IDs.
// Rel properties are never flagged as primary keys.
void Binder::bindQueryRelProperties(RelExpression& rel) {
    std::vector<RelTableSchema*> relSchemas;
    for (auto relTableID : rel.getTableIDs()) {
        auto schema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
        relSchemas.push_back(schema);
    }
    auto namedProperties = getRelPropertyNameAndPropertiesPairs(relSchemas);
    for (auto& [name, schemas] : namedProperties) {
        auto property =
            expressionBinder.createPropertyExpression(rel, schemas, false /* isPrimaryKey */);
        rel.addPropertyExpression(name, std::move(property));
    }
}

std::shared_ptr<NodeExpression> Binder::bindQueryNode(
const NodePattern& nodePattern, QueryGraph& queryGraph, PropertyKeyValCollection& collection) {
auto parsedName = nodePattern.getVariableName();
Expand Down Expand Up @@ -294,26 +337,31 @@ std::shared_ptr<NodeExpression> Binder::createQueryNode(
make_shared<NodeExpression>(getUniqueExpressionName(parsedName), parsedName, tableIDs);
queryNode->setAlias(parsedName);
queryNode->setInternalIDProperty(expressionBinder.createInternalNodeIDExpression(*queryNode));
// resolve properties associated with node table
std::vector<NodeTableSchema*> nodeTableSchemas;
for (auto tableID : tableIDs) {
nodeTableSchemas.push_back(catalog.getReadOnlyVersion()->getNodeTableSchema(tableID));
}
auto isSingleTable = nodeTableSchemas.size() == 1;
for (auto& [propertyName, propertySchemas] :
getNodePropertyNameAndPropertiesPairs(nodeTableSchemas)) {
auto isPrimaryKey = isSingleTable && nodeTableSchemas[0]->getPrimaryKey().propertyID ==
propertySchemas[0].propertyID;
auto propertyExpression =
expressionBinder.createPropertyExpression(*queryNode, propertySchemas, isPrimaryKey);
queryNode->addPropertyExpression(propertyName, std::move(propertyExpression));
}
bindQueryNodeProperties(*queryNode);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryNode);
}
return queryNode;
}

// Attaches one property expression to `node` for every property name produced by
// getNodePropertyNameAndPropertiesPairs over the schemas of the node's table IDs.
// A property is flagged as primary key only when the node is not multi-labeled and the property
// ID matches the primary key of the first table schema.
void Binder::bindQueryNodeProperties(NodeExpression& node) {
    std::vector<NodeTableSchema*> nodeSchemas;
    for (auto nodeTableID : node.getTableIDs()) {
        auto schema = catalog.getReadOnlyVersion()->getNodeTableSchema(nodeTableID);
        nodeSchemas.push_back(schema);
    }
    auto namedProperties = getNodePropertyNameAndPropertiesPairs(nodeSchemas);
    for (auto& [name, schemas] : namedProperties) {
        // Short-circuit keeps the schema lookup from running for multi-labeled nodes.
        const bool isPrimaryKey =
            !node.isMultiLabeled() &&
            nodeSchemas[0]->getPrimaryKey().propertyID == schemas[0].propertyID;
        auto property = expressionBinder.createPropertyExpression(node, schemas, isPrimaryKey);
        node.addPropertyExpression(name, std::move(property));
    }
}

std::vector<table_id_t> Binder::bindTableIDs(
const std::vector<std::string>& tableNames, LogicalTypeID nodeOrRelType) {
std::unordered_set<table_id_t> tableIDs;
Expand Down
23 changes: 15 additions & 8 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
// e.g. COUNT(a) -> COUNT(a._id)
if (child->dataType.getLogicalTypeID() == LogicalTypeID::NODE ||
child->dataType.getLogicalTypeID() == LogicalTypeID::REL) {
child = bindInternalIDExpression(*child);
child = bindInternalIDExpression(child);
}
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
Expand Down Expand Up @@ -131,9 +131,9 @@ std::shared_ptr<Expression> ExpressionBinder::rewriteFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName) {
if (functionName == ID_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
*child, std::vector<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
return bindInternalIDExpression(*child);
validateExpectedDataType(*child, std::vector<LogicalTypeID>{LogicalTypeID::NODE,
LogicalTypeID::REL, LogicalTypeID::STRUCT});
return bindInternalIDExpression(child);
} else if (functionName == LABEL_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
Expand All @@ -158,14 +158,21 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const Expression& expression) {
switch (expression.getDataType().getLogicalTypeID()) {
std::shared_ptr<Expression> expression) {
switch (expression->getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
auto& node = (NodeExpression&)expression;
auto& node = (NodeExpression&)*expression;
return node.getInternalIDProperty();
}
case common::LogicalTypeID::REL: {
return bindRelPropertyExpression(expression, InternalKeyword::ID);
return bindRelPropertyExpression(*expression, InternalKeyword::ID);
}
case common::LogicalTypeID::STRUCT: {
auto stringValue =
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, InternalKeyword::ID);
return bindScalarFunctionExpression(
expression_vector{expression, createLiteralExpression(std::move(stringValue))},
STRUCT_EXTRACT_FUNC_NAME);
}
default:
throw NotImplementedException("ExpressionBinder::bindInternalIDExpression");
Expand Down
7 changes: 7 additions & 0 deletions src/binder/bind_expression/bind_variable_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/variable_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_variable_expression.h"

Expand All @@ -18,5 +19,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindVariableExpression(
"Variable " + parsedExpression.getRawName() + " is not in scope.");
}

// Constructs a VariableExpression with the given data type, unique name and user-facing name,
// returned through the base Expression pointer. All arguments are sinks and are moved into the
// new expression.
std::shared_ptr<Expression> ExpressionBinder::createVariableExpression(
    common::LogicalType logicalType, std::string uniqueName, std::string name) {
    std::shared_ptr<Expression> variable = std::make_shared<VariableExpression>(
        std::move(logicalType), std::move(uniqueName), std::move(name));
    return variable;
}

} // namespace binder
} // namespace kuzu
8 changes: 4 additions & 4 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ std::shared_ptr<Expression> Binder::createVariable(
throw BinderException("Variable " + name + " already exists.");
}
auto uniqueName = getUniqueExpressionName(name);
auto variable = make_shared<VariableExpression>(dataType, uniqueName, name);
variable->setAlias(name);
variableScope->addExpression(name, variable);
return variable;
auto expression = expressionBinder.createVariableExpression(dataType, uniqueName, name);
expression->setAlias(name);
variableScope->addExpression(name, expression);
return expression;
}

void Binder::validateFirstMatchIsNotOptional(const SingleQuery& singleQuery) {
Expand Down
8 changes: 8 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ struct_field_idx_t StructTypeInfo::getStructFieldIdx(std::string fieldName) cons
return INVALID_STRUCT_FIELD_IDX;
}

// Looks up a struct field by name via getStructFieldIdx.
// Returns a non-owning pointer to the field; throws BinderException when the lookup yields
// INVALID_STRUCT_FIELD_IDX.
StructField* StructTypeInfo::getStructField(const std::string& fieldName) const {
    const auto fieldIdx = getStructFieldIdx(fieldName);
    if (fieldIdx != INVALID_STRUCT_FIELD_IDX) {
        return fields[fieldIdx].get();
    }
    throw BinderException("Cannot find field " + fieldName + " in STRUCT.");
}

std::vector<LogicalType*> StructTypeInfo::getChildrenTypes() const {
std::vector<LogicalType*> childrenTypesToReturn{fields.size()};
for (auto i = 0u; i < fields.size(); i++) {
Expand Down
10 changes: 10 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,22 @@ class Binder {
const std::shared_ptr<NodeExpression>& leftNode,
const std::shared_ptr<NodeExpression>& rightNode, QueryGraph& queryGraph,
PropertyKeyValCollection& collection);
std::shared_ptr<RelExpression> createNonRecursiveQueryRel(const std::string& parsedName,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType);
std::shared_ptr<RelExpression> createRecursiveQueryRel(const std::string& parsedName,
common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound,
const std::vector<common::table_id_t>& tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType);
std::pair<uint64_t, uint64_t> bindVariableLengthRelBound(const parser::RelPattern& relPattern);
void bindQueryRelProperties(RelExpression& rel);

std::shared_ptr<NodeExpression> bindQueryNode(const parser::NodePattern& nodePattern,
QueryGraph& queryGraph, PropertyKeyValCollection& collection);
std::shared_ptr<NodeExpression> createQueryNode(const parser::NodePattern& nodePattern);
std::shared_ptr<NodeExpression> createQueryNode(
const std::string& parsedName, const std::vector<common::table_id_t>& tableIDs);
void bindQueryNodeProperties(NodeExpression& node);
inline std::vector<common::table_id_t> bindNodeTableIDs(
const std::vector<std::string>& tableNames) {
return bindTableIDs(tableNames, common::LogicalTypeID::NODE);
Expand Down
16 changes: 8 additions & 8 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ enum class RelDirectionType : uint8_t {
BOTH = 1,
};

class RelExpression;

struct RecursiveInfo {
uint64_t lowerBound;
uint64_t upperBound;
std::shared_ptr<NodeExpression> recursiveNode;
std::shared_ptr<NodeExpression> node;
std::shared_ptr<RelExpression> rel;
std::shared_ptr<Expression> lengthExpression;

RecursiveInfo(size_t lowerBound, size_t upperBound,
std::shared_ptr<NodeExpression> recursiveNode, std::shared_ptr<Expression> lengthExpression)
: lowerBound{lowerBound}, upperBound{upperBound}, recursiveNode{std::move(recursiveNode)},
lengthExpression{std::move(lengthExpression)} {}
RecursiveInfo(size_t lowerBound, size_t upperBound, std::shared_ptr<NodeExpression> node,
std::shared_ptr<RelExpression> rel, std::shared_ptr<Expression> lengthExpression)
: lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)},
rel{std::move(rel)}, lengthExpression{std::move(lengthExpression)} {}
};

class RelExpression : public NodeOrRelExpression {
Expand Down Expand Up @@ -58,9 +61,6 @@ class RelExpression : public NodeOrRelExpression {
inline RecursiveInfo* getRecursiveInfo() const { return recursiveInfo.get(); }
inline size_t getLowerBound() const { return recursiveInfo->lowerBound; }
inline size_t getUpperBound() const { return recursiveInfo->upperBound; }
inline std::shared_ptr<NodeExpression> getRecursiveNode() const {
return recursiveInfo->recursiveNode;
}
inline std::shared_ptr<Expression> getLengthExpression() const {
return recursiveInfo->lengthExpression;
}
Expand Down
Loading

0 comments on commit 1dd0201

Please sign in to comment.