Skip to content

Commit

Permalink
Add aggregate key dependency optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 7, 2023
1 parent bfc1b0d commit 1aef377
Show file tree
Hide file tree
Showing 17 changed files with 369 additions and 239 deletions.
26 changes: 26 additions & 0 deletions src/include/optimizer/agg_key_dependency_optimizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

// This optimizer analyzes the dependencies between group-by keys. If key2 depends on key1 (e.g.
// key1 is a primary key column), we hash only on key1 and save key2 as a payload.
class AggKeyDependencyOptimizer : public LogicalOperatorVisitor {
public:
// Entry point: runs this optimizer over the given logical plan.
// NOTE(review): returns void, so the plan is presumably modified in place — confirm in
// agg_key_dependency_optimizer.cpp.
void rewrite(planner::LogicalPlan* plan);

private:
// Visits a single logical operator.
// NOTE(review): presumably recurses into children and dispatches to the typed visit* overrides
// below; the implementation is not visible in this header — confirm.
void visitOperator(planner::LogicalOperator* op);

// Overrides of the LogicalOperatorVisitor hooks for the two operator kinds that carry group-by
// style keys (aggregate and distinct).
void visitAggregate(planner::LogicalOperator* op) override;
void visitDistinct(planner::LogicalOperator* op) override;

// Partitions `keys` into a pair of expression vectors.
// NOTE(review): judging by the name, `first` holds the keys to hash on and `second` holds the
// keys that depend on them — confirm the ordering against the implementation.
std::pair<binder::expression_vector, binder::expression_vector> resolveKeysAndDependentKeys(
const binder::expression_vector& keys);
};

} // namespace optimizer
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@ namespace planner {

class LogicalAggregate : public LogicalOperator {
public:
LogicalAggregate(binder::expression_vector expressionsToGroupBy,
binder::expression_vector expressionsToAggregate, std::shared_ptr<LogicalOperator> child)
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)} {}
keyExpressions{std::move(keyExpressions)}, aggregateExpressions{
std::move(aggregateExpressions)} {}
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{std::move(
dependentKeyExpressions)},
aggregateExpressions{std::move(aggregateExpressions)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -21,30 +28,42 @@ class LogicalAggregate : public LogicalOperator {

std::string getExpressionsForPrinting() const override;

inline bool hasExpressionsToGroupBy() const { return !expressionsToGroupBy.empty(); }
inline binder::expression_vector getExpressionsToGroupBy() const {
return expressionsToGroupBy;
inline bool hasKeyExpressions() const { return !keyExpressions.empty(); }
inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline size_t getNumAggregateExpressions() const { return expressionsToAggregate.size(); }
inline std::shared_ptr<binder::Expression> getAggregateExpression(
common::vector_idx_t idx) const {
return expressionsToAggregate[idx];
inline binder::expression_vector getAllKeyExpressions() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
return result;
}
inline binder::expression_vector getExpressionsToAggregate() const {
return expressionsToAggregate;
inline binder::expression_vector getAggregateExpressions() const {
return aggregateExpressions;
}

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(
expressionsToGroupBy, expressionsToAggregate, children[0]->copy());
keyExpressions, dependentKeyExpressions, aggregateExpressions, children[0]->copy());
}

private:
bool hasDistinctAggregate();
void insertAllExpressionsToGroupAndScope(f_group_pos groupPos);

private:
binder::expression_vector expressionsToGroupBy;
binder::expression_vector expressionsToAggregate;
binder::expression_vector keyExpressions;
// A dependentKeyExpression depends on a keyExpression (e.g. a.age depends on a.ID) and will not
// be treated as a hash key during hash aggregation.
binder::expression_vector dependentKeyExpressions;
binder::expression_vector aggregateExpressions;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ namespace planner {
class LogicalDistinct : public LogicalOperator {
public:
LogicalDistinct(
binder::expression_vector expressionsToDistinct, std::shared_ptr<LogicalOperator> child)
binder::expression_vector keyExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)} {}
keyExpressions{std::move(keyExpressions)} {}
LogicalDistinct(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{
std::move(dependentKeyExpressions)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -20,17 +25,32 @@ class LogicalDistinct : public LogicalOperator {

std::string getExpressionsForPrinting() const override;

inline binder::expression_vector getExpressionsToDistinct() const {
return expressionsToDistinct;
inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline binder::expression_vector getAllDistinctExpressions() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
return result;
}
inline Schema* getSchemaBeforeDistinct() const { return children[0]->getSchema(); }

std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDistinct>(expressionsToDistinct, children[0]->copy());
return make_unique<LogicalDistinct>(
keyExpressions, dependentKeyExpressions, children[0]->copy());
}

private:
binder::expression_vector expressionsToDistinct;
binder::expression_vector keyExpressions;
// See logical_aggregate.h for details.
binder::expression_vector dependentKeyExpressions;
};

} // namespace planner
Expand Down
25 changes: 12 additions & 13 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,22 @@ class PlanMapper {

inline uint32_t getOperatorID() { return physicalOperatorID++; }

BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

std::unique_ptr<PhysicalOperator> createHashAggregate(
const binder::expression_vector& keyExpressions,
const binder::expression_vector& dependentKeyExpressions,
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions,
std::vector<std::unique_ptr<AggregateInputInfo>> inputAggregateInfo,
std::vector<DataPos> outputAggVectorsPos,
const binder::expression_vector& groupByExpressions,
std::unique_ptr<PhysicalOperator> prevOperator, const planner::Schema& inSchema,
const planner::Schema& outSchema, const std::string& paramsString);
std::vector<std::unique_ptr<AggregateInputInfo>> aggregateInputInfos,
std::vector<DataPos> aggregatesOutputPos, const planner::Schema& inSchema,
const planner::Schema& outSchema, std::unique_ptr<PhysicalOperator> prevOperator,
const std::string& paramsString);

void appendGroupByExpressions(const binder::expression_vector& groupByExpressions,
std::vector<DataPos>& inputGroupByHashKeyVectorsPos,
std::vector<DataPos>& outputGroupByKeyVectorsPos, const planner::Schema& inSchema,
const planner::Schema& outSchema, std::vector<bool>& isInputGroupByHashKeyVectorFlat);

BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);
static void mapAccHashJoin(PhysicalOperator* probe);

void mapAccHashJoin(PhysicalOperator* probe);
static std::vector<DataPos> getExpressionsDataPos(
const binder::expression_vector& expressions, const planner::Schema& schema);

public:
storage::StorageManager& storageManager;
Expand Down
35 changes: 17 additions & 18 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class AggregateHashTable : public BaseHashTable {
public:
// Used by distinct aggregate hash table only.
AggregateHashTable(storage::MemoryManager& memoryManager,
const std::vector<common::DataType>& groupByHashKeysDataTypes,
const std::vector<common::DataType>& keysDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate)
: AggregateHashTable(memoryManager, groupByHashKeysDataTypes,
std::vector<common::DataType>(), aggregateFunctions, numEntriesToAllocate) {}
: AggregateHashTable(memoryManager, keysDataTypes, std::vector<common::DataType>(),
aggregateFunctions, numEntriesToAllocate) {}

AggregateHashTable(storage::MemoryManager& memoryManager,
std::vector<common::DataType> groupByHashKeysDataTypes,
std::vector<common::DataType> groupByNonHashKeysDataTypes,
std::vector<common::DataType> keysDataTypes,
std::vector<common::DataType> payloadsDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate);

Expand All @@ -61,17 +61,17 @@ class AggregateHashTable : public BaseHashTable {
inline uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); }

inline void append(const std::vector<common::ValueVector*>& groupByFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity) {
append(groupByFlatKeyVectors, groupByUnFlatHashKeyVectors,
std::vector<common::ValueVector*>(), aggregateInputs, resultSetMultiplicity);
append(groupByFlatKeyVectors, groupByUnFlatKeyVectors, std::vector<common::ValueVector*>(),
aggregateInputs, resultSetMultiplicity);
}

//! update aggregate states for an input
void append(const std::vector<common::ValueVector*>& groupByFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity);

Expand Down Expand Up @@ -107,7 +107,7 @@ class AggregateHashTable : public BaseHashTable {

void initializeFTEntries(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors,
uint64_t numFTEntriesToInitialize);

uint8_t* createEntryInDistinctHT(
Expand All @@ -121,7 +121,7 @@ class AggregateHashTable : public BaseHashTable {

void findHashSlots(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors);
const std::vector<common::ValueVector*>& groupByDependentKeyVectors);

void computeAndCombineVecHash(
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors, uint32_t startVecIdx);
Expand Down Expand Up @@ -157,8 +157,7 @@ class AggregateHashTable : public BaseHashTable {

uint64_t matchFTEntries(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches);
uint64_t numMayMatches, uint64_t numNoMatches);

void fillEntryWithInitialNullAggregateState(uint8_t* entry);

Expand Down Expand Up @@ -219,8 +218,8 @@ class AggregateHashTable : public BaseHashTable {
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

private:
std::vector<common::DataType> groupByHashKeysDataTypes;
std::vector<common::DataType> groupByNonHashKeysDataTypes;
std::vector<common::DataType> keyDataTypes;
std::vector<common::DataType> dependentKeyDataTypes;
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions;

//! special handling of distinct aggregate
Expand All @@ -229,8 +228,8 @@ class AggregateHashTable : public BaseHashTable {
uint32_t hashColOffsetInFT;
uint32_t aggStateColOffsetInFT;
uint32_t aggStateColIdxInFT;
uint32_t numBytesForGroupByHashKeys = 0;
uint32_t numBytesForGroupByNonHashKeys = 0;
uint32_t numBytesForKeys = 0;
uint32_t numBytesForDependentKeys = 0;
std::vector<compare_function_t> compareFuncs;
std::vector<update_agg_function_t> updateAggFuncs;
bool hasStrCol = false;
Expand Down
31 changes: 14 additions & 17 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,15 @@ class HashAggregateSharedState : public BaseAggregateSharedState {
class HashAggregate : public BaseAggregate {
public:
HashAggregate(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<HashAggregateSharedState> sharedState,
std::vector<DataPos> inputGroupByHashKeyVectorsPos,
std::vector<DataPos> inputGroupByNonHashKeyVectorsPos,
std::vector<bool> isInputGroupByHashKeyVectorFlat,
std::shared_ptr<HashAggregateSharedState> sharedState, std::vector<DataPos> flatKeysPos,
std::vector<DataPos> unFlatKeysPos, std::vector<DataPos> dependentKeysPos,
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions,
std::vector<std::unique_ptr<AggregateInputInfo>> aggregateInputInfos,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseAggregate{std::move(resultSetDescriptor), std::move(aggregateFunctions),
std::move(aggregateInputInfos), std::move(child), id, paramsString},
groupByHashKeyVectorsPos{std::move(inputGroupByHashKeyVectorsPos)},
groupByNonHashKeyVectorsPos{std::move(inputGroupByNonHashKeyVectorsPos)},
isGroupByHashKeyVectorFlat{std::move(isInputGroupByHashKeyVectorFlat)},
sharedState{std::move(sharedState)} {}
flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)},
dependentKeysPos{std::move(dependentKeysPos)}, sharedState{std::move(sharedState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand All @@ -54,18 +50,19 @@ class HashAggregate : public BaseAggregate {
void finalize(ExecutionContext* context) override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<HashAggregate>(resultSetDescriptor->copy(), sharedState,
groupByHashKeyVectorsPos, groupByNonHashKeyVectorsPos, isGroupByHashKeyVectorFlat,
cloneAggFunctions(), cloneAggInputInfos(), children[0]->clone(), id, paramsString);
return make_unique<HashAggregate>(resultSetDescriptor->copy(), sharedState, flatKeysPos,
unFlatKeysPos, dependentKeysPos, cloneAggFunctions(), cloneAggInputInfos(),
children[0]->clone(), id, paramsString);
}

private:
std::vector<DataPos> groupByHashKeyVectorsPos;
std::vector<DataPos> groupByNonHashKeyVectorsPos;
std::vector<bool> isGroupByHashKeyVectorFlat;
std::vector<common::ValueVector*> groupByFlatHashKeyVectors;
std::vector<common::ValueVector*> groupByUnflatHashKeyVectors;
std::vector<common::ValueVector*> groupByNonHashKeyVectors;
std::vector<DataPos> flatKeysPos;
std::vector<DataPos> unFlatKeysPos;
std::vector<DataPos> dependentKeysPos;

std::vector<common::ValueVector*> flatKeyVectors;
std::vector<common::ValueVector*> unFlatKeyVectors;
std::vector<common::ValueVector*> dependentKeyVectors;

std::shared_ptr<HashAggregateSharedState> sharedState;
std::unique_ptr<AggregateHashTable> localAggregateHashTable;
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_optimizer
OBJECT
acc_hash_join_optimizer.cpp
agg_key_dependency_optimizer.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Loading

0 comments on commit 1aef377

Please sign in to comment.