Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi label asp #1370

Merged
merged 1 commit into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def serialize(dataset_name, dataset_path, serialized_graph_path):
try:
# Run kuzu shell one query at a time. This ensures a new process is
# created for each query to avoid memory leaks.
subprocess.run([kuzu_exec_path, '-i', serialized_graph_path],
subprocess.run([kuzu_exec_path, serialized_graph_path],
input=(s + ";" + "\n").encode("ascii"), check=True)
except subprocess.CalledProcessError as e:
logging.error('Error executing query: %s', s)
Expand Down
32 changes: 15 additions & 17 deletions src/include/processor/operator/scan_node_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ struct Mask {
};

// Note: This class is not thread-safe.
struct ScanNodeIDSemiMask {
struct NodeTableSemiMask {
public:
explicit ScanNodeIDSemiMask() : numMaskers{0} {}
NodeTableSemiMask() : numMaskers{0} {}

inline void initializeMaskData(common::offset_t maxNodeOffset, common::offset_t maxMorselIdx) {
if (nodeMask != nullptr) {
Expand Down Expand Up @@ -59,11 +59,12 @@ struct ScanNodeIDSemiMask {
};

// Note: This class is not thread-safe. It relies on its caller to correctly synchronize its state.
class ScanTableNodeIDSharedState {
class NodeTableState {
public:
explicit ScanTableNodeIDSharedState(storage::NodeTable* table)
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0} {
semiMask = std::make_unique<ScanNodeIDSemiMask>();
explicit NodeTableState(storage::NodeTable* table)
: table{table}, maxNodeOffset{common::INVALID_NODE_OFFSET}, maxMorselIdx{UINT64_MAX},
currentNodeOffset{0} {
semiMask = std::make_unique<NodeTableSemiMask>();
}

inline storage::NodeTable* getTable() { return table; }
Expand All @@ -83,7 +84,7 @@ class ScanTableNodeIDSharedState {
semiMask->initializeMaskData(maxNodeOffset, maxMorselIdx);
}
inline bool isSemiMaskEnabled() { return semiMask->getNumMaskers() > 0; }
inline ScanNodeIDSemiMask* getSemiMask() { return semiMask.get(); }
inline NodeTableSemiMask* getSemiMask() { return semiMask.get(); }
inline uint8_t getNumMaskers() const { return semiMask->getNumMaskers(); }
inline void incrementNumMaskers() { semiMask->incrementNumMaskers(); }

Expand All @@ -94,33 +95,30 @@ class ScanTableNodeIDSharedState {
uint64_t maxNodeOffset;
uint64_t maxMorselIdx;
uint64_t currentNodeOffset;
std::unique_ptr<ScanNodeIDSemiMask> semiMask;
std::unique_ptr<NodeTableSemiMask> semiMask;
};

class ScanNodeIDSharedState {
public:
ScanNodeIDSharedState() : currentStateIdx{0} {};

inline void addTableState(storage::NodeTable* table) {
tableStates.push_back(std::make_unique<ScanTableNodeIDSharedState>(table));
tableStates.push_back(std::make_unique<NodeTableState>(table));
}
inline uint32_t getNumTableStates() const { return tableStates.size(); }
inline ScanTableNodeIDSharedState* getTableState(uint32_t idx) const {
return tableStates[idx].get();
}
inline NodeTableState* getTableState(uint32_t idx) const { return tableStates[idx].get(); }

inline void initialize(transaction::Transaction* transaction) {
for (auto& tableState : tableStates) {
tableState->initializeMaxOffset(transaction);
}
}

std::tuple<ScanTableNodeIDSharedState*, common::offset_t, common::offset_t>
getNextRangeToRead();
std::tuple<NodeTableState*, common::offset_t, common::offset_t> getNextRangeToRead();

private:
std::mutex mtx;
std::vector<std::unique_ptr<ScanTableNodeIDSharedState>> tableStates;
std::vector<std::unique_ptr<NodeTableState>> tableStates;
uint32_t currentStateIdx;
};

Expand Down Expand Up @@ -148,8 +146,8 @@ class ScanNodeID : public PhysicalOperator {
sharedState->initialize(context->transaction);
}

void setSelVector(ScanTableNodeIDSharedState* tableState, common::offset_t startOffset,
common::offset_t endOffset);
void setSelVector(
NodeTableState* tableState, common::offset_t startOffset, common::offset_t endOffset);

private:
DataPos outDataPos;
Expand Down
72 changes: 48 additions & 24 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,67 @@
namespace kuzu {
namespace processor {

class SemiMasker : public PhysicalOperator {
public:
SemiMasker(const DataPos& keyDataPos, std::unique_ptr<PhysicalOperator> child, uint32_t id,
const std::string& paramsString)
class BaseSemiMasker : public PhysicalOperator {
protected:
BaseSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, std::move(child), id, paramsString},
keyDataPos{keyDataPos}, maskerIdx{0}, scanTableNodeIDSharedState{nullptr} {}
keyDataPos{keyDataPos}, scanNodeIDSharedState{scanNodeIDSharedState} {}

SemiMasker(const SemiMasker& other)
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, other.children[0]->clone(), other.id,
other.paramsString},
keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
}
protected:
DataPos keyDataPos;
ScanNodeIDSharedState* scanNodeIDSharedState;
std::shared_ptr<common::ValueVector> keyValueVector;
};

class SingleTableSemiMasker : public BaseSemiMasker {
public:
SingleTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}

void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;

bool getNextTuplesInternal() override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<SemiMasker>(*this);
auto result = std::make_unique<SingleTableSemiMasker>(
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
result->maskerIdxAndMask = maskerIdxAndMask;
return result;
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;
// Multiple maskers can point to the same SemiMask, thus we associate each masker with an idx
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
// indicate if a value in the mask is masked or not, as each masker will increment the selected
// value in the mask by 1. More details are described in NodeTableSemiMask.
std::pair<uint8_t, NodeTableSemiMask*> maskerIdxAndMask;
};

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
class MultiTableSemiMasker : public BaseSemiMasker {
public:
MultiTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}

void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;

bool getNextTuplesInternal() override;

inline std::unique_ptr<PhysicalOperator> clone() override {
auto result = std::make_unique<MultiTableSemiMasker>(
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
result->maskerIdxAndMasks = maskerIdxAndMasks;
return result;
}

private:
DataPos keyDataPos;
// Multiple maskers can point to the same scanNodeID, thus we associate each masker with an idx
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
// indicate if a value in the mask is masked or not, as each masker will increment the selected
// value in the mask by 1. More details are described in ScanNodeIDSemiMask.
uint8_t maskerIdx;
std::shared_ptr<common::ValueVector> keyValueVector;
ScanTableNodeIDSharedState* scanTableNodeIDSharedState;
std::unordered_map<common::table_id_t, std::pair<uint8_t, NodeTableSemiMask*>>
maskerIdxAndMasks;
};

} // namespace processor
} // namespace kuzu
4 changes: 0 additions & 4 deletions src/optimizer/asp_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemi
scanNodesCollector.collect(buildRoot);
for (auto& op : scanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op;
if (scanNode->getNode()->isMultiLabeled()) {
// We don't push semi mask to multi-labeled scan. This can be solved.
continue;
}
auto nodeID = scanNode->getNode()->getInternalIDProperty();
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
nodeIDToScanOperatorsMap.insert({nodeID, std::vector<LogicalOperator*>{}});
Expand Down
14 changes: 9 additions & 5 deletions src/processor/mapper/map_asp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalSemiMaskerToPhysical(
auto physicalScanNode = (ScanNodeID*)logicalOpToPhysicalOpMap.at(logicalScanNode);
auto keyDataPos =
DataPos(inSchema->getExpressionPos(*logicalScanNode->getNode()->getInternalIDProperty()));
auto semiMasker = make_unique<SemiMasker>(keyDataPos, std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
assert(physicalScanNode->getSharedState()->getNumTableStates() == 1);
semiMasker->setSharedState(physicalScanNode->getSharedState()->getTableState(0));
return semiMasker;
if (physicalScanNode->getSharedState()->getNumTableStates() > 1) {
return std::make_unique<MultiTableSemiMasker>(keyDataPos,
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
} else {
return std::make_unique<SingleTableSemiMasker>(keyDataPos,
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
}
}

} // namespace processor
Expand Down
9 changes: 4 additions & 5 deletions src/processor/operator/scan_node_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace processor {

// Note: blindly update mask does not parallelize well, so we minimize write by first checking
// if the mask is set to true (mask value is equal to the expected currentMaskValue) or not.
void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
void NodeTableSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
if (nodeMask->isMasked(nodeOffset, currentMaskValue)) {
nodeMask->setMask(nodeOffset, currentMaskValue + 1);
}
Expand All @@ -17,7 +17,7 @@ void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t current
}
}

std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
std::pair<offset_t, offset_t> NodeTableState::getNextRangeToRead() {
// Note: we use maxNodeOffset=UINT64_MAX to represent an empty table.
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == INVALID_NODE_OFFSET) {
return std::make_pair(currentNodeOffset, currentNodeOffset);
Expand All @@ -36,8 +36,7 @@ std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
return std::make_pair(startOffset, startOffset + range);
}

std::tuple<ScanTableNodeIDSharedState*, offset_t, offset_t>
ScanNodeIDSharedState::getNextRangeToRead() {
std::tuple<NodeTableState*, offset_t, offset_t> ScanNodeIDSharedState::getNextRangeToRead() {
std::unique_lock lck{mtx};
if (currentStateIdx == tableStates.size()) {
return std::make_tuple(nullptr, INVALID_NODE_OFFSET, INVALID_NODE_OFFSET);
Expand Down Expand Up @@ -81,7 +80,7 @@ bool ScanNodeID::getNextTuplesInternal() {
}

void ScanNodeID::setSelVector(
ScanTableNodeIDSharedState* tableState, offset_t startOffset, offset_t endOffset) {
NodeTableState* tableState, offset_t startOffset, offset_t endOffset) {
if (tableState->isSemiMaskEnabled()) {
outValueVector->state->selVector->resetSelectorToValuePosBuffer();
// Fill selected positions based on node mask for nodes between the given startOffset and
Expand Down
58 changes: 46 additions & 12 deletions src/processor/operator/semi_masker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,65 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

void SemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
void BaseSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
keyValueVector = resultSet->getValueVector(keyDataPos);
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
}

static std::pair<uint8_t, NodeTableSemiMask*> initSemiMaskForTableState(
NodeTableState* tableState, transaction::Transaction* trx) {
tableState->initSemiMask(trx);
auto maskerIdx = tableState->getNumMaskers();
assert(maskerIdx < UINT8_MAX);
scanTableNodeIDSharedState->incrementNumMaskers();
tableState->incrementNumMaskers();
return std::make_pair(maskerIdx, tableState->getSemiMask());
}

void SemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
keyValueVector = resultSet->getValueVector(keyDataPos);
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
void SingleTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
assert(scanNodeIDSharedState->getNumTableStates() == 1);
auto tableState = scanNodeIDSharedState->getTableState(0);
maskerIdxAndMask = initSemiMaskForTableState(tableState, context->transaction);
}

bool SingleTableSemiMasker::getNextTuplesInternal() {
if (!children[0]->getNextTuple()) {
return false;
}
auto [maskerIdx, mask] = maskerIdxAndMask;
auto numValues =
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
}
metrics->numOutputTuple.increase(numValues);
return true;
}

void MultiTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
assert(scanNodeIDSharedState->getNumTableStates() > 1);
for (auto i = 0u; i < scanNodeIDSharedState->getNumTableStates(); ++i) {
auto tableState = scanNodeIDSharedState->getTableState(i);
auto maskerIdxAndMask = initSemiMaskForTableState(tableState, context->transaction);
maskerIdxAndMasks.insert(
{tableState->getTable()->getTableID(), std::move(maskerIdxAndMask)});
}
}

bool SemiMasker::getNextTuplesInternal() {
bool MultiTableSemiMasker::getNextTuplesInternal() {
if (!children[0]->getNextTuple()) {
return false;
}
auto numValues =
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
scanTableNodeIDSharedState->getSemiMask()->incrementMaskValue(
keyValueVector->getValue<nodeID_t>(pos).offset, maskerIdx);
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
auto [maskerIdx, mask] = maskerIdxAndMasks.at(nodeID.tableID);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
}
metrics->numOutputTuple.increase(
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize);
metrics->numOutputTuple.increase(numValues);
return true;
}

Expand Down
9 changes: 9 additions & 0 deletions test/test_files/tinysnb/asp/asp.test
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ Alice
Bob
Dan

-NAME AspMultiLabel
-QUERY MATCH (a:person)-[e1:knows|:studyAt|:workAt]->(b:person:organisation) WHERE a.age > 35 RETURN b.fName, b.name
-ENCODED_JOIN HJ(b._id){E(b)S(a)}{S(b)}
---- 4
Alice|
Bob|
Dan|
|CsWork

-NAME AspMultiKey
-QUERY MATCH (a:person)-[e1:knows]->(b:person)-[e2:knows]->(c:person), (a)-[e3:knows]->(c) WHERE a.fName='Alice' RETURN b.fName, c.fName
#-ENCODED_JOIN HJ(c._id,b._id){E(b)E(c)S(a)}{HJ(b._id){S(b)}{E(b)S(c)}}
Expand Down