Skip to content

Commit

Permalink
Add generic hash join
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Aug 11, 2023
1 parent e81f88c commit e8e0283
Show file tree
Hide file tree
Showing 33 changed files with 598 additions and 337 deletions.
25 changes: 25 additions & 0 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ std::string ExpressionUtil::toString(const expression_vector& expressions) {
return result;
}

std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
if (expressionPairs.empty()) {
return std::string{};
}
auto result = toString(expressionPairs[0]);
for (auto i = 1u; i < expressionPairs.size(); ++i) {
result += "," + toString(expressionPairs[i]);
}
return result;
}

// Renders one expression pair as "<first>=<second>", using each side's own toString().
// Both sides are dereferenced unconditionally, so neither may be null.
std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
    std::string text = expressionPair.first->toString();
    text += "=";
    text += expressionPair.second->toString();
    return text;
}

expression_vector ExpressionUtil::excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude) {
expression_set excludeSet;
Expand All @@ -64,5 +79,15 @@ expression_vector ExpressionUtil::excludeExpressions(
return result;
}

// Returns one owned LogicalType per input expression, in input order. Each entry is produced by
// calling copy() on the expression's data type, so the caller owns the returned types.
std::vector<std::unique_ptr<common::LogicalType>> ExpressionUtil::getDataTypes(
    const kuzu::binder::expression_vector& expressions) {
    std::vector<std::unique_ptr<common::LogicalType>> dataTypes;
    const auto numExpressions = expressions.size();
    // Allocate once up front; the final size is known.
    dataTypes.reserve(numExpressions);
    for (auto idx = 0u; idx < numExpressions; ++idx) {
        dataTypes.emplace_back(expressions[idx]->getDataType().copy());
    }
    return dataTypes;
}

Check warning on line 90 in src/binder/expression/expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression.cpp#L90

Added line #L90 was not covered by tests

} // namespace binder
} // namespace kuzu
10 changes: 10 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ std::unique_ptr<LogicalType> LogicalType::copy() const {
return dataType;
}

std::vector<std::unique_ptr<LogicalType>> LogicalType::copy(
const std::vector<std::unique_ptr<LogicalType>>& types) {
std::vector<std::unique_ptr<LogicalType>> typesCopy;
typesCopy.reserve(types.size());
for (auto& type : types) {
typesCopy.push_back(type->copy());
}
return typesCopy;
}

Check warning on line 315 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L315

Added line #L315 was not covered by tests

