Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework hashjoin #1465

Merged
merged 1 commit into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class LogicalHashJoin : public LogicalOperator {
joinNodeIDs, joinType, mark, children[0]->copy(), children[1]->copy());
}

private:
// Flat probe side key group in either of the following two cases:
// 1. there are multiple join nodes;
// 2. if the build side contains more than one group or the build side has projected out data
Expand All @@ -69,6 +68,7 @@ class LogicalHashJoin : public LogicalOperator {
// flattening probe key, instead duplicating keys as in vectorized processing if necessary.
bool requireFlatProbeKeys();

private:
bool isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID);

private:
Expand Down
46 changes: 30 additions & 16 deletions src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,46 +47,60 @@ struct ProbeDataInfo {
};

// Probe side on left, i.e. children[0] and build side on right, i.e. children[1]
class HashJoinProbe : public PhysicalOperator, SelVectorOverWriter {
class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter {
public:
HashJoinProbe(std::shared_ptr<HashJoinSharedState> sharedState, common::JoinType joinType,
const ProbeDataInfo& probeDataInfo, std::unique_ptr<PhysicalOperator> probeChild,
std::unique_ptr<PhysicalOperator> buildChild, uint32_t id, const std::string& paramsString)
bool flatProbe, const ProbeDataInfo& probeDataInfo,
std::unique_ptr<PhysicalOperator> probeChild, std::unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::HASH_JOIN_PROBE, std::move(probeChild),
std::move(buildChild), id, paramsString},
sharedState{std::move(sharedState)}, joinType{joinType}, probeDataInfo{probeDataInfo} {}
sharedState{std::move(sharedState)}, joinType{joinType}, flatProbe{flatProbe},
probeDataInfo{probeDataInfo} {}

// This constructor is used for cloning only.
// HashJoinProbe does not need to clone hashJoinBuild, which is on a different pipeline.
HashJoinProbe(std::shared_ptr<HashJoinSharedState> sharedState, common::JoinType joinType,
const ProbeDataInfo& probeDataInfo, std::unique_ptr<PhysicalOperator> probeChild,
uint32_t id, const std::string& paramsString)
bool flatProbe, const ProbeDataInfo& probeDataInfo,
std::unique_ptr<PhysicalOperator> probeChild, uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::HASH_JOIN_PROBE, std::move(probeChild), id,
paramsString},
sharedState{std::move(sharedState)}, joinType{joinType}, probeDataInfo{probeDataInfo} {}
sharedState{std::move(sharedState)}, joinType{joinType}, flatProbe{flatProbe},
probeDataInfo{probeDataInfo} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal(ExecutionContext* context) override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<HashJoinProbe>(
sharedState, joinType, probeDataInfo, children[0]->clone(), id, paramsString);
return make_unique<HashJoinProbe>(sharedState, joinType, flatProbe, probeDataInfo,
children[0]->clone(), id, paramsString);
}

private:
bool hasMoreLeft();
bool getNextBatchOfMatchedTuples(ExecutionContext* context);
uint64_t getNextInnerJoinResult();
uint64_t getNextLeftJoinResult();
uint64_t getNextMarkJoinResult();
void setVectorsToNull();
// Dispatches to the flat-key or unflat-key matching routine according to flatProbe and forwards
// that routine's result.
inline bool getMatchedTuples(ExecutionContext* context) {
    if (flatProbe) {
        return getMatchedTuplesForFlatKey(context);
    }
    return getMatchedTuplesForUnFlatKey(context);
}
bool getMatchedTuplesForFlatKey(ExecutionContext* context);
// We can probe a batch of input tuples if we know they have at most one match.
bool getMatchedTuplesForUnFlatKey(ExecutionContext* context);

uint64_t getNextJoinResult();
// Dispatches to the flat-key or unflat-key inner-join routine according to flatProbe and
// forwards the count that routine returns.
inline uint64_t getInnerJoinResult() {
    if (flatProbe) {
        return getInnerJoinResultForFlatKey();
    }
    return getInnerJoinResultForUnFlatKey();
}
uint64_t getInnerJoinResultForFlatKey();
uint64_t getInnerJoinResultForUnFlatKey();
uint64_t getLeftJoinResult();
uint64_t getMarkJoinResult();
uint64_t getJoinResult();

