Skip to content

Commit

Permalink
Add multi-label asp
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 13, 2023
1 parent b13e5b8 commit 9fe94ad
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 68 deletions.
2 changes: 1 addition & 1 deletion benchmark/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def serialize(dataset_name, dataset_path, serialized_graph_path):
try:
# Run kuzu shell one query at a time. This ensures a new process is
# created for each query to avoid memory leaks.
subprocess.run([kuzu_exec_path, '-i', serialized_graph_path],
subprocess.run([kuzu_exec_path, serialized_graph_path],
input=(s + ";" + "\n").encode("ascii"), check=True)
except subprocess.CalledProcessError as e:
logging.error('Error executing query: %s', s)
Expand Down
32 changes: 15 additions & 17 deletions src/include/processor/operator/scan_node_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ struct Mask {
};

// Note: This class is not thread-safe.
struct ScanNodeIDSemiMask {
struct NodeTableSemiMask {
public:
explicit ScanNodeIDSemiMask() : numMaskers{0} {}
NodeTableSemiMask() : numMaskers{0} {}

inline void initializeMaskData(common::offset_t maxNodeOffset, common::offset_t maxMorselIdx) {
if (nodeMask != nullptr) {
Expand Down Expand Up @@ -59,11 +59,12 @@ struct ScanNodeIDSemiMask {
};

// Note: This class is not thread-safe. It relies on its caller to correctly synchronize its state.
class ScanTableNodeIDSharedState {
class NodeTableState {
public:
explicit ScanTableNodeIDSharedState(storage::NodeTable* table)
: table{table}, maxNodeOffset{UINT64_MAX}, maxMorselIdx{UINT64_MAX}, currentNodeOffset{0} {
semiMask = std::make_unique<ScanNodeIDSemiMask>();
explicit NodeTableState(storage::NodeTable* table)
: table{table}, maxNodeOffset{common::INVALID_NODE_OFFSET}, maxMorselIdx{UINT64_MAX},
currentNodeOffset{0} {
semiMask = std::make_unique<NodeTableSemiMask>();
}

inline storage::NodeTable* getTable() { return table; }
Expand All @@ -83,7 +84,7 @@ class ScanTableNodeIDSharedState {
semiMask->initializeMaskData(maxNodeOffset, maxMorselIdx);
}
inline bool isSemiMaskEnabled() { return semiMask->getNumMaskers() > 0; }
inline ScanNodeIDSemiMask* getSemiMask() { return semiMask.get(); }
inline NodeTableSemiMask* getSemiMask() { return semiMask.get(); }
inline uint8_t getNumMaskers() const { return semiMask->getNumMaskers(); }
inline void incrementNumMaskers() { semiMask->incrementNumMaskers(); }

Expand All @@ -94,33 +95,30 @@ class ScanTableNodeIDSharedState {
uint64_t maxNodeOffset;
uint64_t maxMorselIdx;
uint64_t currentNodeOffset;
std::unique_ptr<ScanNodeIDSemiMask> semiMask;
std::unique_ptr<NodeTableSemiMask> semiMask;
};

class ScanNodeIDSharedState {
public:
ScanNodeIDSharedState() : currentStateIdx{0} {};

inline void addTableState(storage::NodeTable* table) {
tableStates.push_back(std::make_unique<ScanTableNodeIDSharedState>(table));
tableStates.push_back(std::make_unique<NodeTableState>(table));
}
inline uint32_t getNumTableStates() const { return tableStates.size(); }
inline ScanTableNodeIDSharedState* getTableState(uint32_t idx) const {
return tableStates[idx].get();
}
inline NodeTableState* getTableState(uint32_t idx) const { return tableStates[idx].get(); }

inline void initialize(transaction::Transaction* transaction) {
for (auto& tableState : tableStates) {
tableState->initializeMaxOffset(transaction);
}
}

std::tuple<ScanTableNodeIDSharedState*, common::offset_t, common::offset_t>
getNextRangeToRead();
std::tuple<NodeTableState*, common::offset_t, common::offset_t> getNextRangeToRead();

private:
std::mutex mtx;
std::vector<std::unique_ptr<ScanTableNodeIDSharedState>> tableStates;
std::vector<std::unique_ptr<NodeTableState>> tableStates;
uint32_t currentStateIdx;
};

Expand Down Expand Up @@ -148,8 +146,8 @@ class ScanNodeID : public PhysicalOperator {
sharedState->initialize(context->transaction);
}

void setSelVector(ScanTableNodeIDSharedState* tableState, common::offset_t startOffset,
common::offset_t endOffset);
void setSelVector(
NodeTableState* tableState, common::offset_t startOffset, common::offset_t endOffset);

private:
DataPos outDataPos;
Expand Down
72 changes: 48 additions & 24 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,67 @@
namespace kuzu {
namespace processor {

class SemiMasker : public PhysicalOperator {
public:
SemiMasker(const DataPos& keyDataPos, std::unique_ptr<PhysicalOperator> child, uint32_t id,
const std::string& paramsString)
class BaseSemiMasker : public PhysicalOperator {
protected:
BaseSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, std::move(child), id, paramsString},
keyDataPos{keyDataPos}, maskerIdx{0}, scanTableNodeIDSharedState{nullptr} {}
keyDataPos{keyDataPos}, scanNodeIDSharedState{scanNodeIDSharedState} {}

SemiMasker(const SemiMasker& other)
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, other.children[0]->clone(), other.id,
other.paramsString},
keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
}
protected:
DataPos keyDataPos;
ScanNodeIDSharedState* scanNodeIDSharedState;
std::shared_ptr<common::ValueVector> keyValueVector;
};

class SingleTableSemiMasker : public BaseSemiMasker {
public:
SingleTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}

void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;

bool getNextTuplesInternal() override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<SemiMasker>(*this);
auto result = std::make_unique<SingleTableSemiMasker>(
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
result->maskerIdxAndMask = maskerIdxAndMask;
return result;
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;
// Multiple maskers can point to the same SemiMask, thus we associate each masker with an idx
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
// indicate if a value in the mask is masked or not, as each masker will increment the selected
// value in the mask by 1. More details are described in NodeTableSemiMask.
std::pair<uint8_t, NodeTableSemiMask*> maskerIdxAndMask;
};

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
class MultiTableSemiMasker : public BaseSemiMasker {
public:
MultiTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}

void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;

bool getNextTuplesInternal() override;

inline std::unique_ptr<PhysicalOperator> clone() override {
auto result = std::make_unique<MultiTableSemiMasker>(
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
result->maskerIdxAndMasks = maskerIdxAndMasks;
return result;
}

private:
DataPos keyDataPos;
// Multiple maskers can point to the same scanNodeID, thus we associate each masker with an idx
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
// indicate if a value in the mask is masked or not, as each masker will increment the selected
// value in the mask by 1. More details are described in ScanNodeIDSemiMask.
uint8_t maskerIdx;
std::shared_ptr<common::ValueVector> keyValueVector;
ScanTableNodeIDSharedState* scanTableNodeIDSharedState;
std::unordered_map<common::table_id_t, std::pair<uint8_t, NodeTableSemiMask*>>
maskerIdxAndMasks;
};

} // namespace processor
} // namespace kuzu
4 changes: 0 additions & 4 deletions src/optimizer/asp_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemi
scanNodesCollector.collect(buildRoot);
for (auto& op : scanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op;
if (scanNode->getNode()->isMultiLabeled()) {
// We don't push semi mask to multi-labeled scan. This can be solved.
continue;
}
auto nodeID = scanNode->getNode()->getInternalIDProperty();
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
nodeIDToScanOperatorsMap.insert({nodeID, std::vector<LogicalOperator*>{}});
Expand Down
14 changes: 9 additions & 5 deletions src/processor/mapper/map_asp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalSemiMaskerToPhysical(
auto physicalScanNode = (ScanNodeID*)logicalOpToPhysicalOpMap.at(logicalScanNode);
auto keyDataPos =
DataPos(inSchema->getExpressionPos(*logicalScanNode->getNode()->getInternalIDProperty()));
auto semiMasker = make_unique<SemiMasker>(keyDataPos, std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
assert(physicalScanNode->getSharedState()->getNumTableStates() == 1);
semiMasker->setSharedState(physicalScanNode->getSharedState()->getTableState(0));
return semiMasker;
if (physicalScanNode->getSharedState()->getNumTableStates() > 1) {
return std::make_unique<MultiTableSemiMasker>(keyDataPos,
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
} else {
return std::make_unique<SingleTableSemiMasker>(keyDataPos,
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
}
}

} // namespace processor
Expand Down
9 changes: 4 additions & 5 deletions src/processor/operator/scan_node_id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace processor {

// Note: blindly update mask does not parallelize well, so we minimize write by first checking
// if the mask is set to true (mask value is equal to the expected currentMaskValue) or not.
void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
void NodeTableSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t currentMaskValue) {
if (nodeMask->isMasked(nodeOffset, currentMaskValue)) {
nodeMask->setMask(nodeOffset, currentMaskValue + 1);
}
Expand All @@ -17,7 +17,7 @@ void ScanNodeIDSemiMask::incrementMaskValue(uint64_t nodeOffset, uint8_t current
}
}

std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
std::pair<offset_t, offset_t> NodeTableState::getNextRangeToRead() {
// Note: we use maxNodeOffset=INVALID_NODE_OFFSET to represent an empty table.
if (currentNodeOffset > maxNodeOffset || maxNodeOffset == INVALID_NODE_OFFSET) {
return std::make_pair(currentNodeOffset, currentNodeOffset);
Expand All @@ -36,8 +36,7 @@ std::pair<offset_t, offset_t> ScanTableNodeIDSharedState::getNextRangeToRead() {
return std::make_pair(startOffset, startOffset + range);
}

std::tuple<ScanTableNodeIDSharedState*, offset_t, offset_t>
ScanNodeIDSharedState::getNextRangeToRead() {
std::tuple<NodeTableState*, offset_t, offset_t> ScanNodeIDSharedState::getNextRangeToRead() {
std::unique_lock lck{mtx};
if (currentStateIdx == tableStates.size()) {
return std::make_tuple(nullptr, INVALID_NODE_OFFSET, INVALID_NODE_OFFSET);
Expand Down Expand Up @@ -81,7 +80,7 @@ bool ScanNodeID::getNextTuplesInternal() {
}

void ScanNodeID::setSelVector(
ScanTableNodeIDSharedState* tableState, offset_t startOffset, offset_t endOffset) {
NodeTableState* tableState, offset_t startOffset, offset_t endOffset) {
if (tableState->isSemiMaskEnabled()) {
outValueVector->state->selVector->resetSelectorToValuePosBuffer();
// Fill selected positions based on node mask for nodes between the given startOffset and
Expand Down
58 changes: 46 additions & 12 deletions src/processor/operator/semi_masker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,65 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

void SemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
maskerIdx = scanTableNodeIDSharedState->getNumMaskers();
// Binds this masker to the vector holding the node IDs it will mask on: the vector is looked up in
// the result set at keyDataPos and must carry INTERNAL_ID values.
void BaseSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
    auto resolvedKeyVector = resultSet->getValueVector(keyDataPos);
    assert(resolvedKeyVector->dataType.typeID == INTERNAL_ID);
    keyValueVector = std::move(resolvedKeyVector);
}

static std::pair<uint8_t, NodeTableSemiMask*> initSemiMaskForTableState(
NodeTableState* tableState, transaction::Transaction* trx) {
tableState->initSemiMask(trx);
auto maskerIdx = tableState->getNumMaskers();
assert(maskerIdx < UINT8_MAX);
scanTableNodeIDSharedState->incrementNumMaskers();
tableState->incrementNumMaskers();
return std::make_pair(maskerIdx, tableState->getSemiMask());
}

void SemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
keyValueVector = resultSet->getValueVector(keyDataPos);
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
// Registers this masker on the one and only table state of the scan it feeds, and remembers the
// (masker index, semi mask) pair handed back by the registration.
void SingleTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
    // The single-table variant is only valid when the scan covers exactly one node table.
    assert(scanNodeIDSharedState->getNumTableStates() == 1);
    maskerIdxAndMask =
        initSemiMaskForTableState(scanNodeIDSharedState->getTableState(0), context->transaction);
}

bool SingleTableSemiMasker::getNextTuplesInternal() {
if (!children[0]->getNextTuple()) {
return false;
}
auto [maskerIdx, mask] = maskerIdxAndMask;
auto numValues =
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
}
metrics->numOutputTuple.increase(numValues);
return true;
}

// Registers this masker on every table state of the scan it feeds. For each table, the
// (masker index, semi mask) pair returned by the registration is stored under that table's ID so
// that node IDs can later be routed to the mask of the table they belong to.
void MultiTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
    // The multi-table variant is only used when the scan spans more than one node table.
    assert(scanNodeIDSharedState->getNumTableStates() > 1);
    const auto numTableStates = scanNodeIDSharedState->getNumTableStates();
    for (auto stateIdx = 0u; stateIdx < numTableStates; ++stateIdx) {
        auto* state = scanNodeIDSharedState->getTableState(stateIdx);
        auto idxAndMask = initSemiMaskForTableState(state, context->transaction);
        // emplace, like insert, keeps an existing entry if the table ID is already present.
        maskerIdxAndMasks.emplace(state->getTable()->getTableID(), idxAndMask);
    }
}

bool SemiMasker::getNextTuplesInternal() {
bool MultiTableSemiMasker::getNextTuplesInternal() {
if (!children[0]->getNextTuple()) {
return false;
}
auto numValues =
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
scanTableNodeIDSharedState->getSemiMask()->incrementMaskValue(
keyValueVector->getValue<nodeID_t>(pos).offset, maskerIdx);
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
auto [maskerIdx, mask] = maskerIdxAndMasks.at(nodeID.tableID);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
}
metrics->numOutputTuple.increase(
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize);
metrics->numOutputTuple.increase(numValues);
return true;
}

Expand Down
9 changes: 9 additions & 0 deletions test/test_files/tinysnb/asp/asp.test
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ Alice
Bob
Dan

-NAME AspMultiLabel
-QUERY MATCH (a:person)-[e1:knows|:studyAt|:workAt]->(b:person:organisation) WHERE a.age > 35 RETURN b.fName, b.name
-ENCODED_JOIN HJ(b._id){E(b)S(a)}{S(b)}
---- 4
Alice|
Bob|
Dan|
|CsWork

-NAME AspMultiKey
-QUERY MATCH (a:person)-[e1:knows]->(b:person)-[e2:knows]->(c:person), (a)-[e3:knows]->(c) WHERE a.fName='Alice' RETURN b.fName, c.fName
#-ENCODED_JOIN HJ(c._id,b._id){E(b)E(c)S(a)}{HJ(b._id){S(b)}{E(b)S(c)}}
Expand Down

0 comments on commit 9fe94ad

Please sign in to comment.