Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #998 #1315

Merged
merged 1 commit into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/include/processor/operator/intersect/intersect.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ class Intersect : public PhysicalOperator {
const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::INTERSECT, std::move(children), id, paramsString},
outputDataPos{outputDataPos},
intersectDataInfos{std::move(intersectDataInfos)}, sharedHTs{std::move(sharedHTs)} {}
intersectDataInfos{std::move(intersectDataInfos)}, sharedHTs{std::move(sharedHTs)} {
tupleIdxPerBuildSide.resize(this->sharedHTs.size(), 0);
carryBuildSideIdx = -1u;
probedFlatTuples.resize(this->sharedHTs.size());
}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand All @@ -33,14 +37,17 @@ class Intersect : public PhysicalOperator {
}

private:
std::vector<common::nodeID_t> getProbeKeys();
std::vector<uint8_t*> probeHTs(const std::vector<common::nodeID_t>& keys);
// For each build side, probe its HT and return a vector of matched flat tuples.
void probeHTs();
// Left is always the one with less num of values.
static void twoWayIntersect(common::nodeID_t* leftNodeIDs, common::SelectionVector& lSelVector,
common::nodeID_t* rightNodeIDs, common::SelectionVector& rSelVector);
void intersectLists(const std::vector<common::overflow_value_t>& listsToIntersect);
void populatePayloads(
const std::vector<uint8_t*>& tuples, const std::vector<uint32_t>& listIdxes);
bool hasNextTuplesToIntersect();

inline uint32_t getNumBuilds() { return sharedHTs.size(); }

private:
DataPos outputDataPos;
Expand All @@ -53,6 +60,11 @@ class Intersect : public PhysicalOperator {
std::vector<std::unique_ptr<common::SelectionVector>> intersectSelVectors;
std::vector<std::shared_ptr<IntersectSharedState>> sharedHTs;
std::vector<bool> isIntersectListAFlatValue;
std::vector<std::vector<uint8_t*>> probedFlatTuples;
// Keep track of the tuple to intersect for each build side.
std::vector<uint32_t> tupleIdxPerBuildSide;
// This is used to indicate which build side to increment the tuple idx for.
uint32_t carryBuildSideIdx;
};

} // namespace processor
Expand Down
90 changes: 61 additions & 29 deletions src/processor/operator/intersect/intersect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,23 @@ void Intersect::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* c
}
}

