Skip to content

Commit

Permalink
Merge pull request #2429 from kuzudb/count-subquery
Browse files Browse the repository at this point in the history
Count subquery
  • Loading branch information
andyfengHKU committed Nov 17, 2023
2 parents cd61648 + 6ee73d2 commit f7779ac
Show file tree
Hide file tree
Showing 28 changed files with 3,463 additions and 2,941 deletions.
21 changes: 16 additions & 5 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,9 @@ oC_Atom
| oC_CaseExpression
| oC_ParenthesizedExpression
| oC_FunctionInvocation
| oC_ExistentialSubquery
| oC_PathPatterns
| oC_ExistSubquery
| kU_CountSubquery
| oC_Variable
;

Expand Down Expand Up @@ -588,17 +590,25 @@ oC_ParenthesizedExpression
: '(' SP? oC_Expression SP? ')' ;

oC_FunctionInvocation
: oC_FunctionName SP? '(' SP? '*' SP? ')'
: COUNT SP? '(' SP? '*' SP? ')'
| oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( kU_FunctionParameter SP? ( ',' SP? kU_FunctionParameter SP? )* )? ')' ;

COUNT : ( 'C' | 'c' ) ( 'O' | 'o' ) ( 'U' | 'u' ) ( 'N' | 'n' ) ( 'T' | 't' ) ;

oC_FunctionName
: oC_SymbolicName ;

kU_FunctionParameter
: ( oC_SymbolicName SP? ':' '=' SP? )? oC_Expression ;

oC_ExistentialSubquery
: EXISTS SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? SP? '}' ;
oC_PathPatterns
: oC_NodePattern ( SP? oC_PatternElementChain )+;

oC_ExistSubquery
: EXISTS SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? SP? '}' ;

kU_CountSubquery
: COUNT SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? SP? '}' ;

EXISTS : ( 'E' | 'e' ) ( 'X' | 'x' ) ( 'I' | 'i' ) ( 'S' | 's' ) ( 'T' | 't' ) ( 'S' | 's' ) ;

Expand Down Expand Up @@ -709,7 +719,8 @@ oC_SymbolicName
;

kU_NonReservedKeywords
: COMMENT ;
: COMMENT
| COUNT ;

UnescapedSymbolicName
: IdentifierStart ( IdentifierPart )* ;
Expand Down
58 changes: 46 additions & 12 deletions src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,62 @@
#include "binder/binder.h"
#include "binder/expression/existential_subquery_expression.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/subquery_expression.h"
#include "binder/expression_binder.h"
#include "parser/expression/parsed_subquery_expression.h"

using namespace kuzu::parser;
using namespace kuzu::common;
using namespace kuzu::function;

