Skip to content

Commit

Permalink
Merge pull request #1315 from kuzudb/intersect
Browse files Browse the repository at this point in the history
Fix #998
  • Loading branch information
ray6080 committed Feb 23, 2023
2 parents 77de059 + 2d94a5a commit 7c9ccf9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 36 deletions.
18 changes: 15 additions & 3 deletions src/include/processor/operator/intersect/intersect.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ class Intersect : public PhysicalOperator {
const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::INTERSECT, std::move(children), id, paramsString},
outputDataPos{outputDataPos},
intersectDataInfos{std::move(intersectDataInfos)}, sharedHTs{std::move(sharedHTs)} {}
intersectDataInfos{std::move(intersectDataInfos)}, sharedHTs{std::move(sharedHTs)} {
tupleIdxPerBuildSide.resize(this->sharedHTs.size(), 0);
carryBuildSideIdx = -1u;
probedFlatTuples.resize(this->sharedHTs.size());
}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand All @@ -33,14 +37,17 @@ class Intersect : public PhysicalOperator {
}

private:
std::vector<common::nodeID_t> getProbeKeys();
std::vector<uint8_t*> probeHTs(const std::vector<common::nodeID_t>& keys);
// For each build side, probe its HT and collect the matched flat tuples into probedFlatTuples.
void probeHTs();
// Left is always the side with fewer values.
static void twoWayIntersect(common::nodeID_t* leftNodeIDs, common::SelectionVector& lSelVector,
common::nodeID_t* rightNodeIDs, common::SelectionVector& rSelVector);
void intersectLists(const std::vector<common::overflow_value_t>& listsToIntersect);
void populatePayloads(
const std::vector<uint8_t*>& tuples, const std::vector<uint32_t>& listIdxes);
bool hasNextTuplesToIntersect();

// Number of build sides, i.e. the number of entries in sharedHTs.
inline uint32_t getNumBuilds() {
    // Make the size_t -> uint32_t narrowing explicit.
    return static_cast<uint32_t>(sharedHTs.size());
}

private:
DataPos outputDataPos;
Expand All @@ -53,6 +60,11 @@ class Intersect : public PhysicalOperator {
std::vector<std::unique_ptr<common::SelectionVector>> intersectSelVectors;
std::vector<std::shared_ptr<IntersectSharedState>> sharedHTs;
std::vector<bool> isIntersectListAFlatValue;
std::vector<std::vector<uint8_t*>> probedFlatTuples;
// Keep track of the tuple to intersect for each build side.
std::vector<uint32_t> tupleIdxPerBuildSide;
// This is used to indicate which build side to increment the tuple idx for.
uint32_t carryBuildSideIdx;
};

} // namespace processor
Expand Down
90 changes: 61 additions & 29 deletions src/processor/operator/intersect/intersect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,23 @@ void Intersect::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* c
}
}

