Skip to content

Commit

Permalink
Merge pull request #1659 from kuzudb/hash-join-map-rework
Browse files Browse the repository at this point in the history
Rework hash join build mapper
  • Loading branch information
andyfengHKU committed Jun 11, 2023
2 parents b2c85b2 + fb68844 commit 139dbe0
Show file tree
Hide file tree
Showing 16 changed files with 170 additions and 194 deletions.
15 changes: 15 additions & 0 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,20 @@ std::string ExpressionUtil::toString(const expression_vector& expressions) {
return result;
}

// Returns the entries of `expressions` that are not members of `expressionsToExclude`.
// Surviving entries keep their original relative order. Membership is decided by
// expression_set, i.e. by whatever hashing/equality that set type is declared with.
expression_vector ExpressionUtil::excludeExpressions(
    const expression_vector& expressions, const expression_vector& expressionsToExclude) {
    // Build the lookup set once so each candidate is tested with a single set probe.
    expression_set excluded;
    for (const auto& toExclude : expressionsToExclude) {
        excluded.insert(toExclude);
    }
    expression_vector kept;
    for (const auto& candidate : expressions) {
        if (excluded.contains(candidate)) {
            continue;
        }
        kept.push_back(candidate);
    }
    return kept;
}

} // namespace binder
} // namespace kuzu
6 changes: 4 additions & 2 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ struct ExpressionEquality {
}
};

