diff --git a/src/binder/expression/expression.cpp b/src/binder/expression/expression.cpp index c83050becb..26409e2cc9 100644 --- a/src/binder/expression/expression.cpp +++ b/src/binder/expression/expression.cpp @@ -109,5 +109,20 @@ std::string ExpressionUtil::toString(const expression_vector& expressions) { return result; } +expression_vector ExpressionUtil::excludeExpressions( + const expression_vector& expressions, const expression_vector& expressionsToExclude) { + expression_set excludeSet; + for (auto& expression : expressionsToExclude) { + excludeSet.insert(expression); + } + expression_vector result; + for (auto& expression : expressions) { + if (!excludeSet.contains(expression)) { + result.push_back(expression); + } + } + return result; +} + } // namespace binder } // namespace kuzu diff --git a/src/include/binder/expression/expression.h b/src/include/binder/expression/expression.h index 31c86d1151..204064cf1f 100644 --- a/src/include/binder/expression/expression.h +++ b/src/include/binder/expression/expression.h @@ -128,14 +128,16 @@ struct ExpressionEquality { } }; -class ExpressionUtil { -public: +struct ExpressionUtil { static bool allExpressionsHaveDataType( expression_vector& expressions, common::LogicalTypeID dataTypeID); static uint32_t find(Expression* target, expression_vector expressions); static std::string toString(const expression_vector& expressions); + + static expression_vector excludeExpressions( + const expression_vector& expressions, const expression_vector& expressionsToExclude); }; } // namespace binder diff --git a/src/include/processor/mapper/plan_mapper.h b/src/include/processor/mapper/plan_mapper.h index e8c334e49e..9a6eaa3f14 100644 --- a/src/include/processor/mapper/plan_mapper.h +++ b/src/include/processor/mapper/plan_mapper.h @@ -13,7 +13,7 @@ namespace kuzu { namespace processor { -struct BuildDataInfo; +struct HashJoinBuildInfo; struct AggregateInputInfo; class PlanMapper { @@ -114,7 +114,7 @@ class PlanMapper { inline uint32_t getOperatorID() { return physicalOperatorID++; } - BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema, + std::unique_ptr createHashBuildInfo(const planner::Schema& buildSideSchema, const binder::expression_vector& keys, const binder::expression_vector& payloads); std::unique_ptr createHashAggregate( diff --git a/src/include/processor/operator/cross_product.h b/src/include/processor/operator/cross_product.h index 9dbaa8ecad..9d94c318c0 100644 --- a/src/include/processor/operator/cross_product.h +++ b/src/include/processor/operator/cross_product.h @@ -25,40 +25,53 @@ class CrossProductLocalState { uint64_t startIdx = 0u; }; +class CrossProductInfo { + friend class CrossProduct; + +public: + CrossProductInfo(std::vector outVecPos, std::vector colIndicesToScan) + : outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)} {} + CrossProductInfo(const CrossProductInfo& other) + : outVecPos{other.outVecPos}, colIndicesToScan{other.colIndicesToScan} {} + + inline std::unique_ptr copy() const { + return std::make_unique(*this); + } + +private: + std::vector outVecPos; + std::vector colIndicesToScan; +}; + class CrossProduct : public PhysicalOperator { public: - CrossProduct(std::vector outVecPos, std::vector colIndicesToScan, + CrossProduct(std::unique_ptr info, std::unique_ptr localState, std::unique_ptr probeChild, std::unique_ptr buildChild, uint32_t id, const std::string& paramsString) : PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(probeChild), std::move(buildChild), id, paramsString}, - outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)}, - localState{std::move(localState)} {} + info{std::move(info)}, localState{std::move(localState)} {} // Clone only. - CrossProduct(std::vector outVecPos, std::vector colIndicesToScan, + CrossProduct(std::unique_ptr info, std::unique_ptr localState, std::unique_ptr child, uint32_t id, const std::string& paramsString) : PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(child), id, paramsString}, - outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)}, - localState{std::move(localState)} {} + info{std::move(info)}, localState{std::move(localState)} {} void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; bool getNextTuplesInternal(ExecutionContext* context) override; std::unique_ptr clone() override { - return std::make_unique(outVecPos, colIndicesToScan, localState->copy(), - children[0]->clone(), id, paramsString); + return std::make_unique( + info->copy(), localState->copy(), children[0]->clone(), id, paramsString); } private: - std::vector outVecPos; - std::vector colIndicesToScan; - + std::unique_ptr info; std::unique_ptr localState; - std::vector vectorsToScan; }; diff --git a/src/include/processor/operator/hash_join/hash_join_build.h b/src/include/processor/operator/hash_join/hash_join_build.h index 24e4cc2103..ec16627518 100644 --- a/src/include/processor/operator/hash_join/hash_join_build.h +++ b/src/include/processor/operator/hash_join/hash_join_build.h @@ -10,6 +10,8 @@ namespace kuzu { namespace processor { +class HashJoinBuild; + // This is a shared state between HashJoinBuild and HashJoinProbe operators. // Each clone of these two operators will share the same state. // Inside the state, we keep the materialized tuples in factorizedTable, which are merged by each @@ -18,13 +20,11 @@ namespace processor { // task/pipeline, and probed by the HashJoinProbe operators. class HashJoinSharedState { public: - HashJoinSharedState() = default; + explicit HashJoinSharedState(std::unique_ptr hashTable) + : hashTable{std::move(hashTable)} {}; virtual ~HashJoinSharedState() = default; - virtual void initEmptyHashTable(storage::MemoryManager& memoryManager, uint64_t numKeyColumns, - std::unique_ptr tableSchema); - void mergeLocalHashTable(JoinHashTable& localHashTable); inline JoinHashTable* getHashTable() { return hashTable.get(); } @@ -34,42 +34,45 @@ class HashJoinSharedState { std::unique_ptr hashTable; }; -struct BuildDataInfo { +class HashJoinBuildInfo { + friend class HashJoinBuild; + public: - BuildDataInfo(std::vector> keysPosAndType, - std::vector> payloadsPosAndType, - std::vector isPayloadsFlat, std::vector isPayloadsInKeyChunk) - : keysPosAndType{std::move(keysPosAndType)}, payloadsPosAndType{std::move( - payloadsPosAndType)}, - isPayloadsFlat{std::move(isPayloadsFlat)}, isPayloadsInKeyChunk{ - std::move(isPayloadsInKeyChunk)} {} + HashJoinBuildInfo(std::vector keysPos, std::vector payloadsPos, + std::unique_ptr tableSchema) + : keysPos{std::move(keysPos)}, payloadsPos{std::move(payloadsPos)}, tableSchema{std::move( + tableSchema)} {} + HashJoinBuildInfo(const HashJoinBuildInfo& other) + : keysPos{other.keysPos}, payloadsPos{other.payloadsPos}, tableSchema{ + other.tableSchema->copy()} {} - BuildDataInfo(const BuildDataInfo& other) - : BuildDataInfo{other.keysPosAndType, other.payloadsPosAndType, other.isPayloadsFlat, - other.isPayloadsInKeyChunk} {} + inline uint32_t getNumKeys() const { return keysPos.size(); } - inline uint32_t getNumKeys() const { return keysPosAndType.size(); } + inline FactorizedTableSchema* getTableSchema() const { return tableSchema.get(); } -public: - std::vector> keysPosAndType; - std::vector> payloadsPosAndType; - std::vector isPayloadsFlat; - std::vector isPayloadsInKeyChunk; + inline std::unique_ptr copy() const { + return std::make_unique(*this); + } + +private: + std::vector keysPos; + std::vector payloadsPos; + std::unique_ptr tableSchema; }; class HashJoinBuild : public Sink { public: HashJoinBuild(std::unique_ptr resultSetDescriptor, - std::shared_ptr sharedState, const BuildDataInfo& buildDataInfo, + std::shared_ptr sharedState, std::unique_ptr info, std::unique_ptr child, uint32_t id, const std::string& paramsString) : HashJoinBuild{std::move(resultSetDescriptor), PhysicalOperatorType::HASH_JOIN_BUILD, - std::move(sharedState), buildDataInfo, std::move(child), id, paramsString} {} + std::move(sharedState), std::move(info), std::move(child), id, paramsString} {} HashJoinBuild(std::unique_ptr resultSetDescriptor, PhysicalOperatorType operatorType, std::shared_ptr sharedState, - const BuildDataInfo& buildDataInfo, std::unique_ptr child, uint32_t id, - const std::string& paramsString) + std::unique_ptr info, std::unique_ptr child, + uint32_t id, const std::string& paramsString) : Sink{std::move(resultSetDescriptor), operatorType, std::move(child), id, paramsString}, - sharedState{std::move(sharedState)}, buildDataInfo{buildDataInfo} {} + sharedState{std::move(sharedState)}, info{std::move(info)} {} ~HashJoinBuild() override = default; void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; @@ -78,24 +81,21 @@ class HashJoinBuild : public Sink { void finalize(ExecutionContext* context) override; inline std::unique_ptr clone() override { - return make_unique(resultSetDescriptor->copy(), sharedState, buildDataInfo, + return make_unique(resultSetDescriptor->copy(), sharedState, info->copy(), children[0]->clone(), id, paramsString); } protected: - // TODO(Guodong/Xiyang): construct schema in mapper. - std::unique_ptr populateTableSchema(); - void initGlobalStateInternal(ExecutionContext* context) override; - - virtual void initLocalHashTable( - storage::MemoryManager& memoryManager, std::unique_ptr tableSchema); - inline void appendVectors() { hashTable->append(vectorsToAppend); } + virtual void initLocalHashTable(storage::MemoryManager& memoryManager) { + hashTable = std::make_unique( + memoryManager, info->getNumKeys(), info->tableSchema->copy()); + } protected: std::shared_ptr sharedState; - BuildDataInfo buildDataInfo; + std::unique_ptr info; std::vector vectorsToAppend; - std::unique_ptr hashTable; + std::unique_ptr hashTable; // local state }; } // namespace processor diff --git a/src/include/processor/operator/hash_join/hash_join_probe.h b/src/include/processor/operator/hash_join/hash_join_probe.h index 748cc2747b..b7c1bc8091 100644 --- a/src/include/processor/operator/hash_join/hash_join_probe.h +++ b/src/include/processor/operator/hash_join/hash_join_probe.h @@ -59,7 +59,6 @@ class HashJoinProbe : public PhysicalOperator, public SelVectorOverWriter { probeDataInfo{probeDataInfo} {} // This constructor is used for cloning only. - // HashJoinProbe do not need to clone hashJoinBuild which is on a different pipeline. HashJoinProbe(std::shared_ptr sharedState, common::JoinType joinType, bool flatProbe, const ProbeDataInfo& probeDataInfo, std::unique_ptr probeChild, uint32_t id, const std::string& paramsString) diff --git a/src/include/processor/operator/intersect/intersect.h b/src/include/processor/operator/intersect/intersect.h index a5850ee827..f5246b3121 100644 --- a/src/include/processor/operator/intersect/intersect.h +++ b/src/include/processor/operator/intersect/intersect.h @@ -8,6 +8,7 @@ namespace processor { struct IntersectDataInfo { DataPos keyDataPos; + // TODO(Xiyang): payload is not an accurate name for intersect. std::vector payloadsDataPos; }; diff --git a/src/include/processor/operator/intersect/intersect_build.h b/src/include/processor/operator/intersect/intersect_build.h index 572b089572..74692d74c5 100644 --- a/src/include/processor/operator/intersect/intersect_build.h +++ b/src/include/processor/operator/intersect/intersect_build.h @@ -8,30 +8,29 @@ namespace processor { class IntersectSharedState : public HashJoinSharedState { public: - IntersectSharedState() = default; - - void initEmptyHashTable(storage::MemoryManager& memoryManager, uint64_t numKeyColumns, - std::unique_ptr tableSchema) override; + explicit IntersectSharedState(std::unique_ptr hashtable) + : HashJoinSharedState{std::move(hashtable)} {} }; class IntersectBuild : public HashJoinBuild { public: IntersectBuild(std::unique_ptr resultSetDescriptor, - std::shared_ptr sharedState, const BuildDataInfo& buildDataInfo, + std::shared_ptr sharedState, std::unique_ptr info, std::unique_ptr child, uint32_t id, const std::string& paramsString) : HashJoinBuild{std::move(resultSetDescriptor), PhysicalOperatorType::INTERSECT_BUILD, - std::move(sharedState), buildDataInfo, std::move(child), id, paramsString} {} + std::move(sharedState), std::move(info), std::move(child), id, paramsString} {} inline std::unique_ptr clone() override { return make_unique(resultSetDescriptor->copy(), common::ku_reinterpret_pointer_cast( sharedState), - buildDataInfo, children[0]->clone(), id, paramsString); + info->copy(), children[0]->clone(), id, paramsString); } protected: - void initLocalHashTable(storage::MemoryManager& memoryManager, - std::unique_ptr tableSchema) override; + inline void initLocalHashTable(storage::MemoryManager& memoryManager) override { + hashTable = make_unique(memoryManager, info->getTableSchema()->copy()); + } }; } // namespace processor diff --git a/src/include/processor/result/factorized_table.h b/src/include/processor/result/factorized_table.h index 2f02f71c9a..c87b6db253 100644 --- a/src/include/processor/result/factorized_table.h +++ b/src/include/processor/result/factorized_table.h @@ -5,6 +5,7 @@ #include "common/in_mem_overflow_buffer.h" #include "common/vector/value_vector.h" +#include "processor/data_pos.h" #include "processor/result/flat_tuple.h" #include "storage/buffer_manager/memory_manager.h" #include "storage/storage_structure/disk_overflow_file.h" @@ -94,7 +95,7 @@ class DataBlockCollection { class ColumnSchema { public: - ColumnSchema(bool isUnflat, uint32_t dataChunksPos, uint32_t numBytes) + ColumnSchema(bool isUnflat, data_chunk_pos_t dataChunksPos, uint32_t numBytes) : isUnflat{isUnflat}, dataChunkPos{dataChunksPos}, numBytes{numBytes}, mayContainNulls{ false} {} @@ -102,7 +103,7 @@ class ColumnSchema { inline bool isFlat() const { return !isUnflat; } - inline uint32_t getDataChunkPos() const { return dataChunkPos; } + inline data_chunk_pos_t getDataChunkPos() const { return dataChunkPos; } inline uint32_t getNumBytes() const { return numBytes; } @@ -119,7 +120,7 @@ class ColumnSchema { private: // We need isUnflat, dataChunkPos to know the factorization structure in the factorizedTable. bool isUnflat; - uint32_t dataChunkPos; + data_chunk_pos_t dataChunkPos; uint32_t numBytes; bool mayContainNulls; }; diff --git a/src/processor/mapper/map_cross_product.cpp b/src/processor/mapper/map_cross_product.cpp index ea1a364ce0..a027285457 100644 --- a/src/processor/mapper/map_cross_product.cpp +++ b/src/processor/mapper/map_cross_product.cpp @@ -26,12 +26,14 @@ std::unique_ptr PlanMapper::mapLogicalCrossProductToPhysical( outVecPos.emplace_back(outSchema->getExpressionPos(*expression)); colIndicesToScan.push_back(i); } + auto info = + std::make_unique(std::move(outVecPos), std::move(colIndicesToScan)); auto sharedState = resultCollector->getSharedState(); auto localState = std::make_unique( sharedState->getTable(), sharedState->getMaxMorselSize()); - return make_unique(std::move(outVecPos), std::move(colIndicesToScan), - std::move(localState), std::move(probeSidePrevOperator), std::move(resultCollector), - getOperatorID(), logicalCrossProduct->getExpressionsForPrinting()); + return make_unique(std::move(info), std::move(localState), + std::move(probeSidePrevOperator), std::move(resultCollector), getOperatorID(), + logicalCrossProduct->getExpressionsForPrinting()); } } // namespace processor diff --git a/src/processor/mapper/map_hash_join.cpp b/src/processor/mapper/map_hash_join.cpp index 86017744ff..00faae140d 100644 --- a/src/processor/mapper/map_hash_join.cpp +++ b/src/processor/mapper/map_hash_join.cpp @@ -9,31 +9,45 @@ using namespace kuzu::planner; namespace kuzu { namespace processor { -BuildDataInfo PlanMapper::generateBuildDataInfo(const Schema& buildSideSchema, - const expression_vector& keys, const expression_vector& payloads) { - std::vector> buildKeysPosAndType, - buildPayloadsPosAndTypes; - std::vector isBuildPayloadsFlat, isBuildPayloadsInKeyChunk; - std::vector isBuildDataChunkContainKeys(buildSideSchema.getNumGroups(), false); - std::unordered_set joinKeyNames; +std::unique_ptr PlanMapper::createHashBuildInfo( + const Schema& buildSchema, const expression_vector& keys, const expression_vector& payloads) { + planner::f_group_pos_set keyGroupPosSet; + std::vector keysPos; + std::vector payloadsPos; + auto tableSchema = std::make_unique(); for (auto& key : keys) { - auto buildSideKeyPos = DataPos(buildSideSchema.getExpressionPos(*key)); - isBuildDataChunkContainKeys[buildSideKeyPos.dataChunkPos] = true; - buildKeysPosAndType.emplace_back(buildSideKeyPos, common::LogicalTypeID::INTERNAL_ID); - joinKeyNames.insert(key->getUniqueName()); + auto pos = DataPos(buildSchema.getExpressionPos(*key)); + keyGroupPosSet.insert(pos.dataChunkPos); + // Keys are always stored in flat column. + auto columnSchema = std::make_unique(false /* isUnFlat */, pos.dataChunkPos, + FactorizedTable::getDataTypeSize(key->dataType)); + tableSchema->appendColumn(std::move(columnSchema)); + keysPos.push_back(pos); } for (auto& payload : payloads) { - if (joinKeyNames.find(payload->getUniqueName()) != joinKeyNames.end()) { - continue; + auto pos = DataPos(buildSchema.getExpressionPos(*payload)); + std::unique_ptr columnSchema; + if (keyGroupPosSet.contains(pos.dataChunkPos) || + buildSchema.getGroup(pos.dataChunkPos)->isFlat()) { + // Payloads need to be stored in flat column in 2 cases + // 1. payload is in the same chunk as a key. Since keys are always stored as flat, + // payloads must also be stored as flat. + // 2. payload is in flat chunk + columnSchema = std::make_unique(false /* isUnFlat */, pos.dataChunkPos, + FactorizedTable::getDataTypeSize(payload->dataType)); + } else { + columnSchema = std::make_unique( + true /* isUnFlat */, pos.dataChunkPos, (uint32_t)sizeof(common::overflow_value_t)); } - auto payloadPos = DataPos(buildSideSchema.getExpressionPos(*payload)); - buildPayloadsPosAndTypes.emplace_back(payloadPos, payload->dataType); - auto payloadGroup = buildSideSchema.getGroup(payloadPos.dataChunkPos); - isBuildPayloadsFlat.push_back(payloadGroup->isFlat()); - isBuildPayloadsInKeyChunk.push_back(isBuildDataChunkContainKeys[payloadPos.dataChunkPos]); + tableSchema->appendColumn(std::move(columnSchema)); + payloadsPos.push_back(pos); } - return BuildDataInfo(buildKeysPosAndType, buildPayloadsPosAndTypes, isBuildPayloadsFlat, - isBuildPayloadsInKeyChunk); + auto pointerType = common::LogicalType(common::LogicalTypeID::INT64); + auto pointerColumn = std::make_unique(false /* isUnFlat */, + INVALID_DATA_CHUNK_POS, FactorizedTable::getDataTypeSize(pointerType)); + tableSchema->appendColumn(std::move(pointerColumn)); + return std::make_unique( + std::move(keysPos), std::move(payloadsPos), std::move(tableSchema)); } std::unique_ptr PlanMapper::mapLogicalHashJoinToPhysical( @@ -51,26 +65,26 @@ std::unique_ptr PlanMapper::mapLogicalHashJoinToPhysical( buildSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(1)); probeSidePrevOperator = mapLogicalOperatorToPhysical(hashJoin->getChild(0)); } - // Populate build side and probe side std::vector positions auto paramsString = hashJoin->getExpressionsForPrinting(); - auto buildDataInfo = generateBuildDataInfo( - *buildSchema, hashJoin->getJoinNodeIDs(), hashJoin->getExpressionsToMaterialize()); + auto payloads = ExpressionUtil::excludeExpressions( + hashJoin->getExpressionsToMaterialize(), hashJoin->getJoinNodeIDs()); + // Create build + auto buildInfo = createHashBuildInfo(*buildSchema, hashJoin->getJoinNodeIDs(), payloads); + auto globalHashTable = std::make_unique( + *memoryManager, buildInfo->getNumKeys(), buildInfo->getTableSchema()->copy()); + auto sharedState = std::make_shared(std::move(globalHashTable)); + auto hashJoinBuild = + make_unique(std::make_unique(buildSchema), sharedState, + std::move(buildInfo), std::move(buildSidePrevOperator), getOperatorID(), paramsString); + // Create probe std::vector probeKeysDataPos; for (auto& joinNodeID : hashJoin->getJoinNodeIDs()) { probeKeysDataPos.emplace_back(outSchema->getExpressionPos(*joinNodeID)); } std::vector probePayloadsOutPos; - for (auto& [dataPos, _] : buildDataInfo.payloadsPosAndType) { - auto expression = - buildSchema->getGroup(dataPos.dataChunkPos)->getExpressions()[dataPos.valueVectorPos]; - probePayloadsOutPos.emplace_back(outSchema->getExpressionPos(*expression)); + for (auto& payload : payloads) { + probePayloadsOutPos.emplace_back(outSchema->getExpressionPos(*payload)); } - auto sharedState = std::make_shared(); - // create hashJoin build - auto hashJoinBuild = - make_unique(std::make_unique(buildSchema), sharedState, - buildDataInfo, std::move(buildSidePrevOperator), getOperatorID(), paramsString); - // create hashJoin probe ProbeDataInfo probeDataInfo(probeKeysDataPos, probePayloadsOutPos); if (hashJoin->getJoinType() == common::JoinType::MARK) { auto mark = hashJoin->getMark(); diff --git a/src/processor/mapper/map_intersect.cpp b/src/processor/mapper/map_intersect.cpp index c0d7d79690..16b8a7c688 100644 --- a/src/processor/mapper/map_intersect.cpp +++ b/src/processor/mapper/map_intersect.cpp @@ -21,27 +21,25 @@ std::unique_ptr PlanMapper::mapLogicalIntersectToPhysical( for (auto i = 1u; i < logicalIntersect->getNumChildren(); i++) { auto keyNodeID = logicalIntersect->getKeyNodeID(i - 1); auto buildSchema = logicalIntersect->getChild(i)->getSchema(); - auto buildSidePrevOperator = mapLogicalOperatorToPhysical(logicalIntersect->getChild(i)); - std::vector payloadsDataPos; - binder::expression_vector expressionsToMaterialize; - expressionsToMaterialize.push_back(keyNodeID); - expressionsToMaterialize.push_back(intersectNodeID); - for (auto& expression : buildSchema->getExpressionsInScope()) { - if (expression->getUniqueName() == keyNodeID->getUniqueName() || - expression->getUniqueName() == intersectNodeID->getUniqueName()) { - continue; - } - expressionsToMaterialize.push_back(expression); - payloadsDataPos.emplace_back(outSchema->getExpressionPos(*expression)); - } - auto buildDataInfo = - generateBuildDataInfo(*buildSchema, {keyNodeID}, expressionsToMaterialize); - auto sharedState = std::make_shared(); + auto buildPrevOperator = mapLogicalOperatorToPhysical(logicalIntersect->getChild(i)); + auto payloadExpressions = binder::ExpressionUtil::excludeExpressions( + buildSchema->getExpressionsInScope(), {keyNodeID}); + auto buildInfo = createHashBuildInfo(*buildSchema, {keyNodeID}, payloadExpressions); + auto globalHashTable = std::make_unique( + *memoryManager, buildInfo->getTableSchema()->copy()); + auto sharedState = std::make_shared(std::move(globalHashTable)); sharedStates.push_back(sharedState); children[i] = make_unique( - std::make_unique(buildSchema), sharedState, buildDataInfo, - std::move(buildSidePrevOperator), getOperatorID(), keyNodeID->toString()); - IntersectDataInfo info{DataPos(outSchema->getExpressionPos(*keyNodeID)), payloadsDataPos}; + std::make_unique(buildSchema), sharedState, std::move(buildInfo), + std::move(buildPrevOperator), getOperatorID(), keyNodeID->toString()); + // Collect intersect info + std::vector vectorsToScanPos; + auto expressionsToScan = binder::ExpressionUtil::excludeExpressions( + buildSchema->getExpressionsInScope(), {keyNodeID, intersectNodeID}); + for (auto& expression : expressionsToScan) { + vectorsToScanPos.emplace_back(outSchema->getExpressionPos(*expression)); + } + IntersectDataInfo info{DataPos(outSchema->getExpressionPos(*keyNodeID)), vectorsToScanPos}; intersectDataInfos.push_back(info); } // Map probe side child. diff --git a/src/processor/operator/cross_product.cpp b/src/processor/operator/cross_product.cpp index a039aeca05..e93fa8e792 100644 --- a/src/processor/operator/cross_product.cpp +++ b/src/processor/operator/cross_product.cpp @@ -4,7 +4,7 @@ namespace kuzu { namespace processor { void CrossProduct::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { - for (auto& pos : outVecPos) { + for (auto& pos : info->outVecPos) { vectorsToScan.push_back(resultSet->getValueVector(pos).get()); } localState->init(); @@ -26,7 +26,7 @@ bool CrossProduct::getNextTuplesInternal(ExecutionContext* context) { // scan from right table if there is tuple left auto numTuplesToScan = std::min(localState->maxMorselSize, table->getNumTuples() - localState->startIdx); - table->scan(vectorsToScan, localState->startIdx, numTuplesToScan, colIndicesToScan); + table->scan(vectorsToScan, localState->startIdx, numTuplesToScan, info->colIndicesToScan); localState->startIdx += numTuplesToScan; metrics->numOutputTuple.increase(numTuplesToScan); return true; diff --git a/src/processor/operator/hash_join/hash_join_build.cpp b/src/processor/operator/hash_join/hash_join_build.cpp index 97650aba24..0e324388ae 100644 --- a/src/processor/operator/hash_join/hash_join_build.cpp +++ b/src/processor/operator/hash_join/hash_join_build.cpp @@ -1,71 +1,24 @@ #include "processor/operator/hash_join/hash_join_build.h" -#include "common/utils.h" - using namespace kuzu::common; using namespace kuzu::storage; namespace kuzu { namespace processor { -void HashJoinSharedState::initEmptyHashTable(MemoryManager& memoryManager, uint64_t numKeyColumns, - std::unique_ptr tableSchema) { - assert(hashTable == nullptr); - hashTable = - std::make_unique(memoryManager, numKeyColumns, std::move(tableSchema)); -} - void HashJoinSharedState::mergeLocalHashTable(JoinHashTable& localHashTable) { std::unique_lock lck(mtx); hashTable->merge(localHashTable); } void HashJoinBuild::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { - for (auto& [pos, dataType] : buildDataInfo.keysPosAndType) { + for (auto& pos : info->keysPos) { vectorsToAppend.push_back(resultSet->getValueVector(pos).get()); } - for (auto& [pos, dataType] : buildDataInfo.payloadsPosAndType) { + for (auto& pos : info->payloadsPos) { vectorsToAppend.push_back(resultSet->getValueVector(pos).get()); } - auto tableSchema = populateTableSchema(); - initLocalHashTable(*context->memoryManager, std::move(tableSchema)); -} - -std::unique_ptr HashJoinBuild::populateTableSchema() { - std::unique_ptr tableSchema = std::make_unique(); - for (auto& [pos, dataType] : buildDataInfo.keysPosAndType) { - tableSchema->appendColumn(std::make_unique( - false /* is flat */, pos.dataChunkPos, FactorizedTable::getDataTypeSize(dataType))); - } - for (auto i = 0u; i < buildDataInfo.payloadsPosAndType.size(); ++i) { - auto [pos, dataType] = buildDataInfo.payloadsPosAndType[i]; - if (buildDataInfo.isPayloadsInKeyChunk[i]) { - tableSchema->appendColumn(std::make_unique( - false /* is flat */, pos.dataChunkPos, FactorizedTable::getDataTypeSize(dataType))); - } else { - auto isVectorFlat = buildDataInfo.isPayloadsFlat[i]; - tableSchema->appendColumn( - std::make_unique(!isVectorFlat, pos.dataChunkPos, - isVectorFlat ? FactorizedTable::getDataTypeSize(dataType) : - (uint32_t)sizeof(overflow_value_t))); - } - } - // The prev pointer column. - tableSchema->appendColumn(std::make_unique(false /* is flat */, - UINT32_MAX /* For now, we just put UINT32_MAX for prev pointer */, - FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::INT64}))); - return tableSchema; -} - -void HashJoinBuild::initGlobalStateInternal(ExecutionContext* context) { - sharedState->initEmptyHashTable( - *context->memoryManager, buildDataInfo.getNumKeys(), populateTableSchema()); -} - -void HashJoinBuild::initLocalHashTable( - MemoryManager& memoryManager, std::unique_ptr tableSchema) { - hashTable = std::make_unique(memoryManager, buildDataInfo.getNumKeys(), - std::make_unique(*tableSchema)); + initLocalHashTable(*context->memoryManager); } void HashJoinBuild::finalize(ExecutionContext* context) { @@ -78,7 +31,7 @@ void HashJoinBuild::executeInternal(ExecutionContext* context) { // Append thread-local tuples while (children[0]->getNextTuple(context)) { for (auto i = 0u; i < resultSet->multiplicity; ++i) { - appendVectors(); + hashTable->append(vectorsToAppend); } } // Merge with global hash table once local tuples are all appended. diff --git a/src/processor/operator/intersect/CMakeLists.txt b/src/processor/operator/intersect/CMakeLists.txt index c0d8cb5551..4419fbea0f 100644 --- a/src/processor/operator/intersect/CMakeLists.txt +++ b/src/processor/operator/intersect/CMakeLists.txt @@ -1,7 +1,6 @@ add_library(kuzu_processor_operator_intersect OBJECT intersect.cpp - intersect_build.cpp intersect_hash_table.cpp) set(ALL_OBJECT_FILES diff --git a/src/processor/operator/intersect/intersect_build.cpp b/src/processor/operator/intersect/intersect_build.cpp deleted file mode 100644 index 7fa3b3bc3d..0000000000 --- a/src/processor/operator/intersect/intersect_build.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "processor/operator/intersect/intersect_build.h" - -using namespace kuzu::storage; -namespace kuzu { -namespace processor { - -void IntersectSharedState::initEmptyHashTable(MemoryManager& memoryManager, uint64_t numKeyColumns, - std::unique_ptr tableSchema) { - assert(hashTable == nullptr && numKeyColumns == 1); - hashTable = make_unique(memoryManager, std::move(tableSchema)); -} - -void IntersectBuild::initLocalHashTable( - MemoryManager& memoryManager, std::unique_ptr tableSchema) { - hashTable = make_unique( - memoryManager, std::make_unique(*tableSchema)); -} - -} // namespace processor -} // namespace kuzu