Skip to content

Commit

Permalink
Merge pull request #1771 from kuzudb/named-path
Browse files Browse the repository at this point in the history
Named path
  • Loading branch information
andyfengHKU committed Jul 7, 2023
2 parents 4dd9770 + 27aaae8 commit 52bc493
Show file tree
Hide file tree
Showing 26 changed files with 1,858 additions and 1,337 deletions.
3 changes: 2 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ oC_Pattern
: oC_PatternPart ( SP? ',' SP? oC_PatternPart )* ;

oC_PatternPart
: oC_AnonymousPatternPart ;
: ( oC_Variable SP? '=' SP? oC_AnonymousPatternPart )
| oC_AnonymousPatternPart ;

oC_AnonymousPatternPart
: oC_PatternElement ;
Expand Down
91 changes: 74 additions & 17 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <set>

#include "binder/binder.h"
#include "binder/expression/path_expression.h"
#include "binder/expression/property_expression.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -32,18 +33,88 @@ Binder::bindGraphPattern(const std::vector<std::unique_ptr<PatternElement>>& gra
std::unique_ptr<QueryGraph> Binder::bindPatternElement(
const PatternElement& patternElement, PropertyKeyValCollection& collection) {
auto queryGraph = std::make_unique<QueryGraph>();
expression_vector nodeAndRels;
auto leftNode = bindQueryNode(*patternElement.getFirstNodePattern(), *queryGraph, collection);
nodeAndRels.push_back(leftNode);
for (auto i = 0u; i < patternElement.getNumPatternElementChains(); ++i) {
auto patternElementChain = patternElement.getPatternElementChain(i);
auto rightNode =
bindQueryNode(*patternElementChain->getNodePattern(), *queryGraph, collection);
bindQueryRel(
auto rel = bindQueryRel(
*patternElementChain->getRelPattern(), leftNode, rightNode, *queryGraph, collection);
nodeAndRels.push_back(rel);
nodeAndRels.push_back(rightNode);
leftNode = rightNode;
}
if (patternElement.hasPathName()) {
auto pathName = patternElement.getPathName();
auto pathExpression = createPathExpression(pathName, nodeAndRels);
variableScope->addExpression(pathName, pathExpression);
}
return queryGraph;
}

static std::unique_ptr<LogicalType> getRecursiveRelLogicalType(
const NodeExpression& node, const RelExpression& rel) {
auto nodesType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(node.getDataType().copy()));
auto relsType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(rel.getDataType().copy()));
std::vector<std::unique_ptr<StructField>> recursiveRelFields;
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::NODES, std::move(nodesType)));
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::RELS, std::move(relsType)));
return std::make_unique<LogicalType>(LogicalTypeID::RECURSIVE_REL,
std::make_unique<StructTypeInfo>(std::move(recursiveRelFields)));
}

std::shared_ptr<Expression> Binder::createPathExpression(
const std::string& pathName, const expression_vector& children) {
std::unordered_set<common::table_id_t> nodeTableIDSet;
std::unordered_set<common::table_id_t> relTableIDSet;
for (auto& child : children) {
switch (child->getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
auto node = (NodeExpression*)child.get();
for (auto tableID : node->getTableIDs()) {
nodeTableIDSet.insert(tableID);
}
} break;
case common::LogicalTypeID::REL: {
auto rel = (RelExpression*)child.get();
for (auto tableID : rel->getTableIDs()) {
relTableIDSet.insert(tableID);
}
} break;
case common::LogicalTypeID::RECURSIVE_REL: {
auto recursiveRel = (RelExpression*)child.get();
auto recursiveInfo = recursiveRel->getRecursiveInfo();
for (auto tableID : recursiveInfo->node->getTableIDs()) {
nodeTableIDSet.insert(tableID);
}
for (auto tableID : recursiveInfo->rel->getTableIDs()) {
relTableIDSet.insert(tableID);
}
} break;
default:
throw NotImplementedException("Binder::createPathExpression");
}
}
auto nodeTableIDs =
std::vector<common::table_id_t>{nodeTableIDSet.begin(), nodeTableIDSet.end()};
std::sort(nodeTableIDs.begin(), nodeTableIDs.end());
auto relTableIDs = std::vector<common::table_id_t>{relTableIDSet.begin(), relTableIDSet.end()};
std::sort(relTableIDs.begin(), relTableIDs.end());
auto node = createQueryNode(InternalKeyword::ANONYMOUS, nodeTableIDs);
auto rel = createNonRecursiveQueryRel(
InternalKeyword::ANONYMOUS, relTableIDs, nullptr, nullptr, RelDirectionType::UNKNOWN);
auto dataType = getRecursiveRelLogicalType(*node, *rel);
auto uniqueName = getUniqueExpressionName(pathName);
return std::make_shared<PathExpression>(
*dataType, uniqueName, pathName, std::move(node), std::move(rel), children);
}