namespace kuzu {
namespace binder {

std::shared_ptr<Expression> ExpressionBinder::bindExistentialSubqueryExpression(
const ParsedExpression& parsedExpression) {
auto& subqueryExpression = (ParsedSubqueryExpression&)parsedExpression;
std::shared_ptr<Expression> ExpressionBinder::bindSubqueryExpression(
const ParsedExpression& parsedExpr) {
auto& subqueryExpr = reinterpret_cast<const ParsedSubqueryExpression&>(parsedExpr);
auto prevScope = binder->saveScope();
auto [queryGraph, _] = binder->bindGraphPattern(subqueryExpression.getPatternElements());
auto rawName = parsedExpression.getRawName();
auto [queryGraph, _] = binder->bindGraphPattern(subqueryExpr.getPatternElements());
auto subqueryType = subqueryExpr.getSubqueryType();
auto dataType =
subqueryType == SubqueryType::COUNT ? LogicalType::INT64() : LogicalType::BOOL();
auto rawName = subqueryExpr.getRawName();
auto uniqueName = binder->getUniqueExpressionName(rawName);
auto boundSubqueryExpression = make_shared<ExistentialSubqueryExpression>(
std::move(queryGraph), std::move(uniqueName), std::move(rawName));
if (subqueryExpression.hasWhereClause()) {
boundSubqueryExpression->setWhereExpression(
binder->bindWhereExpression(*subqueryExpression.getWhereClause()));
auto boundSubqueryExpr = make_shared<SubqueryExpression>(
subqueryType, *dataType, std::move(queryGraph), uniqueName, std::move(rawName));
// Bind predicate
if (subqueryExpr.hasWhereClause()) {
auto where = binder->bindWhereExpression(*subqueryExpr.getWhereClause());
boundSubqueryExpr->setWhereExpression(std::move(where));
}
// Bind projection
auto function = binder->catalog.getBuiltInFunctions()->matchAggregateFunction(
COUNT_STAR_FUNC_NAME, std::vector<LogicalType*>{}, false);
auto bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
auto countStarExpr = std::make_shared<AggregateFunctionExpression>(COUNT_STAR_FUNC_NAME,
std::move(bindData), expression_vector{}, function->copy(),
binder->getUniqueExpressionName(COUNT_STAR_FUNC_NAME));
boundSubqueryExpr->setCountStarExpr(countStarExpr);
std::shared_ptr<Expression> projectionExpr;
switch (subqueryType) {
case SubqueryType::COUNT: {
// Rewrite COUNT subquery as COUNT(*)
projectionExpr = countStarExpr;
} break;
case SubqueryType::EXISTS: {
// Rewrite EXISTS subquery as COUNT(*) > 0
auto literalExpr = createLiteralExpression(std::make_unique<Value>((int64_t)0));
projectionExpr = bindComparisonExpression(
ExpressionType::GREATER_THAN, expression_vector{countStarExpr, literalExpr});
} break;
default:
KU_UNREACHABLE;
}
// Use the same unique identifier for projection & subquery expression. We will replace subquery
// expression with projection expression during processing.
projectionExpr->setUniqueName(uniqueName);
boundSubqueryExpr->setProjectionExpr(projectionExpr);
binder->restoreScope(std::move(prevScope));
return boundSubqueryExpression;
return boundSubqueryExpr;
}

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindExpression(
expression = bindLiteralExpression(parsedExpression);
} else if (ExpressionType::VARIABLE == expressionType) {
expression = bindVariableExpression(parsedExpression);
} else if (ExpressionType::EXISTENTIAL_SUBQUERY == expressionType) {
expression = bindExistentialSubqueryExpression(parsedExpression);
} else if (ExpressionType::SUBQUERY == expressionType) {
expression = bindSubqueryExpression(parsedExpression);
} else if (ExpressionType::CASE_ELSE == expressionType) {
expression = bindCaseExpression(parsedExpression);
} else {
Expand Down
10 changes: 5 additions & 5 deletions src/binder/expression_visitor.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "binder/expression_visitor.h"

#include "binder/expression/case_expression.h"
#include "binder/expression/existential_subquery_expression.h"
#include "binder/expression/node_expression.h"
#include "binder/expression/property_expression.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression/subquery_expression.h"

using namespace kuzu::common;

Expand All @@ -16,8 +16,8 @@ expression_vector ExpressionChildrenCollector::collectChildren(const Expression&
case ExpressionType::CASE_ELSE: {
return collectCaseChildren(expression);
}
case ExpressionType::EXISTENTIAL_SUBQUERY: {
return collectExistentialSubqueryChildren(expression);
case ExpressionType::SUBQUERY: {
return collectSubqueryChildren(expression);
}
case ExpressionType::PATTERN: {
switch (expression.dataType.getLogicalTypeID()) {
Expand Down Expand Up @@ -50,10 +50,10 @@ expression_vector ExpressionChildrenCollector::collectCaseChildren(const Express
return result;
}

expression_vector ExpressionChildrenCollector::collectExistentialSubqueryChildren(
expression_vector ExpressionChildrenCollector::collectSubqueryChildren(
const Expression& expression) {
expression_vector result;
auto& subqueryExpression = (ExistentialSubqueryExpression&)expression;
auto& subqueryExpression = (SubqueryExpression&)expression;
for (auto& node : subqueryExpression.getQueryGraphCollection()->getQueryNodes()) {
result.push_back(node->getInternalID());
}
Expand Down
6 changes: 3 additions & 3 deletions src/common/expression_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ bool isExpressionAggregate(ExpressionType type) {
}

bool isExpressionSubquery(ExpressionType type) {
return ExpressionType::EXISTENTIAL_SUBQUERY == type;
return ExpressionType::SUBQUERY == type;
}

// LCOV_EXCL_START
Expand Down Expand Up @@ -87,8 +87,8 @@ std::string expressionTypeToString(ExpressionType type) {
return "SCALAR_FUNCTION";
case ExpressionType::AGGREGATE_FUNCTION:
return "AGGREGATE_FUNCTION";
case ExpressionType::EXISTENTIAL_SUBQUERY:
return "EXISTENTIAL_SUBQUERY";
case ExpressionType::SUBQUERY:
return "SUBQUERY";
default:
KU_UNREACHABLE;
}
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Expression : public std::enable_shared_from_this<Expression> {

inline void setAlias(const std::string& name) { alias = name; }

// Overwrites this expression's unique name. The subquery binder uses this to give a
// subquery's projection expression the same unique name as the subquery expression.
inline void setUniqueName(const std::string& name) { uniqueName = name; }
inline std::string getUniqueName() const {
KU_ASSERT(!uniqueName.empty());
return uniqueName;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
#pragma once

#include "binder/query/query_graph.h"
#include "common/enums/subquery_type.h"
#include "expression.h"

namespace kuzu {
namespace binder {

class ExistentialSubqueryExpression : public Expression {
class SubqueryExpression : public Expression {
public:
ExistentialSubqueryExpression(std::unique_ptr<QueryGraphCollection> queryGraphCollection,
std::string uniqueName, std::string rawName)
: Expression{common::ExpressionType::EXISTENTIAL_SUBQUERY,
common::LogicalType(common::LogicalTypeID::BOOL), std::move(uniqueName)},
SubqueryExpression(common::SubqueryType subqueryType, common::LogicalType dataType,
std::unique_ptr<QueryGraphCollection> queryGraphCollection, std::string uniqueName,
std::string rawName)
: Expression{common::ExpressionType::SUBQUERY, dataType, std::move(uniqueName)},
subqueryType{subqueryType},
queryGraphCollection{std::move(queryGraphCollection)}, rawName{std::move(rawName)} {}

inline common::SubqueryType getSubqueryType() const { return subqueryType; }

inline QueryGraphCollection* getQueryGraphCollection() const {
return queryGraphCollection.get();
}
Expand All @@ -27,11 +31,23 @@ class ExistentialSubqueryExpression : public Expression {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

// Stores / returns the COUNT(*) aggregate expression associated with this subquery.
// The binder creates it for both COUNT and EXISTS subqueries.
inline void setCountStarExpr(std::shared_ptr<Expression> expr) {
countStarExpr = std::move(expr);
}
inline std::shared_ptr<Expression> getCountStarExpr() const { return countStarExpr; }
// Stores / returns the expression this subquery is rewritten to by the binder:
// COUNT(*) for a COUNT subquery, COUNT(*) > 0 for an EXISTS subquery.
inline void setProjectionExpr(std::shared_ptr<Expression> expr) {
projectionExpr = std::move(expr);
}
inline std::shared_ptr<Expression> getProjectionExpr() const { return projectionExpr; }

std::string toStringInternal() const final { return rawName; }

private:
common::SubqueryType subqueryType;
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> whereExpression;
std::shared_ptr<Expression> countStarExpr;
std::shared_ptr<Expression> projectionExpr;
std::string rawName;
};

Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ExpressionBinder {
std::shared_ptr<Expression> createVariableExpression(
common::LogicalType logicalType, std::string name);
// Subquery expressions.
std::shared_ptr<Expression> bindExistentialSubqueryExpression(
std::shared_ptr<Expression> bindSubqueryExpression(
const parser::ParsedExpression& parsedExpression);
// Case expressions.
std::shared_ptr<Expression> bindCaseExpression(
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/expression_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ExpressionChildrenCollector {
private:
static expression_vector collectCaseChildren(const Expression& expression);

static expression_vector collectExistentialSubqueryChildren(const Expression& expression);
static expression_vector collectSubqueryChildren(const Expression& expression);

static expression_vector collectNodeChildren(const Expression& expression);

Expand Down Expand Up @@ -61,7 +61,7 @@ class ExpressionCollector {
const std::shared_ptr<Expression>& expression) {
KU_ASSERT(expressions.empty());
collectExpressionsInternal(expression, [&](const Expression& expression) {
return expression.expressionType == common::ExpressionType::EXISTENTIAL_SUBQUERY;
return expression.expressionType == common::ExpressionType::SUBQUERY;
});
return expressions;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/common/enums/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ enum class ExpressionType : uint8_t {

AGGREGATE_FUNCTION = 130,

EXISTENTIAL_SUBQUERY = 190,
SUBQUERY = 190,

CASE_ELSE = 200,

Expand Down
2 changes: 1 addition & 1 deletion src/include/common/enums/join_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ enum class JoinType : uint8_t {
INNER = 0,
LEFT = 1,
MARK = 2,
COUNT = 3,
};

enum class AccumulateType : uint8_t {
REGULAR = 0,
OPTIONAL_ = 1,
EXISTS = 2,
};

} // namespace common
Expand Down
14 changes: 14 additions & 0 deletions src/include/common/enums/subquery_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once

#include <cstdint>

namespace kuzu {
namespace common {

// Distinguishes the two pattern-subquery forms accepted by the grammar.
enum class SubqueryType : uint8_t {
// COUNT { MATCH ... [WHERE ...] } -- bound with an INT64 result type.
COUNT = 1,
// EXISTS { MATCH ... [WHERE ...] } -- bound with a BOOL result type.
EXISTS = 2,
};

} // namespace common
} // namespace kuzu
24 changes: 13 additions & 11 deletions src/include/parser/expression/parsed_subquery_expression.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "common/assert.h"
#include "common/enums/subquery_type.h"
#include "parsed_expression.h"
#include "parser/query/graph_pattern/pattern_element.h"

Expand All @@ -9,18 +10,18 @@ namespace parser {

class ParsedSubqueryExpression : public ParsedExpression {
public:
ParsedSubqueryExpression(
std::vector<std::unique_ptr<PatternElement>> patternElements, std::string rawName)
: ParsedExpression{common::ExpressionType::EXISTENTIAL_SUBQUERY, std::move(rawName)},
patternElements{std::move(patternElements)} {}

ParsedSubqueryExpression(common::ExpressionType type, std::string alias, std::string rawName,
parsed_expression_vector children,
std::vector<std::unique_ptr<PatternElement>> patternElements,
std::unique_ptr<ParsedExpression> whereClause)
: ParsedExpression{type, std::move(alias), std::move(rawName), std::move(children)},
patternElements{std::move(patternElements)}, whereClause{std::move(whereClause)} {}
ParsedSubqueryExpression(common::SubqueryType subqueryType, std::string rawName)
: ParsedExpression{common::ExpressionType::SUBQUERY, std::move(rawName)},
subqueryType{subqueryType} {}

inline common::SubqueryType getSubqueryType() const { return subqueryType; }

inline void addPatternElement(std::unique_ptr<PatternElement> element) {
patternElements.push_back(std::move(element));
}
inline void setPatternElements(std::vector<std::unique_ptr<PatternElement>> elements) {
patternElements = std::move(elements);
}
inline const std::vector<std::unique_ptr<PatternElement>>& getPatternElements() const {
return patternElements;
}
Expand All @@ -42,6 +43,7 @@ class ParsedSubqueryExpression : public ParsedExpression {
void serializeInternal(common::Serializer& /*serializer*/) const override { KU_UNREACHABLE; }

private:
common::SubqueryType subqueryType;
std::vector<std::unique_ptr<PatternElement>> patternElements;
std::unique_ptr<ParsedExpression> whereClause;
};
Expand Down
8 changes: 6 additions & 2 deletions src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,12 @@ class Transformer {
std::unique_ptr<ParsedExpression> transformFunctionParameterExpression(
CypherParser::KU_FunctionParameterContext& ctx);

std::unique_ptr<ParsedExpression> transformExistentialSubquery(
CypherParser::OC_ExistentialSubqueryContext& ctx);
std::unique_ptr<ParsedExpression> transformPathPattern(
CypherParser::OC_PathPatternsContext& ctx);
std::unique_ptr<ParsedExpression> transformExistSubquery(
CypherParser::OC_ExistSubqueryContext& ctx);
std::unique_ptr<ParsedExpression> transformCountSubquery(
CypherParser::KU_CountSubqueryContext& ctx);

std::unique_ptr<ParsedExpression> createPropertyExpression(
CypherParser::OC_PropertyLookupContext& ctx, std::unique_ptr<ParsedExpression> child);
Expand Down
2 changes: 1 addition & 1 deletion src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class QueryPlanner {
const binder::expression_vector& predicates, LogicalPlan& leftPlan);
void planRegularMatch(const binder::QueryGraphCollection& queryGraphCollection,
const binder::expression_vector& predicates, LogicalPlan& leftPlan);
void planExistsSubquery(std::shared_ptr<binder::Expression> subquery, LogicalPlan& outerPlan);
void planSubquery(std::shared_ptr<binder::Expression> subquery, LogicalPlan& outerPlan);
void planSubqueryIfNecessary(
const std::shared_ptr<binder::Expression>& expression, LogicalPlan& plan);

Expand Down
1 change: 1 addition & 0 deletions src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter {
uint64_t getInnerJoinResultForUnFlatKey();
uint64_t getLeftJoinResult();
uint64_t getMarkJoinResult();
uint64_t getCountJoinResult();
uint64_t getJoinResult();

private:
Expand Down
Loading

0 comments on commit f7779ac

Please sign in to comment.