Skip to content

Commit

Permalink
refactoring filtering operator
Browse files Browse the repository at this point in the history
  • Loading branch information
ray6080 committed Dec 14, 2022
1 parent ff546cc commit 60dfe48
Show file tree
Hide file tree
Showing 14 changed files with 71 additions and 93 deletions.
Empty file removed examples/CMakeLists.txt
Empty file.
8 changes: 3 additions & 5 deletions src/include/processor/operator/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,26 @@ using namespace kuzu::evaluator;
namespace kuzu {
namespace processor {

class Filter : public PhysicalOperator, public FilteringOperator {
class Filter : public PhysicalOperator, public SelVectorOverWriter {
public:
Filter(unique_ptr<BaseExpressionEvaluator> expressionEvaluator, uint32_t dataChunkToSelectPos,
unique_ptr<PhysicalOperator> child, uint32_t id, const string& paramsString)
: PhysicalOperator{PhysicalOperatorType::FILTER, std::move(child), id, paramsString},
FilteringOperator{1 /* numStatesToSave */}, expressionEvaluator{std::move(
expressionEvaluator)},
expressionEvaluator{std::move(expressionEvaluator)},
dataChunkToSelectPos(dataChunkToSelectPos) {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal() override;

unique_ptr<PhysicalOperator> clone() override {
inline unique_ptr<PhysicalOperator> clone() override {
return make_unique<Filter>(expressionEvaluator->clone(), dataChunkToSelectPos,
children[0]->clone(), id, paramsString);
}

private:
unique_ptr<BaseExpressionEvaluator> expressionEvaluator;
uint32_t dataChunkToSelectPos;

shared_ptr<DataChunk> dataChunkToSelect;
};

Expand Down
36 changes: 9 additions & 27 deletions src/include/processor/operator/filtering_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,27 @@
#include "common/data_chunk/data_chunk_state.h"

using namespace kuzu::common;
using namespace std;

namespace kuzu {
namespace processor {

class FilteringOperator {
class SelVectorOverWriter {
public:
explicit FilteringOperator(uint64_t numStatesToSave) {
for (auto i = 0u; i < numStatesToSave; ++i) {
auto prevSelVector = make_unique<SelectionVector>(DEFAULT_VECTOR_CAPACITY);
prevSelVector->selectedPositions = nullptr;
prevSelVectors.push_back(std::move(prevSelVector));
}
explicit SelVectorOverWriter() {
currentSelVector = make_shared<SelectionVector>(DEFAULT_VECTOR_CAPACITY);
}

protected:
inline void restoreSelVectors(vector<SelectionVector*>& selVectors) {
for (auto i = 0u; i < selVectors.size(); ++i) {
restoreSelVector(prevSelVectors[i].get(), selVectors[i]);
}
}
inline void restoreSelVector(SelectionVector* selVector) {
restoreSelVector(prevSelVectors[0].get(), selVector);
}
void restoreSelVector(shared_ptr<SelectionVector>& selVector);

inline void saveSelVectors(vector<SelectionVector*>& selVectors) {
for (auto i = 0u; i < selVectors.size(); ++i) {
saveSelVector(prevSelVectors[i].get(), selVectors[i]);
}
}
inline void saveSelVector(SelectionVector* selVector) {
saveSelVector(prevSelVectors[0].get(), selVector);
}
void saveSelVector(shared_ptr<SelectionVector>& selVector);

private:
static void restoreSelVector(SelectionVector* prevSelVector, SelectionVector* selVector);
static void saveSelVector(SelectionVector* prevSelVector, SelectionVector* selVector);
virtual void resetToCurrentSelVector(shared_ptr<SelectionVector>& selVector);

vector<unique_ptr<SelectionVector>> prevSelVectors;
protected:
shared_ptr<SelectionVector> prevSelVector;
shared_ptr<SelectionVector> currentSelVector;
};
} // namespace processor
} // namespace kuzu
12 changes: 6 additions & 6 deletions src/include/processor/operator/flatten.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#pragma once

#include "processor/operator/filtering_operator.h"
#include "processor/operator/physical_operator.h"

namespace kuzu {
namespace processor {

class Flatten : public PhysicalOperator {
class Flatten : public PhysicalOperator, SelVectorOverWriter {
public:
Flatten(uint32_t dataChunkToFlattenPos, unique_ptr<PhysicalOperator> child, uint32_t id,
const string& paramsString)
Expand All @@ -16,21 +17,20 @@ class Flatten : public PhysicalOperator {

bool getNextTuplesInternal() override;

unique_ptr<PhysicalOperator> clone() override {
inline unique_ptr<PhysicalOperator> clone() override {
return make_unique<Flatten>(dataChunkToFlattenPos, children[0]->clone(), id, paramsString);
}

private:
bool isCurrIdxInitialOrLast() {
inline bool isCurrIdxInitialOrLast() {
return dataChunkToFlatten->state->currIdx == -1 ||
dataChunkToFlatten->state->currIdx == (unFlattenedSelVector->selectedSize - 1);
dataChunkToFlatten->state->currIdx == (prevSelVector->selectedSize - 1);
}
void resetToCurrentSelVector(shared_ptr<SelectionVector>& selVector) override;

private:
uint32_t dataChunkToFlattenPos;
std::shared_ptr<DataChunk> dataChunkToFlatten;
std::shared_ptr<SelectionVector> unFlattenedSelVector;
std::shared_ptr<SelectionVector> flattenedSelVector;
};

} // namespace processor
Expand Down
7 changes: 2 additions & 5 deletions src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ struct ProbeDataInfo {
};

// Probe side on left, i.e. children[0] and build side on right, i.e. children[1]
class HashJoinProbe : public PhysicalOperator, FilteringOperator {
class HashJoinProbe : public PhysicalOperator, SelVectorOverWriter {
public:
HashJoinProbe(shared_ptr<HashJoinSharedState> sharedState, JoinType joinType,
const ProbeDataInfo& probeDataInfo, unique_ptr<PhysicalOperator> probeChild,
unique_ptr<PhysicalOperator> buildChild, uint32_t id, const string& paramsString)
: PhysicalOperator{PhysicalOperatorType::HASH_JOIN_PROBE, std::move(probeChild),
std::move(buildChild), id, paramsString},
FilteringOperator{probeDataInfo.keysDataPos.size()},
sharedState{std::move(sharedState)}, joinType{joinType}, probeDataInfo{probeDataInfo} {}

// This constructor is used for cloning only.
Expand All @@ -63,7 +62,6 @@ class HashJoinProbe : public PhysicalOperator, FilteringOperator {
const string& paramsString)
: PhysicalOperator{PhysicalOperatorType::HASH_JOIN_PROBE, std::move(probeChild), id,
paramsString},
FilteringOperator{probeDataInfo.keysDataPos.size()},
sharedState{std::move(sharedState)}, joinType{joinType}, probeDataInfo{probeDataInfo} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
Expand All @@ -81,7 +79,7 @@ class HashJoinProbe : public PhysicalOperator, FilteringOperator {
uint64_t getNextInnerJoinResult();
uint64_t getNextLeftJoinResult();
uint64_t getNextMarkJoinResult();
void setVectorsToNull(vector<shared_ptr<ValueVector>>& vectors);
void setVectorsToNull();

uint64_t getNextJoinResult();

Expand All @@ -93,7 +91,6 @@ class HashJoinProbe : public PhysicalOperator, FilteringOperator {
vector<shared_ptr<ValueVector>> vectorsToReadInto;
vector<uint32_t> columnIdxsToReadFrom;
vector<shared_ptr<ValueVector>> keyVectors;
vector<SelectionVector*> keySelVectors;
shared_ptr<ValueVector> markVector;
unique_ptr<ProbeState> probeState;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace kuzu {
namespace processor {

class ColumnExtendAndScanRelProperties : public BaseExtendAndScanRelProperties,
public FilteringOperator {
public SelVectorOverWriter {
public:
ColumnExtendAndScanRelProperties(const DataPos& inNodeIDVectorPos,
const DataPos& outNodeIDVectorPos, vector<DataPos> outPropertyVectorsPos, Column* adjColumn,
Expand All @@ -17,8 +17,7 @@ class ColumnExtendAndScanRelProperties : public BaseExtendAndScanRelProperties,
: BaseExtendAndScanRelProperties{PhysicalOperatorType::COLUMN_EXTEND, inNodeIDVectorPos,
outNodeIDVectorPos, std::move(outPropertyVectorsPos), std::move(child), id,
paramsString},
FilteringOperator{1 /* numStatesToSave */}, adjColumn{adjColumn},
propertyColumns{std::move(propertyColumns)} {}
adjColumn{adjColumn}, propertyColumns{std::move(propertyColumns)} {}
~ColumnExtendAndScanRelProperties() override = default;

bool getNextTuplesInternal() override;
Expand Down
10 changes: 6 additions & 4 deletions src/include/processor/operator/skip.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
namespace kuzu {
namespace processor {

class Skip : public PhysicalOperator, public FilteringOperator {
class Skip : public PhysicalOperator, public SelVectorOverWriter {
public:
Skip(uint64_t skipNumber, shared_ptr<atomic_uint64_t> counter, uint32_t dataChunkToSelectPos,
unordered_set<uint32_t> dataChunksPosInScope, unique_ptr<PhysicalOperator> child,
uint32_t id, const string& paramsString)
: PhysicalOperator{PhysicalOperatorType::SKIP, std::move(child), id, paramsString},
FilteringOperator{1 /* numStatesToSave */}, skipNumber{skipNumber}, counter{std::move(
counter)},
skipNumber{skipNumber}, counter{std::move(counter)},
dataChunkToSelectPos{dataChunkToSelectPos}, dataChunksPosInScope{
std::move(dataChunksPosInScope)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal() override;

unique_ptr<PhysicalOperator> clone() override {
inline unique_ptr<PhysicalOperator> clone() override {
return make_unique<Skip>(skipNumber, counter, dataChunkToSelectPos, dataChunksPosInScope,
children[0]->clone(), id, paramsString);
}
Expand All @@ -28,6 +29,7 @@ class Skip : public PhysicalOperator, public FilteringOperator {
uint64_t skipNumber;
shared_ptr<atomic_uint64_t> counter;
uint32_t dataChunkToSelectPos;
shared_ptr<DataChunk> dataChunkToSelect;
unordered_set<uint32_t> dataChunksPosInScope;
};

Expand Down
6 changes: 3 additions & 3 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,12 @@ void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph,
auto rightPlanProbeCopy = rightPlan->shallowCopy();
planInnerHashJoin(joinNodes, *leftPlanProbeCopy, *rightPlanBuildCopy);
planFiltersForHashJoin(predicates, *leftPlanProbeCopy);
context->addPlan(newSubgraph, move(leftPlanProbeCopy));
context->addPlan(newSubgraph, std::move(leftPlanProbeCopy));
// flip build and probe side to get another HashJoin plan
if (flipPlan) {
planInnerHashJoin(joinNodes, *rightPlanProbeCopy, *leftPlanBuildCopy);
planFiltersForHashJoin(predicates, *rightPlanProbeCopy);
context->addPlan(newSubgraph, move(rightPlanProbeCopy));
context->addPlan(newSubgraph, std::move(rightPlanProbeCopy));
}
}
}
Expand Down Expand Up @@ -678,7 +678,7 @@ void JoinOrderEnumerator::appendHashJoin(const vector<shared_ptr<NodeExpression>
auto hashJoin = make_shared<LogicalHashJoin>(joinNodes, joinType, isProbeAcc,
buildSideSchema->copy(), buildSideSchema->getExpressionsInScope(),
probePlan.getLastOperator(), buildPlan.getLastOperator());
probePlan.setLastOperator(move(hashJoin));
probePlan.setLastOperator(std::move(hashJoin));
}

void JoinOrderEnumerator::appendMarkJoin(const vector<shared_ptr<NodeExpression>>& joinNodes,
Expand Down
4 changes: 2 additions & 2 deletions src/processor/operator/filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ void Filter::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* cont
bool Filter::getNextTuplesInternal() {
bool hasAtLeastOneSelectedValue;
do {
restoreSelVector(dataChunkToSelect->state->selVector.get());
restoreSelVector(dataChunkToSelect->state->selVector);
if (!children[0]->getNextTuple()) {
return false;
}
saveSelVector(dataChunkToSelect->state->selVector.get());
saveSelVector(dataChunkToSelect->state->selVector);
hasAtLeastOneSelectedValue =
expressionEvaluator->select(*dataChunkToSelect->state->selVector);
if (!dataChunkToSelect->state->isFlat() &&
Expand Down
35 changes: 16 additions & 19 deletions src/processor/operator/filtering_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,29 @@
namespace kuzu {
namespace processor {

void FilteringOperator::restoreSelVector(
SelectionVector* prevSelVector, SelectionVector* selVector) {
if (prevSelVector->selectedPositions == nullptr) {
return;
void SelVectorOverWriter::restoreSelVector(shared_ptr<SelectionVector>& selVector) {
if (prevSelVector != nullptr) {
selVector = prevSelVector;
}
selVector->selectedSize = prevSelVector->selectedSize;
if (prevSelVector->isUnfiltered()) {
selVector->resetSelectorToUnselected();
} else {
auto sizeToCopy = prevSelVector->selectedSize * sizeof(sel_t);
selVector->resetSelectorToValuePosBuffer();
memcpy(
selVector->selectedPositions, prevSelVector->getSelectedPositionsBuffer(), sizeToCopy);
}

void SelVectorOverWriter::saveSelVector(shared_ptr<SelectionVector>& selVector) {
if (prevSelVector == nullptr) {
prevSelVector = selVector;
}
resetToCurrentSelVector(selVector);
}

void FilteringOperator::saveSelVector(SelectionVector* prevSelVector, SelectionVector* selVector) {
prevSelVector->selectedSize = selVector->selectedSize;
void SelVectorOverWriter::resetToCurrentSelVector(shared_ptr<SelectionVector>& selVector) {
currentSelVector->selectedSize = selVector->selectedSize;
if (selVector->isUnfiltered()) {
prevSelVector->resetSelectorToUnselected();
currentSelVector->resetSelectorToUnselected();
} else {
auto sizeToCopy = prevSelVector->selectedSize * sizeof(sel_t);
memcpy(prevSelVector->getSelectedPositionsBuffer(), selVector->getSelectedPositionsBuffer(),
sizeToCopy);
prevSelVector->resetSelectorToValuePosBuffer();
memcpy(currentSelVector->getSelectedPositionsBuffer(), selVector->selectedPositions,
selVector->selectedSize * sizeof(sel_t));
currentSelVector->resetSelectorToValuePosBuffer();
}
selVector = currentSelVector;
}

} // namespace processor
Expand Down
17 changes: 9 additions & 8 deletions src/processor/operator/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,28 @@ namespace processor {

void Flatten::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
dataChunkToFlatten = resultSet->dataChunks[dataChunkToFlattenPos];
unFlattenedSelVector = dataChunkToFlatten->state->selVector;
flattenedSelVector = make_shared<SelectionVector>(1 /* capacity */);
flattenedSelVector->resetSelectorToValuePosBufferWithSize(1 /* size */);
dataChunkToFlatten->state->selVector = flattenedSelVector;
currentSelVector->resetSelectorToValuePosBufferWithSize(1 /* size */);
}

bool Flatten::getNextTuplesInternal() {
if (isCurrIdxInitialOrLast()) {
dataChunkToFlatten->state->currIdx = -1;
dataChunkToFlatten->state->selVector = unFlattenedSelVector;
restoreSelVector(dataChunkToFlatten->state->selVector);
if (!children[0]->getNextTuple()) {
return false;
}
dataChunkToFlatten->state->selVector = flattenedSelVector;
saveSelVector(dataChunkToFlatten->state->selVector);
}
dataChunkToFlatten->state->currIdx++;
flattenedSelVector->selectedPositions[0] =
unFlattenedSelVector->selectedPositions[dataChunkToFlatten->state->currIdx];
currentSelVector->selectedPositions[0] =
prevSelVector->selectedPositions[dataChunkToFlatten->state->currIdx];
metrics->numOutputTuple.incrementByOne();
return true;
}

void Flatten::resetToCurrentSelVector(shared_ptr<SelectionVector>& selVector) {
selVector = currentSelVector;
}

} // namespace processor
} // namespace kuzu
9 changes: 4 additions & 5 deletions src/processor/operator/hash_join/hash_join_probe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ void HashJoinProbe::initLocalStateInternal(ResultSet* resultSet, ExecutionContex
for (auto& keyDataPos : probeDataInfo.keysDataPos) {
auto keyVector = resultSet->getValueVector(keyDataPos);
keyVectors.push_back(keyVector);
keySelVectors.push_back(keyVector->state->selVector.get());
}
if (joinType == JoinType::MARK) {
markVector = resultSet->getValueVector(probeDataInfo.markDataPos);
Expand Down Expand Up @@ -43,11 +42,11 @@ bool HashJoinProbe::getNextBatchOfMatchedTuples() {
return true;
}
if (!hasMoreLeft()) {
restoreSelVectors(keySelVectors);
restoreSelVector(keyVectors[0]->state->selVector);
if (!children[0]->getNextTuple()) {
return false;
}
saveSelVectors(keySelVectors);
saveSelVector(keyVectors[0]->state->selVector);
sharedState->getHashTable()->probe(keyVectors, probeState->probedTuples.get());
}
auto numMatchedTuples = 0;
Expand Down Expand Up @@ -92,7 +91,7 @@ bool HashJoinProbe::getNextBatchOfMatchedTuples() {
return true;
}

void HashJoinProbe::setVectorsToNull(vector<shared_ptr<ValueVector>>& vectors) {
void HashJoinProbe::setVectorsToNull() {
for (auto& vector : vectorsToReadInto) {
if (vector->state->isFlat()) {
vector->setNull(vector->state->selVector->selectedPositions[0], true);
Expand Down Expand Up @@ -130,7 +129,7 @@ uint64_t HashJoinProbe::getNextInnerJoinResult() {

uint64_t HashJoinProbe::getNextLeftJoinResult() {
if (getNextInnerJoinResult() == 0) {
setVectorsToNull(vectorsToReadInto);
setVectorsToNull();
}
return 1;
}
Expand Down
4 changes: 2 additions & 2 deletions src/processor/operator/scan_column/adj_column_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ bool ColumnExtendAndScanRelProperties::getNextTuplesInternal() {
bool hasAtLeastOneNonNullValue;
// join with adjColumn
do {
restoreSelVector(inNodeIDVector->state->selVector.get());
restoreSelVector(inNodeIDVector->state->selVector);
if (!children[0]->getNextTuple()) {
return false;
}
saveSelVector(inNodeIDVector->state->selVector.get());
saveSelVector(inNodeIDVector->state->selVector);
outNodeIDVector->setAllNull();
adjColumn->read(transaction, inNodeIDVector, outNodeIDVector);
hasAtLeastOneNonNullValue = NodeIDVector::discardNull(*outNodeIDVector);
Expand Down
Loading

0 comments on commit 60dfe48

Please sign in to comment.