void setVectorsToNull();

private:
std::shared_ptr<HashJoinSharedState> sharedState;
common::JoinType joinType;
bool flatProbe;

ProbeDataInfo probeDataInfo;
std::vector<common::ValueVector*> vectorsToReadInto;
Expand Down
4 changes: 2 additions & 2 deletions src/processor/mapper/map_hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalHashJoinToPhysical(
probeDataInfo.markDataPos = markOutputPos;
}
auto hashJoinProbe = make_unique<HashJoinProbe>(sharedState, hashJoin->getJoinType(),
probeDataInfo, std::move(probeSidePrevOperator), std::move(hashJoinBuild), getOperatorID(),
paramsString);
hashJoin->requireFlatProbeKeys(), probeDataInfo, std::move(probeSidePrevOperator),
std::move(hashJoinBuild), getOperatorID(), paramsString);
if (hashJoin->getSIP() == planner::SidewaysInfoPassing::PROBE_TO_BUILD) {
mapAccHashJoin(hashJoinProbe.get());
}
Expand Down
159 changes: 89 additions & 70 deletions src/processor/operator/hash_join/hash_join_probe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,15 @@ void HashJoinProbe::initLocalStateInternal(ResultSet* resultSet, ExecutionContex
}
}

// True only when the first key vector is flat and probedTuples[0] is still non-null.
bool HashJoinProbe::hasMoreLeft() {
    const bool keyIsFlat = keyVectors[0]->state->isFlat();
    return keyIsFlat && probeState->probedTuples[0] != nullptr;
}

