Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Optional Match as the First Clause #1849

Merged
merged 1 commit into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/binder/bind/bind_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ std::unique_ptr<BoundRegularQuery> Binder::bindQuery(const RegularQuery& regular
}

std::unique_ptr<BoundSingleQuery> Binder::bindSingleQuery(const SingleQuery& singleQuery) {
validateFirstMatchIsNotOptional(singleQuery);
auto boundSingleQuery = std::make_unique<BoundSingleQuery>();
for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) {
boundSingleQuery->addQueryPart(bindQueryPart(*singleQuery.getQueryPart(i)));
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
auto& matchClause = reinterpret_cast<const MatchClause&>(readingClause);
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(matchClause.getPatternElements());
auto boundMatchClause =
make_unique<BoundMatchClause>(std::move(queryGraphCollection), matchClause.getIsOptional());
auto boundMatchClause = make_unique<BoundMatchClause>(
std::move(queryGraphCollection), matchClause.getMatchClauseType());
std::shared_ptr<Expression> whereExpression;
if (matchClause.hasWhereClause()) {
whereExpression = bindWhereExpression(*matchClause.getWhereClause());
Expand Down
6 changes: 0 additions & 6 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,6 @@ std::shared_ptr<Expression> Binder::createVariable(
return expression;
}

void Binder::validateFirstMatchIsNotOptional(const SingleQuery& singleQuery) {
if (singleQuery.isFirstReadingClauseOptionalMatch()) {
throw BinderException("First match clause cannot be optional match.");
}
}

void Binder::validateProjectionColumnNamesAreUnique(const expression_vector& expressions) {
auto existColumnNames = std::unordered_set<std::string>();
for (auto& expression : expressions) {
Expand Down
5 changes: 0 additions & 5 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,6 @@ class Binder {
std::vector<common::table_id_t> bindRelTableIDs(const std::vector<std::string>& tableNames);

/*** validations ***/
// E.g. Optional MATCH (a) RETURN a.age
// Although this is doable in Neo4j, I don't think the semantic make a lot of sense because
// there is nothing to left join on.
static void validateFirstMatchIsNotOptional(const parser::SingleQuery& singleQuery);

// E.g. ... RETURN a, b AS a
static void validateProjectionColumnNamesAreUnique(const expression_vector& expressions);

Expand Down
16 changes: 6 additions & 10 deletions src/include/binder/query/reading_clause/bound_match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@
namespace kuzu {
namespace binder {

/**
* BoundMatchClause may not have whereExpression
*/
class BoundMatchClause : public BoundReadingClause {

public:
explicit BoundMatchClause(
std::unique_ptr<QueryGraphCollection> queryGraphCollection, bool isOptional)
explicit BoundMatchClause(std::unique_ptr<QueryGraphCollection> queryGraphCollection,
common::MatchClauseType matchClauseType)
: BoundReadingClause{common::ClauseType::MATCH},
queryGraphCollection{std::move(queryGraphCollection)}, isOptional{isOptional} {}
queryGraphCollection{std::move(queryGraphCollection)}, matchClauseType{matchClauseType} {}

BoundMatchClause(const BoundMatchClause& other)
: BoundReadingClause(common::ClauseType::MATCH),
queryGraphCollection{other.queryGraphCollection->copy()},
whereExpression(other.whereExpression), isOptional{other.isOptional} {}
whereExpression(other.whereExpression), matchClauseType{other.matchClauseType} {}

~BoundMatchClause() override = default;

Expand All @@ -38,7 +34,7 @@ class BoundMatchClause : public BoundReadingClause {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

inline bool getIsOptional() const { return isOptional; }
inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundMatchClause>(*this);
Expand All @@ -47,7 +43,7 @@ class BoundMatchClause : public BoundReadingClause {
private:
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> whereExpression;
bool isOptional;
common::MatchClauseType matchClauseType;
};

} // namespace binder
Expand Down
5 changes: 5 additions & 0 deletions src/include/common/clause_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,10 @@ enum class ClauseType : uint8_t {
InQueryCall = 5,
};

enum class MatchClauseType : uint8_t {
MATCH = 0,
OPTIONAL_MATCH = 1,
};

} // namespace common
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/include/common/join_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,11 @@ enum class JoinType : uint8_t {
MARK = 2,
};

enum class AccumulateType : uint8_t {
REGULAR = 0,
OPTIONAL_ = 1,
EXISTS = 2,
};

} // namespace common
} // namespace kuzu
8 changes: 6 additions & 2 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ValueVector {
friend class StructVector;
friend class StringVector;
friend class ArrowColumnVector;
friend class ValueVectorUtils;

public:
explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr);
Expand All @@ -34,8 +35,7 @@ class ValueVector {
inline void setAllNull() { nullMask->setAllNull(); }
inline void setAllNonNull() { nullMask->setAllNonNull(); }
inline void setMayContainNulls() { nullMask->setMayContainNulls(); }
// Note that if this function returns true, there are no null. However, if it returns false, it
// doesn't mean there are nulls, i.e., there may or may not be nulls.
// On return true, there are no null. On return false, there may or may not be nulls.
inline bool hasNoNullsGuarantee() const { return nullMask->hasNoNullsGuarantee(); }
inline void setRangeNonNull(uint32_t startPos, uint32_t len) {
for (auto i = 0u; i < len; ++i) {
Expand All @@ -45,6 +45,10 @@ class ValueVector {
inline uint64_t* getNullMaskData() { return nullMask->getData(); }
inline void setNull(uint32_t pos, bool isNull) { nullMask->setNull(pos, isNull); }
inline uint8_t isNull(uint32_t pos) const { return nullMask->isNull(pos); }
inline void setAsSingleNullEntry() {
state->selVector->selectedSize = 1;
setNull(state->selVector->selectedPositions[0], true);
}

inline uint32_t getNumBytesPerValue() const { return numBytesPerValue; }

Expand Down
10 changes: 5 additions & 5 deletions src/include/parser/query/reading_clause/match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace parser {

class MatchClause : public ReadingClause {
public:
explicit MatchClause(
std::vector<std::unique_ptr<PatternElement>> patternElements, bool isOptional = false)
explicit MatchClause(std::vector<std::unique_ptr<PatternElement>> patternElements,
common::MatchClauseType matchClauseType)
: ReadingClause{common::ClauseType::MATCH}, patternElements{std::move(patternElements)},
isOptional{isOptional} {}
matchClauseType{matchClauseType} {}
~MatchClause() override = default;

inline const std::vector<std::unique_ptr<PatternElement>>& getPatternElements() const {
Expand All @@ -26,12 +26,12 @@ class MatchClause : public ReadingClause {
inline bool hasWhereClause() const { return whereClause != nullptr; }
inline ParsedExpression* getWhereClause() const { return whereClause.get(); }

inline bool getIsOptional() const { return isOptional; }
inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }

private:
std::vector<std::unique_ptr<PatternElement>> patternElements;
std::unique_ptr<ParsedExpression> whereClause;
bool isOptional;
common::MatchClauseType matchClauseType;
};

} // namespace parser
Expand Down
2 changes: 0 additions & 2 deletions src/include/parser/query/single_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class SingleQuery {
inline bool hasReturnClause() const { return returnClause != nullptr; }
inline ReturnClause* getReturnClause() const { return returnClause.get(); }

bool isFirstReadingClauseOptionalMatch() const;

private:
std::vector<std::unique_ptr<QueryPart>> queryParts;
std::vector<std::unique_ptr<ReadingClause>> readingClauses;
Expand Down
4 changes: 2 additions & 2 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class JoinOrderEnumerator {

inline void resetState() { context->resetState(); }

std::unique_ptr<JoinOrderEnumeratorContext> enterSubquery(
std::unique_ptr<JoinOrderEnumeratorContext> enterContext(
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
void exitSubquery(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);
void exitContext(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

inline void planMarkJoin(const binder::expression_vector& joinNodeIDs,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
#pragma once

#include "base_logical_operator.h"
#include "common/join_type.h"

namespace kuzu {
namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)} {}
LogicalAccumulate(common::AccumulateType accumulateType, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, accumulateType{
accumulateType} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
void computeFactorizedSchema() final;
void computeFlatSchema() final;

inline std::string getExpressionsForPrinting() const override { return std::string{}; }
inline std::string getExpressionsForPrinting() const final { return std::string{}; }

inline common::AccumulateType getAccumulateType() const { return accumulateType; }
inline binder::expression_vector getExpressions() const {
return children[0]->getSchema()->getExpressionsInScope();
}

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(children[0]->copy());
inline std::unique_ptr<LogicalOperator> copy() final {
return make_unique<LogicalAccumulate>(accumulateType, children[0]->copy());
}

private:
common::AccumulateType accumulateType;
};

} // namespace planner
Expand Down
15 changes: 10 additions & 5 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,20 @@ class QueryPlanner {
void planInQueryCall(
BoundReadingClause* boundReadingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans);

// CTE & subquery planning
void planOptionalMatch(const QueryGraphCollection& queryGraphCollection,
expression_vector& predicates, LogicalPlan& outerPlan);
const expression_vector& predicates, LogicalPlan& leftPlan);
void planRegularMatch(const QueryGraphCollection& queryGraphCollection,
expression_vector& predicates, LogicalPlan& prevPlan);
void planExistsSubquery(std::shared_ptr<Expression>& subquery, LogicalPlan& outerPlan);
const expression_vector& predicates, LogicalPlan& leftPlan);
void planExistsSubquery(std::shared_ptr<Expression> subquery, LogicalPlan& outerPlan);
void planSubqueryIfNecessary(const std::shared_ptr<Expression>& expression, LogicalPlan& plan);

void appendAccumulate(LogicalPlan& plan);
std::unique_ptr<LogicalPlan> planJoins(
const QueryGraphCollection& queryGraphCollection, const expression_vector& predicates);
std::unique_ptr<LogicalPlan> planJoinsInNewContext(
const expression_vector& expressionsToExcludeScan,
const QueryGraphCollection& queryGraphCollection, const expression_vector& predicates);

void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan);

void appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan);

Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class PlanMapper {
std::unique_ptr<PhysicalOperator> mapExplain(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCreateMacro(planner::LogicalOperator* logicalOperator);

std::unique_ptr<ResultCollector> createResultCollector(
std::unique_ptr<ResultCollector> createResultCollector(common::AccumulateType accumulateType,
const binder::expression_vector& expressions, planner::Schema* schema,
std::unique_ptr<PhysicalOperator> prevOperator);
std::unique_ptr<PhysicalOperator> createFactorizedTableScan(
Expand Down
2 changes: 0 additions & 2 deletions src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter {
uint64_t getMarkJoinResult();
uint64_t getJoinResult();

void setVectorsToNull();

private:
std::shared_ptr<HashJoinSharedState> sharedState;
common::JoinType joinType;
Expand Down
19 changes: 13 additions & 6 deletions src/include/processor/operator/result_collector.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "common/join_type.h"
#include "processor/operator/sink.h"
#include "processor/result/factorized_table.h"

Expand All @@ -24,14 +25,17 @@ class ResultCollectorSharedState {
};

struct ResultCollectorInfo {
common::AccumulateType accumulateType;
std::unique_ptr<FactorizedTableSchema> tableSchema;
std::vector<DataPos> payloadPositions;

ResultCollectorInfo(
ResultCollectorInfo(common::AccumulateType accumulateType,
std::unique_ptr<FactorizedTableSchema> tableSchema, std::vector<DataPos> payloadPositions)
: tableSchema{std::move(tableSchema)}, payloadPositions{std::move(payloadPositions)} {}
: accumulateType{accumulateType}, tableSchema{std::move(tableSchema)},
payloadPositions{std::move(payloadPositions)} {}
ResultCollectorInfo(const ResultCollectorInfo& other)
: tableSchema{other.tableSchema->copy()}, payloadPositions{other.payloadPositions} {}
: accumulateType{other.accumulateType}, tableSchema{other.tableSchema->copy()},
payloadPositions{other.payloadPositions} {}

inline std::unique_ptr<ResultCollectorInfo> copy() const {
return std::make_unique<ResultCollectorInfo>(*this);
Expand All @@ -48,19 +52,22 @@ class ResultCollector : public Sink {
std::move(child), id, paramsString},
info{std::move(info)}, sharedState{std::move(sharedState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
void executeInternal(ExecutionContext* context) final;

void executeInternal(ExecutionContext* context) override;
void finalize(ExecutionContext* context) final;

inline std::shared_ptr<FactorizedTable> getResultFactorizedTable() {
return sharedState->getTable();
}

std::unique_ptr<PhysicalOperator> clone() override {
std::unique_ptr<PhysicalOperator> clone() final {
return make_unique<ResultCollector>(resultSetDescriptor->copy(), info->copy(), sharedState,
children[0]->clone(), id, paramsString);
}

private:
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final;

private:
std::unique_ptr<ResultCollectorInfo> info;
std::shared_ptr<ResultCollectorSharedState> sharedState;
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/acc_hash_join_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ std::shared_ptr<planner::LogicalOperator> HashJoinSIPOptimizer::appendPathSemiMa

std::shared_ptr<planner::LogicalOperator> HashJoinSIPOptimizer::appendAccumulate(
std::shared_ptr<planner::LogicalOperator> child) {
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(child));
auto accumulate =
std::make_shared<LogicalAccumulate>(common::AccumulateType::REGULAR, std::move(child));
accumulate->computeFlatSchema();
return accumulate;
}
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace kuzu {
namespace optimizer {

void Optimizer::optimize(planner::LogicalPlan* plan) {

// Factorization structure should be removed before further optimization can be applied.
auto removeFactorizationRewriter = RemoveFactorizationRewriter();
removeFactorizationRewriter.rewrite(plan);
Expand Down
3 changes: 3 additions & 0 deletions src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ void ProjectionPushDownOptimizer::visitExtend(planner::LogicalOperator* op) {

void ProjectionPushDownOptimizer::visitAccumulate(planner::LogicalOperator* op) {
auto accumulate = (LogicalAccumulate*)op;
if (accumulate->getAccumulateType() != common::AccumulateType::REGULAR) {
return;
}
auto expressionsBeforePruning = accumulate->getExpressions();
auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning);
if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) {
Expand Down
23 changes: 0 additions & 23 deletions src/parser/query/single_query.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +0,0 @@
#include "parser/query/single_query.h"

namespace kuzu {
namespace parser {

bool SingleQuery::isFirstReadingClauseOptionalMatch() const {
for (auto& queryPart : queryParts) {
if (queryPart->getNumReadingClauses() != 0 &&
queryPart->getReadingClause(0)->getClauseType() == common::ClauseType::MATCH) {
return ((MatchClause*)queryPart->getReadingClause(0))->getIsOptional();
} else {
return false;
}
}
if (getNumReadingClauses() != 0 &&
getReadingClause(0)->getClauseType() == common::ClauseType::MATCH) {
return ((MatchClause*)getReadingClause(0))->getIsOptional();
}
return false;
}

} // namespace parser
} // namespace kuzu
4 changes: 3 additions & 1 deletion src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,10 @@ std::unique_ptr<ReadingClause> Transformer::transformReadingClause(
}

std::unique_ptr<ReadingClause> Transformer::transformMatch(CypherParser::OC_MatchContext& ctx) {
auto matchClauseType =
ctx.OPTIONAL() ? common::MatchClauseType::OPTIONAL_MATCH : common::MatchClauseType::MATCH;
auto matchClause =
std::make_unique<MatchClause>(transformPattern(*ctx.oC_Pattern()), ctx.OPTIONAL());
std::make_unique<MatchClause>(transformPattern(*ctx.oC_Pattern()), matchClauseType);
if (ctx.oC_Where()) {
matchClause->setWhereClause(transformWhere(*ctx.oC_Where()));
}
Expand Down
1 change: 1 addition & 0 deletions src/planner/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_subdirectory(join_order)
add_subdirectory(operator)
add_subdirectory(plan_operator)

add_library(kuzu_planner
OBJECT
Expand Down
Loading
Loading