Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 6, 2023
1 parent a8ce8af commit 1bd8063
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 34 deletions.
1 change: 0 additions & 1 deletion src/include/function/aggregate/sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ struct SumFunction {
struct SumState : public AggregateState {
inline uint32_t getStateSize() const override { return sizeof(*this); }
inline void moveResultToVector(common::ValueVector* outputVector, uint64_t pos) override {
auto b = 1;
memcpy(outputVector->getData() + pos * outputVector->getNumBytesPerValue(),
reinterpret_cast<uint8_t*>(&sum), outputVector->getNumBytesPerValue());
}
Expand Down
18 changes: 9 additions & 9 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class AggregateHashTable : public BaseHashTable {
public:
// Used by distinct aggregate hash table only.
AggregateHashTable(storage::MemoryManager& memoryManager,
const std::vector<common::DataType>& groupByHashKeysDataTypes,
const std::vector<common::DataType>& keysDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate)
: AggregateHashTable(memoryManager, groupByHashKeysDataTypes,
std::vector<common::DataType>(), aggregateFunctions, numEntriesToAllocate) {}
: AggregateHashTable(memoryManager, keysDataTypes, std::vector<common::DataType>(),
aggregateFunctions, numEntriesToAllocate) {}

AggregateHashTable(storage::MemoryManager& memoryManager,
std::vector<common::DataType> groupByHashKeysDataTypes,
std::vector<common::DataType> groupByNonHashKeysDataTypes,
std::vector<common::DataType> keysDataTypes,
std::vector<common::DataType> payloadsDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate);

Expand Down Expand Up @@ -219,8 +219,8 @@ class AggregateHashTable : public BaseHashTable {
common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset);

private:
std::vector<common::DataType> groupByHashKeysDataTypes;
std::vector<common::DataType> groupByNonHashKeysDataTypes;
std::vector<common::DataType> keysDataTypes;
std::vector<common::DataType> payloadsDataTypes;
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions;

//! special handling of distinct aggregate
Expand All @@ -229,8 +229,8 @@ class AggregateHashTable : public BaseHashTable {
uint32_t hashColOffsetInFT;
uint32_t aggStateColOffsetInFT;
uint32_t aggStateColIdxInFT;
uint32_t numBytesForGroupByHashKeys = 0;
uint32_t numBytesForGroupByNonHashKeys = 0;
uint32_t numBytesForHashKeys = 0;
uint32_t numBytesForPayloads = 0;
std::vector<compare_function_t> compareFuncs;
std::vector<update_agg_function_t> updateAggFuncs;
bool hasStrCol = false;
Expand Down
43 changes: 19 additions & 24 deletions src/processor/operator/aggregate/aggregate_hash_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ namespace kuzu {
namespace processor {

AggregateHashTable::AggregateHashTable(MemoryManager& memoryManager,
std::vector<DataType> groupByHashKeysDataTypes,
std::vector<DataType> groupByNonHashKeysDataTypes,
std::vector<DataType> keysDataTypes, std::vector<DataType> payloadsDataTypes,
const std::vector<std::unique_ptr<AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate)
: BaseHashTable{memoryManager}, groupByHashKeysDataTypes{std::move(groupByHashKeysDataTypes)},
groupByNonHashKeysDataTypes{std::move(groupByNonHashKeysDataTypes)} {
: BaseHashTable{memoryManager}, keysDataTypes{std::move(keysDataTypes)},
payloadsDataTypes{std::move(payloadsDataTypes)} {
initializeFT(aggregateFunctions);
initializeHashTable(numEntriesToAllocate);
distinctHashTables = AggregateHashTableUtils::createDistinctHashTables(
memoryManager, this->groupByHashKeysDataTypes, this->aggregateFunctions);
memoryManager, this->keysDataTypes, this->aggregateFunctions);
initializeTmpVectors();
}

Expand Down Expand Up @@ -71,25 +70,22 @@ bool AggregateHashTable::isAggregateValueDistinctForGroupByKeys(

void AggregateHashTable::merge(AggregateHashTable& other) {
std::shared_ptr<DataChunkState> vectorsToScanState = std::make_shared<DataChunkState>();
std::vector<ValueVector*> vectorsToScan(
groupByHashKeysDataTypes.size() + groupByNonHashKeysDataTypes.size());
std::vector<ValueVector*> groupByHashVectors(groupByHashKeysDataTypes.size());
std::vector<ValueVector*> groupByNonHashVectors(groupByNonHashKeysDataTypes.size());
std::vector<std::unique_ptr<ValueVector>> hashKeyVectors(groupByHashKeysDataTypes.size());
std::vector<ValueVector*> vectorsToScan(keysDataTypes.size() + payloadsDataTypes.size());
std::vector<ValueVector*> groupByHashVectors(keysDataTypes.size());
std::vector<ValueVector*> groupByNonHashVectors(payloadsDataTypes.size());
std::vector<std::unique_ptr<ValueVector>> hashKeyVectors(keysDataTypes.size());
std::vector<std::unique_ptr<ValueVector>> nonHashKeyVectors(groupByNonHashVectors.size());
for (auto i = 0u; i < groupByHashKeysDataTypes.size(); i++) {
auto hashKeyVec =
std::make_unique<ValueVector>(groupByHashKeysDataTypes[i], &memoryManager);
for (auto i = 0u; i < keysDataTypes.size(); i++) {
auto hashKeyVec = std::make_unique<ValueVector>(keysDataTypes[i], &memoryManager);
hashKeyVec->state = vectorsToScanState;
vectorsToScan[i] = hashKeyVec.get();
groupByHashVectors[i] = hashKeyVec.get();
hashKeyVectors[i] = std::move(hashKeyVec);
}
for (auto i = 0u; i < groupByNonHashKeysDataTypes.size(); i++) {
auto nonHashKeyVec =
std::make_unique<ValueVector>(groupByNonHashKeysDataTypes[i], &memoryManager);
for (auto i = 0u; i < payloadsDataTypes.size(); i++) {
auto nonHashKeyVec = std::make_unique<ValueVector>(payloadsDataTypes[i], &memoryManager);
nonHashKeyVec->state = vectorsToScanState;
vectorsToScan[i + groupByHashKeysDataTypes.size()] = nonHashKeyVec.get();
vectorsToScan[i + keysDataTypes.size()] = nonHashKeyVec.get();
groupByNonHashVectors[i] = nonHashKeyVec.get();
nonHashKeyVectors[i] = std::move(nonHashKeyVec);
}
Expand Down Expand Up @@ -137,27 +133,26 @@ void AggregateHashTable::initializeFT(
auto isUnflat = false;
auto dataChunkPos = 0u;
std::unique_ptr<FactorizedTableSchema> tableSchema = std::make_unique<FactorizedTableSchema>();
aggStateColIdxInFT =
this->groupByHashKeysDataTypes.size() + this->groupByNonHashKeysDataTypes.size();
aggStateColIdxInFT = keysDataTypes.size() + payloadsDataTypes.size();
compareFuncs.resize(aggStateColIdxInFT);
auto colIdx = 0u;
for (auto& dataType : this->groupByHashKeysDataTypes) {
for (auto& dataType : keysDataTypes) {
auto size = Types::getDataTypeSize(dataType);
tableSchema->appendColumn(std::make_unique<ColumnSchema>(isUnflat, dataChunkPos, size));
hasStrCol = hasStrCol || dataType.typeID == STRING;
compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.typeID);
numBytesForGroupByHashKeys += size;
numBytesForHashKeys += size;
colIdx++;
}
for (auto& dataType : this->groupByNonHashKeysDataTypes) {
for (auto& dataType : payloadsDataTypes) {
auto size = Types::getDataTypeSize(dataType);
tableSchema->appendColumn(std::make_unique<ColumnSchema>(isUnflat, dataChunkPos, size));
hasStrCol = hasStrCol || dataType.typeID == STRING;
compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.typeID);
numBytesForGroupByNonHashKeys += size;
numBytesForPayloads += size;
colIdx++;
}
aggStateColOffsetInFT = numBytesForGroupByHashKeys + numBytesForGroupByNonHashKeys;
aggStateColOffsetInFT = numBytesForHashKeys + numBytesForPayloads;

aggregateFunctions.resize(aggFuncs.size());
updateAggFuncs.resize(aggFuncs.size());
Expand Down

0 comments on commit 1bd8063

Please sign in to comment.