Skip to content

Commit

Permalink
allow optional match as first match
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jul 23, 2023
1 parent 77c7829 commit 0643b92
Show file tree
Hide file tree
Showing 39 changed files with 211 additions and 162 deletions.
1 change: 0 additions & 1 deletion src/binder/bind/bind_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ std::unique_ptr<BoundRegularQuery> Binder::bindQuery(const RegularQuery& regular
}

std::unique_ptr<BoundSingleQuery> Binder::bindSingleQuery(const SingleQuery& singleQuery) {
validateFirstMatchIsNotOptional(singleQuery);
auto boundSingleQuery = std::make_unique<BoundSingleQuery>();
for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) {
boundSingleQuery->addQueryPart(bindQueryPart(*singleQuery.getQueryPart(i)));
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
auto& matchClause = reinterpret_cast<const MatchClause&>(readingClause);
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(matchClause.getPatternElements());
auto boundMatchClause =
make_unique<BoundMatchClause>(std::move(queryGraphCollection), matchClause.getIsOptional());
auto boundMatchClause = make_unique<BoundMatchClause>(
std::move(queryGraphCollection), matchClause.getMatchClauseType());
std::shared_ptr<Expression> whereExpression;
if (matchClause.hasWhereClause()) {
whereExpression = bindWhereExpression(*matchClause.getWhereClause());
Expand Down
6 changes: 0 additions & 6 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,6 @@ std::shared_ptr<Expression> Binder::createVariable(
return expression;
}

// Rejects a single query whose first reading clause is an optional match.
// Throws BinderException in that case; otherwise returns without side effects.
void Binder::validateFirstMatchIsNotOptional(const SingleQuery& singleQuery) {
    const bool startsWithOptionalMatch = singleQuery.isFirstReadingClauseOptionalMatch();
    if (!startsWithOptionalMatch) {
        return;
    }
    throw BinderException("First match clause cannot be optional match.");
}

void Binder::validateProjectionColumnNamesAreUnique(const expression_vector& expressions) {
auto existColumnNames = std::unordered_set<std::string>();
for (auto& expression : expressions) {
Expand Down
5 changes: 0 additions & 5 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,6 @@ class Binder {
std::vector<common::table_id_t> bindRelTableIDs(const std::vector<std::string>& tableNames);

/*** validations ***/
// E.g. Optional MATCH (a) RETURN a.age
// Although this is doable in Neo4j, I don't think the semantics make a lot of sense because
// there is nothing to left join on.
static void validateFirstMatchIsNotOptional(const parser::SingleQuery& singleQuery);

// E.g. ... RETURN a, b AS a
static void validateProjectionColumnNamesAreUnique(const expression_vector& expressions);

Expand Down
16 changes: 6 additions & 10 deletions src/include/binder/query/reading_clause/bound_match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@
namespace kuzu {
namespace binder {

/**
* BoundMatchClause may not have whereExpression
*/
class BoundMatchClause : public BoundReadingClause {

public:
explicit BoundMatchClause(
std::unique_ptr<QueryGraphCollection> queryGraphCollection, bool isOptional)
explicit BoundMatchClause(std::unique_ptr<QueryGraphCollection> queryGraphCollection,
common::MatchClauseType matchClauseType)
: BoundReadingClause{common::ClauseType::MATCH},
queryGraphCollection{std::move(queryGraphCollection)}, isOptional{isOptional} {}
queryGraphCollection{std::move(queryGraphCollection)}, matchClauseType{matchClauseType} {}

BoundMatchClause(const BoundMatchClause& other)
: BoundReadingClause(common::ClauseType::MATCH),
queryGraphCollection{other.queryGraphCollection->copy()},
whereExpression(other.whereExpression), isOptional{other.isOptional} {}
whereExpression(other.whereExpression), matchClauseType{other.matchClauseType} {}

~BoundMatchClause() override = default;

Expand All @@ -38,7 +34,7 @@ class BoundMatchClause : public BoundReadingClause {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
}

inline bool getIsOptional() const { return isOptional; }
inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundMatchClause>(*this);
Expand All @@ -47,7 +43,7 @@ class BoundMatchClause : public BoundReadingClause {
private:
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> whereExpression;
bool isOptional;
common::MatchClauseType matchClauseType;
};

} // namespace binder
Expand Down
5 changes: 5 additions & 0 deletions src/include/common/clause_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,10 @@ enum class ClauseType : uint8_t {
InQueryCall = 5,
};

// Distinguishes the two forms of a match clause. The transformer picks OPTIONAL_MATCH when the
// parsed clause carries the OPTIONAL keyword and MATCH otherwise; the value is then stored on
// parser::MatchClause and forwarded to binder::BoundMatchClause.
enum class MatchClauseType : uint8_t {
// Match clause written without the OPTIONAL keyword.
MATCH = 0,
// Match clause written with the OPTIONAL keyword.
OPTIONAL_MATCH = 1,
};

} // namespace common
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/include/common/join_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,11 @@ enum class JoinType : uint8_t {
MARK = 2,
};

// Tag carried by planner::LogicalAccumulate and processor::ResultCollectorInfo that records
// which kind of accumulation an operator performs.
enum class AccumulateType : uint8_t {
// Kind used by HashJoinSIPOptimizer::appendAccumulate. The projection push-down optimizer
// returns early from visitAccumulate for any kind other than this one.
REGULAR = 0,
// presumably used when planning OPTIONAL MATCH — TODO confirm against the query planner.
// NOTE(review): `OPTIONAL` is also a macro name in some Windows SDK headers — confirm there is
// no clash when building on that platform.
OPTIONAL = 1,
// presumably used when planning EXISTS subqueries — TODO confirm against the query planner.
EXISTS = 2,
};

} // namespace common
} // namespace kuzu
8 changes: 6 additions & 2 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ValueVector {
friend class StructVector;
friend class StringVector;
friend class ArrowColumnVector;
friend class ValueVectorUtils;

public:
explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr);
Expand All @@ -34,8 +35,7 @@ class ValueVector {
inline void setAllNull() { nullMask->setAllNull(); }
inline void setAllNonNull() { nullMask->setAllNonNull(); }
inline void setMayContainNulls() { nullMask->setMayContainNulls(); }
// Note that if this function returns true, there are no null. However, if it returns false, it
// doesn't mean there are nulls, i.e., there may or may not be nulls.
// Returns true only if there are no nulls. A false return means there may or may not be nulls.
inline bool hasNoNullsGuarantee() const { return nullMask->hasNoNullsGuarantee(); }
inline void setRangeNonNull(uint32_t startPos, uint32_t len) {
for (auto i = 0u; i < len; ++i) {
Expand All @@ -45,6 +45,10 @@ class ValueVector {
inline uint64_t* getNullMaskData() { return nullMask->getData(); }
inline void setNull(uint32_t pos, bool isNull) { nullMask->setNull(pos, isNull); }
inline uint8_t isNull(uint32_t pos) const { return nullMask->isNull(pos); }
// Shrinks the selection of this vector's state to a single entry and marks that entry as null.
inline void setAsSingleNullEntry() {
    auto& selection = *state->selVector;
    selection.selectedSize = 1;
    // The entry that gets nulled is whichever position the selection vector lists first.
    setNull(selection.selectedPositions[0], true);
}

inline uint32_t getNumBytesPerValue() const { return numBytesPerValue; }

Expand Down
10 changes: 5 additions & 5 deletions src/include/parser/query/reading_clause/match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace parser {

class MatchClause : public ReadingClause {
public:
explicit MatchClause(
std::vector<std::unique_ptr<PatternElement>> patternElements, bool isOptional = false)
explicit MatchClause(std::vector<std::unique_ptr<PatternElement>> patternElements,
common::MatchClauseType matchClauseType)
: ReadingClause{common::ClauseType::MATCH}, patternElements{std::move(patternElements)},
isOptional{isOptional} {}
matchClauseType{matchClauseType} {}
~MatchClause() override = default;

inline const std::vector<std::unique_ptr<PatternElement>>& getPatternElements() const {
Expand All @@ -26,12 +26,12 @@ class MatchClause : public ReadingClause {
inline bool hasWhereClause() const { return whereClause != nullptr; }
inline ParsedExpression* getWhereClause() const { return whereClause.get(); }

inline bool getIsOptional() const { return isOptional; }
inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }

private:
std::vector<std::unique_ptr<PatternElement>> patternElements;
std::unique_ptr<ParsedExpression> whereClause;
bool isOptional;
common::MatchClauseType matchClauseType;
};

} // namespace parser
Expand Down
2 changes: 0 additions & 2 deletions src/include/parser/query/single_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class SingleQuery {
inline bool hasReturnClause() const { return returnClause != nullptr; }
inline ReturnClause* getReturnClause() const { return returnClause.get(); }

bool isFirstReadingClauseOptionalMatch() const;

private:
std::vector<std::unique_ptr<QueryPart>> queryParts;
std::vector<std::unique_ptr<ReadingClause>> readingClauses;
Expand Down
4 changes: 2 additions & 2 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class JoinOrderEnumerator {

inline void resetState() { context->resetState(); }

std::unique_ptr<JoinOrderEnumeratorContext> enterSubquery(
std::unique_ptr<JoinOrderEnumeratorContext> enterContext(
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
void exitSubquery(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);
void exitContext(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

inline void planMarkJoin(const binder::expression_vector& joinNodeIDs,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
#pragma once

#include "base_logical_operator.h"
#include "common/join_type.h"

namespace kuzu {
namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)} {}
LogicalAccumulate(common::AccumulateType accumulateType, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, accumulateType{
accumulateType} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
void computeFactorizedSchema() final;
void computeFlatSchema() final;

inline std::string getExpressionsForPrinting() const override { return std::string{}; }
inline std::string getExpressionsForPrinting() const final { return std::string{}; }

inline common::AccumulateType getAccumulateType() const { return accumulateType; }
inline binder::expression_vector getExpressions() const {
return children[0]->getSchema()->getExpressionsInScope();
}

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(children[0]->copy());
inline std::unique_ptr<LogicalOperator> copy() final {
return make_unique<LogicalAccumulate>(accumulateType, children[0]->copy());
}

private:
common::AccumulateType accumulateType;
};

} // namespace planner
Expand Down
15 changes: 10 additions & 5 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,20 @@ class QueryPlanner {
void planInQueryCall(
BoundReadingClause* boundReadingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans);

// CTE & subquery planning
void planOptionalMatch(const QueryGraphCollection& queryGraphCollection,
expression_vector& predicates, LogicalPlan& outerPlan);
const expression_vector& predicates, LogicalPlan& leftPlan);
void planRegularMatch(const QueryGraphCollection& queryGraphCollection,
expression_vector& predicates, LogicalPlan& prevPlan);
void planExistsSubquery(std::shared_ptr<Expression>& subquery, LogicalPlan& outerPlan);
const expression_vector& predicates, LogicalPlan& leftPlan);
void planExistsSubquery(std::shared_ptr<Expression> subquery, LogicalPlan& outerPlan);
void planSubqueryIfNecessary(const std::shared_ptr<Expression>& expression, LogicalPlan& plan);

void appendAccumulate(LogicalPlan& plan);
std::unique_ptr<LogicalPlan> planJoins(
const QueryGraphCollection& queryGraphCollection, const expression_vector& predicates);
std::unique_ptr<LogicalPlan> planJoinsInNewContext(
const expression_vector& expressionsToExcludeScan,
const QueryGraphCollection& queryGraphCollection, const expression_vector& predicates);

void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan);

void appendExpressionsScan(const expression_vector& expressions, LogicalPlan& plan);

Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class PlanMapper {
std::unique_ptr<PhysicalOperator> mapExplain(planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapCreateMacro(planner::LogicalOperator* logicalOperator);

std::unique_ptr<ResultCollector> createResultCollector(
std::unique_ptr<ResultCollector> createResultCollector(common::AccumulateType accumulateType,
const binder::expression_vector& expressions, planner::Schema* schema,
std::unique_ptr<PhysicalOperator> prevOperator);
std::unique_ptr<PhysicalOperator> createFactorizedTableScan(
Expand Down
2 changes: 0 additions & 2 deletions src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter {
uint64_t getMarkJoinResult();
uint64_t getJoinResult();

void setVectorsToNull();

private:
std::shared_ptr<HashJoinSharedState> sharedState;
common::JoinType joinType;
Expand Down
19 changes: 13 additions & 6 deletions src/include/processor/operator/result_collector.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "common/join_type.h"
#include "processor/operator/sink.h"
#include "processor/result/factorized_table.h"

Expand All @@ -24,14 +25,17 @@ class ResultCollectorSharedState {
};

struct ResultCollectorInfo {
common::AccumulateType accumulateType;
std::unique_ptr<FactorizedTableSchema> tableSchema;
std::vector<DataPos> payloadPositions;

ResultCollectorInfo(
ResultCollectorInfo(common::AccumulateType accumulateType,
std::unique_ptr<FactorizedTableSchema> tableSchema, std::vector<DataPos> payloadPositions)
: tableSchema{std::move(tableSchema)}, payloadPositions{std::move(payloadPositions)} {}
: accumulateType{accumulateType}, tableSchema{std::move(tableSchema)},
payloadPositions{std::move(payloadPositions)} {}
ResultCollectorInfo(const ResultCollectorInfo& other)
: tableSchema{other.tableSchema->copy()}, payloadPositions{other.payloadPositions} {}
: accumulateType{other.accumulateType}, tableSchema{other.tableSchema->copy()},
payloadPositions{other.payloadPositions} {}

inline std::unique_ptr<ResultCollectorInfo> copy() const {
return std::make_unique<ResultCollectorInfo>(*this);
Expand All @@ -48,19 +52,22 @@ class ResultCollector : public Sink {
std::move(child), id, paramsString},
info{std::move(info)}, sharedState{std::move(sharedState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
void executeInternal(ExecutionContext* context) final;

void executeInternal(ExecutionContext* context) override;
void finalize(ExecutionContext* context) final;

inline std::shared_ptr<FactorizedTable> getResultFactorizedTable() {
return sharedState->getTable();
}

std::unique_ptr<PhysicalOperator> clone() override {
std::unique_ptr<PhysicalOperator> clone() final {
return make_unique<ResultCollector>(resultSetDescriptor->copy(), info->copy(), sharedState,
children[0]->clone(), id, paramsString);
}

private:
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final;

private:
std::unique_ptr<ResultCollectorInfo> info;
std::shared_ptr<ResultCollectorSharedState> sharedState;
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/acc_hash_join_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ std::shared_ptr<planner::LogicalOperator> HashJoinSIPOptimizer::appendPathSemiMa

std::shared_ptr<planner::LogicalOperator> HashJoinSIPOptimizer::appendAccumulate(
std::shared_ptr<planner::LogicalOperator> child) {
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(child));
auto accumulate =
std::make_shared<LogicalAccumulate>(common::AccumulateType::REGULAR, std::move(child));
accumulate->computeFlatSchema();
return accumulate;
}
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace kuzu {
namespace optimizer {

void Optimizer::optimize(planner::LogicalPlan* plan) {

// Factorization structure should be removed before further optimization can be applied.
auto removeFactorizationRewriter = RemoveFactorizationRewriter();
removeFactorizationRewriter.rewrite(plan);
Expand Down
3 changes: 3 additions & 0 deletions src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ void ProjectionPushDownOptimizer::visitExtend(planner::LogicalOperator* op) {

void ProjectionPushDownOptimizer::visitAccumulate(planner::LogicalOperator* op) {
auto accumulate = (LogicalAccumulate*)op;
if (accumulate->getAccumulateType() != common::AccumulateType::REGULAR) {
return;
}
auto expressionsBeforePruning = accumulate->getExpressions();
auto expressionsAfterPruning = pruneExpressions(expressionsBeforePruning);
if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) {
Expand Down
23 changes: 0 additions & 23 deletions src/parser/query/single_query.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +0,0 @@
#include "parser/query/single_query.h"

namespace kuzu {
namespace parser {

bool SingleQuery::isFirstReadingClauseOptionalMatch() const {
for (auto& queryPart : queryParts) {
if (queryPart->getNumReadingClauses() != 0 &&
queryPart->getReadingClause(0)->getClauseType() == common::ClauseType::MATCH) {
return ((MatchClause*)queryPart->getReadingClause(0))->getIsOptional();
} else {
return false;
}
}
if (getNumReadingClauses() != 0 &&
getReadingClause(0)->getClauseType() == common::ClauseType::MATCH) {
return ((MatchClause*)getReadingClause(0))->getIsOptional();
}
return false;
}

} // namespace parser
} // namespace kuzu
4 changes: 3 additions & 1 deletion src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,10 @@ std::unique_ptr<ReadingClause> Transformer::transformReadingClause(
}

std::unique_ptr<ReadingClause> Transformer::transformMatch(CypherParser::OC_MatchContext& ctx) {
auto matchClauseType =
ctx.OPTIONAL() ? common::MatchClauseType::OPTIONAL_MATCH : common::MatchClauseType::MATCH;
auto matchClause =
std::make_unique<MatchClause>(transformPattern(*ctx.oC_Pattern()), ctx.OPTIONAL());
std::make_unique<MatchClause>(transformPattern(*ctx.oC_Pattern()), matchClauseType);
if (ctx.oC_Where()) {
matchClause->setWhereClause(transformWhere(*ctx.oC_Where()));
}
Expand Down
1 change: 1 addition & 0 deletions src/planner/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_subdirectory(join_order)
add_subdirectory(operator)
add_subdirectory(plan_operator)

add_library(kuzu_planner
OBJECT
Expand Down
Loading

0 comments on commit 0643b92

Please sign in to comment.