Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic hash join #1919

Merged
merged 1 commit into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@
return result;
}

std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
if (expressionPairs.empty()) {
return std::string{};
}
auto result = toString(expressionPairs[0]);
for (auto i = 1u; i < expressionPairs.size(); ++i) {
result += "," + toString(expressionPairs[i]);
}
return result;
}

// Renders a single expression pair as "<first>=<second>", using each side's own toString().
// Both sides of the pair are dereferenced, so neither may be null.
std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
    std::string text = expressionPair.first->toString();
    text += "=";
    text += expressionPair.second->toString();
    return text;
}

expression_vector ExpressionUtil::excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude) {
expression_set excludeSet;
Expand All @@ -64,5 +79,15 @@
return result;
}

// Collects one data type per expression, in input order. Each entry is whatever
// getDataType().copy() returns for the corresponding expression.
std::vector<std::unique_ptr<common::LogicalType>> ExpressionUtil::getDataTypes(
    const kuzu::binder::expression_vector& expressions) {
    std::vector<std::unique_ptr<common::LogicalType>> dataTypes;
    // Size is known up front, so allocate once.
    dataTypes.reserve(expressions.size());
    for (const auto& expr : expressions) {
        dataTypes.emplace_back(expr->getDataType().copy());
    }
    return dataTypes;
}

Check warning on line 90 in src/binder/expression/expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression.cpp#L90

Added line #L90 was not covered by tests

} // namespace binder
} // namespace kuzu
10 changes: 10 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@
return dataType;
}

std::vector<std::unique_ptr<LogicalType>> LogicalType::copy(
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<std::unique_ptr<LogicalType>>& types) {
std::vector<std::unique_ptr<LogicalType>> typesCopy;
typesCopy.reserve(types.size());
for (auto& type : types) {
typesCopy.push_back(type->copy());
}
return typesCopy;
}

Check warning on line 315 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L315

Added line #L315 was not covered by tests

