diff --git a/src/include/optimizer/agg_key_dependency_optimizer.h b/src/include/optimizer/agg_key_dependency_optimizer.h new file mode 100644 index 0000000000..7cf22ff809 --- /dev/null +++ b/src/include/optimizer/agg_key_dependency_optimizer.h @@ -0,0 +1,26 @@ +#pragma once + +#include "logical_operator_visitor.h" +#include "planner/logical_plan/logical_plan.h" + +namespace kuzu { +namespace optimizer { + +// This optimizer analyzes the dependency between group by keys. If key2 depends on key1 (e.g. key1 +// is a primary key column) we only hash on key1 and saves key2 as a payload. +class AggKeyDependencyOptimizer : public LogicalOperatorVisitor { +public: + void rewrite(planner::LogicalPlan* plan); + +private: + void visitOperator(planner::LogicalOperator* op); + + void visitAggregate(planner::LogicalOperator* op) override; + void visitDistinct(planner::LogicalOperator* op) override; + + std::pair resolveKeysAndDependentKeys( + const binder::expression_vector& keys); +}; + +} // namespace optimizer +} // namespace kuzu diff --git a/src/include/planner/logical_plan/logical_operator/logical_aggregate.h b/src/include/planner/logical_plan/logical_operator/logical_aggregate.h index 7de64c8745..498ded9dc9 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_aggregate.h +++ b/src/include/planner/logical_plan/logical_operator/logical_aggregate.h @@ -7,11 +7,18 @@ namespace planner { class LogicalAggregate : public LogicalOperator { public: - LogicalAggregate(binder::expression_vector expressionsToGroupBy, - binder::expression_vector expressionsToAggregate, std::shared_ptr child) + LogicalAggregate(binder::expression_vector keyExpressions, + binder::expression_vector aggregateExpressions, std::shared_ptr child) : LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)}, - expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move( - expressionsToAggregate)} {} + keyExpressions{std::move(keyExpressions)}, aggregateExpressions{ + std::move(aggregateExpressions)} {} + LogicalAggregate(binder::expression_vector keyExpressions, + binder::expression_vector dependentKeyExpressions, + binder::expression_vector aggregateExpressions, std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)}, + keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{std::move( + dependentKeyExpressions)}, + aggregateExpressions{std::move(aggregateExpressions)} {} void computeFactorizedSchema() override; void computeFlatSchema() override; @@ -21,30 +28,42 @@ class LogicalAggregate : public LogicalOperator { std::string getExpressionsForPrinting() const override; - inline bool hasExpressionsToGroupBy() const { return !expressionsToGroupBy.empty(); } - inline binder::expression_vector getExpressionsToGroupBy() const { - return expressionsToGroupBy; + inline bool hasKeyExpressions() const { return !keyExpressions.empty(); } + inline binder::expression_vector getKeyExpressions() const { return keyExpressions; } + inline void setKeyExpressions(binder::expression_vector expressions) { + keyExpressions = std::move(expressions); + } + inline binder::expression_vector getDependentKeyExpressions() const { + return dependentKeyExpressions; + } + inline void setDependentKeyExpressions(binder::expression_vector expressions) { + dependentKeyExpressions = std::move(expressions); } - inline size_t getNumAggregateExpressions() const { return expressionsToAggregate.size(); } - inline std::shared_ptr getAggregateExpression( - common::vector_idx_t idx) const { - return expressionsToAggregate[idx]; + inline binder::expression_vector getAllKeyExpressions() const { + binder::expression_vector result; + result.insert(result.end(), keyExpressions.begin(), keyExpressions.end()); + result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end()); + return result; } - inline binder::expression_vector getExpressionsToAggregate() const { - return expressionsToAggregate; + inline binder::expression_vector getAggregateExpressions() const { + return aggregateExpressions; } inline std::unique_ptr copy() override { return make_unique( - expressionsToGroupBy, expressionsToAggregate, children[0]->copy()); + keyExpressions, dependentKeyExpressions, aggregateExpressions, children[0]->copy()); } private: bool hasDistinctAggregate(); + void insertAllExpressionsToGroupAndScope(f_group_pos groupPos); private: - binder::expression_vector expressionsToGroupBy; - binder::expression_vector expressionsToAggregate; + binder::expression_vector keyExpressions; + // A dependentKeyExpression depend on a keyExpression (e.g. a.age depends on a.ID) and will not + // be treated as a hash key during hash aggregation. + binder::expression_vector dependentKeyExpressions; + binder::expression_vector aggregateExpressions; }; } // namespace planner diff --git a/src/include/planner/logical_plan/logical_operator/logical_distinct.h b/src/include/planner/logical_plan/logical_operator/logical_distinct.h index 5ad3e462e3..6e1a09c4e6 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_distinct.h +++ b/src/include/planner/logical_plan/logical_operator/logical_distinct.h @@ -9,9 +9,14 @@ namespace planner { class LogicalDistinct : public LogicalOperator { public: LogicalDistinct( - binder::expression_vector expressionsToDistinct, std::shared_ptr child) + binder::expression_vector keyExpressions, std::shared_ptr child) : LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)}, - expressionsToDistinct{std::move(expressionsToDistinct)} {} + keyExpressions{std::move(keyExpressions)} {} + LogicalDistinct(binder::expression_vector keyExpressions, + binder::expression_vector dependentKeyExpressions, std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)}, + keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{ + std::move(dependentKeyExpressions)} {} void computeFactorizedSchema() override; void computeFlatSchema() override; @@ -20,17 +25,32 @@ class LogicalDistinct : public LogicalOperator { std::string getExpressionsForPrinting() const override; - inline binder::expression_vector getExpressionsToDistinct() const { - return expressionsToDistinct; + inline binder::expression_vector getKeyExpressions() const { return keyExpressions; } + inline void setKeyExpressions(binder::expression_vector expressions) { + keyExpressions = std::move(expressions); + } + inline binder::expression_vector getDependentKeyExpressions() const { + return dependentKeyExpressions; + } + inline void setDependentKeyExpressions(binder::expression_vector expressions) { + dependentKeyExpressions = std::move(expressions); + } + inline binder::expression_vector getAllDistinctExpressions() const { + binder::expression_vector result; + result.insert(result.end(), keyExpressions.begin(), keyExpressions.end()); + result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end()); + return result; } - inline Schema* getSchemaBeforeDistinct() const { return children[0]->getSchema(); } std::unique_ptr copy() override { - return make_unique(expressionsToDistinct, children[0]->copy()); + return make_unique( + keyExpressions, dependentKeyExpressions, children[0]->copy()); } private: - binder::expression_vector expressionsToDistinct; + binder::expression_vector keyExpressions; + // See logical_aggregate.h for details. + binder::expression_vector dependentKeyExpressions; }; } // namespace planner diff --git a/src/include/processor/mapper/plan_mapper.h b/src/include/processor/mapper/plan_mapper.h index 65c92059d5..15f5708d7f 100644 --- a/src/include/processor/mapper/plan_mapper.h +++ b/src/include/processor/mapper/plan_mapper.h @@ -109,23 +109,22 @@ class PlanMapper { inline uint32_t getOperatorID() { return physicalOperatorID++; } + BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema, + const binder::expression_vector& keys, const binder::expression_vector& payloads); + std::unique_ptr createHashAggregate( + const binder::expression_vector& keyExpressions, + const binder::expression_vector& dependentKeyExpressions, std::vector> aggregateFunctions, - std::vector> inputAggregateInfo, - std::vector outputAggVectorsPos, - const binder::expression_vector& groupByExpressions, - std::unique_ptr prevOperator, const planner::Schema& inSchema, - const planner::Schema& outSchema, const std::string& paramsString); + std::vector> aggregateInputInfos, + std::vector aggregatesOutputPos, const planner::Schema& inSchema, + const planner::Schema& outSchema, std::unique_ptr prevOperator, + const std::string& paramsString); - void appendGroupByExpressions(const binder::expression_vector& groupByExpressions, - std::vector& inputGroupByHashKeyVectorsPos, - std::vector& outputGroupByKeyVectorsPos, const planner::Schema& inSchema, - const planner::Schema& outSchema, std::vector& isInputGroupByHashKeyVectorFlat); - - BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema, - const binder::expression_vector& keys, const binder::expression_vector& payloads); + static void mapAccHashJoin(PhysicalOperator* probe); - void mapAccHashJoin(PhysicalOperator* probe); + static std::vector getExpressionsDataPos( + const binder::expression_vector& expressions, const planner::Schema& schema); public: storage::StorageManager& storageManager; diff --git a/src/include/processor/operator/aggregate/aggregate_hash_table.h b/src/include/processor/operator/aggregate/aggregate_hash_table.h index 1b5e408d2d..f4cf5b8482 100644 --- a/src/include/processor/operator/aggregate/aggregate_hash_table.h +++ b/src/include/processor/operator/aggregate/aggregate_hash_table.h @@ -42,15 +42,15 @@ class AggregateHashTable : public BaseHashTable { public: // Used by distinct aggregate hash table only. AggregateHashTable(storage::MemoryManager& memoryManager, - const std::vector& groupByHashKeysDataTypes, + const std::vector& keysDataTypes, const std::vector>& aggregateFunctions, uint64_t numEntriesToAllocate) - : AggregateHashTable(memoryManager, groupByHashKeysDataTypes, - std::vector(), aggregateFunctions, numEntriesToAllocate) {} + : AggregateHashTable(memoryManager, keysDataTypes, std::vector(), + aggregateFunctions, numEntriesToAllocate) {} AggregateHashTable(storage::MemoryManager& memoryManager, - std::vector groupByHashKeysDataTypes, - std::vector groupByNonHashKeysDataTypes, + std::vector keysDataTypes, + std::vector payloadsDataTypes, const std::vector>& aggregateFunctions, uint64_t numEntriesToAllocate); @@ -61,17 +61,17 @@ class AggregateHashTable : public BaseHashTable { inline uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); } inline void append(const std::vector& groupByFlatKeyVectors, - const std::vector& groupByUnFlatHashKeyVectors, + const std::vector& groupByUnFlatKeyVectors, const std::vector>& aggregateInputs, uint64_t resultSetMultiplicity) { - append(groupByFlatKeyVectors, groupByUnFlatHashKeyVectors, - std::vector(), aggregateInputs, resultSetMultiplicity); + append(groupByFlatKeyVectors, groupByUnFlatKeyVectors, std::vector(), + aggregateInputs, resultSetMultiplicity); } //! update aggregate states for an input void append(const std::vector& groupByFlatKeyVectors, - const std::vector& groupByUnFlatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors, + const std::vector& groupByUnFlatKeyVectors, + const std::vector& groupByDependentKeyVectors, const std::vector>& aggregateInputs, uint64_t resultSetMultiplicity); @@ -107,7 +107,7 @@ class AggregateHashTable : public BaseHashTable { void initializeFTEntries(const std::vector& groupByFlatHashKeyVectors, const std::vector& groupByUnflatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors, + const std::vector& groupByDependentKeyVectors, uint64_t numFTEntriesToInitialize); uint8_t* createEntryInDistinctHT( @@ -121,7 +121,7 @@ class AggregateHashTable : public BaseHashTable { void findHashSlots(const std::vector& groupByFlatHashKeyVectors, const std::vector& groupByUnflatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors); + const std::vector& groupByDependentKeyVectors); void computeAndCombineVecHash( const std::vector& groupByUnflatHashKeyVectors, uint32_t startVecIdx); @@ -157,8 +157,7 @@ class AggregateHashTable : public BaseHashTable { uint64_t matchFTEntries(const std::vector& groupByFlatHashKeyVectors, const std::vector& groupByUnflatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors, uint64_t numMayMatches, - uint64_t numNoMatches); + uint64_t numMayMatches, uint64_t numNoMatches); void fillEntryWithInitialNullAggregateState(uint8_t* entry); @@ -219,8 +218,8 @@ class AggregateHashTable : public BaseHashTable { common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset); private: - std::vector groupByHashKeysDataTypes; - std::vector groupByNonHashKeysDataTypes; + std::vector keyDataTypes; + std::vector dependentKeyDataTypes; std::vector> aggregateFunctions; //! special handling of distinct aggregate @@ -229,8 +228,8 @@ class AggregateHashTable : public BaseHashTable { uint32_t hashColOffsetInFT; uint32_t aggStateColOffsetInFT; uint32_t aggStateColIdxInFT; - uint32_t numBytesForGroupByHashKeys = 0; - uint32_t numBytesForGroupByNonHashKeys = 0; + uint32_t numBytesForKeys = 0; + uint32_t numBytesForDependentKeys = 0; std::vector compareFuncs; std::vector updateAggFuncs; bool hasStrCol = false; diff --git a/src/include/processor/operator/aggregate/hash_aggregate.h b/src/include/processor/operator/aggregate/hash_aggregate.h index 60f6e6347e..3dc473bc2a 100644 --- a/src/include/processor/operator/aggregate/hash_aggregate.h +++ b/src/include/processor/operator/aggregate/hash_aggregate.h @@ -33,19 +33,15 @@ class HashAggregateSharedState : public BaseAggregateSharedState { class HashAggregate : public BaseAggregate { public: HashAggregate(std::unique_ptr resultSetDescriptor, - std::shared_ptr sharedState, - std::vector inputGroupByHashKeyVectorsPos, - std::vector inputGroupByNonHashKeyVectorsPos, - std::vector isInputGroupByHashKeyVectorFlat, + std::shared_ptr sharedState, std::vector flatKeysPos, + std::vector unFlatKeysPos, std::vector dependentKeysPos, std::vector> aggregateFunctions, std::vector> aggregateInputInfos, std::unique_ptr child, uint32_t id, const std::string& paramsString) : BaseAggregate{std::move(resultSetDescriptor), std::move(aggregateFunctions), std::move(aggregateInputInfos), std::move(child), id, paramsString}, - groupByHashKeyVectorsPos{std::move(inputGroupByHashKeyVectorsPos)}, - groupByNonHashKeyVectorsPos{std::move(inputGroupByNonHashKeyVectorsPos)}, - isGroupByHashKeyVectorFlat{std::move(isInputGroupByHashKeyVectorFlat)}, - sharedState{std::move(sharedState)} {} + flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)}, + dependentKeysPos{std::move(dependentKeysPos)}, sharedState{std::move(sharedState)} {} void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; @@ -54,18 +50,19 @@ class HashAggregate : public BaseAggregate { void finalize(ExecutionContext* context) override; inline std::unique_ptr clone() override { - return make_unique(resultSetDescriptor->copy(), sharedState, - groupByHashKeyVectorsPos, groupByNonHashKeyVectorsPos, isGroupByHashKeyVectorFlat, - cloneAggFunctions(), cloneAggInputInfos(), children[0]->clone(), id, paramsString); + return make_unique(resultSetDescriptor->copy(), sharedState, flatKeysPos, + unFlatKeysPos, dependentKeysPos, cloneAggFunctions(), cloneAggInputInfos(), + children[0]->clone(), id, paramsString); } private: - std::vector groupByHashKeyVectorsPos; - std::vector groupByNonHashKeyVectorsPos; - std::vector isGroupByHashKeyVectorFlat; - std::vector groupByFlatHashKeyVectors; - std::vector groupByUnflatHashKeyVectors; - std::vector groupByNonHashKeyVectors; + std::vector flatKeysPos; + std::vector unFlatKeysPos; + std::vector dependentKeysPos; + + std::vector flatKeyVectors; + std::vector unFlatKeyVectors; + std::vector dependentKeyVectors; std::shared_ptr sharedState; std::unique_ptr localAggregateHashTable; diff --git a/src/optimizer/CMakeLists.txt b/src/optimizer/CMakeLists.txt index ce4ffd08b9..2311d67c71 100644 --- a/src/optimizer/CMakeLists.txt +++ b/src/optimizer/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(kuzu_optimizer OBJECT acc_hash_join_optimizer.cpp + agg_key_dependency_optimizer.cpp factorization_rewriter.cpp filter_push_down_optimizer.cpp logical_operator_collector.cpp diff --git a/src/optimizer/agg_key_dependency_optimizer.cpp b/src/optimizer/agg_key_dependency_optimizer.cpp new file mode 100644 index 0000000000..98749fa720 --- /dev/null +++ b/src/optimizer/agg_key_dependency_optimizer.cpp @@ -0,0 +1,82 @@ +#include "optimizer/agg_key_dependency_optimizer.h" + +#include "binder/expression/property_expression.h" +#include "planner/logical_plan/logical_operator/logical_aggregate.h" +#include "planner/logical_plan/logical_operator/logical_distinct.h" + +using namespace kuzu::planner; + +namespace kuzu { +namespace optimizer { + +void AggKeyDependencyOptimizer::rewrite(planner::LogicalPlan* plan) { + visitOperator(plan->getLastOperator().get()); +} + +void AggKeyDependencyOptimizer::visitOperator(planner::LogicalOperator* op) { + // bottom up traversal + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i).get()); + } + visitOperatorSwitch(op); +} + +void AggKeyDependencyOptimizer::visitAggregate(planner::LogicalOperator* op) { + auto agg = (LogicalAggregate*)op; + auto [keyExpressions, payloadExpressions] = + resolveKeysAndDependentKeys(agg->getKeyExpressions()); + agg->setKeyExpressions(keyExpressions); + agg->setDependentKeyExpressions(payloadExpressions); +} + +void AggKeyDependencyOptimizer::visitDistinct(planner::LogicalOperator* op) { + auto distinct = (LogicalDistinct*)op; + auto [keyExpressions, payloadExpressions] = + resolveKeysAndDependentKeys(distinct->getKeyExpressions()); + distinct->setKeyExpressions(keyExpressions); + distinct->setDependentKeyExpressions(payloadExpressions); +} + +std::pair +AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const binder::expression_vector& keys) { + // Consider example RETURN a.ID, a.age, COUNT(*). + // We first collect a.ID into primaryKeys. Then collect "a" into primaryVarNames. + // Finally, we loop through all group by keys to put non-primary key properties under name "a" + // into dependentKeyExpressions. + + // Collect primary keys from group keys. + std::vector primaryKeys; + for (auto& expression : keys) { + if (expression->expressionType == common::PROPERTY) { + auto propertyExpression = (binder::PropertyExpression*)expression.get(); + if (propertyExpression->isPrimaryKey() || propertyExpression->isInternalID()) { + primaryKeys.push_back(propertyExpression); + } + } + } + // Collect variable names whose primary key is part of group keys. + std::unordered_set primaryVarNames; + for (auto& primaryKey : primaryKeys) { + primaryVarNames.insert(primaryKey->getVariableName()); + } + binder::expression_vector groupExpressions; + binder::expression_vector dependentExpressions; + for (auto& expression : keys) { + if (expression->expressionType == common::PROPERTY) { + auto propertyExpression = (binder::PropertyExpression*)expression.get(); + if (propertyExpression->isPrimaryKey() || propertyExpression->isInternalID()) { + groupExpressions.push_back(expression); + } else if (primaryVarNames.contains(propertyExpression->getVariableName())) { + dependentExpressions.push_back(expression); + } else { + groupExpressions.push_back(expression); + } + } else { + groupExpressions.push_back(expression); + } + } + return std::make_pair(std::move(groupExpressions), std::move(dependentExpressions)); +} + +} // namespace optimizer +} // namespace kuzu diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index f85d49b54b..3a626fee38 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -1,6 +1,7 @@ #include "optimizer/optimizer.h" #include "optimizer/acc_hash_join_optimizer.h" +#include "optimizer/agg_key_dependency_optimizer.h" #include "optimizer/factorization_rewriter.h" #include "optimizer/filter_push_down_optimizer.h" #include "optimizer/projection_push_down_optimizer.h" @@ -21,7 +22,7 @@ void Optimizer::optimize(planner::LogicalPlan* plan) { auto filterPushDownOptimizer = FilterPushDownOptimizer(); filterPushDownOptimizer.rewrite(plan); - // ASP optimizer should be applied after optimizers that manipulate hash join. + // HashJoinSIPOptimizer should be applied after optimizers that manipulate hash join. auto hashJoinSIPOptimizer = HashJoinSIPOptimizer(); hashJoinSIPOptimizer.rewrite(plan); @@ -30,6 +31,11 @@ void Optimizer::optimize(planner::LogicalPlan* plan) { auto factorizationRewriter = FactorizationRewriter(); factorizationRewriter.rewrite(plan); + + // AggKeyDependencyOptimizer doesn't change factorization structure and thus can be put after + // FactorizationRewriter. + auto aggKeyDependencyOptimizer = AggKeyDependencyOptimizer(); + aggKeyDependencyOptimizer.rewrite(plan); } } // namespace optimizer diff --git a/src/planner/operator/logical_aggregate.cpp b/src/planner/operator/logical_aggregate.cpp index 1105ff0278..e9e1fca111 100644 --- a/src/planner/operator/logical_aggregate.cpp +++ b/src/planner/operator/logical_aggregate.cpp @@ -11,28 +11,18 @@ using namespace factorization; void LogicalAggregate::computeFactorizedSchema() { createEmptySchema(); auto groupPos = schema->createGroup(); - for (auto& expression : expressionsToGroupBy) { - schema->insertToGroupAndScope(expression, groupPos); - } - for (auto& expression : expressionsToAggregate) { - schema->insertToGroupAndScope(expression, groupPos); - } + insertAllExpressionsToGroupAndScope(groupPos); } void LogicalAggregate::computeFlatSchema() { createEmptySchema(); schema->createGroup(); - for (auto& expression : expressionsToGroupBy) { - schema->insertToGroupAndScope(expression, 0); - } - for (auto& expression : expressionsToAggregate) { - schema->insertToGroupAndScope(expression, 0); - } + insertAllExpressionsToGroupAndScope(0 /* groupPos */); } f_group_pos_set LogicalAggregate::getGroupsPosToFlattenForGroupBy() { f_group_pos_set dependentGroupsPos; - for (auto& expression : expressionsToGroupBy) { + for (auto& expression : getAllKeyExpressions()) { for (auto groupPos : children[0]->getSchema()->getDependentGroupsPos(expression)) { dependentGroupsPos.insert(groupPos); } @@ -48,7 +38,7 @@ f_group_pos_set LogicalAggregate::getGroupsPosToFlattenForGroupBy() { f_group_pos_set LogicalAggregate::getGroupsPosToFlattenForAggregate() { if (hasDistinctAggregate()) { f_group_pos_set dependentGroupsPos; - for (auto& expression : expressionsToAggregate) { + for (auto& expression : aggregateExpressions) { for (auto groupPos : children[0]->getSchema()->getDependentGroupsPos(expression)) { dependentGroupsPos.insert(groupPos); } @@ -60,11 +50,14 @@ f_group_pos_set LogicalAggregate::getGroupsPosToFlattenForAggregate() { std::string LogicalAggregate::getExpressionsForPrinting() const { std::string result = "Group By ["; - for (auto& expression : expressionsToGroupBy) { + for (auto& expression : keyExpressions) { + result += expression->toString() + ", "; + } + for (auto& expression : dependentKeyExpressions) { result += expression->toString() + ", "; } result += "], Aggregate ["; - for (auto& expression : expressionsToAggregate) { + for (auto& expression : aggregateExpressions) { result += expression->toString() + ", "; } result += "]"; @@ -72,8 +65,8 @@ std::string LogicalAggregate::getExpressionsForPrinting() const { } bool LogicalAggregate::hasDistinctAggregate() { - for (auto& expressionToAggregate : expressionsToAggregate) { - auto& functionExpression = (binder::AggregateFunctionExpression&)*expressionToAggregate; + for (auto& expression : aggregateExpressions) { + auto& functionExpression = (binder::AggregateFunctionExpression&)*expression; if (functionExpression.isDistinct()) { return true; } @@ -81,5 +74,17 @@ bool LogicalAggregate::hasDistinctAggregate() { return false; } +void LogicalAggregate::insertAllExpressionsToGroupAndScope(f_group_pos groupPos) { + for (auto& expression : keyExpressions) { + schema->insertToGroupAndScope(expression, groupPos); + } + for (auto& expression : dependentKeyExpressions) { + schema->insertToGroupAndScope(expression, groupPos); + } + for (auto& expression : aggregateExpressions) { + schema->insertToGroupAndScope(expression, groupPos); + } +} + } // namespace planner } // namespace kuzu diff --git a/src/planner/operator/logical_distinct.cpp b/src/planner/operator/logical_distinct.cpp index 1fd6446ce7..f7bbc1a54c 100644 --- a/src/planner/operator/logical_distinct.cpp +++ b/src/planner/operator/logical_distinct.cpp @@ -8,7 +8,7 @@ namespace planner { void LogicalDistinct::computeFactorizedSchema() { createEmptySchema(); auto groupPos = schema->createGroup(); - for (auto& expression : expressionsToDistinct) { + for (auto& expression : getAllDistinctExpressions()) { schema->insertToGroupAndScope(expression, groupPos); } } @@ -16,7 +16,7 @@ void LogicalDistinct::computeFactorizedSchema() { void LogicalDistinct::computeFlatSchema() { createEmptySchema(); schema->createGroup(); - for (auto& expression : expressionsToDistinct) { + for (auto& expression : getAllDistinctExpressions()) { schema->insertToGroupAndScope(expression, 0); } } @@ -24,7 +24,7 @@ void LogicalDistinct::computeFlatSchema() { f_group_pos_set LogicalDistinct::getGroupsPosToFlatten() { f_group_pos_set dependentGroupsPos; auto childSchema = children[0]->getSchema(); - for (auto& expression : expressionsToDistinct) { + for (auto& expression : getAllDistinctExpressions()) { for (auto groupPos : childSchema->getDependentGroupsPos(expression)) { dependentGroupsPos.insert(groupPos); } @@ -34,7 +34,7 @@ f_group_pos_set LogicalDistinct::getGroupsPosToFlatten() { std::string LogicalDistinct::getExpressionsForPrinting() const { std::string result; - for (auto& expression : expressionsToDistinct) { + for (auto& expression : getAllDistinctExpressions()) { result += expression->getUniqueName() + ", "; } return result; diff --git a/src/processor/mapper/map_aggregate.cpp b/src/processor/mapper/map_aggregate.cpp index 43a43e1538..0d253f2293 100644 --- a/src/processor/mapper/map_aggregate.cpp +++ b/src/processor/mapper/map_aggregate.cpp @@ -51,6 +51,17 @@ static std::vector> getAggregateInputInfos( return result; } +static binder::expression_vector getKeyExpressions( + const binder::expression_vector& expressions, const Schema& schema, bool isFlat) { + binder::expression_vector result; + for (auto& expression : expressions) { + if (schema.getGroup(schema.getGroupPos(*expression))->isFlat() == isFlat) { + result.emplace_back(expression); + } + } + return result; +} + std::unique_ptr PlanMapper::mapLogicalAggregateToPhysical( LogicalOperator* logicalOperator) { auto& logicalAggregate = (const LogicalAggregate&)*logicalOperator; @@ -58,21 +69,20 @@ std::unique_ptr PlanMapper::mapLogicalAggregateToPhysical( auto inSchema = logicalAggregate.getChild(0)->getSchema(); auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); auto paramsString = logicalAggregate.getExpressionsForPrinting(); - std::vector outputAggVectorsPos{logicalAggregate.getNumAggregateExpressions()}; - std::vector> aggregateFunctions{ - logicalAggregate.getNumAggregateExpressions()}; - for (auto i = 0u; i < logicalAggregate.getNumAggregateExpressions(); ++i) { - auto expression = logicalAggregate.getAggregateExpression(i); - aggregateFunctions[i] = - ((AggregateFunctionExpression&)*expression).aggregateFunction->clone(); - outputAggVectorsPos[i] = DataPos(outSchema->getExpressionPos(*expression)); + std::vector> aggregateFunctions; + for (auto& expression : logicalAggregate.getAggregateExpressions()) { + aggregateFunctions.push_back( + ((AggregateFunctionExpression&)*expression).aggregateFunction->clone()); } - auto aggregateInputInfos = getAggregateInputInfos(logicalAggregate.getExpressionsToGroupBy(), - logicalAggregate.getExpressionsToAggregate(), *inSchema); - if (logicalAggregate.hasExpressionsToGroupBy()) { - return createHashAggregate(std::move(aggregateFunctions), std::move(aggregateInputInfos), - outputAggVectorsPos, logicalAggregate.getExpressionsToGroupBy(), - std::move(prevOperator), *inSchema, *outSchema, paramsString); + auto aggregatesOutputPos = + getExpressionsDataPos(logicalAggregate.getAggregateExpressions(), *outSchema); + auto aggregateInputInfos = getAggregateInputInfos(logicalAggregate.getAllKeyExpressions(), + logicalAggregate.getAggregateExpressions(), *inSchema); + if (logicalAggregate.hasKeyExpressions()) { + return createHashAggregate(logicalAggregate.getKeyExpressions(), + logicalAggregate.getDependentKeyExpressions(), std::move(aggregateFunctions), + std::move(aggregateInputInfos), std::move(aggregatesOutputPos), *inSchema, *outSchema, + std::move(prevOperator), paramsString); } else { auto sharedState = make_shared(aggregateFunctions); auto aggregate = @@ -80,72 +90,36 @@ std::unique_ptr PlanMapper::mapLogicalAggregateToPhysical( sharedState, std::move(aggregateFunctions), std::move(aggregateInputInfos), std::move(prevOperator), getOperatorID(), paramsString); return make_unique( - sharedState, outputAggVectorsPos, std::move(aggregate), getOperatorID(), paramsString); + sharedState, aggregatesOutputPos, std::move(aggregate), getOperatorID(), paramsString); } } std::unique_ptr PlanMapper::createHashAggregate( - std::vector> aggregateFunctions, - std::vector> inputAggregateInfo, - std::vector outputAggVectorsPos, const expression_vector& groupByExpressions, - std::unique_ptr prevOperator, const Schema& inSchema, const Schema& outSchema, + const binder::expression_vector& keyExpressions, + const binder::expression_vector& dependentKeyExpressions, + std::vector> aggregateFunctions, + std::vector> aggregateInputInfos, + std::vector aggregatesOutputPos, const planner::Schema& inSchema, + const planner::Schema& outSchema, std::unique_ptr prevOperator, const std::string& paramsString) { - expression_vector groupByHashExpressions; - expression_vector groupByNonHashExpressions; - std::unordered_set HashPrimaryKeysNodeId; - for (auto& expressionToGroupBy : groupByExpressions) { - if (expressionToGroupBy->expressionType == PROPERTY) { - auto& propertyExpression = (const PropertyExpression&)(*expressionToGroupBy); - if (propertyExpression.isInternalID()) { - groupByHashExpressions.push_back(expressionToGroupBy); - HashPrimaryKeysNodeId.insert(propertyExpression.getVariableName()); - } else if (HashPrimaryKeysNodeId.contains(propertyExpression.getVariableName())) { - groupByNonHashExpressions.push_back(expressionToGroupBy); - } else { - groupByHashExpressions.push_back(expressionToGroupBy); - } - } else { - groupByHashExpressions.push_back(expressionToGroupBy); - } - } - std::vector inputGroupByHashKeyVectorsPos; - std::vector inputGroupByNonHashKeyVectorsPos; - std::vector isInputGroupByHashKeyVectorFlat; - std::vector outputGroupByKeyVectorsPos; - appendGroupByExpressions(groupByHashExpressions, inputGroupByHashKeyVectorsPos, - outputGroupByKeyVectorsPos, inSchema, outSchema, isInputGroupByHashKeyVectorFlat); - appendGroupByExpressions(groupByNonHashExpressions, inputGroupByNonHashKeyVectorsPos, - outputGroupByKeyVectorsPos, inSchema, outSchema, isInputGroupByHashKeyVectorFlat); auto sharedState = make_shared(aggregateFunctions); + auto flatKeyExpressions = getKeyExpressions(keyExpressions, inSchema, true /* isFlat */); + auto unFlatKeyExpressions = getKeyExpressions(keyExpressions, inSchema, false /* isFlat */); auto aggregate = make_unique(std::make_unique(inSchema), - sharedState, inputGroupByHashKeyVectorsPos, inputGroupByNonHashKeyVectorsPos, - isInputGroupByHashKeyVectorFlat, std::move(aggregateFunctions), - std::move(inputAggregateInfo), std::move(prevOperator), getOperatorID(), paramsString); - auto aggregateScan = - std::make_unique(sharedState, outputGroupByKeyVectorsPos, - std::move(outputAggVectorsPos), std::move(aggregate), getOperatorID(), paramsString); - return aggregateScan; -} - -void PlanMapper::appendGroupByExpressions(const expression_vector& groupByExpressions, - std::vector& inputGroupByHashKeyVectorsPos, - std::vector& outputGroupByKeyVectorsPos, const Schema& inSchema, - const Schema& outSchema, std::vector& isInputGroupByHashKeyVectorFlat) { - for (auto& expression : groupByExpressions) { - if (inSchema.getGroup(expression->getUniqueName())->isFlat()) { - inputGroupByHashKeyVectorsPos.emplace_back(inSchema.getExpressionPos(*expression)); - outputGroupByKeyVectorsPos.emplace_back(outSchema.getExpressionPos(*expression)); - isInputGroupByHashKeyVectorFlat.push_back(true); - } - } - - for (auto& expression : groupByExpressions) { - if (!inSchema.getGroup(expression->getUniqueName())->isFlat()) { - inputGroupByHashKeyVectorsPos.emplace_back(inSchema.getExpressionPos(*expression)); - outputGroupByKeyVectorsPos.emplace_back(outSchema.getExpressionPos(*expression)); - isInputGroupByHashKeyVectorFlat.push_back(false); - } - } + sharedState, getExpressionsDataPos(flatKeyExpressions, inSchema), + getExpressionsDataPos(unFlatKeyExpressions, inSchema), + getExpressionsDataPos(dependentKeyExpressions, inSchema), std::move(aggregateFunctions), + std::move(aggregateInputInfos), std::move(prevOperator), getOperatorID(), paramsString); + binder::expression_vector outputExpressions; + outputExpressions.insert( + outputExpressions.end(), flatKeyExpressions.begin(), flatKeyExpressions.end()); + outputExpressions.insert( + outputExpressions.end(), unFlatKeyExpressions.begin(), unFlatKeyExpressions.end()); + outputExpressions.insert( + outputExpressions.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end()); + return std::make_unique(sharedState, + getExpressionsDataPos(outputExpressions, outSchema), std::move(aggregatesOutputPos), + std::move(aggregate), getOperatorID(), paramsString); } } // namespace processor diff --git a/src/processor/mapper/map_distinct.cpp b/src/processor/mapper/map_distinct.cpp index 0c36dfa344..448b1e1b87 100644 --- a/src/processor/mapper/map_distinct.cpp +++ b/src/processor/mapper/map_distinct.cpp @@ -12,15 +12,15 @@ std::unique_ptr PlanMapper::mapLogicalDistinctToPhysical( LogicalOperator* logicalOperator) { auto& logicalDistinct = (const LogicalDistinct&)*logicalOperator; auto outSchema = logicalDistinct.getSchema(); - auto inSchema = logicalDistinct.getSchemaBeforeDistinct(); + auto inSchema = logicalDistinct.getChild(0)->getSchema(); auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); - std::vector> emptyAggregateFunctions; + std::vector> emptyAggFunctions; std::vector> emptyAggInputInfos; - std::vector emptyOutputAggVectorsPos; - return createHashAggregate(std::move(emptyAggregateFunctions), std::move(emptyAggInputInfos), - emptyOutputAggVectorsPos, logicalDistinct.getExpressionsToDistinct(), - std::move(prevOperator), *inSchema, *outSchema, - logicalDistinct.getExpressionsForPrinting()); + std::vector emptyAggregatesOutputPos; + return createHashAggregate(logicalDistinct.getKeyExpressions(), + logicalDistinct.getDependentKeyExpressions(), std::move(emptyAggFunctions), + std::move(emptyAggInputInfos), std::move(emptyAggregatesOutputPos), *inSchema, *outSchema, + std::move(prevOperator), logicalDistinct.getExpressionsForPrinting()); } } // namespace processor diff --git a/src/processor/mapper/plan_mapper.cpp b/src/processor/mapper/plan_mapper.cpp index 665b859b6c..6c06988ef0 100644 --- a/src/processor/mapper/plan_mapper.cpp +++ b/src/processor/mapper/plan_mapper.cpp @@ -162,5 +162,14 @@ std::unique_ptr PlanMapper::appendResultCollector( binder::ExpressionUtil::toString(expressionsToCollect)); } +std::vector PlanMapper::getExpressionsDataPos( + const binder::expression_vector& expressions, const planner::Schema& schema) { + std::vector result; + for (auto& expression : expressions) { + result.emplace_back(schema.getExpressionPos(*expression)); + } + return result; +} + } // namespace processor } // namespace kuzu diff --git a/src/processor/operator/aggregate/aggregate_hash_table.cpp b/src/processor/operator/aggregate/aggregate_hash_table.cpp index b5c4fa03d7..5d62949630 100644 --- a/src/processor/operator/aggregate/aggregate_hash_table.cpp +++ b/src/processor/operator/aggregate/aggregate_hash_table.cpp @@ -14,29 +14,29 @@ namespace kuzu { namespace processor { AggregateHashTable::AggregateHashTable(MemoryManager& memoryManager, - std::vector groupByHashKeysDataTypes, - std::vector groupByNonHashKeysDataTypes, + std::vector keyDataTypes, std::vector dependentKeyDataTypes, const std::vector>& aggregateFunctions, uint64_t numEntriesToAllocate) - : BaseHashTable{memoryManager}, groupByHashKeysDataTypes{std::move(groupByHashKeysDataTypes)}, - groupByNonHashKeysDataTypes{std::move(groupByNonHashKeysDataTypes)} { + : BaseHashTable{memoryManager}, keyDataTypes{std::move(keyDataTypes)}, + dependentKeyDataTypes{std::move(dependentKeyDataTypes)} { initializeFT(aggregateFunctions); initializeHashTable(numEntriesToAllocate); distinctHashTables = AggregateHashTableUtils::createDistinctHashTables( - memoryManager, this->groupByHashKeysDataTypes, this->aggregateFunctions); + memoryManager, this->keyDataTypes, this->aggregateFunctions); initializeTmpVectors(); } void AggregateHashTable::append(const std::vector& groupByFlatHashKeyVectors, const std::vector& groupByUnFlatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors, + const std::vector& groupByDependentKeyVectors, const std::vector>& aggregateInputs, uint64_t resultSetMultiplicity) { resizeHashTableIfNecessary(groupByUnFlatHashKeyVectors.empty() ? 1 : groupByUnFlatHashKeyVectors[0]->state->selVector->selectedSize); computeVectorHashes(groupByFlatHashKeyVectors, groupByUnFlatHashKeyVectors); - findHashSlots(groupByFlatHashKeyVectors, groupByUnFlatHashKeyVectors, groupByNonHashKeyVectors); + findHashSlots( + groupByFlatHashKeyVectors, groupByUnFlatHashKeyVectors, groupByDependentKeyVectors); updateAggStates(groupByFlatHashKeyVectors, groupByUnFlatHashKeyVectors, aggregateInputs, resultSetMultiplicity); } @@ -71,25 +71,23 @@ bool AggregateHashTable::isAggregateValueDistinctForGroupByKeys( void AggregateHashTable::merge(AggregateHashTable& other) { std::shared_ptr vectorsToScanState = std::make_shared(); - std::vector vectorsToScan( - groupByHashKeysDataTypes.size() + groupByNonHashKeysDataTypes.size()); - std::vector groupByHashVectors(groupByHashKeysDataTypes.size()); - std::vector groupByNonHashVectors(groupByNonHashKeysDataTypes.size()); - std::vector> hashKeyVectors(groupByHashKeysDataTypes.size()); + std::vector vectorsToScan(keyDataTypes.size() + dependentKeyDataTypes.size()); + std::vector groupByHashVectors(keyDataTypes.size()); + std::vector groupByNonHashVectors(dependentKeyDataTypes.size()); + std::vector> hashKeyVectors(keyDataTypes.size()); std::vector> nonHashKeyVectors(groupByNonHashVectors.size()); - for (auto i = 0u; i < groupByHashKeysDataTypes.size(); i++) { - auto hashKeyVec = - std::make_unique(groupByHashKeysDataTypes[i], &memoryManager); + for (auto i = 0u; i < keyDataTypes.size(); i++) { + auto hashKeyVec = std::make_unique(keyDataTypes[i], &memoryManager); hashKeyVec->state = vectorsToScanState; vectorsToScan[i] = hashKeyVec.get(); groupByHashVectors[i] = hashKeyVec.get(); hashKeyVectors[i] = std::move(hashKeyVec); } - for (auto i = 0u; i < groupByNonHashKeysDataTypes.size(); i++) { + for (auto i = 0u; i < dependentKeyDataTypes.size(); i++) { auto nonHashKeyVec = - std::make_unique(groupByNonHashKeysDataTypes[i], &memoryManager); + std::make_unique(dependentKeyDataTypes[i], &memoryManager); nonHashKeyVec->state = vectorsToScanState; - vectorsToScan[i + groupByHashKeysDataTypes.size()] = nonHashKeyVec.get(); + vectorsToScan[i + keyDataTypes.size()] = nonHashKeyVec.get(); groupByNonHashVectors[i] = nonHashKeyVec.get(); nonHashKeyVectors[i] = std::move(nonHashKeyVec); } @@ -137,27 +135,26 @@ void AggregateHashTable::initializeFT( auto isUnflat = false; auto dataChunkPos = 0u; std::unique_ptr tableSchema = std::make_unique(); - aggStateColIdxInFT = - this->groupByHashKeysDataTypes.size() + this->groupByNonHashKeysDataTypes.size(); + aggStateColIdxInFT = keyDataTypes.size() + dependentKeyDataTypes.size(); compareFuncs.resize(aggStateColIdxInFT); auto colIdx = 0u; - for (auto& dataType : this->groupByHashKeysDataTypes) { + for (auto& dataType : keyDataTypes) { auto size = Types::getDataTypeSize(dataType); tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); hasStrCol = hasStrCol || dataType.typeID == STRING; compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.typeID); - numBytesForGroupByHashKeys += size; + numBytesForKeys += size; colIdx++; } - for (auto& dataType : this->groupByNonHashKeysDataTypes) { + for (auto& dataType : dependentKeyDataTypes) { auto size = Types::getDataTypeSize(dataType); tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); hasStrCol = hasStrCol || dataType.typeID == STRING; compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.typeID); - numBytesForGroupByNonHashKeys += size; + numBytesForDependentKeys += size; colIdx++; } - aggStateColOffsetInFT = numBytesForGroupByHashKeys + numBytesForGroupByNonHashKeys; + aggStateColOffsetInFT = numBytesForKeys + numBytesForDependentKeys; aggregateFunctions.resize(aggFuncs.size()); updateAggFuncs.resize(aggFuncs.size()); @@ -280,7 +277,8 @@ void AggregateHashTable::initializeFTEntryWithUnflatVec( void AggregateHashTable::initializeFTEntries( const std::vector& groupByFlatHashKeyVectors, const std::vector& groupByUnflatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors, uint64_t numFTEntriesToInitialize) { + const std::vector& groupByDependentKeyVectors, + uint64_t numFTEntriesToInitialize) { auto colIdx = 0u; for (auto flatKeyVector : groupByFlatHashKeyVectors) { initializeFTEntryWithFlatVec(flatKeyVector, numFTEntriesToInitialize, colIdx++); @@ -288,11 +286,11 @@ void AggregateHashTable::initializeFTEntries( for (auto unflatKeyVector : groupByUnflatHashKeyVectors) { initializeFTEntryWithUnflatVec(unflatKeyVector, numFTEntriesToInitialize, colIdx++); } - for (auto nonHashKeyVector : groupByNonHashKeyVectors) { - if (nonHashKeyVector->state->isFlat()) { - initializeFTEntryWithFlatVec(nonHashKeyVector, numFTEntriesToInitialize, colIdx++); + for (auto dependentKeyVector : groupByDependentKeyVectors) { + if (dependentKeyVector->state->isFlat()) { + initializeFTEntryWithFlatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); } else { - initializeFTEntryWithUnflatVec(nonHashKeyVector, numFTEntriesToInitialize, colIdx++); + initializeFTEntryWithUnflatVec(dependentKeyVector, numFTEntriesToInitialize, colIdx++); } } for (auto i = 0u; i < numFTEntriesToInitialize; i++) { @@ -359,7 +357,7 @@ void AggregateHashTable::increaseHashSlotIdxes(uint64_t numNoMatches) { void AggregateHashTable::findHashSlots(const std::vector& groupByFlatHashKeyVectors, const std::vector& groupByUnflatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors) { + const std::vector& groupByDependentKeyVectors) { initTmpHashSlotsAndIdxes(); auto numEntriesToFindHashSlots = groupByUnflatHashKeyVectors.empty() ? @@ -384,9 +382,9 @@ void AggregateHashTable::findHashSlots(const std::vector& groupByF } } initializeFTEntries(groupByFlatHashKeyVectors, groupByUnflatHashKeyVectors, - groupByNonHashKeyVectors, numFTEntriesToUpdate); - numNoMatches = matchFTEntries(groupByFlatHashKeyVectors, groupByUnflatHashKeyVectors, - groupByNonHashKeyVectors, numMayMatches, numNoMatches); + groupByDependentKeyVectors, numFTEntriesToUpdate); + numNoMatches = matchFTEntries( + groupByFlatHashKeyVectors, groupByUnflatHashKeyVectors, numMayMatches, numNoMatches); increaseHashSlotIdxes(numNoMatches); numEntriesToFindHashSlots = numNoMatches; memcpy(tmpValueIdxes.get(), noMatchIdxes.get(), DEFAULT_VECTOR_CAPACITY * sizeof(uint64_t)); @@ -610,29 +608,17 @@ uint64_t AggregateHashTable::matchFlatVecWithFTColumn( uint64_t AggregateHashTable::matchFTEntries( const std::vector& groupByFlatHashKeyVectors, - const std::vector& groupByUnflatHashKeyVectors, - const std::vector& groupByNonHashKeyVectors, uint64_t numMayMatches, + const std::vector& groupByUnflatHashKeyVectors, uint64_t numMayMatches, uint64_t numNoMatches) { auto colIdx = 0u; for (auto& groupByFlatHashKeyVector : groupByFlatHashKeyVectors) { numMayMatches = matchFlatVecWithFTColumn( groupByFlatHashKeyVector, numMayMatches, numNoMatches, colIdx++); } - for (auto& groupByUnflatHashKeyVector : groupByUnflatHashKeyVectors) { numMayMatches = matchUnflatVecWithFTColumn( groupByUnflatHashKeyVector, numMayMatches, numNoMatches, colIdx++); } - - for (auto& groupByNonHashKeyVector : groupByNonHashKeyVectors) { - if (groupByNonHashKeyVector->state->isFlat()) { - numMayMatches = matchFlatVecWithFTColumn( - groupByNonHashKeyVector, numMayMatches, numNoMatches, colIdx++); - } else { - numMayMatches = matchUnflatVecWithFTColumn( - groupByNonHashKeyVector, numMayMatches, numNoMatches, colIdx++); - } - } return numNoMatches; } diff --git a/src/processor/operator/aggregate/hash_aggregate.cpp b/src/processor/operator/aggregate/hash_aggregate.cpp index ad3363edbd..e6b89ab71d 100644 --- a/src/processor/operator/aggregate/hash_aggregate.cpp +++ b/src/processor/operator/aggregate/hash_aggregate.cpp @@ -49,30 +49,31 @@ std::pair HashAggregateSharedState::getNextRangeToRead() { void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { BaseAggregate::initLocalStateInternal(resultSet, context); - std::vector groupByHashKeysDataTypes; - for (auto i = 0u; i < groupByHashKeyVectorsPos.size(); i++) { - auto vector = resultSet->getValueVector(groupByHashKeyVectorsPos[i]).get(); - if (isGroupByHashKeyVectorFlat[i]) { - groupByFlatHashKeyVectors.push_back(vector); - } else { - groupByUnflatHashKeyVectors.push_back(vector); - } - groupByHashKeysDataTypes.push_back(vector->dataType); + std::vector keyDataTypes; + for (auto& pos : flatKeysPos) { + auto vector = resultSet->getValueVector(pos).get(); + flatKeyVectors.push_back(vector); + keyDataTypes.push_back(vector->dataType); + } + for (auto& pos : unFlatKeysPos) { + auto vector = resultSet->getValueVector(pos).get(); + unFlatKeyVectors.push_back(vector); + keyDataTypes.push_back(vector->dataType); } - std::vector groupByNonHashKeysDataTypes; - for (auto& dataPos : groupByNonHashKeyVectorsPos) { - auto vector = resultSet->getValueVector(dataPos).get(); - groupByNonHashKeyVectors.push_back(vector); - groupByNonHashKeysDataTypes.push_back(vector->dataType); + std::vector payloadDataTypes; + for (auto& pos : dependentKeysPos) { + auto vector = resultSet->getValueVector(pos).get(); + dependentKeyVectors.push_back(vector); + payloadDataTypes.push_back(vector->dataType); } - localAggregateHashTable = make_unique(*context->memoryManager, - groupByHashKeysDataTypes, groupByNonHashKeysDataTypes, aggregateFunctions, 0); + localAggregateHashTable = make_unique( + *context->memoryManager, keyDataTypes, payloadDataTypes, aggregateFunctions, 0); } void HashAggregate::executeInternal(ExecutionContext* context) { while (children[0]->getNextTuple(context)) { - localAggregateHashTable->append(groupByFlatHashKeyVectors, groupByUnflatHashKeyVectors, - groupByNonHashKeyVectors, aggregateInputs, resultSet->multiplicity); + localAggregateHashTable->append(flatKeyVectors, unFlatKeyVectors, dependentKeyVectors, + aggregateInputs, resultSet->multiplicity); } sharedState->appendAggregateHashTable(std::move(localAggregateHashTable)); } diff --git a/test/test_files/tinysnb/agg/multi_query_part.test b/test/test_files/tinysnb/agg/multi_query_part.test index 61953ccc46..9f443c75bb 100644 --- a/test/test_files/tinysnb/agg/multi_query_part.test +++ b/test/test_files/tinysnb/agg/multi_query_part.test @@ -38,3 +38,9 @@ True 3|5|3 7|8|2 7|9|2 + +-NAME GroupByMultiQueryTest3 +-QUERY MATCH (a:person) WHERE a.ID > 4 WITH a, a.age AS foo MATCH (a)-[:knows]->(b:person) RETURN a.ID, foo, COUNT(*) +---- 2 +5|20|3 +7|20|2