Skip to content

Commit

Permalink
rework mask struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ray6080 committed Feb 26, 2023
1 parent f8cbe3e commit 18aba05
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 59 deletions.
77 changes: 39 additions & 38 deletions src/include/processor/operator/scan_node_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,75 +8,77 @@
namespace kuzu {
namespace processor {

// Note: This class is not thread-safe.
struct Mask {
public:
Mask(uint64_t size, uint8_t maskedFlag) : maskedFlag{maskedFlag} {
explicit Mask(uint64_t size) {
data = std::make_unique<uint8_t[]>(size);
std::fill(data.get(), data.get() + size, 0);
}

// Notice: This function is not protected with a lock for concurrent writes because of the
// special use case that there is no mixed reads and writes to the mask, and all writes to the
// mask try to set a position to the same value, thus it doesn't matter which thread succeeds.
inline void setMask(uint64_t pos, uint8_t maskerIdx, uint8_t maskValue) {
// Note: blindly update mask does not parallel well, so we minimize write by first checking
// if the mask is true or not.
if (data[pos] == maskerIdx) {
data[pos] = maskValue;
}
}
inline bool isMasked(uint64_t pos) { return data[pos] == maskedFlag; }
inline void setMask(uint64_t pos, uint8_t maskValue) { data[pos] = maskValue; }
inline bool isMasked(uint64_t pos, uint8_t trueMaskVal) { return data[pos] == trueMaskVal; }

private:
// The value of maskedFlag is equivalent to the num of maskers passed. It is used to check if a
// value is selected by all maskers or not. Each masker will increment its selected value by 1.
uint8_t maskedFlag;
std::unique_ptr<uint8_t[]> data;
};

// Note: This class is not thread-safe.
struct ScanNodeIDSemiMask {
public:
ScanNodeIDSemiMask(common::offset_t maxNodeOffset, uint8_t maskedFlag) {
nodeMask = std::make_unique<Mask>(maxNodeOffset + 1, maskedFlag);
morselMask = std::make_unique<Mask>(
(maxNodeOffset >> common::DEFAULT_VECTOR_CAPACITY_LOG_2) + 1, maskedFlag);
explicit ScanNodeIDSemiMask() : numMaskers{0} {}

inline void initializeMaskData(common::offset_t maxNodeOffset, common::offset_t maxMorselIdx) {
if (nodeMask == nullptr) {
assert(morselMask == nullptr);
nodeMask = std::make_unique<Mask>(maxNodeOffset + 1);
morselMask = std::make_unique<Mask>(maxMorselIdx + 1);
}
}

inline bool isMorselMasked(uint64_t morselIdx) {
return morselMask->isMasked(morselIdx, numMaskers);
}
inline bool isNodeMasked(uint64_t nodeOffset) {
return nodeMask->isMasked(nodeOffset, numMaskers);
}

inline bool isNodeMaskEnabled() { return nodeMask != nullptr; }
inline bool isMorselMasked(uint64_t morselIdx) { return morselMask->isMasked(morselIdx); }
inline bool isNodeMasked(uint64_t nodeOffset) { return nodeMask->isMasked(nodeOffset); }
// Increment mask value for the given nodeOffset if its current mask value is equal to
// the specified `currentMaskValue`.
void incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue);

void setMask(uint64_t nodeOffset, uint8_t maskerIdx);
inline uint8_t getNumMaskers() const { return numMaskers; }
inline void incrementNumMaskers() { numMaskers++; }

private:
std::unique_ptr<Mask> nodeMask;
std::unique_ptr<Mask> morselMask;
uint8_t numMaskers;
};

// Note: This class is not thread-safe. It relies on its caller to correctly synchronize its state.
class ScanTableNodeIDSharedState {
public:
explicit ScanTableNodeIDSharedState(storage::NodeTable* table)
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0},
numMaskers{0}, semiMask{nullptr} {}
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0} {
semiMask = std::make_unique<ScanNodeIDSemiMask>();
}

inline storage::NodeTable* getTable() { return table; }

inline void initialize(transaction::Transaction* transaction) {
inline void initializeMaxOffset(transaction::Transaction* transaction) {
assert(maxNodeOffset == UINT64_MAX && maxMorselIdx == UINT64_MAX);
maxNodeOffset = table->getMaxNodeOffset(transaction);
maxMorselIdx = maxNodeOffset >> common::DEFAULT_VECTOR_CAPACITY_LOG_2;
}

inline void initSemiMask(transaction::Transaction* transaction) {
if (semiMask == nullptr) {
semiMask = std::make_unique<ScanNodeIDSemiMask>(
table->getMaxNodeOffset(transaction), numMaskers);
}
semiMask->initializeMaskData(maxNodeOffset, maxMorselIdx);
}
inline bool isSemiMaskEnabled() { return semiMask != nullptr && semiMask->isNodeMaskEnabled(); }
inline bool isSemiMaskEnabled() { return semiMask->getNumMaskers() > 0; }
inline ScanNodeIDSemiMask* getSemiMask() { return semiMask.get(); }
inline uint8_t getNumMaskers() const { return numMaskers; }
inline void incrementNumMaskers() { numMaskers++; }
inline uint8_t getNumMaskers() const { return semiMask->getNumMaskers(); }
inline void incrementNumMaskers() { semiMask->incrementNumMaskers(); }

std::pair<common::offset_t, common::offset_t> getNextRangeToRead();

Expand All @@ -85,7 +87,6 @@ class ScanTableNodeIDSharedState {
uint64_t maxNodeOffset;
uint64_t maxMorselIdx;
uint64_t currentNodeOffset;
uint8_t numMaskers;
std::unique_ptr<ScanNodeIDSemiMask> semiMask;
};

Expand All @@ -103,7 +104,7 @@ class ScanNodeIDSharedState {

inline void initialize(transaction::Transaction* transaction) {
for (auto& tableState : tableStates) {
tableState->initialize(transaction);
tableState->initializeMaxOffset(transaction);
}
}

Expand All @@ -112,7 +113,6 @@ class ScanNodeIDSharedState {

private:
std::mutex mtx;

std::vector<std::unique_ptr<ScanTableNodeIDSharedState>> tableStates;
uint32_t currentStateIdx;
};
Expand All @@ -139,7 +139,9 @@ class ScanNodeID : public PhysicalOperator {
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;
inline void initGlobalStateInternal(ExecutionContext* context) override {
sharedState->initialize(context->transaction);
}

void setSelVector(ScanTableNodeIDSharedState* tableState, common::offset_t startOffset,
common::offset_t endOffset);
Expand All @@ -148,7 +150,6 @@ class ScanNodeID : public PhysicalOperator {
std::string nodeID;
DataPos outDataPos;
std::shared_ptr<ScanNodeIDSharedState> sharedState;

std::shared_ptr<common::ValueVector> outValueVector;
};

Expand Down
13 changes: 6 additions & 7 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@ class SemiMasker : public PhysicalOperator {
keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}

inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
assert(maskerIdx < UINT8_MAX);
scanTableNodeIDSharedState->incrementNumMaskers();
}
// This function is used in the plan mapper to configure the shared state between the SemiMasker
// and ScanNodeID.
void setSharedState(ScanTableNodeIDSharedState* sharedState);

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand All @@ -36,7 +33,9 @@ class SemiMasker : public PhysicalOperator {
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;
inline void initGlobalStateInternal(ExecutionContext* context) override {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
}

private:
DataPos keyDataPos;
Expand Down
2 changes: 2 additions & 0 deletions src/processor/mapper/map_hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ static void mapASPJoin(Expression* joinNodeID, HashJoinProbe* hashJoinProbe) {
assert(scanNodeIDCandidates.size() == 1);
// set semi masker
auto tableScan = getTableScanForAccHashJoin(hashJoinProbe);
// TODO(Xiyang): `tableScan->getChild(0)->getChild(0)`. This is not a good practice, can we
// change this to a more meaningful way?
assert(tableScan->getChild(0)->getChild(0)->getOperatorType() ==
PhysicalOperatorType::SEMI_MASKER);
auto semiMasker = (SemiMasker*)tableScan->getChild(0)->getChild(0);
Expand Down
21 changes: 12 additions & 9 deletions src/processor/operator/scan_node_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

void ScanNodeIDSemiMask::setMask(uint64_t nodeOffset, uint8_t maskerIdx) {
nodeMask->setMask(nodeOffset, maskerIdx, maskerIdx + 1);
morselMask->setMask(nodeOffset >> DEFAULT_VECTOR_CAPACITY_LOG_2, maskerIdx, maskerIdx + 1);
// Note: blindly update mask does not parallelize well, so we minimize write by first checking
// if the mask is set to true (mask value is equal to the expected currentMaskValue) or not.
void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
if (nodeMask->isMasked(nodeOffset, currentMaskValue)) {
nodeMask->setMask(nodeOffset, currentMaskValue + 1);
}
auto morselIdx = nodeOffset >> DEFAULT_VECTOR_CAPACITY_LOG_2;
if (morselMask->isMasked(morselIdx, currentMaskValue)) {
morselMask->setMask(morselIdx, currentMaskValue + 1);
}
}

std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
// Note: we use maxNodeOffset=UINT64_MAX to represent an empty table.
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == UINT64_MAX) {
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == INVALID_NODE_OFFSET) {
return std::make_pair(currentNodeOffset, currentNodeOffset);
}
if (semiMask) {
if (isSemiMaskEnabled()) {
auto currentMorselIdx = currentNodeOffset >> DEFAULT_VECTOR_CAPACITY_LOG_2;
assert(currentNodeOffset % DEFAULT_VECTOR_CAPACITY == 0);
while (currentMorselIdx <= maxMorselIdx && !semiMask->isMorselMasked(currentMorselIdx)) {
Expand Down Expand Up @@ -73,10 +80,6 @@ bool ScanNodeID::getNextTuplesInternal() {
return true;
}

void ScanNodeID::initGlobalStateInternal(ExecutionContext* context) {
sharedState->initialize(context->transaction);
}

void ScanNodeID::setSelVector(
ScanTableNodeIDSharedState* tableState, offset_t startOffset, offset_t endOffset) {
if (tableState->isSemiMaskEnabled()) {
Expand Down
13 changes: 8 additions & 5 deletions src/processor/operator/semi_masker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

void SemiMasker::setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
assert(maskerIdx < UINT8_MAX);
scanTableNodeIDSharedState->incrementNumMaskers();
}

void SemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
keyValueVector = resultSet->getValueVector(keyDataPos);
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
Expand All @@ -18,17 +25,13 @@ bool SemiMasker::getNextTuplesInternal() {
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
scanTableNodeIDSharedState->getSemiMask()->setMask(
scanTableNodeIDSharedState->getSemiMask()->incrementMaskValue(
keyValueVector->getValue<nodeID_t>(pos).offset, maskerIdx);
}
metrics->numOutputTuple.increase(
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize);
return true;
}

void SemiMasker::initGlobalStateInternal(ExecutionContext* context) {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
}

} // namespace processor
} // namespace kuzu

0 comments on commit 18aba05

Please sign in to comment.