void LogicalType::setPhysicalType() {
switch (typeID) {
case LogicalTypeID::ANY: {
Expand Down
2 changes: 1 addition & 1 deletion src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void ValueVector::setState(std::shared_ptr<DataChunkState> state) {
}
}

bool NodeIDVector::discardNull(ValueVector& vector) {
bool ValueVector::discardNull(ValueVector& vector) {
if (vector.hasNoNullsGuarantee()) {
return true;
} else {
Expand Down
8 changes: 8 additions & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ struct ExpressionUtil {

static uint32_t find(Expression* target, expression_vector expressions);

// Print as a1,a2,a3,...
static std::string toString(const expression_vector& expressions);
// Print as a1=a2,a3=a4,...
static std::string toString(const std::vector<expression_pair>& expressionPairs);
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
// Print as a1=a2
static std::string toString(const expression_pair& expressionPair);
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved

static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);
Expand All @@ -136,6 +141,9 @@ struct ExpressionUtil {
return expression.expressionType == common::ExpressionType::VARIABLE &&
expression.dataType.getLogicalTypeID() == common::LogicalTypeID::RECURSIVE_REL;
}

static std::vector<std::unique_ptr<common::LogicalType>> getDataTypes(
const expression_vector& expressions);
};

} // namespace binder
Expand Down
6 changes: 5 additions & 1 deletion src/include/common/data_chunk/data_chunk_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
namespace kuzu {
namespace common {

class DataChunkState {
// Two-valued tag (FLAT or UNFLAT) stored in a single byte.
// NOTE(review): presumably records whether a vector/key is flat or unflat in the factorized
// representation — confirm against the users of this enum.
enum class FactorizationStateType : uint8_t {
FLAT = 0,
UNFLAT = 1,
};

class DataChunkState {
public:
DataChunkState() : DataChunkState(DEFAULT_VECTOR_CAPACITY) {}
explicit DataChunkState(uint64_t capacity) : currIdx{-1}, originalSize{0} {
Expand Down
3 changes: 3 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ class LogicalType {

std::unique_ptr<LogicalType> copy() const;

static std::vector<std::unique_ptr<LogicalType>> copy(
const std::vector<std::unique_ptr<LogicalType>>& types);

private:
void setPhysicalType();

Expand Down
11 changes: 4 additions & 7 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class ValueVector {

void resetAuxiliaryBuffer();

// Returns true if any non-null values remain after discarding; otherwise returns false.
// For an unflat vector, its selection vector is also updated to the resultSelVector.
static bool discardNull(ValueVector& vector);

private:
uint32_t getDataTypeSize(const LogicalType& type);
void initializeValueBuffer();
Expand Down Expand Up @@ -228,13 +232,6 @@ class ArrowColumnVector {
static void slice(ValueVector* vector, offset_t offset);
};

// Holder for a single static null-discarding routine that operates on a ValueVector.
class NodeIDVector {
public:
// Returns true if any non-null values remain after discarding; otherwise returns false.
// For an unflat vector, its selection vector is also updated to the resultSelVector.
static bool discardNull(ValueVector& vector);
};

class MapVector {
public:
static inline ValueVector* getKeyVector(const ValueVector* vector) {
Expand Down
30 changes: 19 additions & 11 deletions src/include/planner/logical_plan/logical_hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,32 @@
namespace kuzu {
namespace planner {

// We only support equality comparison as join condition
using join_condition_t = binder::expression_pair;
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved

// Probe side on left, i.e. children[0]. Build side on right, i.e. children[1].
class LogicalHashJoin : public LogicalOperator {
public:
// Inner and left join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
LogicalHashJoin(std::vector<join_condition_t> joinConditions, common::JoinType joinType,
std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), joinType, nullptr, std::move(probeSideChild),
: LogicalHashJoin{std::move(joinConditions), joinType, nullptr, std::move(probeSideChild),
std::move(buildSideChild)} {}

// Mark join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, std::shared_ptr<binder::Expression> mark,
std::shared_ptr<LogicalOperator> probeSideChild,
LogicalHashJoin(std::vector<join_condition_t> joinConditions,
std::shared_ptr<binder::Expression> mark, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), common::JoinType::MARK, std::move(mark),
: LogicalHashJoin{std::move(joinConditions), common::JoinType::MARK, std::move(mark),
std::move(probeSideChild), std::move(buildSideChild)} {}

LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
LogicalHashJoin(std::vector<join_condition_t> joinConditions, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
joinConditions(std::move(joinConditions)), joinType{joinType}, mark{std::move(mark)},
sip{SidewaysInfoPassing::NONE} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
Expand All @@ -41,11 +44,15 @@ class LogicalHashJoin : public LogicalOperator {
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(joinNodeIDs);
return isNodeIDOnlyJoin() ? binder::ExpressionUtil::toString(getJoinNodeIDs()) :
binder::ExpressionUtil::toString(joinConditions);
}

binder::expression_vector getExpressionsToMaterialize() const;
inline binder::expression_vector getJoinNodeIDs() const { return joinNodeIDs; }

binder::expression_vector getJoinNodeIDs() const;
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved

// Returns the join conditions by value, i.e. the caller receives its own copy of the vector.
inline std::vector<join_condition_t> getJoinConditions() const {
    return joinConditions;
}
// Returns the join type this operator was constructed with.
inline common::JoinType getJoinType() const {
    return joinType;
}
inline std::shared_ptr<binder::Expression> getMark() const {
assert(joinType == common::JoinType::MARK && mark);
Expand All @@ -56,7 +63,7 @@ class LogicalHashJoin : public LogicalOperator {

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(
joinNodeIDs, joinType, mark, children[0]->copy(), children[1]->copy());
joinConditions, joinType, mark, children[0]->copy(), children[1]->copy());
}

// Flat probe side key group in either of the following two cases:
Expand All @@ -69,10 +76,11 @@ class LogicalHashJoin : public LogicalOperator {
bool requireFlatProbeKeys();

private:
bool isNodeIDOnlyJoin() const;
bool isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID);

private:
binder::expression_vector joinNodeIDs;
std::vector<join_condition_t> joinConditions;
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
SidewaysInfoPassing sip;
Expand Down
2 changes: 0 additions & 2 deletions src/include/planner/logical_plan/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ class Schema {

class SchemaUtils {
public:
static std::vector<binder::expression_vector> getExpressionsPerGroup(
const binder::expression_vector& expressions, const Schema& schema);
// Given a set of factorization groups, a leading group is selected as the unFlat group (caller
// should ensure at most one unFlat group which is our general assumption of factorization). If
// all groups are flat, we select any (the first) group as leading group.
Expand Down
9 changes: 5 additions & 4 deletions src/include/planner/logical_plan/sip/side_way_info_passing.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ namespace planner {

enum class SidewaysInfoPassing : uint8_t {
NONE = 0,
PROBE_TO_BUILD = 1,
PROHIBIT_PROBE_TO_BUILD = 2,
BUILD_TO_PROBE = 3,
PROHIBIT_BUILD_TO_PROBE = 4,
PROHIBIT = 1,
PROBE_TO_BUILD = 2,
PROHIBIT_PROBE_TO_BUILD = 3,
BUILD_TO_PROBE = 4,
PROHIBIT_BUILD_TO_PROBE = 5,
};

} // namespace planner
Expand Down
17 changes: 16 additions & 1 deletion src/include/processor/operator/base_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,29 @@ class BaseHashTable {

virtual ~BaseHashTable() = default;

protected:
// Stores the new slot capacity and derives the hash mask as newSize - 1; the mask is what
// getSlotIdxForHash() ANDs with a hash.
// NOTE(review): newSize - 1 only works as a slot mask when newSize is a power of two —
// confirm callers guarantee this.
inline void setMaxNumHashSlots(uint64_t newSize) {
    bitmask = newSize - 1;
    maxNumHashSlots = newSize;
}

// Initializes the per-block slot constants from numSlotsPerBlock_: the slot count itself, its
// base-2 logarithm, and the mask returned by all1sMaskForLeastSignificantBits() for that
// logarithm. Asserts that numSlotsPerBlock_ equals common::nextPowerOfTwo(numSlotsPerBlock_).
inline void initSlotConstant(uint64_t numSlotsPerBlock_) {
    assert(common::nextPowerOfTwo(numSlotsPerBlock_) == numSlotsPerBlock_);
    // std::log2 yields a floating-point value; storing it in a uint64_t converts it back.
    const uint64_t slotsLog2 = std::log2(numSlotsPerBlock_);
    numSlotsPerBlock = numSlotsPerBlock_;
    numSlotsPerBlockLog2 = slotsLog2;
    slotIdxInBlockMask =
        common::BitmaskUtils::all1sMaskForLeastSignificantBits(numSlotsPerBlockLog2);
}

// Maps a hash value to a slot index by AND-ing it with bitmask.
inline uint64_t getSlotIdxForHash(common::hash_t hash) const {
    const uint64_t slotIdx = hash & bitmask;
    return slotIdx;
}

protected:
uint64_t maxNumHashSlots;
uint64_t bitmask;
std::vector<std::unique_ptr<DataBlock>> hashSlotsBlocks;
uint64_t numSlotsPerBlock;
uint64_t numSlotsPerBlockLog2;
uint64_t slotIdxInBlockMask;
std::vector<std::unique_ptr<DataBlock>> hashSlotsBlocks;
storage::MemoryManager& memoryManager;
std::unique_ptr<FactorizedTable> factorizedTable;
};
Expand Down
30 changes: 19 additions & 11 deletions src/include/processor/operator/hash_join/hash_join_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ class HashJoinBuildInfo {
friend class HashJoinBuild;

public:
HashJoinBuildInfo(std::vector<DataPos> keysPos, std::vector<DataPos> payloadsPos,
std::unique_ptr<FactorizedTableSchema> tableSchema)
: keysPos{std::move(keysPos)}, payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(
tableSchema)} {}
HashJoinBuildInfo(std::vector<DataPos> keysPos,
std::vector<common::FactorizationStateType> factorizationStateTypes,
std::vector<DataPos> payloadsPos, std::unique_ptr<FactorizedTableSchema> tableSchema)
: keysPos{std::move(keysPos)}, factorizationStateTypes{std::move(factorizationStateTypes)},
payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(tableSchema)} {}
HashJoinBuildInfo(const HashJoinBuildInfo& other)
: keysPos{other.keysPos}, payloadsPos{other.payloadsPos}, tableSchema{
other.tableSchema->copy()} {}
: keysPos{other.keysPos}, factorizationStateTypes{other.factorizationStateTypes},
payloadsPos{other.payloadsPos}, tableSchema{other.tableSchema->copy()} {}

// Number of join keys, i.e. the number of entries in keysPos, narrowed to uint32_t.
inline uint32_t getNumKeys() const {
    return static_cast<uint32_t>(keysPos.size());
}

Expand All @@ -56,6 +57,7 @@ class HashJoinBuildInfo {

private:
std::vector<DataPos> keysPos;
std::vector<common::FactorizationStateType> factorizationStateTypes;
std::vector<DataPos> payloadsPos;
std::unique_ptr<FactorizedTableSchema> tableSchema;
};
Expand All @@ -73,7 +75,6 @@ class HashJoinBuild : public Sink {
uint32_t id, const std::string& paramsString)
: Sink{std::move(resultSetDescriptor), operatorType, std::move(child), id, paramsString},
sharedState{std::move(sharedState)}, info{std::move(info)} {}
~HashJoinBuild() override = default;

// Returns a shared_ptr copy referring to this operator's HashJoinSharedState.
inline std::shared_ptr<HashJoinSharedState> getSharedState() const {
    return sharedState;
}

Expand All @@ -88,15 +89,22 @@ class HashJoinBuild : public Sink {
}

protected:
virtual void initLocalHashTable(storage::MemoryManager& memoryManager) {
hashTable = std::make_unique<JoinHashTable>(
memoryManager, info->getNumKeys(), info->tableSchema->copy());
virtual inline void appendVectors() {
hashTable->appendVectors(keyVectors, payloadVectors, keyState);
}

private:
void setKeyState(common::DataChunkState* state);

protected:
std::shared_ptr<HashJoinSharedState> sharedState;
std::unique_ptr<HashJoinBuildInfo> info;
std::vector<common::ValueVector*> vectorsToAppend;

std::vector<common::ValueVector*> keyVectors;
// State of unFlat key(s). If all keys are flat, it points to any flat key state.
common::DataChunkState* keyState = nullptr;
std::vector<common::ValueVector*> payloadVectors;

std::unique_ptr<JoinHashTable> hashTable; // local state
};

Expand Down
Loading
Loading