Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aggregate key dependency optimizer #1517

Merged
merged 1 commit into from
May 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/include/optimizer/agg_key_dependency_optimizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

// This optimizer analyzes the dependency between group-by keys. If key2 depends on key1 (e.g. key1
// is a primary key column), we only hash on key1 and save key2 as a payload.
//
// It is implemented as a LogicalOperatorVisitor; the member functions are defined out of line,
// so the notes below describe only what the declarations themselves show.
class AggKeyDependencyOptimizer : public LogicalOperatorVisitor {
public:
// Entry point of the optimizer: takes the logical plan to be processed.
// NOTE(review): presumably rewrites `plan` in place (returns void) -- confirm in the .cpp.
void rewrite(planner::LogicalPlan* plan);

private:
// Per-operator visiting helper. Not marked `override`, so it appears to be this class's own
// traversal routine rather than a base-class hook -- TODO confirm against the definition.
void visitOperator(planner::LogicalOperator* op);

// Overrides of the LogicalOperatorVisitor hooks for aggregate and distinct operators, the two
// operator kinds this optimizer handles.
void visitAggregate(planner::LogicalOperator* op) override;
void visitDistinct(planner::LogicalOperator* op) override;

// Partitions `keys` into two expression vectors.
// NOTE(review): judging by the name, the pair is (keys, dependent keys) in that order --
// verify against the definition before relying on it.
std::pair<binder::expression_vector, binder::expression_vector> resolveKeysAndDependentKeys(
const binder::expression_vector& keys);
};

} // namespace optimizer
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@ namespace planner {

class LogicalAggregate : public LogicalOperator {
public:
LogicalAggregate(binder::expression_vector expressionsToGroupBy,
binder::expression_vector expressionsToAggregate, std::shared_ptr<LogicalOperator> child)
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)} {}
keyExpressions{std::move(keyExpressions)}, aggregateExpressions{
std::move(aggregateExpressions)} {}
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{std::move(
dependentKeyExpressions)},
aggregateExpressions{std::move(aggregateExpressions)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -21,30 +28,42 @@ class LogicalAggregate : public LogicalOperator {

std::string getExpressionsForPrinting() const override;

inline bool hasExpressionsToGroupBy() const { return !expressionsToGroupBy.empty(); }
inline binder::expression_vector getExpressionsToGroupBy() const {
return expressionsToGroupBy;
inline bool hasKeyExpressions() const { return !keyExpressions.empty(); }
inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline size_t getNumAggregateExpressions() const { return expressionsToAggregate.size(); }
inline std::shared_ptr<binder::Expression> getAggregateExpression(
common::vector_idx_t idx) const {
return expressionsToAggregate[idx];
inline binder::expression_vector getAllKeyExpressions() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
return result;
}
inline binder::expression_vector getExpressionsToAggregate() const {
return expressionsToAggregate;
inline binder::expression_vector getAggregateExpressions() const {
return aggregateExpressions;
}

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(
expressionsToGroupBy, expressionsToAggregate, children[0]->copy());
keyExpressions, dependentKeyExpressions, aggregateExpressions, children[0]->copy());
}

private:
bool hasDistinctAggregate();
void insertAllExpressionsToGroupAndScope(f_group_pos groupPos);

private:
binder::expression_vector expressionsToGroupBy;
binder::expression_vector expressionsToAggregate;
binder::expression_vector keyExpressions;
// A dependentKeyExpression depends on a keyExpression (e.g. a.age depends on a.ID) and will not
// be treated as a hash key during hash aggregation.
binder::expression_vector dependentKeyExpressions;
binder::expression_vector aggregateExpressions;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ namespace planner {
class LogicalDistinct : public LogicalOperator {
public:
LogicalDistinct(
binder::expression_vector expressionsToDistinct, std::shared_ptr<LogicalOperator> child)
binder::expression_vector keyExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)} {}
keyExpressions{std::move(keyExpressions)} {}
LogicalDistinct(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
keyExpressions{std::move(keyExpressions)}, dependentKeyExpressions{
std::move(dependentKeyExpressions)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -20,17 +25,32 @@ class LogicalDistinct : public LogicalOperator {

std::string getExpressionsForPrinting() const override;

inline binder::expression_vector getExpressionsToDistinct() const {
return expressionsToDistinct;
inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline binder::expression_vector getAllDistinctExpressions() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
return result;
}
inline Schema* getSchemaBeforeDistinct() const { return children[0]->getSchema(); }

std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDistinct>(expressionsToDistinct, children[0]->copy());
return make_unique<LogicalDistinct>(
keyExpressions, dependentKeyExpressions, children[0]->copy());
}

private:
binder::expression_vector expressionsToDistinct;
binder::expression_vector keyExpressions;
// See logical_aggregate.h for details.
binder::expression_vector dependentKeyExpressions;
};

} // namespace planner
Expand Down
25 changes: 12 additions & 13 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,22 @@ class PlanMapper {

inline uint32_t getOperatorID() { return physicalOperatorID++; }

BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

std::unique_ptr<PhysicalOperator> createHashAggregate(
const binder::expression_vector& keyExpressions,
const binder::expression_vector& dependentKeyExpressions,
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions,
std::vector<std::unique_ptr<AggregateInputInfo>> inputAggregateInfo,
std::vector<DataPos> outputAggVectorsPos,
const binder::expression_vector& groupByExpressions,
std::unique_ptr<PhysicalOperator> prevOperator, const planner::Schema& inSchema,
const planner::Schema& outSchema, const std::string& paramsString);
std::vector<std::unique_ptr<AggregateInputInfo>> aggregateInputInfos,
std::vector<DataPos> aggregatesOutputPos, const planner::Schema& inSchema,
const planner::Schema& outSchema, std::unique_ptr<PhysicalOperator> prevOperator,
const std::string& paramsString);

void appendGroupByExpressions(const binder::expression_vector& groupByExpressions,
std::vector<DataPos>& inputGroupByHashKeyVectorsPos,
std::vector<DataPos>& outputGroupByKeyVectorsPos, const planner::Schema& inSchema,
const planner::Schema& outSchema, std::vector<bool>& isInputGroupByHashKeyVectorFlat);

BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);
static void mapAccHashJoin(PhysicalOperator* probe);

void mapAccHashJoin(PhysicalOperator* probe);
static std::vector<DataPos> getExpressionsDataPos(
const binder::expression_vector& expressions, const planner::Schema& schema);

public:
storage::StorageManager& storageManager;
Expand Down
35 changes: 17 additions & 18 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class AggregateHashTable : public BaseHashTable {
public:
// Used by distinct aggregate hash table only.
AggregateHashTable(storage::MemoryManager& memoryManager,
const std::vector<common::DataType>& groupByHashKeysDataTypes,
const std::vector<common::DataType>& keysDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate)
: AggregateHashTable(memoryManager, groupByHashKeysDataTypes,
std::vector<common::DataType>(), aggregateFunctions, numEntriesToAllocate) {}
: AggregateHashTable(memoryManager, keysDataTypes, std::vector<common::DataType>(),
aggregateFunctions, numEntriesToAllocate) {}

AggregateHashTable(storage::MemoryManager& memoryManager,
std::vector<common::DataType> groupByHashKeysDataTypes,
std::vector<common::DataType> groupByNonHashKeysDataTypes,
std::vector<common::DataType> keysDataTypes,
std::vector<common::DataType> payloadsDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate);

Expand All @@ -61,17 +61,17 @@ class AggregateHashTable : public BaseHashTable {
inline uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); }

inline void append(const std::vector<common::ValueVector*>& groupByFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity) {
append(groupByFlatKeyVectors, groupByUnFlatHashKeyVectors,
std::vector<common::ValueVector*>(), aggregateInputs, resultSetMultiplicity);
append(groupByFlatKeyVectors, groupByUnFlatKeyVectors, std::vector<common::ValueVector*>(),
aggregateInputs, resultSetMultiplicity);
}

//! update aggregate states for an input
void append(const std::vector<common::ValueVector*>& groupByFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity);

Expand Down Expand Up @@ -107,7 +107,7 @@ class AggregateHashTable : public BaseHashTable {

void initializeFTEntries(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors,
uint64_t numFTEntriesToInitialize);

uint8_t* createEntryInDistinctHT(
Expand All @@ -121,7 +121,7 @@ class AggregateHashTable : public BaseHashTable {

void findHashSlots(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors);
const std::vector<common::ValueVector*>& groupByDependentKeyVectors);

void computeAndCombineVecHash(
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors, uint32_t startVecIdx);
Expand Down Expand Up @@ -157,8 +157,7 @@ class AggregateHashTable : public BaseHashTable {

uint64_t matchFTEntries(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByNonHashKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches);
uint64_t numMayMatches, uint64_t numNoMatches);

void fillEntryWithInitialNullAggregateState(uint8_t* entry);

Expand Down Expand Up @@ -219,8 +218,8 @@ class AggregateHashTable : public BaseHashTable {
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

private:
std::vector<common::DataType> groupByHashKeysDataTypes;
std::vector<common::DataType> groupByNonHashKeysDataTypes;
std::vector<common::DataType> keyDataTypes;
std::vector<common::DataType> dependentKeyDataTypes;
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions;

//! special handling of distinct aggregate
Expand All @@ -229,8 +228,8 @@ class AggregateHashTable : public BaseHashTable {
uint32_t hashColOffsetInFT;
uint32_t aggStateColOffsetInFT;
uint32_t aggStateColIdxInFT;
uint32_t numBytesForGroupByHashKeys = 0;
uint32_t numBytesForGroupByNonHashKeys = 0;
uint32_t numBytesForKeys = 0;
uint32_t numBytesForDependentKeys = 0;
std::vector<compare_function_t> compareFuncs;
std::vector<update_agg_function_t> updateAggFuncs;
bool hasStrCol = false;
Expand Down
31 changes: 14 additions & 17 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,15 @@ class HashAggregateSharedState : public BaseAggregateSharedState {
class HashAggregate : public BaseAggregate {
public:
HashAggregate(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::shared_ptr<HashAggregateSharedState> sharedState,
std::vector<DataPos> inputGroupByHashKeyVectorsPos,
std::vector<DataPos> inputGroupByNonHashKeyVectorsPos,
std::vector<bool> isInputGroupByHashKeyVectorFlat,
std::shared_ptr<HashAggregateSharedState> sharedState, std::vector<DataPos> flatKeysPos,
std::vector<DataPos> unFlatKeysPos, std::vector<DataPos> dependentKeysPos,
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions,
std::vector<std::unique_ptr<AggregateInputInfo>> aggregateInputInfos,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseAggregate{std::move(resultSetDescriptor), std::move(aggregateFunctions),
std::move(aggregateInputInfos), std::move(child), id, paramsString},
groupByHashKeyVectorsPos{std::move(inputGroupByHashKeyVectorsPos)},
groupByNonHashKeyVectorsPos{std::move(inputGroupByNonHashKeyVectorsPos)},
isGroupByHashKeyVectorFlat{std::move(isInputGroupByHashKeyVectorFlat)},
sharedState{std::move(sharedState)} {}
flatKeysPos{std::move(flatKeysPos)}, unFlatKeysPos{std::move(unFlatKeysPos)},
dependentKeysPos{std::move(dependentKeysPos)}, sharedState{std::move(sharedState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand All @@ -54,18 +50,19 @@ class HashAggregate : public BaseAggregate {
void finalize(ExecutionContext* context) override;

inline std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<HashAggregate>(resultSetDescriptor->copy(), sharedState,
groupByHashKeyVectorsPos, groupByNonHashKeyVectorsPos, isGroupByHashKeyVectorFlat,
cloneAggFunctions(), cloneAggInputInfos(), children[0]->clone(), id, paramsString);
return make_unique<HashAggregate>(resultSetDescriptor->copy(), sharedState, flatKeysPos,
unFlatKeysPos, dependentKeysPos, cloneAggFunctions(), cloneAggInputInfos(),
children[0]->clone(), id, paramsString);
}

private:
std::vector<DataPos> groupByHashKeyVectorsPos;
std::vector<DataPos> groupByNonHashKeyVectorsPos;
std::vector<bool> isGroupByHashKeyVectorFlat;
std::vector<common::ValueVector*> groupByFlatHashKeyVectors;
std::vector<common::ValueVector*> groupByUnflatHashKeyVectors;
std::vector<common::ValueVector*> groupByNonHashKeyVectors;
std::vector<DataPos> flatKeysPos;
std::vector<DataPos> unFlatKeysPos;
std::vector<DataPos> dependentKeysPos;

std::vector<common::ValueVector*> flatKeyVectors;
std::vector<common::ValueVector*> unFlatKeyVectors;
std::vector<common::ValueVector*> dependentKeyVectors;

std::shared_ptr<HashAggregateSharedState> sharedState;
std::unique_ptr<AggregateHashTable> localAggregateHashTable;
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_optimizer
OBJECT
acc_hash_join_optimizer.cpp
agg_key_dependency_optimizer.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Loading