Skip to content

Commit

Permalink
Merge pull request #1957 from kuzudb/rework-agg-hash-table
Browse files Browse the repository at this point in the history
Rename keys in agg hash table
  • Loading branch information
andyfengHKU committed Aug 25, 2023
2 parents 9399be0 + 8d40ddd commit dc4266a
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 190 deletions.
88 changes: 44 additions & 44 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,20 @@ class AggregateHashTable : public BaseHashTable {

inline uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); }

inline void append(const std::vector<common::ValueVector*>& groupByFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatKeyVectors,
inline void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
common::DataChunkState* leadingState,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity) {
append(groupByFlatKeyVectors, groupByUnFlatKeyVectors, std::vector<common::ValueVector*>(),
append(flatKeyVectors, unFlatKeyVectors, std::vector<common::ValueVector*>(), leadingState,
aggregateInputs, resultSetMultiplicity);
}

//! update aggregate states for an input
void append(const std::vector<common::ValueVector*>& groupByFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors,
void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<common::ValueVector*>& dependentKeyVectors,
common::DataChunkState* leadingState,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity);

Expand Down Expand Up @@ -100,14 +102,14 @@ class AggregateHashTable : public BaseHashTable {
const std::vector<common::ValueVector*>& groupByKeyVectors, common::hash_t hash);

void initializeFTEntryWithFlatVec(
common::ValueVector* groupByFlatVector, uint64_t numEntriesToInitialize, uint32_t colIdx);
common::ValueVector* flatVector, uint64_t numEntriesToInitialize, uint32_t colIdx);

void initializeFTEntryWithUnflatVec(
common::ValueVector* groupByUnflatVector, uint64_t numEntriesToInitialize, uint32_t colIdx);
void initializeFTEntryWithUnFlatVec(
common::ValueVector* unFlatVector, uint64_t numEntriesToInitialize, uint32_t colIdx);

void initializeFTEntries(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors,
void initializeFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<common::ValueVector*>& dependentKeyVectors,
uint64_t numFTEntriesToInitialize);

uint8_t* createEntryInDistinctHT(
Expand All @@ -119,45 +121,46 @@ class AggregateHashTable : public BaseHashTable {

void increaseHashSlotIdxes(uint64_t numNoMatches);

void findHashSlots(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByDependentKeyVectors);
void findHashSlots(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<common::ValueVector*>& dependentKeyVectors,
common::DataChunkState* leadingState);

void computeAndCombineVecHash(
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors, uint32_t startVecIdx);
void computeVectorHashes(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors);
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint32_t startVecIdx);
void computeVectorHashes(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors);

void updateDistinctAggState(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
void updateDistinctAggState(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggregateVector, uint64_t multiplicity, uint32_t colIdx,
uint32_t aggStateOffset);

void updateAggState(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
void updateAggState(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t colIdx,
uint32_t aggStateOffset);

void updateAggStates(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnFlatHashKeyVectors,
void updateAggStates(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity);

// ! This function will only be used by distinct aggregate, which assumes that all keyVectors
// are flat.
bool matchFlatGroupByKeys(const std::vector<common::ValueVector*>& keyVectors, uint8_t* entry);

uint64_t matchUnflatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t matchUnFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);

uint64_t matchFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);

uint64_t matchFTEntries(const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
uint64_t numMayMatches, uint64_t numNoMatches);
uint64_t matchFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches);

void fillEntryWithInitialNullAggregateState(uint8_t* entry);

Expand Down Expand Up @@ -189,37 +192,34 @@ class AggregateHashTable : public BaseHashTable {
static void getCompareEntryWithKeysFunc(
common::PhysicalTypeID physicalType, compare_function_t& func);

void updateNullAggVectorState(
const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
void updateNullAggVectorState(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction, uint64_t multiplicity,
uint32_t aggStateOffset);

void updateBothFlatAggVectorState(
const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
void updateBothFlatAggVectorState(const std::vector<common::ValueVector*>& flatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

void updateFlatUnflatKeyFlatAggVectorState(
const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
void updateFlatUnFlatKeyFlatAggVectorState(
const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

void updateFlatKeyUnflatAggVectorState(
const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
void updateFlatKeyUnFlatAggVectorState(const std::vector<common::ValueVector*>& flatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

void updateBothUnflatSameDCAggVectorState(
const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
void updateBothUnFlatSameDCAggVectorState(
const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

void updateBothUnflatDifferentDCAggVectorState(
const std::vector<common::ValueVector*>& groupByFlatHashKeyVectors,
const std::vector<common::ValueVector*>& groupByUnflatHashKeyVectors,
void updateBothUnFlatDifferentDCAggVectorState(
const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
std::unique_ptr<function::AggregateFunction>& aggregateFunction,
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

Expand Down
1 change: 1 addition & 0 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class HashAggregate : public BaseAggregate {
std::vector<common::ValueVector*> flatKeyVectors;
std::vector<common::ValueVector*> unFlatKeyVectors;
std::vector<common::ValueVector*> dependentKeyVectors;
common::DataChunkState* leadingState;

std::shared_ptr<HashAggregateSharedState> sharedState;
std::unique_ptr<AggregateHashTable> localAggregateHashTable;
Expand Down
Loading

0 comments on commit dc4266a

Please sign in to comment.