// Collection of static helpers that operate on expression_vector values.
class ExpressionUtil {
public:
struct ExpressionUtil {
static bool allExpressionsHaveDataType(
expression_vector& expressions, common::LogicalTypeID dataTypeID);

// NOTE(review): `expressions` is taken by value, so every call copies the vector.
static uint32_t find(Expression* target, expression_vector expressions);

static std::string toString(const expression_vector& expressions);

// Returns the entries of `expressions` that are not contained in `expressionsToExclude`,
// preserving the order in which they appear in `expressions`.
static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);
};

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace kuzu {
namespace processor {

struct BuildDataInfo;
struct HashJoinBuildInfo;
struct AggregateInputInfo;

class PlanMapper {
Expand Down Expand Up @@ -114,7 +114,7 @@ class PlanMapper {

// Returns the current value of physicalOperatorID and then increments the counter,
// so consecutive calls hand out consecutive IDs.
inline uint32_t getOperatorID() { return physicalOperatorID++; }

BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
std::unique_ptr<HashJoinBuildInfo> createHashBuildInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

std::unique_ptr<PhysicalOperator> createHashAggregate(
Expand Down
37 changes: 25 additions & 12 deletions src/include/processor/operator/cross_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,53 @@ class CrossProductLocalState {
uint64_t startIdx = 0u;
};

// Bundles the two position lists a CrossProduct operator is configured with:
// the output vector positions and the table column indices to scan.
// CrossProduct is a friend and reads both members directly.
class CrossProductInfo {
    friend class CrossProduct;

public:
    // Takes both lists by value and moves them into the members.
    CrossProductInfo(std::vector<DataPos> outVecPos, std::vector<ft_col_idx_t> colIndicesToScan)
        : outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)} {}

    // Memberwise copy of both vectors.
    CrossProductInfo(const CrossProductInfo& other) = default;

    // Returns a heap-allocated copy of this object.
    std::unique_ptr<CrossProductInfo> copy() const {
        auto duplicate = std::make_unique<CrossProductInfo>(*this);
        return duplicate;
    }

private:
    std::vector<DataPos> outVecPos;
    std::vector<ft_col_idx_t> colIndicesToScan;
};

// Physical operator tagged PhysicalOperatorType::CROSS_PRODUCT.
// The first constructor wires a probe child and a build child; the second constructor
// takes a single child and is the one used by clone().
class CrossProduct : public PhysicalOperator {
public:
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
CrossProduct(std::unique_ptr<CrossProductInfo> info,
std::unique_ptr<CrossProductLocalState> localState,
std::unique_ptr<PhysicalOperator> probeChild, std::unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(probeChild),
std::move(buildChild), id, paramsString},
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}
info{std::move(info)}, localState{std::move(localState)} {}

// Clone only.
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
CrossProduct(std::unique_ptr<CrossProductInfo> info,
std::unique_ptr<CrossProductLocalState> localState, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(child), id, paramsString},
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}
info{std::move(info)}, localState{std::move(localState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal(ExecutionContext* context) override;

// Builds a new CrossProduct through the single-child constructor: the local state is
// copied and only children[0] is cloned; id and paramsString are passed through unchanged.
std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<CrossProduct>(outVecPos, colIndicesToScan, localState->copy(),
children[0]->clone(), id, paramsString);
return std::make_unique<CrossProduct>(
info->copy(), localState->copy(), children[0]->clone(), id, paramsString);
}

private:
std::vector<DataPos> outVecPos;
std::vector<uint32_t> colIndicesToScan;

std::unique_ptr<CrossProductInfo> info;
std::unique_ptr<CrossProductLocalState> localState;

std::vector<common::ValueVector*> vectorsToScan;
};

Expand Down
72 changes: 36 additions & 36 deletions src/include/processor/operator/hash_join/hash_join_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
namespace kuzu {
namespace processor {

class HashJoinBuild;

// This is a shared state between HashJoinBuild and HashJoinProbe operators.
// Each clone of these two operators will share the same state.
// Inside the state, we keep the materialized tuples in factorizedTable, which are merged by each
Expand All @@ -18,13 +20,11 @@ namespace processor {
// task/pipeline, and probed by the HashJoinProbe operators.
class HashJoinSharedState {
public:
HashJoinSharedState() = default;
explicit HashJoinSharedState(std::unique_ptr<JoinHashTable> hashTable)
: hashTable{std::move(hashTable)} {};

virtual ~HashJoinSharedState() = default;

virtual void initEmptyHashTable(storage::MemoryManager& memoryManager, uint64_t numKeyColumns,
std::unique_ptr<FactorizedTableSchema> tableSchema);

void mergeLocalHashTable(JoinHashTable& localHashTable);

inline JoinHashTable* getHashTable() { return hashTable.get(); }
Expand All @@ -34,42 +34,45 @@ class HashJoinSharedState {
std::unique_ptr<JoinHashTable> hashTable;
};

// Build-side description handed to HashJoinBuild: the positions of the key vectors,
// the positions of the payload vectors, and the FactorizedTableSchema of the table.
struct BuildDataInfo {
class HashJoinBuildInfo {
friend class HashJoinBuild;

public:
BuildDataInfo(std::vector<std::pair<DataPos, common::LogicalType>> keysPosAndType,
std::vector<std::pair<DataPos, common::LogicalType>> payloadsPosAndType,
std::vector<bool> isPayloadsFlat, std::vector<bool> isPayloadsInKeyChunk)
: keysPosAndType{std::move(keysPosAndType)}, payloadsPosAndType{std::move(
payloadsPosAndType)},
isPayloadsFlat{std::move(isPayloadsFlat)}, isPayloadsInKeyChunk{
std::move(isPayloadsInKeyChunk)} {}
HashJoinBuildInfo(std::vector<DataPos> keysPos, std::vector<DataPos> payloadsPos,
std::unique_ptr<FactorizedTableSchema> tableSchema)
: keysPos{std::move(keysPos)}, payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(
tableSchema)} {}
// Copies the position vectors and duplicates the schema through tableSchema->copy().
// NOTE(review): other.tableSchema is dereferenced unconditionally, so it must be non-null.
HashJoinBuildInfo(const HashJoinBuildInfo& other)
: keysPos{other.keysPos}, payloadsPos{other.payloadsPos}, tableSchema{
other.tableSchema->copy()} {}

BuildDataInfo(const BuildDataInfo& other)
: BuildDataInfo{other.keysPosAndType, other.payloadsPosAndType, other.isPayloadsFlat,
other.isPayloadsInKeyChunk} {}
// Number of key positions held by this info object.
inline uint32_t getNumKeys() const { return keysPos.size(); }

inline uint32_t getNumKeys() const { return keysPosAndType.size(); }
// Returns a non-owning pointer; the schema stays owned by this object.
inline FactorizedTableSchema* getTableSchema() const { return tableSchema.get(); }

public:
std::vector<std::pair<DataPos, common::LogicalType>> keysPosAndType;
std::vector<std::pair<DataPos, common::LogicalType>> payloadsPosAndType;
std::vector<bool> isPayloadsFlat;
std::vector<bool> isPayloadsInKeyChunk;
// Returns a heap-allocated copy made through the copy constructor.
inline std::unique_ptr<HashJoinBuildInfo> copy() const {
return std::make_unique<HashJoinBuildInfo>(*this);
}

private:
std::vector<DataPos> keysPos;
std::vector<DataPos> payloadsPos;
std::unique_ptr<FactorizedTableSchema> tableSchema;
};

class HashJoinBuild : public Sink {
public:
HashJoinBuild(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<HashJoinSharedState> sharedState, const BuildDataInfo& buildDataInfo,
std::shared_ptr<HashJoinSharedState> sharedState, std::unique_ptr<HashJoinBuildInfo> info,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: HashJoinBuild{std::move(resultSetDescriptor), PhysicalOperatorType::HASH_JOIN_BUILD,
std::move(sharedState), buildDataInfo, std::move(child), id, paramsString} {}
std::move(sharedState), std::move(info), std::move(child), id, paramsString} {}
HashJoinBuild(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
PhysicalOperatorType operatorType, std::shared_ptr<HashJoinSharedState> sharedState,
const BuildDataInfo& buildDataInfo, std::unique_ptr<PhysicalOperator> child, uint32_t id,
const std::string& paramsString)
std::unique_ptr<HashJoinBuildInfo> info, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: Sink{std::move(resultSetDescriptor), operatorType, std::move(child), id, paramsString},
sharedState{std::move(sharedState)}, buildDataInfo{buildDataInfo} {}
sharedState{std::move(sharedState)}, info{std::move(info)} {}
~HashJoinBuild() override = default;

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
Expand All @@ -78,24 +81,21 @@ class HashJoinBuild : public Sink {
void finalize(ExecutionContext* context) override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<HashJoinBuild>(resultSetDescriptor->copy(), sharedState, buildDataInfo,
return make_unique<HashJoinBuild>(resultSetDescriptor->copy(), sharedState, info->copy(),
children[0]->clone(), id, paramsString);
}

protected:
// TODO(Guodong/Xiyang): construct schema in mapper.
std::unique_ptr<FactorizedTableSchema> populateTableSchema();
void initGlobalStateInternal(ExecutionContext* context) override;

virtual void initLocalHashTable(
storage::MemoryManager& memoryManager, std::unique_ptr<FactorizedTableSchema> tableSchema);
inline void appendVectors() { hashTable->append(vectorsToAppend); }
virtual void initLocalHashTable(storage::MemoryManager& memoryManager) {
hashTable = std::make_unique<JoinHashTable>(
memoryManager, info->getNumKeys(), info->tableSchema->copy());
}

protected:
std::shared_ptr<HashJoinSharedState> sharedState;
BuildDataInfo buildDataInfo;
std::unique_ptr<HashJoinBuildInfo> info;
std::vector<common::ValueVector*> vectorsToAppend;
std::unique_ptr<JoinHashTable> hashTable;
std::unique_ptr<JoinHashTable> hashTable; // local state
};

} // namespace processor
Expand Down
1 change: 0 additions & 1 deletion src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter {
probeDataInfo{probeDataInfo} {}

// This constructor is used for cloning only.
// HashJoinProbe does not need to clone HashJoinBuild, which is on a different pipeline.
HashJoinProbe(std::shared_ptr<HashJoinSharedState> sharedState, common::JoinType joinType,
bool flatProbe, const ProbeDataInfo& probeDataInfo,
std::unique_ptr<PhysicalOperator> probeChild, uint32_t id, const std::string& paramsString)
Expand Down
1 change: 1 addition & 0 deletions src/include/processor/operator/intersect/intersect.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace processor {

// Groups a single key DataPos with a list of payload DataPos entries.
struct IntersectDataInfo {
DataPos keyDataPos;
// TODO(Xiyang): payload is not an accurate name for intersect.
std::vector<DataPos> payloadsDataPos;
};

Expand Down
17 changes: 8 additions & 9 deletions src/include/processor/operator/intersect/intersect_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,29 @@ namespace processor {

// HashJoinSharedState subclass used by the intersect operators.
class IntersectSharedState : public HashJoinSharedState {
public:
IntersectSharedState() = default;

void initEmptyHashTable(storage::MemoryManager& memoryManager, uint64_t numKeyColumns,
std::unique_ptr<FactorizedTableSchema> tableSchema) override;
// Takes ownership of the given IntersectHashTable and forwards it to the
// HashJoinSharedState constructor.
explicit IntersectSharedState(std::unique_ptr<IntersectHashTable> hashtable)
: HashJoinSharedState{std::move(hashtable)} {}
};

// HashJoinBuild subclass constructed with PhysicalOperatorType::INTERSECT_BUILD.
class IntersectBuild : public HashJoinBuild {
public:
IntersectBuild(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<IntersectSharedState> sharedState, const BuildDataInfo& buildDataInfo,
std::shared_ptr<IntersectSharedState> sharedState, std::unique_ptr<HashJoinBuildInfo> info,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: HashJoinBuild{std::move(resultSetDescriptor), PhysicalOperatorType::INTERSECT_BUILD,
std::move(sharedState), buildDataInfo, std::move(child), id, paramsString} {}
std::move(sharedState), std::move(info), std::move(child), id, paramsString} {}

// The clone receives the same sharedState pointer (cast back to IntersectSharedState),
// while the result set descriptor and children[0] are copied/cloned.
inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<IntersectBuild>(resultSetDescriptor->copy(),
common::ku_reinterpret_pointer_cast<HashJoinSharedState, IntersectSharedState>(
sharedState),
buildDataInfo, children[0]->clone(), id, paramsString);
info->copy(), children[0]->clone(), id, paramsString);
}

protected:
void initLocalHashTable(storage::MemoryManager& memoryManager,
std::unique_ptr<FactorizedTableSchema> tableSchema) override;
// Creates the local hash table as an IntersectHashTable, giving it a copy of the
// table schema held by info.
inline void initLocalHashTable(storage::MemoryManager& memoryManager) override {
hashTable = make_unique<IntersectHashTable>(memoryManager, info->getTableSchema()->copy());
}
};

} // namespace processor
Expand Down
7 changes: 4 additions & 3 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "common/in_mem_overflow_buffer.h"
#include "common/vector/value_vector.h"
#include "processor/data_pos.h"
#include "processor/result/flat_tuple.h"
#include "storage/buffer_manager/memory_manager.h"
#include "storage/storage_structure/disk_overflow_file.h"
Expand Down Expand Up @@ -94,15 +95,15 @@ class DataBlockCollection {

class ColumnSchema {
public:
ColumnSchema(bool isUnflat, uint32_t dataChunksPos, uint32_t numBytes)
ColumnSchema(bool isUnflat, data_chunk_pos_t dataChunksPos, uint32_t numBytes)
: isUnflat{isUnflat}, dataChunkPos{dataChunksPos}, numBytes{numBytes}, mayContainNulls{
false} {}

ColumnSchema(const ColumnSchema& other);

inline bool isFlat() const { return !isUnflat; }

inline uint32_t getDataChunkPos() const { return dataChunkPos; }
inline data_chunk_pos_t getDataChunkPos() const { return dataChunkPos; }

inline uint32_t getNumBytes() const { return numBytes; }

Expand All @@ -119,7 +120,7 @@ class ColumnSchema {
private:
// We need isUnflat, dataChunkPos to know the factorization structure in the factorizedTable.
bool isUnflat;
uint32_t dataChunkPos;
data_chunk_pos_t dataChunkPos;
uint32_t numBytes;
bool mayContainNulls;
};
Expand Down
8 changes: 5 additions & 3 deletions src/processor/mapper/map_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalCrossProductToPhysical(
outVecPos.emplace_back(outSchema->getExpressionPos(*expression));
colIndicesToScan.push_back(i);
}
auto info =
std::make_unique<CrossProductInfo>(std::move(outVecPos), std::move(colIndicesToScan));
auto sharedState = resultCollector->getSharedState();
auto localState = std::make_unique<CrossProductLocalState>(
sharedState->getTable(), sharedState->getMaxMorselSize());
return make_unique<CrossProduct>(std::move(outVecPos), std::move(colIndicesToScan),
std::move(localState), std::move(probeSidePrevOperator), std::move(resultCollector),
getOperatorID(), logicalCrossProduct->getExpressionsForPrinting());
return make_unique<CrossProduct>(std::move(info), std::move(localState),
std::move(probeSidePrevOperator), std::move(resultCollector), getOperatorID(),
logicalCrossProduct->getExpressionsForPrinting());
}

} // namespace processor
Expand Down
Loading

0 comments on commit 139dbe0

Please sign in to comment.