void LogicalType::setPhysicalType() {
switch (typeID) {
case LogicalTypeID::ANY: {
Expand Down
2 changes: 1 addition & 1 deletion src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void ValueVector::setState(std::shared_ptr<DataChunkState> state) {
}
}

bool NodeIDVector::discardNull(ValueVector& vector) {
bool ValueVector::discardNull(ValueVector& vector) {
if (vector.hasNoNullsGuarantee()) {
return true;
} else {
Expand Down
8 changes: 8 additions & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ struct ExpressionUtil {

static uint32_t find(Expression* target, expression_vector expressions);

// Print as a1,a2,a3,...
static std::string toString(const expression_vector& expressions);
// Print as a1=a2,a3=a4,...
static std::string toString(const std::vector<expression_pair>& expressionPairs);
// Print as a1=a2
static std::string toString(const expression_pair& expressionPair);

static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);
Expand All @@ -136,6 +141,9 @@ struct ExpressionUtil {
return expression.expressionType == common::ExpressionType::VARIABLE &&
expression.dataType.getLogicalTypeID() == common::LogicalTypeID::RECURSIVE_REL;
}

static std::vector<std::unique_ptr<common::LogicalType>> getDataTypes(
const expression_vector& expressions);
};

} // namespace binder
Expand Down
6 changes: 5 additions & 1 deletion src/include/common/data_chunk/data_chunk_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
namespace kuzu {
namespace common {

class DataChunkState {
// Classifies a factorization state as either flat or unflat. Stored in a single byte.
// NOTE(review): presumably FLAT means the state exposes a single current value and UNFLAT means
// it exposes multiple values — confirm against DataChunkState.
enum class FactorizationStateType : uint8_t {
FLAT = 0,
UNFLAT = 1,
};

class DataChunkState {
public:
DataChunkState() : DataChunkState(DEFAULT_VECTOR_CAPACITY) {}
explicit DataChunkState(uint64_t capacity) : currIdx{-1}, originalSize{0} {
Expand Down
3 changes: 3 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ class LogicalType {

std::unique_ptr<LogicalType> copy() const;

static std::vector<std::unique_ptr<LogicalType>> copy(
const std::vector<std::unique_ptr<LogicalType>>& types);

private:
void setPhysicalType();

Expand Down
11 changes: 4 additions & 7 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class ValueVector {

void resetAuxiliaryBuffer();

// If there are still non-null values after discarding, return true. Otherwise, return false.
// For an unflat vector, its selection vector is also updated to the resultSelVector.
static bool discardNull(ValueVector& vector);

private:
uint32_t getDataTypeSize(const LogicalType& type);
void initializeValueBuffer();
Expand Down Expand Up @@ -228,13 +232,6 @@ class ArrowColumnVector {
static void slice(ValueVector* vector, offset_t offset);
};

class NodeIDVector {
public:
// If there is still non-null values after discarding, return true. Otherwise, return false.
// For an unflat vector, its selection vector is also updated to the resultSelVector.
static bool discardNull(ValueVector& vector);
};

class MapVector {
public:
static inline ValueVector* getKeyVector(const ValueVector* vector) {
Expand Down
30 changes: 19 additions & 11 deletions src/include/planner/logical_plan/logical_hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,32 @@
namespace kuzu {
namespace planner {

// We only support equality comparison as join condition
using join_condition_t = binder::expression_pair;

// Probe side on left, i.e. children[0]. Build side on right, i.e. children[1].
class LogicalHashJoin : public LogicalOperator {
public:
// Inner and left join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
LogicalHashJoin(std::vector<join_condition_t> joinConditions, common::JoinType joinType,
std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), joinType, nullptr, std::move(probeSideChild),
: LogicalHashJoin{std::move(joinConditions), joinType, nullptr, std::move(probeSideChild),
std::move(buildSideChild)} {}

// Mark join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, std::shared_ptr<binder::Expression> mark,
std::shared_ptr<LogicalOperator> probeSideChild,
LogicalHashJoin(std::vector<join_condition_t> joinConditions,
std::shared_ptr<binder::Expression> mark, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), common::JoinType::MARK, std::move(mark),
: LogicalHashJoin{std::move(joinConditions), common::JoinType::MARK, std::move(mark),
std::move(probeSideChild), std::move(buildSideChild)} {}

LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
LogicalHashJoin(std::vector<join_condition_t> joinConditions, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
joinConditions(std::move(joinConditions)), joinType{joinType}, mark{std::move(mark)},
sip{SidewaysInfoPassing::NONE} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
Expand All @@ -41,11 +44,15 @@ class LogicalHashJoin : public LogicalOperator {
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(joinNodeIDs);
return isNodeIDOnlyJoin() ? binder::ExpressionUtil::toString(getJoinNodeIDs()) :
binder::ExpressionUtil::toString(joinConditions);
}

binder::expression_vector getExpressionsToMaterialize() const;
inline binder::expression_vector getJoinNodeIDs() const { return joinNodeIDs; }

binder::expression_vector getJoinNodeIDs() const;

inline std::vector<join_condition_t> getJoinConditions() const { return joinConditions; }
inline common::JoinType getJoinType() const { return joinType; }
inline std::shared_ptr<binder::Expression> getMark() const {
assert(joinType == common::JoinType::MARK && mark);
Expand All @@ -56,7 +63,7 @@ class LogicalHashJoin : public LogicalOperator {

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(
joinNodeIDs, joinType, mark, children[0]->copy(), children[1]->copy());
joinConditions, joinType, mark, children[0]->copy(), children[1]->copy());
}

// Flat probe side key group in either of the following two cases:
Expand All @@ -69,10 +76,11 @@ class LogicalHashJoin : public LogicalOperator {
bool requireFlatProbeKeys();

private:
bool isNodeIDOnlyJoin() const;
bool isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID);

private:
binder::expression_vector joinNodeIDs;
std::vector<join_condition_t> joinConditions;
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
SidewaysInfoPassing sip;
Expand Down
2 changes: 0 additions & 2 deletions src/include/planner/logical_plan/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ class Schema {

class SchemaUtils {
public:
static std::vector<binder::expression_vector> getExpressionsPerGroup(
const binder::expression_vector& expressions, const Schema& schema);
// Given a set of factorization group, a leading group is selected as the unFlat group (caller
// should ensure at most one unFlat group which is our general assumption of factorization). If
// all groups are flat, we select any (the first) group as leading group.
Expand Down
9 changes: 5 additions & 4 deletions src/include/planner/logical_plan/sip/side_way_info_passing.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ namespace planner {

enum class SidewaysInfoPassing : uint8_t {
NONE = 0,
PROBE_TO_BUILD = 1,
PROHIBIT_PROBE_TO_BUILD = 2,
BUILD_TO_PROBE = 3,
PROHIBIT_BUILD_TO_PROBE = 4,
PROHIBIT = 1,
PROBE_TO_BUILD = 2,
PROHIBIT_PROBE_TO_BUILD = 3,
BUILD_TO_PROBE = 4,
PROHIBIT_BUILD_TO_PROBE = 5,
};

} // namespace planner
Expand Down
17 changes: 16 additions & 1 deletion src/include/processor/operator/base_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,29 @@ class BaseHashTable {

virtual ~BaseHashTable() = default;

protected:
// Records the new slot count and derives the bitmask (slot count minus one) that
// getSlotIdxForHash applies to hashes.
// NOTE(review): the mask only addresses every slot when newSize is a power of two — confirm
// that callers guarantee this.
inline void setMaxNumHashSlots(uint64_t newSize) {
    bitmask = newSize - 1;
    maxNumHashSlots = newSize;
}

// Initializes the per-block slot constants from the number of slots stored in one block:
//  - numSlotsPerBlock: the slot count itself,
//  - numSlotsPerBlockLog2: its base-2 logarithm,
//  - slotIdxInBlockMask: a mask with the numSlotsPerBlockLog2 least significant bits set.
// The assertion requires numSlotsPerBlock_ to equal common::nextPowerOfTwo of itself
// (presumably: to be a power of two).
inline void initSlotConstant(uint64_t numSlotsPerBlock_) {
    assert(numSlotsPerBlock_ == common::nextPowerOfTwo(numSlotsPerBlock_));
    numSlotsPerBlock = numSlotsPerBlock_;
    // Compute floor(log2) with integer shifts rather than std::log2: this avoids a round trip
    // through double and the implicit double -> uint64_t narrowing, so the result is exact.
    numSlotsPerBlockLog2 = 0;
    for (auto remaining = numSlotsPerBlock_; remaining > 1; remaining >>= 1) {
        ++numSlotsPerBlockLog2;
    }
    slotIdxInBlockMask =
        common::BitmaskUtils::all1sMaskForLeastSignificantBits(numSlotsPerBlockLog2);
}

// Maps a hash value to a slot index by keeping only the bits selected by `bitmask`.
inline uint64_t getSlotIdxForHash(common::hash_t hash) const {
    const auto slotIdx = hash & bitmask;
    return slotIdx;
}

protected:
uint64_t maxNumHashSlots;
uint64_t bitmask;
std::vector<std::unique_ptr<DataBlock>> hashSlotsBlocks;
uint64_t numSlotsPerBlock;
uint64_t numSlotsPerBlockLog2;
uint64_t slotIdxInBlockMask;
std::vector<std::unique_ptr<DataBlock>> hashSlotsBlocks;
storage::MemoryManager& memoryManager;
std::unique_ptr<FactorizedTable> factorizedTable;
};
Expand Down
30 changes: 19 additions & 11 deletions src/include/processor/operator/hash_join/hash_join_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ class HashJoinBuildInfo {
friend class HashJoinBuild;

public:
HashJoinBuildInfo(std::vector<DataPos> keysPos, std::vector<DataPos> payloadsPos,
std::unique_ptr<FactorizedTableSchema> tableSchema)
: keysPos{std::move(keysPos)}, payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(
tableSchema)} {}
HashJoinBuildInfo(std::vector<DataPos> keysPos,
std::vector<common::FactorizationStateType> factorizationStateTypes,
std::vector<DataPos> payloadsPos, std::unique_ptr<FactorizedTableSchema> tableSchema)
: keysPos{std::move(keysPos)}, factorizationStateTypes{std::move(factorizationStateTypes)},
payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(tableSchema)} {}
HashJoinBuildInfo(const HashJoinBuildInfo& other)
: keysPos{other.keysPos}, payloadsPos{other.payloadsPos}, tableSchema{
other.tableSchema->copy()} {}
: keysPos{other.keysPos}, factorizationStateTypes{other.factorizationStateTypes},
payloadsPos{other.payloadsPos}, tableSchema{other.tableSchema->copy()} {}

inline uint32_t getNumKeys() const { return keysPos.size(); }

Expand All @@ -56,6 +57,7 @@ class HashJoinBuildInfo {

private:
std::vector<DataPos> keysPos;
std::vector<common::FactorizationStateType> factorizationStateTypes;
std::vector<DataPos> payloadsPos;
std::unique_ptr<FactorizedTableSchema> tableSchema;
};
Expand All @@ -73,7 +75,6 @@ class HashJoinBuild : public Sink {
uint32_t id, const std::string& paramsString)
: Sink{std::move(resultSetDescriptor), operatorType, std::move(child), id, paramsString},
sharedState{std::move(sharedState)}, info{std::move(info)} {}
~HashJoinBuild() override = default;

inline std::shared_ptr<HashJoinSharedState> getSharedState() const { return sharedState; }

Expand All @@ -88,15 +89,22 @@ class HashJoinBuild : public Sink {
}

protected:
virtual void initLocalHashTable(storage::MemoryManager& memoryManager) {
hashTable = std::make_unique<JoinHashTable>(
memoryManager, info->getNumKeys(), info->tableSchema->copy());
// Forwards the key vectors, the payload vectors and the key state to the local hash table's
// appendVectors. Declared virtual, so derived build operators may replace how vectors are
// appended.
virtual inline void appendVectors() {
hashTable->appendVectors(keyVectors, payloadVectors, keyState);
}

private:
void setKeyState(common::DataChunkState* state);

protected:
std::shared_ptr<HashJoinSharedState> sharedState;
std::unique_ptr<HashJoinBuildInfo> info;
std::vector<common::ValueVector*> vectorsToAppend;

std::vector<common::ValueVector*> keyVectors;
// State of unFlat key(s). If all keys are flat, it points to any flat key state.
common::DataChunkState* keyState = nullptr;
std::vector<common::ValueVector*> payloadVectors;

std::unique_ptr<JoinHashTable> hashTable; // local state
};

Expand Down
Loading

0 comments on commit e8e0283

Please sign in to comment.