std::vector<uint8_t*> Intersect::probeHTs(const std::vector<nodeID_t>& keys) {
std::vector<uint8_t*> tuples(keys.size());
hash_t tmpHash;
for (auto i = 0u; i < keys.size(); i++) {
Hash::operation<nodeID_t>(keys[i], false, tmpHash);
tuples[i] = sharedHTs[i]->getHashTable()->getTupleForHash(tmpHash);
while (tuples[i]) {
if (*(nodeID_t*)tuples[i] == keys[i]) {
break; // The build side should guarantee each key only has one matching tuple.
void Intersect::probeHTs() {
std::vector<std::vector<overflow_value_t>> flatTuples(probeKeyVectors.size());
hash_t hashVal;
for (auto i = 0u; i < probeKeyVectors.size(); i++) {
assert(probeKeyVectors[i]->state->isFlat());
probedFlatTuples[i].clear();
auto key = probeKeyVectors[i]->getValue<nodeID_t>(
probeKeyVectors[i]->state->selVector->selectedPositions[0]);
Hash::operation<nodeID_t>(key, false, hashVal);
auto flatTuple = sharedHTs[i]->getHashTable()->getTupleForHash(hashVal);
while (flatTuple) {
if (*(nodeID_t*)flatTuple == key) {
probedFlatTuples[i].push_back(flatTuple);
}
tuples[i] = *sharedHTs[i]->getHashTable()->getPrevTuple(tuples[i]);
flatTuple = *sharedHTs[i]->getHashTable()->getPrevTuple(flatTuple);
}
}
return tuples;
}

void Intersect::twoWayIntersect(nodeID_t* leftNodeIDs, SelectionVector& lSelVector,
Expand Down Expand Up @@ -71,23 +74,10 @@ void Intersect::twoWayIntersect(nodeID_t* leftNodeIDs, SelectionVector& lSelVect
rSelVector.resetSelectorToValuePosBufferWithSize(outputValuePosition);
}

std::vector<nodeID_t> Intersect::getProbeKeys() {
std::vector<nodeID_t> keys(probeKeyVectors.size());
for (auto i = 0u; i < keys.size(); i++) {
assert(probeKeyVectors[i]->state->isFlat());
keys[i] = probeKeyVectors[i]->getValue<nodeID_t>(
probeKeyVectors[i]->state->selVector->selectedPositions[0]);
}
return keys;
}

static std::vector<overflow_value_t> fetchListsToIntersectFromTuples(
const std::vector<uint8_t*>& tuples, const std::vector<bool>& isFlatValue) {
std::vector<overflow_value_t> listsToIntersect(tuples.size());
for (auto i = 0u; i < tuples.size(); i++) {
if (!tuples[i]) {
continue; // overflow_value will be initialized with size 0 for non-matching tuples.
}
listsToIntersect[i] =
isFlatValue[i] ? overflow_value_t{1 /* numElements */, tuples[i] + sizeof(nodeID_t)} :
*(overflow_value_t*)(tuples[i] + sizeof(nodeID_t));
Expand Down Expand Up @@ -160,17 +150,59 @@ void Intersect::populatePayloads(
}
}

bool Intersect::hasNextTuplesToIntersect() {
tupleIdxPerBuildSide[carryBuildSideIdx]++;
if (tupleIdxPerBuildSide[carryBuildSideIdx] == probedFlatTuples[carryBuildSideIdx].size()) {
if (carryBuildSideIdx == 0) {
return false;
}
tupleIdxPerBuildSide[carryBuildSideIdx] = 0;
carryBuildSideIdx--;
if (!hasNextTuplesToIntersect()) {
return false;
}
carryBuildSideIdx++;
}
return true;
}

bool Intersect::getNextTuplesInternal() {
do {
if (!children[0]->getNextTuple()) {
return false;
while (carryBuildSideIdx == -1u) {
if (!children[0]->getNextTuple()) {
return false;
}
// For each build side, probe its HT and collect the matched flat tuples into probedFlatTuples.
probeHTs();
auto maxNumTuplesToIntersect = 1u;
for (auto i = 0u; i < getNumBuilds(); i++) {
maxNumTuplesToIntersect *= probedFlatTuples[i].size();
}
if (maxNumTuplesToIntersect == 0) {
// Skip if any build side has no matches.
continue;
}
carryBuildSideIdx = getNumBuilds() - 1;
std::fill(tupleIdxPerBuildSide.begin(), tupleIdxPerBuildSide.end(), 0);
}
auto tuples = probeHTs(getProbeKeys());
auto listsToIntersect = fetchListsToIntersectFromTuples(tuples, isIntersectListAFlatValue);
// Cartesian product of all flat tuples probed from all build sides.
// Notice: when there are large adjacency lists in the build side, which means the list is
// too large to fit in a single ValueVector, we end up chunking the list as multiple tuples
// in FTable. Thus, when performing the intersection, we need to perform cartesian product
// between all flat tuples probed from all build sides.
std::vector<uint8_t*> flatTuplesToIntersect(getNumBuilds());
for (auto i = 0u; i < getNumBuilds(); i++) {
flatTuplesToIntersect[i] = probedFlatTuples[i][tupleIdxPerBuildSide[i]];
}
auto listsToIntersect =
fetchListsToIntersectFromTuples(flatTuplesToIntersect, isIntersectListAFlatValue);
auto listIdxes = swapSmallestListToFront(listsToIntersect);
intersectLists(listsToIntersect);
if (outKeyVector->state->selVector->selectedSize != 0) {
populatePayloads(tuples, listIdxes);
populatePayloads(flatTuplesToIntersect, listIdxes);
}
if (!hasNextTuplesToIntersect()) {
carryBuildSideIdx = -1u;
}
} while (outKeyVector->state->selVector->selectedSize == 0);
return true;
Expand Down
11 changes: 7 additions & 4 deletions src/processor/operator/intersect/intersect_hash_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

static void sortSelectedPos(const std::shared_ptr<ValueVector>& nodeIDVector) {
static void sortSelectedPos(ValueVector* nodeIDVector) {
auto selVector = nodeIDVector->state->selVector.get();
auto size = selVector->selectedSize;
auto selectedPos = selVector->getSelectedPositionsBuffer();
Expand All @@ -23,16 +23,19 @@ void IntersectHashTable::append(const std::vector<std::shared_ptr<ValueVector>>&
// Based on the way we are planning, we assume that the first and second vectors are both
// nodeIDs from extending, while the first one is key, and the second one is payload.
auto keyState = vectorsToAppend[0]->state.get();
auto payloadNodeIDVector = vectorsToAppend[1];
auto payloadNodeIDVector = vectorsToAppend[1].get();
auto payloadsState = payloadNodeIDVector->state.get();
assert(keyState->isFlat());
if (!payloadsState->isFlat()) {
// Sorting is only needed when the payload is unflat (a list of values).
sortSelectedPos(payloadNodeIDVector);
}
// A single appendInfo will return from `allocateFlatTupleBlocks` when numTuplesToAppend is 1.
auto appendInfo = factorizedTable->allocateFlatTupleBlocks(numTuplesToAppend)[0];
auto appendInfos = factorizedTable->allocateFlatTupleBlocks(numTuplesToAppend);
assert(appendInfos.size() == 1);
for (auto i = 0u; i < vectorsToAppend.size(); i++) {
factorizedTable->copyVectorToColumn(*vectorsToAppend[i], appendInfo, numTuplesToAppend, i);
factorizedTable->copyVectorToColumn(
*vectorsToAppend[i], appendInfos[0], numTuplesToAppend, i);
}
if (!payloadsState->isFlat()) {
payloadsState->selVector->resetSelectorToUnselected();
Expand Down

0 comments on commit 7c9ccf9

Please sign in to comment.