Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify mask struct #1320

Merged
merged 1 commit into from
Feb 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions src/include/processor/operator/scan_node_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,75 +8,77 @@
namespace kuzu {
namespace processor {

// Note: This class is not thread-safe.
struct Mask {
public:
Mask(uint64_t size, uint8_t maskedFlag) : maskedFlag{maskedFlag} {
explicit Mask(uint64_t size) {
data = std::make_unique<uint8_t[]>(size);
std::fill(data.get(), data.get() + size, 0);
}

// Notice: This function is not protected with a lock for concurrent writes because of the
// special use case that there is no mixed reads and writes to the mask, and all writes to the
// mask try to set a position to the same value, thus it doesn't matter which thread succeeds.
inline void setMask(uint64_t pos, uint8_t maskerIdx, uint8_t maskValue) {
// Note: blindly update mask does not parallel well, so we minimize write by first checking
// if the mask is true or not.
if (data[pos] == maskerIdx) {
data[pos] = maskValue;
}
}
inline bool isMasked(uint64_t pos) { return data[pos] == maskedFlag; }
inline void setMask(uint64_t pos, uint8_t maskValue) { data[pos] = maskValue; }
inline bool isMasked(uint64_t pos, uint8_t trueMaskVal) { return data[pos] == trueMaskVal; }

private:
// The value of maskedFlag is equivalent to the num of maskers passed. It is used to check if a
// value is selected by all maskers or not. Each masker will increment its selected value by 1.
uint8_t maskedFlag;
std::unique_ptr<uint8_t[]> data;
};

// Note: This class is not thread-safe.
struct ScanNodeIDSemiMask {
public:
ScanNodeIDSemiMask(common::offset_t maxNodeOffset, uint8_t maskedFlag) {
nodeMask = std::make_unique<Mask>(maxNodeOffset + 1, maskedFlag);
morselMask = std::make_unique<Mask>(
(maxNodeOffset >> common::DEFAULT_VECTOR_CAPACITY_LOG_2) + 1, maskedFlag);
explicit ScanNodeIDSemiMask() : numMaskers{0} {}

inline void initializeMaskData(common::offset_t maxNodeOffset, common::offset_t maxMorselIdx) {
if (nodeMask == nullptr) {
assert(morselMask == nullptr);
nodeMask = std::make_unique<Mask>(maxNodeOffset + 1);
morselMask = std::make_unique<Mask>(maxMorselIdx + 1);
}
}

inline bool isMorselMasked(uint64_t morselIdx) {
return morselMask->isMasked(morselIdx, numMaskers);
}
inline bool isNodeMasked(uint64_t nodeOffset) {
return nodeMask->isMasked(nodeOffset, numMaskers);
}

inline bool isNodeMaskEnabled() { return nodeMask != nullptr; }
inline bool isMorselMasked(uint64_t morselIdx) { return morselMask->isMasked(morselIdx); }
inline bool isNodeMasked(uint64_t nodeOffset) { return nodeMask->isMasked(nodeOffset); }
// Increment mask value for the given nodeOffset if its current mask value is equal to
// the specified `currentMaskValue`.
void incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue);

void setMask(uint64_t nodeOffset, uint8_t maskerIdx);
inline uint8_t getNumMaskers() const { return numMaskers; }
inline void incrementNumMaskers() { numMaskers++; }

private:
std::unique_ptr<Mask> nodeMask;
std::unique_ptr<Mask> morselMask;
uint8_t numMaskers;
};

// Note: This class is not thread-safe. It relies on its caller to correctly synchronize its state.
class ScanTableNodeIDSharedState {
public:
explicit ScanTableNodeIDSharedState(storage::NodeTable* table)
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0},
numMaskers{0}, semiMask{nullptr} {}
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0} {
semiMask = std::make_unique<ScanNodeIDSemiMask>();
}

inline storage::NodeTable* getTable() { return table; }

inline void initialize(transaction::Transaction* transaction) {
inline void initializeMaxOffset(transaction::Transaction* transaction) {
assert(maxNodeOffset == UINT64_MAX && maxMorselIdx == UINT64_MAX);
maxNodeOffset = table->getMaxNodeOffset(transaction);
maxMorselIdx = maxNodeOffset >> common::DEFAULT_VECTOR_CAPACITY_LOG_2;
}

inline void initSemiMask(transaction::Transaction* transaction) {
if (semiMask == nullptr) {
semiMask = std::make_unique<ScanNodeIDSemiMask>(
table->getMaxNodeOffset(transaction), numMaskers);
}
semiMask->initializeMaskData(maxNodeOffset, maxMorselIdx);
}
inline bool isSemiMaskEnabled() { return semiMask != nullptr && semiMask->isNodeMaskEnabled(); }
inline bool isSemiMaskEnabled() { return semiMask->getNumMaskers() > 0; }
inline ScanNodeIDSemiMask* getSemiMask() { return semiMask.get(); }
inline uint8_t getNumMaskers() const { return numMaskers; }
inline void incrementNumMaskers() { numMaskers++; }
inline uint8_t getNumMaskers() const { return semiMask->getNumMaskers(); }
inline void incrementNumMaskers() { semiMask->incrementNumMaskers(); }

std::pair<common::offset_t, common::offset_t> getNextRangeToRead();

Expand All @@ -85,7 +87,6 @@ class ScanTableNodeIDSharedState {
uint64_t maxNodeOffset;
uint64_t maxMorselIdx;
uint64_t currentNodeOffset;
uint8_t numMaskers;
std::unique_ptr<ScanNodeIDSemiMask> semiMask;
};

Expand All @@ -103,7 +104,7 @@ class ScanNodeIDSharedState {

inline void initialize(transaction::Transaction* transaction) {
for (auto& tableState : tableStates) {
tableState->initialize(transaction);
tableState->initializeMaxOffset(transaction);
}
}

Expand All @@ -112,7 +113,6 @@ class ScanNodeIDSharedState {

private:
std::mutex mtx;

std::vector<std::unique_ptr<ScanTableNodeIDSharedState>> tableStates;
uint32_t currentStateIdx;
};
Expand All @@ -139,7 +139,9 @@ class ScanNodeID : public PhysicalOperator {
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;
inline void initGlobalStateInternal(ExecutionContext* context) override {
sharedState->initialize(context->transaction);
}

void setSelVector(ScanTableNodeIDSharedState* tableState, common::offset_t startOffset,
common::offset_t endOffset);
Expand All @@ -148,7 +150,6 @@ class ScanNodeID : public PhysicalOperator {
std::string nodeID;
DataPos outDataPos;
std::shared_ptr<ScanNodeIDSharedState> sharedState;

std::shared_ptr<common::ValueVector> outValueVector;
};

Expand Down
13 changes: 6 additions & 7 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@ class SemiMasker : public PhysicalOperator {
keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}

inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
assert(maskerIdx < UINT8_MAX);
scanTableNodeIDSharedState->incrementNumMaskers();
}
// This function is used in the plan mapper to configure the shared state between the SemiMasker
// and ScanNodeID.
void setSharedState(ScanTableNodeIDSharedState* sharedState);

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand All @@ -36,7 +33,9 @@ class SemiMasker : public PhysicalOperator {
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;
inline void initGlobalStateInternal(ExecutionContext* context) override {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
}

private:
DataPos keyDataPos;
Expand Down
2 changes: 2 additions & 0 deletions src/processor/mapper/map_hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ static void mapASPJoin(Expression* joinNodeID, HashJoinProbe* hashJoinProbe) {
assert(scanNodeIDCandidates.size() == 1);
// set semi masker
auto tableScan = getTableScanForAccHashJoin(hashJoinProbe);
// TODO(Xiyang): `tableScan->getChild(0)->getChild(0)`. This is not a good practice, can we
// change this to a more meaningful way?
assert(tableScan->getChild(0)->getChild(0)->getOperatorType() ==
PhysicalOperatorType::SEMI_MASKER);
auto semiMasker = (SemiMasker*)tableScan->getChild(0)->getChild(0);
Expand Down
21 changes: 12 additions & 9 deletions src/processor/operator/scan_node_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

void ScanNodeIDSemiMask::setMask(uint64_t nodeOffset, uint8_t maskerIdx) {
nodeMask->setMask(nodeOffset, maskerIdx, maskerIdx + 1);
morselMask->setMask(nodeOffset >> DEFAULT_VECTOR_CAPACITY_LOG_2, maskerIdx, maskerIdx + 1);
// Note: blindly update mask does not parallelize well, so we minimize write by first checking
// if the mask is set to true (mask value is equal to the expected currentMaskValue) or not.
void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
if (nodeMask->isMasked(nodeOffset, currentMaskValue)) {
nodeMask->setMask(nodeOffset, currentMaskValue + 1);
}
auto morselIdx = nodeOffset >> DEFAULT_VECTOR_CAPACITY_LOG_2;
if (morselMask->isMasked(morselIdx, currentMaskValue)) {
morselMask->setMask(morselIdx, currentMaskValue + 1);
}
}

std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
// Note: we use maxNodeOffset=UINT64_MAX to represent an empty table.
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == UINT64_MAX) {
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == INVALID_NODE_OFFSET) {
return std::make_pair(currentNodeOffset, currentNodeOffset);
}
if (semiMask) {
if (isSemiMaskEnabled()) {
auto currentMorselIdx = currentNodeOffset >> DEFAULT_VECTOR_CAPACITY_LOG_2;
assert(currentNodeOffset % DEFAULT_VECTOR_CAPACITY == 0);
while (currentMorselIdx <= maxMorselIdx && !semiMask->isMorselMasked(currentMorselIdx)) {
Expand Down Expand Up @@ -73,10 +80,6 @@ bool ScanNodeID::getNextTuplesInternal() {
return true;
}

void ScanNodeID::initGlobalStateInternal(ExecutionContext* context) {
sharedState->initialize(context->transaction);
}

void ScanNodeID::setSelVector(
ScanTableNodeIDSharedState* tableState, offset_t startOffset, offset_t endOffset) {
if (tableState->isSemiMaskEnabled()) {
Expand Down
13 changes: 8 additions & 5 deletions src/processor/operator/semi_masker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

void SemiMasker::setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
assert(maskerIdx < UINT8_MAX);
scanTableNodeIDSharedState->incrementNumMaskers();
}

void SemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
keyValueVector = resultSet->getValueVector(keyDataPos);
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
Expand All @@ -18,17 +25,13 @@ bool SemiMasker::getNextTuplesInternal() {
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
scanTableNodeIDSharedState->getSemiMask()->setMask(
scanTableNodeIDSharedState->getSemiMask()->incrementMaskValue(
keyValueVector->getValue<nodeID_t>(pos).offset, maskerIdx);
}
metrics->numOutputTuple.increase(
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize);
return true;
}

void SemiMasker::initGlobalStateInternal(ExecutionContext* context) {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
}

} // namespace processor
} // namespace kuzu