Skip to content

Commit

Permalink
Merge pull request #1859 from kuzudb/agg-memory-leak
Browse files Browse the repository at this point in the history
Fix agg memory leak
  • Loading branch information
acquamarin committed Jul 24, 2023
2 parents 657a830 + 366187a commit 7ffdc79
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,11 @@ void ArrowColumnVector::setArrowColumn(ValueVector* vector, std::shared_ptr<arro
template void ValueVector::setValue<nodeID_t>(uint32_t pos, nodeID_t val);
template void ValueVector::setValue<bool>(uint32_t pos, bool val);
template void ValueVector::setValue<int64_t>(uint32_t pos, int64_t val);
template void ValueVector::setValue<hash_t>(uint32_t pos, hash_t val);
template void ValueVector::setValue<int32_t>(uint32_t pos, int32_t val);
template void ValueVector::setValue<int16_t>(uint32_t pos, int16_t val);
template void ValueVector::setValue<double_t>(uint32_t pos, double_t val);
template void ValueVector::setValue<float_t>(uint32_t pos, float_t val);
template void ValueVector::setValue<hash_t>(uint32_t pos, hash_t val);
template void ValueVector::setValue<date_t>(uint32_t pos, date_t val);
template void ValueVector::setValue<timestamp_t>(uint32_t pos, timestamp_t val);
template void ValueVector::setValue<interval_t>(uint32_t pos, interval_t val);
Expand Down
1 change: 1 addition & 0 deletions src/include/function/aggregate/collect.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ struct CollectFunction {
} else {
state->factorizedTable->merge(*otherState->factorizedTable);
}
otherState->factorizedTable.reset();
}

static void finalize(uint8_t* state_) {}
Expand Down
5 changes: 3 additions & 2 deletions src/include/function/aggregate/min_max.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ struct MinMaxFunction {
struct MinMaxState : public AggregateState {
inline uint32_t getStateSize() const override { return sizeof(*this); }
inline void moveResultToVector(common::ValueVector* outputVector, uint64_t pos) override {
memcpy(outputVector->getData() + pos * outputVector->getNumBytesPerValue(),
reinterpret_cast<uint8_t*>(&val), outputVector->getNumBytesPerValue());
outputVector->setValue(pos, val);
overflowBuffer.reset();
}
inline void setVal(T& val_, storage::MemoryManager* memoryManager) { val = val_; }

Expand Down Expand Up @@ -85,6 +85,7 @@ struct MinMaxFunction {
state->setVal(otherState->val, memoryManager);
}
}
otherState->overflowBuffer.reset();
}

static void finalize(uint8_t* state_) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ class AggregateHashTable : public BaseHashTable {
uint32_t numBytesForDependentKeys = 0;
std::vector<compare_function_t> compareFuncs;
std::vector<update_agg_function_t> updateAggFuncs;
bool hasStrCol = false;
// Temporary arrays to hold intermediate results.
std::shared_ptr<common::DataChunkState> hashState;
std::unique_ptr<common::ValueVector> hashVector;
Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/operator/base_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class BaseHashTable {
: maxNumHashSlots{0}, bitmask{0}, numSlotsPerBlockLog2{0}, slotIdxInBlockMask{0},
memoryManager{memoryManager} {}

virtual ~BaseHashTable() = default;

inline uint64_t getSlotIdxForHash(common::hash_t hash) const { return hash & bitmask; }

protected:
Expand Down
2 changes: 0 additions & 2 deletions src/processor/operator/aggregate/aggregate_hash_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,13 @@ void AggregateHashTable::initializeFT(
for (auto& dataType : keyDataTypes) {
auto size = LogicalTypeUtils::getRowLayoutSize(dataType);
tableSchema->appendColumn(std::make_unique<ColumnSchema>(isUnflat, dataChunkPos, size));
hasStrCol = hasStrCol || dataType.getLogicalTypeID() == LogicalTypeID::STRING;
getCompareEntryWithKeysFunc(dataType.getPhysicalType(), compareFuncs[colIdx]);
numBytesForKeys += size;
colIdx++;
}
for (auto& dataType : dependentKeyDataTypes) {
auto size = LogicalTypeUtils::getRowLayoutSize(dataType);
tableSchema->appendColumn(std::make_unique<ColumnSchema>(isUnflat, dataChunkPos, size));
hasStrCol = hasStrCol || dataType.getLogicalTypeID() == LogicalTypeID::STRING;
numBytesForDependentKeys += size;
colIdx++;
}
Expand Down

0 comments on commit 7ffdc79

Please sign in to comment.