bool HashJoinProbe::getNextBatchOfMatchedTuples(ExecutionContext* context) {
bool HashJoinProbe::getMatchedTuplesForFlatKey(ExecutionContext* context) {
if (probeState->nextMatchedTupleIdx < probeState->matchedSelVector->selectedSize) {
// Not all matched tuples have been shipped. Continue shipping.
return true;
}
if (!hasMoreLeft()) {
if (probeState->probedTuples[0] == nullptr) { // No more matched tuples on the chain.
// We still need to save and restore for flat input because we are discarding NULL join keys
// which changes the selected position.
// TODO(Guodong): we have potential bugs here because all keys' states should be restored.
restoreSelVector(keyVectors[0]->state->selVector);
if (!children[0]->getNextTuple(context)) {
return false;
Expand All @@ -54,91 +51,113 @@ bool HashJoinProbe::getNextBatchOfMatchedTuples(ExecutionContext* context) {
keyVectors, hashVector.get(), tmpHashVector.get(), probeState->probedTuples.get());
}
auto numMatchedTuples = 0;
auto keyState = keyVectors[0]->state.get();
if (keyState->isFlat()) {
// probe side is flat.
while (probeState->probedTuples[0]) {
if (numMatchedTuples == DEFAULT_VECTOR_CAPACITY) {
while (probeState->probedTuples[0]) {
if (numMatchedTuples == DEFAULT_VECTOR_CAPACITY) {
break;
}
auto currentTuple = probeState->probedTuples[0];
probeState->matchedTuples[numMatchedTuples] = currentTuple;
bool isKeysEqual = true;
for (auto i = 0u; i < keyVectors.size(); i++) {
auto pos = keyVectors[i]->state->selVector->selectedPositions[0];
if (((nodeID_t*)currentTuple)[i] != keyVectors[i]->getValue<nodeID_t>(pos)) {
isKeysEqual = false;
break;
}
auto currentTuple = probeState->probedTuples[0];
probeState->matchedTuples[numMatchedTuples] = currentTuple;
bool isKeysEqual = true;
for (auto i = 0u; i < keyVectors.size(); i++) {
auto pos = keyVectors[i]->state->selVector->selectedPositions[0];
if (((nodeID_t*)currentTuple)[i] != keyVectors[i]->getValue<nodeID_t>(pos)) {
isKeysEqual = false;
break;
}
}
numMatchedTuples += isKeysEqual;
probeState->probedTuples[0] = *sharedState->getHashTable()->getPrevTuple(currentTuple);
}
} else {
assert(keyVectors.size() == 1);
for (auto i = 0u; i < keyState->selVector->selectedSize; i++) {
auto pos = keyState->selVector->selectedPositions[i];
while (probeState->probedTuples[i]) {
assert(numMatchedTuples <= DEFAULT_VECTOR_CAPACITY);
auto currentTuple = probeState->probedTuples[i];
numMatchedTuples += isKeysEqual;
probeState->probedTuples[0] = *sharedState->getHashTable()->getPrevTuple(currentTuple);
}
probeState->matchedSelVector->selectedSize = numMatchedTuples;
probeState->nextMatchedTupleIdx = 0;
return true;
}

bool HashJoinProbe::getMatchedTuplesForUnFlatKey(ExecutionContext* context) {
assert(keyVectors.size() == 1);
auto keyVector = keyVectors[0];
restoreSelVector(keyVector->state->selVector);
if (!children[0]->getNextTuple(context)) {
return false;
}
saveSelVector(keyVector->state->selVector);
sharedState->getHashTable()->probe(
keyVectors, hashVector.get(), tmpHashVector.get(), probeState->probedTuples.get());
auto numMatchedTuples = 0;
auto keySelVector = keyVector->state->selVector.get();
for (auto i = 0u; i < keySelVector->selectedSize; i++) {
auto pos = keySelVector->selectedPositions[i];
while (probeState->probedTuples[i]) {
assert(numMatchedTuples <= DEFAULT_VECTOR_CAPACITY);
auto currentTuple = probeState->probedTuples[i];
if (*(nodeID_t*)currentTuple == keyVectors[0]->getValue<nodeID_t>(pos)) {
// Break if a match has been found.
probeState->matchedTuples[numMatchedTuples] = currentTuple;
probeState->matchedSelVector->selectedPositions[numMatchedTuples] = pos;
numMatchedTuples +=
*(nodeID_t*)currentTuple == keyVectors[0]->getValue<nodeID_t>(pos);
probeState->probedTuples[i] =
*sharedState->getHashTable()->getPrevTuple(currentTuple);
numMatchedTuples++;
break;
}
probeState->probedTuples[i] = *sharedState->getHashTable()->getPrevTuple(currentTuple);
}
}
probeState->matchedSelVector->selectedSize = numMatchedTuples;
probeState->nextMatchedTupleIdx = 0;
return true;
}

void HashJoinProbe::setVectorsToNull() {
for (auto& vector : vectorsToReadInto) {
if (vector->state->isFlat()) {
vector->setNull(vector->state->selVector->selectedPositions[0], true);
} else {
assert(vector->state != keyVectors[0]->state);
auto pos = vector->state->selVector->selectedPositions[0];
vector->setNull(pos, true);
vector->state->selVector->selectedSize = 1;
}
// Inner-join output step for a flat probe key: reads one matched tuple per call.
// Returns 0 when matchedSelVector is empty, otherwise 1.
uint64_t HashJoinProbe::getInnerJoinResultForFlatKey() {
    const bool noMatches = probeState->matchedSelVector->selectedSize == 0;
    if (noMatches) {
        return 0;
    }
    // Look up exactly one matched tuple starting at the cursor, then move the cursor past it.
    auto tuplesPerCall = 1;
    sharedState->getHashTable()->lookup(vectorsToReadInto, columnIdxsToReadFrom,
        probeState->matchedTuples.get(), probeState->nextMatchedTupleIdx, tuplesPerCall);
    probeState->nextMatchedTupleIdx += tuplesPerCall;
    return tuplesPerCall;
}

uint64_t HashJoinProbe::getNextInnerJoinResult() {
if (probeState->matchedSelVector->selectedSize == 0) {
uint64_t HashJoinProbe::getInnerJoinResultForUnFlatKey() {
auto numTuplesToRead = probeState->matchedSelVector->selectedSize;
if (numTuplesToRead == 0) {
return 0;
}
auto numTuplesToRead =
keyVectors[0]->state->isFlat() ? 1 : probeState->matchedSelVector->selectedSize;
if (!keyVectors[0]->state->isFlat() &&
keyVectors[0]->state->selVector->selectedSize != numTuplesToRead) {
// Update probeSideKeyVector's selectedPositions when the probe side is unflat and its
// selected positions need to change (i.e., some keys has no matched tuples).
auto keySelectedBuffer = keyVectors[0]->state->selVector->getSelectedPositionsBuffer();
auto keySelVector = keyVectors[0]->state->selVector.get();
if (keySelVector->selectedSize != numTuplesToRead) {
// Some keys have no matched tuple. So we modify selected position.
auto keySelectedBuffer = keySelVector->getSelectedPositionsBuffer();
for (auto i = 0u; i < numTuplesToRead; i++) {
keySelectedBuffer[i] = probeState->matchedSelVector->selectedPositions[i];
}
keyVectors[0]->state->selVector->selectedSize = numTuplesToRead;
keyVectors[0]->state->selVector->resetSelectorToValuePosBuffer();
keySelVector->selectedSize = numTuplesToRead;
keySelVector->resetSelectorToValuePosBuffer();
}
sharedState->getHashTable()->lookup(vectorsToReadInto, columnIdxsToReadFrom,
probeState->matchedTuples.get(), probeState->nextMatchedTupleIdx, numTuplesToRead);
probeState->nextMatchedTupleIdx += numTuplesToRead;
return numTuplesToRead;
}

uint64_t HashJoinProbe::getNextLeftJoinResult() {
if (getNextInnerJoinResult() == 0) {
// Sets the entry at the first selected position of every vector in vectorsToReadInto to NULL.
// For an unflat vector, the selected size of its state is additionally set to 1.
void HashJoinProbe::setVectorsToNull() {
    for (auto& outVector : vectorsToReadInto) {
        const auto& outState = outVector->state;
        const bool unflat = !outState->isFlat();
        if (unflat) {
            // An unflat output vector must not share its state with the first key vector,
            // because its selected size is overwritten below.
            assert(outState != keyVectors[0]->state);
        }
        outVector->setNull(outState->selVector->selectedPositions[0], true);
        if (unflat) {
            outState->selVector->selectedSize = 1;
        }
    }
}

// Left-join output step: runs getInnerJoinResult() and, when it yields 0, calls
// setVectorsToNull(). Always returns 1.
uint64_t HashJoinProbe::getLeftJoinResult() {
    const auto numInnerResults = getInnerJoinResult();
    if (numInnerResults == 0) {
        setVectorsToNull();
    }
    return 1;
}

uint64_t HashJoinProbe::getNextMarkJoinResult() {
uint64_t HashJoinProbe::getMarkJoinResult() {
auto markValues = (bool*)markVector->getData();
if (markVector->state->isFlat()) {
markValues[markVector->state->selVector->selectedPositions[0]] =
Expand All @@ -153,20 +172,20 @@ uint64_t HashJoinProbe::getNextMarkJoinResult() {
return 1;
}

uint64_t HashJoinProbe::getNextJoinResult() {
uint64_t HashJoinProbe::getJoinResult() {
switch (joinType) {
case JoinType::LEFT: {
return getNextLeftJoinResult();
return getLeftJoinResult();
}
case JoinType::MARK: {
return getNextMarkJoinResult();
return getMarkJoinResult();
}
case JoinType::INNER: {
return getNextInnerJoinResult();
}
default: {
assert(false);
return getInnerJoinResult();
}
default:
throw common::InternalException(
"Unimplemented join type for HashJoinProbe::getJoinResult()");
}
}

Expand All @@ -178,10 +197,10 @@ uint64_t HashJoinProbe::getNextJoinResult() {
bool HashJoinProbe::getNextTuplesInternal(ExecutionContext* context) {
uint64_t numPopulatedTuples;
do {
if (!getNextBatchOfMatchedTuples(context)) {
if (!getMatchedTuples(context)) {
return false;
}
numPopulatedTuples = getNextJoinResult();
numPopulatedTuples = getJoinResult();
} while (numPopulatedTuples == 0);
metrics->numOutputTuple.increase(numPopulatedTuples);
return true;
Expand Down