Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework hash join build mapper #1659

Merged
merged 1 commit into from
Jun 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,20 @@
return result;
}

// Builds a new vector holding every entry of `expressions` that is not a member of the
// expression_set populated from `expressionsToExclude`. Entries are kept in the order in
// which they occur in `expressions`; neither input is modified.
expression_vector ExpressionUtil::excludeExpressions(
    const expression_vector& expressions, const expression_vector& expressionsToExclude) {
    // Gather the exclusions into a set once, so the filtering pass below only does set lookups.
    expression_set excluded;
    for (const auto& toExclude : expressionsToExclude) {
        excluded.insert(toExclude);
    }
    expression_vector kept;
    for (const auto& candidate : expressions) {
        if (excluded.contains(candidate)) {
            continue;
        }
        kept.push_back(candidate);
    }
    return kept;
}

Check warning on line 125 in src/binder/expression/expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression.cpp#L125

Added line #L125 was not covered by tests

} // namespace binder
} // namespace kuzu
6 changes: 4 additions & 2 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ struct ExpressionEquality {
}
};

class ExpressionUtil {
public:
struct ExpressionUtil {
static bool allExpressionsHaveDataType(
expression_vector& expressions, common::LogicalTypeID dataTypeID);

static uint32_t find(Expression* target, expression_vector expressions);

static std::string toString(const expression_vector& expressions);

static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);
};

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace kuzu {
namespace processor {

struct BuildDataInfo;
struct HashJoinBuildInfo;
struct AggregateInputInfo;

class PlanMapper {
Expand Down Expand Up @@ -114,7 +114,7 @@ class PlanMapper {

inline uint32_t getOperatorID() { return physicalOperatorID++; }

BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
std::unique_ptr<HashJoinBuildInfo> createHashBuildInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

std::unique_ptr<PhysicalOperator> createHashAggregate(
Expand Down
37 changes: 25 additions & 12 deletions src/include/processor/operator/cross_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,53 @@ class CrossProductLocalState {
uint64_t startIdx = 0u;
};

// Bundles the two vectors a CrossProduct operator is configured with: the output vector
// positions and the indices of the columns to scan.
class CrossProductInfo {
    // CrossProduct reads the private members below directly.
    friend class CrossProduct;

public:
    // Takes both vectors by value and moves them into the members.
    CrossProductInfo(std::vector<DataPos> outVecPos, std::vector<ft_col_idx_t> colIndicesToScan)
        : outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)} {}
    // Member-wise copy of both vectors.
    CrossProductInfo(const CrossProductInfo& other) = default;

    // Returns a heap-allocated copy of this object.
    inline std::unique_ptr<CrossProductInfo> copy() const {
        return std::make_unique<CrossProductInfo>(*this);
    }

private:
    std::vector<DataPos> outVecPos;
    std::vector<ft_col_idx_t> colIndicesToScan;
};

class CrossProduct : public PhysicalOperator {
public:
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
CrossProduct(std::unique_ptr<CrossProductInfo> info,
std::unique_ptr<CrossProductLocalState> localState,
std::unique_ptr<PhysicalOperator> probeChild, std::unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(probeChild),
std::move(buildChild), id, paramsString},
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}
info{std::move(info)}, localState{std::move(localState)} {}

// Clone only.
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
CrossProduct(std::unique_ptr<CrossProductInfo> info,
std::unique_ptr<CrossProductLocalState> localState, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(child), id, paramsString},
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}
info{std::move(info)}, localState{std::move(localState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal(ExecutionContext* context) override;

std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<CrossProduct>(outVecPos, colIndicesToScan, localState->copy(),
children[0]->clone(), id, paramsString);
return std::make_unique<CrossProduct>(
info->copy(), localState->copy(), children[0]->clone(), id, paramsString);
}

private:
std::vector<DataPos> outVecPos;
std::vector<uint32_t> colIndicesToScan;

std::unique_ptr<CrossProductInfo> info;
std::unique_ptr<CrossProductLocalState> localState;

std::vector<common::ValueVector*> vectorsToScan;
};

Expand Down
72 changes: 36 additions & 36 deletions src/include/processor/operator/hash_join/hash_join_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
namespace kuzu {
namespace processor {

class HashJoinBuild;

// This is a shared state between HashJoinBuild and HashJoinProbe operators.
// Each clone of these two operators will share the same state.
// Inside the state, we keep the materialized tuples in factorizedTable, which are merged by each
Expand All @@ -18,13 +20,11 @@ namespace processor {
// task/pipeline, and probed by the HashJoinProbe operators.
class HashJoinSharedState {
public:
HashJoinSharedState() = default;
explicit HashJoinSharedState(std::unique_ptr<JoinHashTable> hashTable)
: hashTable{std::move(hashTable)} {};

virtual ~HashJoinSharedState() = default;

virtual void initEmptyHashTable(storage::MemoryManager& memoryManager, uint64_t numKeyColumns,
std::unique_ptr<FactorizedTableSchema> tableSchema);

void mergeLocalHashTable(JoinHashTable& localHashTable);

inline JoinHashTable* getHashTable() { return hashTable.get(); }
Expand All @@ -34,42 +34,45 @@ class HashJoinSharedState {
std::unique_ptr<JoinHashTable> hashTable;
};

struct BuildDataInfo {
class HashJoinBuildInfo {
friend class HashJoinBuild;

public:
BuildDataInfo(std::vector<std::pair<DataPos, common::LogicalType>> keysPosAndType,
std::vector<std::pair<DataPos, common::LogicalType>> payloadsPosAndType,
std::vector<bool> isPayloadsFlat, std::vector<bool> isPayloadsInKeyChunk)
: keysPosAndType{std::move(keysPosAndType)}, payloadsPosAndType{std::move(
payloadsPosAndType)},
isPayloadsFlat{std::move(isPayloadsFlat)}, isPayloadsInKeyChunk{
std::move(isPayloadsInKeyChunk)} {}
HashJoinBuildInfo(std::vector<DataPos> keysPos, std::vector<DataPos> payloadsPos,
std::unique_ptr<FactorizedTableSchema> tableSchema)
: keysPos{std::move(keysPos)}, payloadsPos{std::move(payloadsPos)}, tableSchema{std::move(
tableSchema)} {}
HashJoinBuildInfo(const HashJoinBuildInfo& other)
: keysPos{other.keysPos}, payloadsPos{other.payloadsPos}, tableSchema{
other.tableSchema->copy()} {}

BuildDataInfo(const BuildDataInfo& other)
: BuildDataInfo{other.keysPosAndType, other.payloadsPosAndType, other.isPayloadsFlat,
other.isPayloadsInKeyChunk} {}
inline uint32_t getNumKeys() const { return keysPos.size(); }

inline uint32_t getNumKeys() const { return keysPosAndType.size(); }
inline FactorizedTableSchema* getTableSchema() const { return tableSchema.get(); }

public:
std::vector<std::pair<DataPos, common::LogicalType>> keysPosAndType;
std::vector<std::pair<DataPos, common::LogicalType>> payloadsPosAndType;
std::vector<bool> isPayloadsFlat;
std::vector<bool> isPayloadsInKeyChunk;
inline std::unique_ptr<HashJoinBuildInfo> copy() const {
return std::make_unique<HashJoinBuildInfo>(*this);
}

private:
std::vector<DataPos> keysPos;
std::vector<DataPos> payloadsPos;
std::unique_ptr<FactorizedTableSchema> tableSchema;
};

class HashJoinBuild : public Sink {
public:
HashJoinBuild(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<HashJoinSharedState> sharedState, const BuildDataInfo& buildDataInfo,
std::shared_ptr<HashJoinSharedState> sharedState, std::unique_ptr<HashJoinBuildInfo> info,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: HashJoinBuild{std::move(resultSetDescriptor), PhysicalOperatorType::HASH_JOIN_BUILD,
std::move(sharedState), buildDataInfo, std::move(child), id, paramsString} {}
std::move(sharedState), std::move(info), std::move(child), id, paramsString} {}
HashJoinBuild(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
PhysicalOperatorType operatorType, std::shared_ptr<HashJoinSharedState> sharedState,
const BuildDataInfo& buildDataInfo, std::unique_ptr<PhysicalOperator> child, uint32_t id,
const std::string& paramsString)
std::unique_ptr<HashJoinBuildInfo> info, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: Sink{std::move(resultSetDescriptor), operatorType, std::move(child), id, paramsString},
sharedState{std::move(sharedState)}, buildDataInfo{buildDataInfo} {}
sharedState{std::move(sharedState)}, info{std::move(info)} {}
~HashJoinBuild() override = default;

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
Expand All @@ -78,24 +81,21 @@ class HashJoinBuild : public Sink {
void finalize(ExecutionContext* context) override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<HashJoinBuild>(resultSetDescriptor->copy(), sharedState, buildDataInfo,
return make_unique<HashJoinBuild>(resultSetDescriptor->copy(), sharedState, info->copy(),
children[0]->clone(), id, paramsString);
}

protected:
// TODO(Guodong/Xiyang): construct schema in mapper.
std::unique_ptr<FactorizedTableSchema> populateTableSchema();
void initGlobalStateInternal(ExecutionContext* context) override;

virtual void initLocalHashTable(
storage::MemoryManager& memoryManager, std::unique_ptr<FactorizedTableSchema> tableSchema);
inline void appendVectors() { hashTable->append(vectorsToAppend); }
virtual void initLocalHashTable(storage::MemoryManager& memoryManager) {
hashTable = std::make_unique<JoinHashTable>(
memoryManager, info->getNumKeys(), info->tableSchema->copy());
}

protected:
std::shared_ptr<HashJoinSharedState> sharedState;
BuildDataInfo buildDataInfo;
std::unique_ptr<HashJoinBuildInfo> info;
std::vector<common::ValueVector*> vectorsToAppend;
std::unique_ptr<JoinHashTable> hashTable;
std::unique_ptr<JoinHashTable> hashTable; // local state
};

} // namespace processor
Expand Down
1 change: 0 additions & 1 deletion src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter {
probeDataInfo{probeDataInfo} {}

// This constructor is used for cloning only.
// HashJoinProbe do not need to clone hashJoinBuild which is on a different pipeline.
HashJoinProbe(std::shared_ptr<HashJoinSharedState> sharedState, common::JoinType joinType,
bool flatProbe, const ProbeDataInfo& probeDataInfo,
std::unique_ptr<PhysicalOperator> probeChild, uint32_t id, const std::string& paramsString)
Expand Down
1 change: 1 addition & 0 deletions src/include/processor/operator/intersect/intersect.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace processor {

// Plain aggregate holding the data positions for an intersect: a single DataPos for the key
// and a list of DataPos for the payloads.
struct IntersectDataInfo {
// Position of the key.
DataPos keyDataPos;
// TODO(Xiyang): payload is not an accurate name for intersect.
// Positions of the payloads (see the TODO above regarding the name).
std::vector<DataPos> payloadsDataPos;
};

Expand Down
17 changes: 8 additions & 9 deletions src/include/processor/operator/intersect/intersect_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,29 @@ namespace processor {

class IntersectSharedState : public HashJoinSharedState {
public:
IntersectSharedState() = default;

void initEmptyHashTable(storage::MemoryManager& memoryManager, uint64_t numKeyColumns,
std::unique_ptr<FactorizedTableSchema> tableSchema) override;
explicit IntersectSharedState(std::unique_ptr<IntersectHashTable> hashtable)
: HashJoinSharedState{std::move(hashtable)} {}
};

class IntersectBuild : public HashJoinBuild {
public:
IntersectBuild(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<IntersectSharedState> sharedState, const BuildDataInfo& buildDataInfo,
std::shared_ptr<IntersectSharedState> sharedState, std::unique_ptr<HashJoinBuildInfo> info,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: HashJoinBuild{std::move(resultSetDescriptor), PhysicalOperatorType::INTERSECT_BUILD,
std::move(sharedState), buildDataInfo, std::move(child), id, paramsString} {}
std::move(sharedState), std::move(info), std::move(child), id, paramsString} {}

inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<IntersectBuild>(resultSetDescriptor->copy(),
common::ku_reinterpret_pointer_cast<HashJoinSharedState, IntersectSharedState>(
sharedState),
buildDataInfo, children[0]->clone(), id, paramsString);
info->copy(), children[0]->clone(), id, paramsString);
}

protected:
void initLocalHashTable(storage::MemoryManager& memoryManager,
std::unique_ptr<FactorizedTableSchema> tableSchema) override;
inline void initLocalHashTable(storage::MemoryManager& memoryManager) override {
hashTable = make_unique<IntersectHashTable>(memoryManager, info->getTableSchema()->copy());
}
};

} // namespace processor
Expand Down
7 changes: 4 additions & 3 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "common/in_mem_overflow_buffer.h"
#include "common/vector/value_vector.h"
#include "processor/data_pos.h"
#include "processor/result/flat_tuple.h"
#include "storage/buffer_manager/memory_manager.h"
#include "storage/storage_structure/disk_overflow_file.h"
Expand Down Expand Up @@ -94,15 +95,15 @@ class DataBlockCollection {

class ColumnSchema {
public:
ColumnSchema(bool isUnflat, uint32_t dataChunksPos, uint32_t numBytes)
ColumnSchema(bool isUnflat, data_chunk_pos_t dataChunksPos, uint32_t numBytes)
: isUnflat{isUnflat}, dataChunkPos{dataChunksPos}, numBytes{numBytes}, mayContainNulls{
false} {}

ColumnSchema(const ColumnSchema& other);

inline bool isFlat() const { return !isUnflat; }

inline uint32_t getDataChunkPos() const { return dataChunkPos; }
inline data_chunk_pos_t getDataChunkPos() const { return dataChunkPos; }

inline uint32_t getNumBytes() const { return numBytes; }

Expand All @@ -119,7 +120,7 @@ class ColumnSchema {
private:
// We need isUnflat, dataChunkPos to know the factorization structure in the factorizedTable.
bool isUnflat;
uint32_t dataChunkPos;
data_chunk_pos_t dataChunkPos;
uint32_t numBytes;
bool mayContainNulls;
};
Expand Down
8 changes: 5 additions & 3 deletions src/processor/mapper/map_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalCrossProductToPhysical(
outVecPos.emplace_back(outSchema->getExpressionPos(*expression));
colIndicesToScan.push_back(i);
}
auto info =
std::make_unique<CrossProductInfo>(std::move(outVecPos), std::move(colIndicesToScan));
auto sharedState = resultCollector->getSharedState();
auto localState = std::make_unique<CrossProductLocalState>(
sharedState->getTable(), sharedState->getMaxMorselSize());
return make_unique<CrossProduct>(std::move(outVecPos), std::move(colIndicesToScan),
std::move(localState), std::move(probeSidePrevOperator), std::move(resultCollector),
getOperatorID(), logicalCrossProduct->getExpressionsForPrinting());
return make_unique<CrossProduct>(std::move(info), std::move(localState),
std::move(probeSidePrevOperator), std::move(resultCollector), getOperatorID(),
logicalCrossProduct->getExpressionsForPrinting());
}

} // namespace processor
Expand Down
Loading
Loading