std::vector<uint8_t*> Intersect::probeHTs(const std::vector<nodeID_t>& keys) {
std::vector<uint8_t*> tuples(keys.size());
hash_t tmpHash;
for (auto i = 0u; i < keys.size(); i++) {
Hash::operation<nodeID_t>(keys[i], false, tmpHash);
tuples[i] = sharedHTs[i]->getHashTable()->getTupleForHash(tmpHash);
while (tuples[i]) {
if (*(nodeID_t*)tuples[i] == keys[i]) {
break; // The build side should guarantee each key only has one matching tuple.
void Intersect::probeHTs() {
std::vector<std::vector<overflow_value_t>> flatTuples(probeKeyVectors.size());
hash_t hashVal;
for (auto i = 0u; i < probeKeyVectors.size(); i++) {
assert(probeKeyVectors[i]->state->isFlat());
probedFlatTuples[i].clear();
auto key = probeKeyVectors[i]->getValue<nodeID_t>(
probeKeyVectors[i]->state->selVector->selectedPositions[0]);
Hash::operation<nodeID_t>(key, false, hashVal);
auto flatTuple = sharedHTs[i]->getHashTable()->getTupleForHash(hashVal);
while (flatTuple) {
if (*(nodeID_t*)flatTuple == key) {
probedFlatTuples[i].push_back(flatTuple);
}
tuples[i] = *sharedHTs[i]->getHashTable()->getPrevTuple(tuples[i]);
flatTuple = *sharedHTs[i]->getHashTable()->getPrevTuple(flatTuple);
}
}
return tuples;
}

void Intersect::twoWayIntersect(nodeID_t* leftNodeIDs, SelectionVector& lSelVector,
Expand Down Expand Up @@ -71,23 +74,10 @@ void Intersect::twoWayIntersect(nodeID_t* leftNodeIDs, SelectionVector& lSelVect
rSelVector.resetSelectorToValuePosBufferWithSize(outputValuePosition);
}

std::vector<nodeID_t> Intersect::getProbeKeys() {
std::vector<nodeID_t> keys(probeKeyVectors.size());
for (auto i = 0u; i < keys.size(); i++) {
assert(probeKeyVectors[i]->state->isFlat());
keys[i] = probeKeyVectors[i]->getValue<nodeID_t>(
probeKeyVectors[i]->state->selVector->selectedPositions[0]);
}
return keys;
}

static std::vector<overflow_value_t> fetchListsToIntersectFromTuples(
const std::vector<uint8_t*>& tuples, const std::vector<bool>& isFlatValue) {
std::vector<overflow_value_t> listsToIntersect(tuples.size());
for (auto i = 0u; i < tuples.size(); i++) {
if (!tuples[i]) {
continue; // overflow_value will be initialized with size 0 for non-matching tuples.
}
listsToIntersect[i] =
isFlatValue[i] ? overflow_value_t{1 /* numElements */, tuples[i] + sizeof(nodeID_t)} :
*(overflow_value_t*)(tuples[i] + sizeof(nodeID_t));
Expand Down Expand Up @@ -160,17 +150,59 @@ void Intersect::populatePayloads(
}
}

bool Intersect::hasNextTuplesToIntersect() {
tupleIdxPerBuildSide[carryBuildSideIdx]++;
if (tupleIdxPerBuildSide[carryBuildSideIdx] == probedFlatTuples[carryBuildSideIdx].size()) {
if (carryBuildSideIdx == 0) {
return false;
}
tupleIdxPerBuildSide[carryBuildSideIdx] = 0;
carryBuildSideIdx--;
if (!hasNextTuplesToIntersect()) {
return false;
}
carryBuildSideIdx++;
}
return true;
}

bool Intersect::getNextTuplesInternal() {
do {
if (!children[0]->getNextTuple()) {
return false;
while (carryBuildSideIdx == -1u) {
if (!children[0]->getNextTuple()) {
return false;
}
// For each build side, probe its HT and return a vector of matched flat tuples.
probeHTs();
auto maxNumTuplesToIntersect = 1u;
for (auto i = 0u; i < getNumBuilds(); i++) {
maxNumTuplesToIntersect *= probedFlatTuples[i].size();
}
if (maxNumTuplesToIntersect == 0) {
// Skip if any build side has no matches.
continue;
}
carryBuildSideIdx = getNumBuilds() - 1;
std::fill(tupleIdxPerBuildSide.begin(), tupleIdxPerBuildSide.end(), 0);
}
auto tuples = probeHTs(getProbeKeys());
auto listsToIntersect = fetchListsToIntersectFromTuples(tuples, isIntersectListAFlatValue);
// Cartesian product of all flat tuples probed from all build sides.
// Notice: when there are large adjacency lists in the build side, which means the list is
// too large to fit in a single ValueVector, we end up chunking the list as multiple tuples
// in FTable. Thus, when performing the intersection, we need to perform cartesian product
// between all flat tuples probed from all build sides.
std::vector<uint8_t*> flatTuplesToIntersect(getNumBuilds());
for (auto i = 0u; i < getNumBuilds(); i++) {
flatTuplesToIntersect[i] = probedFlatTuples[i][tupleIdxPerBuildSide[i]];
}
auto listsToIntersect =
fetchListsToIntersectFromTuples(flatTuplesToIntersect, isIntersectListAFlatValue);
auto listIdxes = swapSmallestListToFront(listsToIntersect);
intersectLists(listsToIntersect);
if (outKeyVector->state->selVector->selectedSize != 0) {
populatePayloads(tuples, listIdxes);
populatePayloads(flatTuplesToIntersect, listIdxes);
}
if (!hasNextTuplesToIntersect()) {
carryBuildSideIdx = -1u;
}
} while (outKeyVector->state->selVector->selectedSize == 0);
return true;
Expand Down
11 changes: 7 additions & 4 deletions src/processor/operator/intersect/intersect_hash_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using namespace kuzu::common;
namespace kuzu {
namespace processor {

static void sortSelectedPos(const std::shared_ptr<ValueVector>& nodeIDVector) {
static void sortSelectedPos(ValueVector* nodeIDVector) {
auto selVector = nodeIDVector->state->selVector.get();
auto size = selVector->selectedSize;
auto selectedPos = selVector->getSelectedPositionsBuffer();
Expand All @@ -23,16 +23,19 @@ void IntersectHashTable::append(const std::vector<std::shared_ptr<ValueVector>>&
// Based on the way we are planning, we assume that the first and second vectors are both
// nodeIDs from extending, while the first one is key, and the second one is payload.
auto keyState = vectorsToAppend[0]->state.get();
auto payloadNodeIDVector = vectorsToAppend[1];
auto payloadNodeIDVector = vectorsToAppend[1].get();
auto payloadsState = payloadNodeIDVector->state.get();
assert(keyState->isFlat());
if (!payloadsState->isFlat()) {
// Sorting is only needed when the payload is unflat (a list of values).
sortSelectedPos(payloadNodeIDVector);
}
// A single appendInfo will return from `allocateFlatTupleBlocks` when numTuplesToAppend is 1.
auto appendInfo = factorizedTable->allocateFlatTupleBlocks(numTuplesToAppend)[0];
auto appendInfos = factorizedTable->allocateFlatTupleBlocks(numTuplesToAppend);
assert(appendInfos.size() == 1);
for (auto i = 0u; i < vectorsToAppend.size(); i++) {
factorizedTable->copyVectorToColumn(*vectorsToAppend[i], appendInfo, numTuplesToAppend, i);
factorizedTable->copyVectorToColumn(
*vectorsToAppend[i], appendInfos[0], numTuplesToAppend, i);
}
if (!payloadsState->isFlat()) {
payloadsState->selVector->resetSelectorToUnselected();
Expand Down