static std::vector<table_id_t> pruneRelTableIDs(const Catalog& catalog_,
const std::vector<table_id_t>& relTableIDs, const NodeExpression& srcNode,
const NodeExpression& dstNode) {
Expand Down Expand Up @@ -104,22 +175,7 @@ getNodePropertyNameAndPropertiesPairs(const std::vector<NodeTableSchema*>& nodeT
return getPropertyNameAndSchemasPairs(propertyNames, propertyNamesToSchemas);
}

static std::unique_ptr<LogicalType> getRecursiveRelLogicalType(
const NodeExpression& node, const RelExpression& rel) {
auto nodesType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(node.getDataType().copy()));
auto relsType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(rel.getDataType().copy()));
std::vector<std::unique_ptr<StructField>> recursiveRelFields;
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::NODES, std::move(nodesType)));
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::RELS, std::move(relsType)));
return std::make_unique<LogicalType>(LogicalTypeID::RECURSIVE_REL,
std::make_unique<StructTypeInfo>(std::move(recursiveRelFields)));
}

void Binder::bindQueryRel(const RelPattern& relPattern,
std::shared_ptr<RelExpression> Binder::bindQueryRel(const RelPattern& relPattern,
const std::shared_ptr<NodeExpression>& leftNode,
const std::shared_ptr<NodeExpression>& rightNode, QueryGraph& queryGraph,
PropertyKeyValCollection& collection) {
Expand Down Expand Up @@ -187,6 +243,7 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
variableScope->addExpression(parsedName, queryRel);
}
queryGraph.addQueryRel(queryRel);
return queryRel;
}

std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::string& parsedName,
Expand Down
4 changes: 4 additions & 0 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withCl
for (auto& property : rel->getPropertyExpressions()) {
newProjectionExpressions.push_back(property->copy());
}
} else if (ExpressionUtil::isRecursiveRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
newProjectionExpressions.push_back(expression);
newProjectionExpressions.push_back(rel->getLengthExpression());
} else {
newProjectionExpressions.push_back(expression);
}
Expand Down
2 changes: 2 additions & 0 deletions src/common/expression_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ std::string expressionTypeToString(ExpressionType type) {
return "LITERAL";
case VARIABLE:
return "VARIABLE";
case PATH:
return "PATH";
case PARAMETER:
return "PARAMETER";
case FUNCTION:
Expand Down
4 changes: 4 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ StructTypeInfo::StructTypeInfo(std::vector<std::unique_ptr<StructField>> fields)
}
}

bool StructTypeInfo::hasField(const std::string& fieldName) const {
return fieldNameToIdxMap.contains(fieldName);
}

struct_field_idx_t StructTypeInfo::getStructFieldIdx(std::string fieldName) const {
StringUtils::toUpper(fieldName);
if (fieldNameToIdxMap.contains(fieldName)) {
Expand Down
8 changes: 4 additions & 4 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ void ValueVector::copyFromVectorData(
}

void ValueVector::copyFromVectorData(
uint64_t posToCopy, const ValueVector* srcVector, uint64_t srcPos) {
setNull(posToCopy, srcVector->isNull(srcPos));
if (!isNull(posToCopy)) {
copyFromVectorData(getData() + posToCopy * getNumBytesPerValue(), srcVector,
uint64_t dstPos, const ValueVector* srcVector, uint64_t srcPos) {
setNull(dstPos, srcVector->isNull(srcPos));
if (!isNull(dstPos)) {
copyFromVectorData(getData() + dstPos * getNumBytesPerValue(), srcVector,
srcVector->getData() + srcPos * srcVector->getNumBytesPerValue());
}
}
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_library(kuzu_expression_evaluator
function_evaluator.cpp
literal_evaluator.cpp
node_rel_evaluator.cpp
path_evaluator.cpp
reference_evaluator.cpp)

set(ALL_OBJECT_FILES
Expand Down
Loading

0 comments on commit 52bc493

Please sign in to comment.