Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Dec 1, 2022
1 parent b5e8b6a commit 53d0518
Show file tree
Hide file tree
Showing 17 changed files with 152 additions and 121 deletions.
24 changes: 24 additions & 0 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <set>

#include "binder/binder.h"

namespace kuzu {
Expand Down Expand Up @@ -37,6 +39,27 @@ unique_ptr<QueryGraph> Binder::bindPatternElement(
return queryGraph;
}

// E.g. MATCH (:person)-[:studyAt]->(:person) ...
static void validateNodeRelConnectivity(const Catalog& catalog_, const RelExpression& rel,
const NodeExpression& srcNode, const NodeExpression& dstNode) {
set<pair<table_id_t, table_id_t>> srcDstTableIDs;
for (auto relTableID : rel.getTableIDs()) {
for (auto [srcTableID, dstTableID] :
catalog_.getReadOnlyVersion()->getRelTableSchema(relTableID)->srcDstTableIDs) {
srcDstTableIDs.insert({srcTableID, dstTableID});
}
}
for (auto srcTableID : srcNode.getTableIDs()) {
for (auto dstTableID : dstNode.getTableIDs()) {
if (srcDstTableIDs.contains(make_pair(srcTableID, dstTableID))) {
return;
}
}
}
throw BinderException("Nodes " + srcNode.getRawName() + " and " + dstNode.getRawName() +
" are not connected through rel " + rel.getRawName() + ".");
}

void Binder::bindQueryRel(const RelPattern& relPattern, const shared_ptr<NodeExpression>& leftNode,
const shared_ptr<NodeExpression>& rightNode, QueryGraph& queryGraph,
PropertyKeyValCollection& collection) {
Expand Down Expand Up @@ -73,6 +96,7 @@ void Binder::bindQueryRel(const RelPattern& relPattern, const shared_ptr<NodeExp
}
queryRel->setAlias(parsedName);
queryRel->setRawName(parsedName);
validateNodeRelConnectivity(catalog, *queryRel, *srcNode, *dstNode);
if (!parsedName.empty()) {
variablesInScope.insert({parsedName, queryRel});
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@ expression_vector Binder::rewriteAsAllProperties(
properties.push_back(property);
}
}
} break ;
} break;
case REL: {
auto& rel = (RelExpression&)*expression;
for (auto tableID : rel.getTableIDs()) {
for (auto& property : catalog.getReadOnlyVersion()->getRelProperties(tableID)) {
properties.push_back(property);
}
}
} break ;
} break;
default:
throw NotImplementedException(
"Cannot rewrite type " + Types::dataTypeToString(nodeOrRelType));
Expand Down
3 changes: 2 additions & 1 deletion src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ unique_ptr<BoundCreateRel> Binder::bindCreateRel(
if (collection.hasPropertyKeyValPair(*rel, property.name)) {
setItems.push_back(collection.getPropertyKeyValPair(*rel, property.name));
} else {
auto propertyExpression = expressionBinder.bindRelPropertyExpression(rel, property.name);
auto propertyExpression =
expressionBinder.bindRelPropertyExpression(rel, property.name);
shared_ptr<Expression> nullExpression =
LiteralExpression::createNullLiteralExpression(getUniqueExpressionName("NULL"));
nullExpression = ExpressionBinder::implicitCastIfNecessary(
Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class Binder {
const vector<unique_ptr<ParsedExpression>>& projectionExpressions, bool containsStar);
// For RETURN clause, we write variable "v" as all properties of "v"
expression_vector rewriteProjectionExpressions(const expression_vector& expressions);
expression_vector rewriteAsAllProperties(const shared_ptr<Expression>& expression, DataTypeID nodeOrRelType);
expression_vector rewriteAsAllProperties(
const shared_ptr<Expression>& expression, DataTypeID nodeOrRelType);

void bindOrderBySkipLimitIfNecessary(
BoundProjectionBody& boundProjectionBody, const ProjectionBody& projectionBody);
Expand Down
1 change: 0 additions & 1 deletion src/include/storage/store/rels_statistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class RelStatistics : public TableStatistics {
move(numRelsPerDirectionBoundTable)} {}
RelStatistics(vector<pair<table_id_t, table_id_t>> srcDstTableIDs);


inline uint64_t getNumRelsForDirectionBoundTable(
RelDirection relDirection, table_id_t boundNodeTableID) const {
if (!numRelsPerDirectionBoundTable[relDirection].contains(boundNodeTableID)) {
Expand Down
75 changes: 38 additions & 37 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ vector<unique_ptr<LogicalPlan>> JoinOrderEnumerator::planCrossProduct(
return result;
}

vector<unique_ptr<LogicalPlan>> JoinOrderEnumerator::enumerate(QueryGraph* queryGraph,
expression_vector& predicates) {
vector<unique_ptr<LogicalPlan>> JoinOrderEnumerator::enumerate(
QueryGraph* queryGraph, expression_vector& predicates) {
context->init(queryGraph, predicates);
if (!context->expressionsToScanFromOuter.empty()) {
planOuterExpressionsScan(context->expressionsToScanFromOuter);
Expand Down Expand Up @@ -122,8 +122,8 @@ void JoinOrderEnumerator::planOuterExpressionsScan(expression_vector& expression
}
auto plan = make_unique<LogicalPlan>();
appendFTableScan(context->outerPlan, expressions, *plan);
auto predicates = getNewlyMatchedExpressions(context->getEmptySubqueryGraph(), newSubgraph,
context->getWhereExpressions());
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
for (auto& predicate : predicates) {
queryPlanner->appendFilter(predicate, *plan);
}
Expand All @@ -150,8 +150,8 @@ static bool isPrimaryPropertyAndLiteralPair(const Expression& left, const Expres
return propertyExpression.getPropertyID(node.getTableID()) == primaryKeyID;
}

static bool isIndexScanExpression(Expression& expression, const NodeExpression& node,
uint32_t primaryKeyID) {
static bool isIndexScanExpression(
Expression& expression, const NodeExpression& node, uint32_t primaryKeyID) {
if (expression.expressionType != EQUALS) { // check equality comparison
return false;
}
Expand Down Expand Up @@ -193,8 +193,8 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
auto newSubgraph = context->getEmptySubqueryGraph();
newSubgraph.addQueryNode(nodePos);
auto plan = make_unique<LogicalPlan>();
auto predicates = getNewlyMatchedExpressions(context->getEmptySubqueryGraph(), newSubgraph,
context->getWhereExpressions());
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
// In un-nested subquery, e.g. MATCH (a) OPTIONAL MATCH (a)-[e1]->(b), the inner query
// ("(a)-[e1]->(b)") needs to scan a, which is already scanned in the outer query (a). To avoid
// scanning storage twice, we keep track of node table "a" and make sure when planning inner
Expand All @@ -221,17 +221,17 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
context->addPlan(newSubgraph, std::move(plan));
}

void JoinOrderEnumerator::planFiltersForNode(expression_vector& predicates,
shared_ptr<NodeExpression> node, LogicalPlan& plan) {
void JoinOrderEnumerator::planFiltersForNode(
expression_vector& predicates, shared_ptr<NodeExpression> node, LogicalPlan& plan) {
for (auto& predicate : predicates) {
auto propertiesToScan = getPropertiesForVariable(*predicate, *node);
queryPlanner->appendScanNodePropIfNecessary(propertiesToScan, node, plan);
queryPlanner->appendFilter(predicate, plan);
}
}

void JoinOrderEnumerator::planPropertyScansForNode(shared_ptr<NodeExpression> node,
LogicalPlan& plan) {
void JoinOrderEnumerator::planPropertyScansForNode(
shared_ptr<NodeExpression> node, LogicalPlan& plan) {
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, plan);
}
Expand All @@ -240,8 +240,8 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
auto rel = context->queryGraph->getQueryRel(relPos);
auto newSubgraph = context->getEmptySubqueryGraph();
newSubgraph.addQueryRel(relPos);
auto predicates = getNewlyMatchedExpressions(context->getEmptySubqueryGraph(), newSubgraph,
context->getWhereExpressions());
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
for (auto direction : REL_DIRECTIONS) {
auto plan = make_unique<LogicalPlan>();
auto boundNode = direction == FWD ? rel->getSrcNode() : rel->getDstNode();
Expand All @@ -251,17 +251,17 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
}
}

void JoinOrderEnumerator::planFiltersForRel(expression_vector& predicates, RelExpression& rel,
RelDirection direction, LogicalPlan& plan) {
void JoinOrderEnumerator::planFiltersForRel(
expression_vector& predicates, RelExpression& rel, RelDirection direction, LogicalPlan& plan) {
for (auto& predicate : predicates) {
auto relPropertiesToScan = getPropertiesForVariable(*predicate, rel);
queryPlanner->appendScanRelPropsIfNecessary(relPropertiesToScan, rel, direction, plan);
queryPlanner->appendFilter(predicate, plan);
}
}

void JoinOrderEnumerator::planPropertyScansForRel(RelExpression& rel, RelDirection direction,
LogicalPlan& plan) {
void JoinOrderEnumerator::planPropertyScansForRel(
RelExpression& rel, RelDirection direction, LogicalPlan& plan) {
auto relProperties = queryPlanner->getPropertiesForRel(rel);
queryPlanner->appendScanRelPropsIfNecessary(relProperties, rel, direction, plan);
}
Expand Down Expand Up @@ -368,8 +368,8 @@ void JoinOrderEnumerator::planWCOJoin(const SubqueryGraph& subgraph,
// We prune such join.
// Note that this does not mean we may lose good plan. An equivalent join can be found between [e2]
// and (a)-[e1]->(b).
static bool needPruneImplicitJoins(const SubqueryGraph& leftSubgraph,
const SubqueryGraph& rightSubgraph, uint32_t numJoinNodes) {
static bool needPruneImplicitJoins(
const SubqueryGraph& leftSubgraph, const SubqueryGraph& rightSubgraph, uint32_t numJoinNodes) {
auto leftNodePositions = leftSubgraph.getNodePositionsIgnoringNodeSelector();
auto rightNodePositions = rightSubgraph.getNodePositionsIgnoringNodeSelector();
auto intersectionSize = 0;
Expand Down Expand Up @@ -493,8 +493,8 @@ void JoinOrderEnumerator::planFiltersForHashJoin(expression_vector& predicates,
}
}

void JoinOrderEnumerator::appendFTableScan(LogicalPlan* outerPlan,
expression_vector& expressionsToScan, LogicalPlan& plan) {
void JoinOrderEnumerator::appendFTableScan(
LogicalPlan* outerPlan, expression_vector& expressionsToScan, LogicalPlan& plan) {
unordered_map<uint32_t, expression_vector> groupPosToExpressionsMap;
for (auto& expression : expressionsToScan) {
auto outerPos = outerPlan->getSchema()->getGroupPos(expression->getUniqueName());
Expand All @@ -516,8 +516,8 @@ void JoinOrderEnumerator::appendFTableScan(LogicalPlan* outerPlan,
assert(outerPlan->getLastOperator()->getLogicalOperatorType() ==
LogicalOperatorType::LOGICAL_ACCUMULATE);
auto logicalAcc = (LogicalAccumulate*)outerPlan->getLastOperator().get();
auto fTableScan = make_shared<LogicalFTableScan>(expressionsToScan,
logicalAcc->getExpressions(), flatOutputGroupPositions);
auto fTableScan = make_shared<LogicalFTableScan>(
expressionsToScan, logicalAcc->getExpressions(), flatOutputGroupPositions);
plan.setLastOperator(std::move(fTableScan));
}

Expand All @@ -536,8 +536,8 @@ void JoinOrderEnumerator::appendScanNode(shared_ptr<NodeExpression>& node, Logic
plan.setLastOperator(std::move(scan));
}

void JoinOrderEnumerator::appendIndexScanNode(shared_ptr<NodeExpression>& node,
shared_ptr<Expression> indexExpression, LogicalPlan& plan) {
void JoinOrderEnumerator::appendIndexScanNode(
shared_ptr<NodeExpression>& node, shared_ptr<Expression> indexExpression, LogicalPlan& plan) {
assert(plan.isEmpty());
auto schema = plan.getSchema();
auto scan = make_shared<LogicalIndexScanNode>(node, std::move(indexExpression));
Expand All @@ -548,8 +548,8 @@ void JoinOrderEnumerator::appendIndexScanNode(shared_ptr<NodeExpression>& node,
plan.setLastOperator(std::move(scan));
}

bool JoinOrderEnumerator::needExtendToNewGroup(RelExpression& rel, NodeExpression& boundNode,
RelDirection direction) {
bool JoinOrderEnumerator::needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, RelDirection direction) {
auto extendToNewGroup = false;
extendToNewGroup |= rel.getNumTableIDs() > 1;
if (rel.getNumTableIDs() == 1) {
Expand All @@ -560,14 +560,15 @@ bool JoinOrderEnumerator::needExtendToNewGroup(RelExpression& rel, NodeExpressio
return extendToNewGroup;
}

bool JoinOrderEnumerator::needFlatInput(RelExpression& rel, NodeExpression& boundNode, RelDirection direction) {
bool JoinOrderEnumerator::needFlatInput(
RelExpression& rel, NodeExpression& boundNode, RelDirection direction) {
auto needFlatInput = needExtendToNewGroup(rel, boundNode, direction);
needFlatInput |= rel.isVariableLength();
return needFlatInput;
}

void JoinOrderEnumerator::appendExtend(shared_ptr<RelExpression>& rel, RelDirection direction,
LogicalPlan& plan) {
void JoinOrderEnumerator::appendExtend(
shared_ptr<RelExpression>& rel, RelDirection direction, LogicalPlan& plan) {
auto schema = plan.getSchema();
auto boundNode = FWD == direction ? rel->getSrcNode() : rel->getDstNode();
if (boundNode->getNumTableIDs() > 1) {
Expand All @@ -578,8 +579,8 @@ void JoinOrderEnumerator::appendExtend(shared_ptr<RelExpression>& rel, RelDirect
if (needFlatInput(*rel, *boundNode, direction)) {
QueryPlanner::appendFlattenIfNecessary(boundNode->getInternalIDProperty(), plan);
}
auto extend = make_shared<LogicalExtend>(boundNode, nbrNode, rel, direction, extendToNewGroup,
plan.getLastOperator());
auto extend = make_shared<LogicalExtend>(
boundNode, nbrNode, rel, direction, extendToNewGroup, plan.getLastOperator());
extend->computeSchema(*schema);
plan.setLastOperator(std::move(extend));
// update cardinality estimation info
Expand Down Expand Up @@ -769,8 +770,8 @@ void JoinOrderEnumerator::appendCrossProduct(LogicalPlan& probePlan, LogicalPlan
probePlan.setLastOperator(std::move(crossProduct));
}

expression_vector JoinOrderEnumerator::getPropertiesForVariable(Expression& expression,
Expression& variable) {
expression_vector JoinOrderEnumerator::getPropertiesForVariable(
Expression& expression, Expression& variable) {
expression_vector result;
for (auto& propertyExpression : expression.getSubPropertyExpressions()) {
if (propertyExpression->getChild(0)->getUniqueName() != variable.getUniqueName()) {
Expand All @@ -781,8 +782,8 @@ expression_vector JoinOrderEnumerator::getPropertiesForVariable(Expression& expr
return result;
}

uint64_t JoinOrderEnumerator::getExtensionRate(const RelExpression& rel,
const NodeExpression& boundNode, RelDirection direction) {
uint64_t JoinOrderEnumerator::getExtensionRate(
const RelExpression& rel, const NodeExpression& boundNode, RelDirection direction) {
auto boundNodeTableID = boundNode.getTableID();
double numBoundNodes =
nodesStatistics.getNodeStatisticsAndDeletedIDs(boundNodeTableID)->getNumTuples();
Expand Down
23 changes: 11 additions & 12 deletions test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,26 +172,25 @@ class TinySnbDDLTest : public DBTest {
auto result = conn->query("MATCH (:person)-[:belongs]->(:organisation) RETURN count(*)");
ASSERT_TRUE(result->isSuccess());
ASSERT_EQ(TestHelper::convertResultToString(*result), vector<string>{"2"});
result = conn->query("MATCH (:person)-[:belongs]->(:country) RETURN count(*)");
result = conn->query("MATCH (a:person)-[e:belongs]->(b:country) RETURN count(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Binder exception: Node table person doesn't connect "
"to country through rel table belongs.");
ASSERT_EQ(result->getErrorMessage(),
"Binder exception: Nodes a and b are not connected through rel e.");
result = conn->query("MATCH (:organisation)-[:belongs]->(:country) RETURN count(*)");
ASSERT_TRUE(result->isSuccess());
ASSERT_EQ(TestHelper::convertResultToString(*result), vector<string>{"1"});
result = conn->query("MATCH (:organisation)-[:belongs]->(:person) RETURN count(*)");
result = conn->query("MATCH (a:organisation)-[e:belongs]->(b:person) RETURN count(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(),
"Binder exception: Node table organisation doesn't connect "
"to person through rel table belongs.");
result = conn->query("MATCH (:country)-[:belongs]->(:person) RETURN count(*)");
"Binder exception: Nodes a and b are not connected through rel e.");
result = conn->query("MATCH (a:country)-[e:belongs]->(b:person) RETURN count(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Binder exception: Node table country doesn't connect "
"to person through rel table belongs.");
result = conn->query("MATCH (:country)-[:belongs]->(:organisation) RETURN count(*)");
ASSERT_EQ(result->getErrorMessage(),
"Binder exception: Nodes a and b are not connected through rel e.");
result = conn->query("MATCH (a:country)-[e:belongs]->(b:organisation) RETURN count(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Binder exception: Node table country doesn't connect "
"to organisation through rel table belongs.");
ASSERT_EQ(result->getErrorMessage(),
"Binder exception: Nodes a and b are not connected through rel e.");
}

void createRelMixedRelationCommitAndRecoveryTest(TransactionTestType transactionTestType) {
Expand Down
Loading

0 comments on commit 53d0518

Please sign in to comment.