Skip to content

Commit

Permalink
Merge pull request #1342 from kuzudb/agg-func-fix
Browse files Browse the repository at this point in the history
Fix min/max agg function on string column BUG.
  • Loading branch information
acquamarin committed Mar 5, 2023
2 parents c6b5317 + 1d884f9 commit 539f6e4
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 32 deletions.
2 changes: 1 addition & 1 deletion dataset/tinysnb/schema.cypher
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
create node table person (ID INt64, fName StRING, gender INT64, isStudent BoOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration interval, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, PRIMARY KEY (ID));
create node table organisation (ID INT64, name STRING, orgCode INT64, mark DOUBLE, score INT64, history STRING, licenseValidInterval INTERVAL, rating DOUBLE, PRIMARY KEY (ID));
create node table movies (name STRING, length INT32, PRIMARY KEY (name));
create node table movies (name STRING, length INT32, note STRING, PRIMARY KEY (name));
create rel table knows (FROM person TO person, date DATE, meetTime TIMESTAMP, validInterval INTERVAL, comments STRING[], MANY_MANY);
create rel table studyAt (FROM person TO organisation, year INT64, places STRING[], length INT16,MANY_ONE);
create rel table workAt (FROM person TO organisation, year INT64, grading DOUBLE[2], rating float, MANY_ONE);
Expand Down
6 changes: 3 additions & 3 deletions dataset/tinysnb/vMovies.csv
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Sóló cón tu párejâ,126
The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie,2544
Roma,298
Sóló cón tu párejâ,126, this is a very very good movie
The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie,2544, the movie is very very good
Roma,298,the movie is very interesting and funny
8 changes: 5 additions & 3 deletions src/include/function/aggregate/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ using aggr_update_all_function_t = std::function<void(uint8_t* state, common::Va
uint64_t multiplicity, storage::MemoryManager* memoryManager)>;
using aggr_update_pos_function_t = std::function<void(uint8_t* state, common::ValueVector* input,
uint64_t multiplicity, uint32_t pos, storage::MemoryManager* memoryManager)>;
using aggr_combine_function_t = std::function<void(uint8_t* state, uint8_t* otherState)>;
using aggr_combine_function_t =
std::function<void(uint8_t* state, uint8_t* otherState, storage::MemoryManager* memoryManager)>;
using aggr_finalize_function_t = std::function<void(uint8_t* state)>;

class AggregateFunction {
Expand Down Expand Up @@ -81,8 +82,9 @@ class AggregateFunction {
return updatePosFunc(state, input, multiplicity, pos, memoryManager);
}

inline void combineState(uint8_t* state, uint8_t* otherState) {
return combineFunc(state, otherState);
inline void combineState(
uint8_t* state, uint8_t* otherState, storage::MemoryManager* memoryManager) {
return combineFunc(state, otherState, memoryManager);
}

inline void finalizeState(uint8_t* state) { return finalizeFunc(state); }
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/aggregate/avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ struct AvgFunction {
state->count += multiplicity;
}

static void combine(uint8_t* state_, uint8_t* otherState_) {
static void combine(
uint8_t* state_, uint8_t* otherState_, storage::MemoryManager* memoryManager) {
auto otherState = reinterpret_cast<AvgState*>(otherState_);
if (otherState->isNull) {
return;
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/aggregate/base_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ struct BaseCountFunction {
return state;
}

static void combine(uint8_t* state_, uint8_t* otherState_) {
static void combine(
uint8_t* state_, uint8_t* otherState_, storage::MemoryManager* memoryManager) {
auto state = reinterpret_cast<CountState*>(state_);
auto otherState = reinterpret_cast<CountState*>(otherState_);
state->count += otherState->count;
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/aggregate/collect.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ struct CollectFunction {
}
}

static void combine(uint8_t* state_, uint8_t* otherState_) {
static void combine(
uint8_t* state_, uint8_t* otherState_, storage::MemoryManager* memoryManager) {
auto otherState = reinterpret_cast<CollectState*>(otherState_);
if (otherState->isNull) {
return;
Expand Down
40 changes: 31 additions & 9 deletions src/include/function/aggregate/min_max.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ struct MinMaxFunction {
memcpy(outputVector->getData() + pos * outputVector->getNumBytesPerValue(),
reinterpret_cast<uint8_t*>(&val), outputVector->getNumBytesPerValue());
}
// Generic setter for the running min/max value: stores val_ by plain assignment.
// memoryManager is not used in this generic version; an explicit specialization
// for common::ku_string_t is defined after MinMaxFunction in this header.
inline void setVal(T& val_, storage::MemoryManager* memoryManager) { val = val_; }

std::unique_ptr<common::InMemOverflowBuffer> overflowBuffer;
T val;
};

Expand All @@ -29,13 +31,13 @@ struct MinMaxFunction {
if (input->hasNoNullsGuarantee()) {
for (auto i = 0u; i < input->state->selVector->selectedSize; ++i) {
auto pos = input->state->selVector->selectedPositions[i];
updateSingleValue<OP>(state, input, pos);
updateSingleValue<OP>(state, input, pos, memoryManager);
}
} else {
for (auto i = 0u; i < input->state->selVector->selectedSize; ++i) {
auto pos = input->state->selVector->selectedPositions[i];
if (!input->isNull(pos)) {
updateSingleValue<OP>(state, input, pos);
updateSingleValue<OP>(state, input, pos, memoryManager);
}
}
}
Expand All @@ -44,41 +46,61 @@ struct MinMaxFunction {
template<class OP>
static inline void updatePos(uint8_t* state_, common::ValueVector* input, uint64_t multiplicity,
uint32_t pos, storage::MemoryManager* memoryManager) {
updateSingleValue<OP>(reinterpret_cast<MinMaxState*>(state_), input, pos);
updateSingleValue<OP>(reinterpret_cast<MinMaxState*>(state_), input, pos, memoryManager);
}

template<class OP>
static void updateSingleValue(MinMaxState* state, common::ValueVector* input, uint32_t pos) {
static void updateSingleValue(MinMaxState* state, common::ValueVector* input, uint32_t pos,
storage::MemoryManager* memoryManager) {
T val = input->getValue<T>(pos);
if (state->isNull) {
state->val = val;
state->setVal(val, memoryManager);
state->isNull = false;
} else {
uint8_t compare_result;
OP::template operation<T, T>(val, state->val, compare_result);
state->val = compare_result ? val : state->val;
if (compare_result) {
state->setVal(val, memoryManager);
}
}
}

template<class OP>
static void combine(uint8_t* state_, uint8_t* otherState_) {
static void combine(
uint8_t* state_, uint8_t* otherState_, storage::MemoryManager* memoryManager) {
auto otherState = reinterpret_cast<MinMaxState*>(otherState_);
if (otherState->isNull) {
return;
}
auto state = reinterpret_cast<MinMaxState*>(state_);
if (state->isNull) {
state->val = otherState->val;
state->setVal(otherState->val, memoryManager);
state->isNull = false;
} else {
uint8_t compareResult;
OP::template operation<T, T>(otherState->val, state->val, compareResult);
state->val = compareResult == 1 ? otherState->val : state->val;
if (compareResult) {
state->setVal(otherState->val, memoryManager);
}
}
}

// Finalize hook for min/max. The body is empty, so the state passed in is left untouched.
static void finalize(uint8_t* state_) {}
};

// String specialization of MinMaxState::setVal.
//
// Stores val_ as the new running min/max. The state's overflowBuffer is created
// on the first call (using memoryManager), and space for a long string is taken
// from that buffer rather than reusing val_'s own overflow pointer.
// NOTE(review): presumably this is so the stored value does not depend on the
// lifetime of the input vector's / other state's buffer -- confirm with callers.
//
// Declared `inline` because an explicit (full) specialization of a function is
// not implicitly inline: without it, every translation unit including this
// header would emit its own definition and violate the one-definition rule.
template<>
inline void MinMaxFunction<common::ku_string_t>::MinMaxState::setVal(
common::ku_string_t& val_, storage::MemoryManager* memoryManager) {
if (overflowBuffer == nullptr) {
overflowBuffer = std::make_unique<common::InMemOverflowBuffer>(memoryManager);
}
// We only need to allocate memory if the new val_ is a long string and is longer
// than the current val; otherwise the space val already points at (or the inline
// short-string storage) is large enough to be reused.
if (val_.len > common::ku_string_t::SHORT_STR_LENGTH && val_.len > val.len) {
val.overflowPtr = reinterpret_cast<uint64_t>(overflowBuffer->allocateSpace(val_.len));
}
// NOTE(review): set() is assumed to copy val_'s length and bytes into val, writing
// long strings through val.overflowPtr -- verify against ku_string_t::set.
val.set(val_);
}

} // namespace function
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/function/aggregate/sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ struct SumFunction {
}
}

static void combine(uint8_t* state_, uint8_t* otherState_) {
static void combine(
uint8_t* state_, uint8_t* otherState_, storage::MemoryManager* memoryManager) {
auto otherState = reinterpret_cast<SumState*>(otherState_);
if (otherState->isNull) {
return;
Expand Down
3 changes: 2 additions & 1 deletion src/include/processor/operator/aggregate/simple_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class SimpleAggregateSharedState : public BaseAggregateSharedState {
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions);

void combineAggregateStates(
const std::vector<std::unique_ptr<function::AggregateState>>& localAggregateStates);
const std::vector<std::unique_ptr<function::AggregateState>>& localAggregateStates,
storage::MemoryManager* memoryManager);

void finalizeAggregateStates();

Expand Down
3 changes: 2 additions & 1 deletion src/processor/operator/aggregate/aggregate_hash_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ void AggregateHashTable::merge(AggregateHashTable& other) {
for (auto i = 0u; i < numTuplesToScan; i++) {
aggregateFunction->combineState(
hashSlotsToUpdateAggState[i]->entry + aggregateStateOffset,
other.factorizedTable->getTuple(startTupleIdx + i) + aggregateStateOffset);
other.factorizedTable->getTuple(startTupleIdx + i) + aggregateStateOffset,
&memoryManager);
}
aggregateStateOffset += aggregateFunction->getAggregateStateSize();
}
Expand Down
9 changes: 5 additions & 4 deletions src/processor/operator/aggregate/simple_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ SimpleAggregateSharedState::SimpleAggregateSharedState(
}

void SimpleAggregateSharedState::combineAggregateStates(
const std::vector<std::unique_ptr<AggregateState>>& localAggregateStates) {
const std::vector<std::unique_ptr<AggregateState>>& localAggregateStates,
storage::MemoryManager* memoryManager) {
assert(localAggregateStates.size() == globalAggregateStates.size());
auto lck = acquireLock();
for (auto i = 0u; i < aggregateFunctions.size(); ++i) {
aggregateFunctions[i]->combineState(
(uint8_t*)globalAggregateStates[i].get(), (uint8_t*)localAggregateStates[i].get());
aggregateFunctions[i]->combineState((uint8_t*)globalAggregateStates[i].get(),
(uint8_t*)localAggregateStates[i].get(), memoryManager);
}
}

Expand Down Expand Up @@ -84,7 +85,7 @@ void SimpleAggregate::executeInternal(ExecutionContext* context) {
}
}
}
sharedState->combineAggregateStates(localAggregateStates);
sharedState->combineAggregateStates(localAggregateStates, context->memoryManager);
}

std::unique_ptr<PhysicalOperator> SimpleAggregate::clone() {
Expand Down
5 changes: 5 additions & 0 deletions test/test_files/tinysnb/agg/simple.test
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ False
---- 1
20|83|False|True|4.500000|5.100000|1900-01-01|1990-11-27

-NAME SimpleMinTest
-QUERY MATCH (m:movies) RETURN MIN(m.note)
---- 1
the movie is very very good

-NAME TwoHopTest
-QUERY MATCH (a:person)-[:knows]->(b:person) RETURN SUM(b.age), MIN(b.ID), AVG(b.eyeSight)
-ENUMERATE
Expand Down
12 changes: 6 additions & 6 deletions test/test_files/tinysnb/projection/multi_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
-NAME MultiLabelReturnStar
-QUERY MATCH (a:movies:organisation) RETURN *
---- 6
(label:organisation, 1:0, {ID:1, name:ABFsUni, orgCode:325, mark:3.700000, score:-2, history:10 years 5 months 13 hours 24 us, licenseValidInterval:3 years 5 days, rating:1.000000, length:})
(label:organisation, 1:1, {ID:4, name:CsWork, orgCode:934, mark:4.100000, score:-100, history:2 years 4 days 10 hours, licenseValidInterval:26 years 52 days 48:00:00, rating:0.780000, length:})
(label:organisation, 1:2, {ID:6, name:DEsWork, orgCode:824, mark:4.100000, score:7, history:2 years 4 hours 22 us 34 minutes, licenseValidInterval:82:00:00.1, rating:0.520000, length:})
(label:movies, 2:0, {ID:, name:Sóló cón tu párejâ, orgCode:, mark:, score:, history:, licenseValidInterval:, rating:, length:126})
(label:movies, 2:1, {ID:, name:The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie, orgCode:, mark:, score:, history:, licenseValidInterval:, rating:, length:2544})
(label:movies, 2:2, {ID:, name:Roma, orgCode:, mark:, score:, history:, licenseValidInterval:, rating:, length:298})
(label:organisation, 1:0, {ID:1, name:ABFsUni, orgCode:325, mark:3.700000, score:-2, history:10 years 5 months 13 hours 24 us, licenseValidInterval:3 years 5 days, rating:1.000000, length:, note:})
(label:organisation, 1:1, {ID:4, name:CsWork, orgCode:934, mark:4.100000, score:-100, history:2 years 4 days 10 hours, licenseValidInterval:26 years 52 days 48:00:00, rating:0.780000, length:, note:})
(label:organisation, 1:2, {ID:6, name:DEsWork, orgCode:824, mark:4.100000, score:7, history:2 years 4 hours 22 us 34 minutes, licenseValidInterval:82:00:00.1, rating:0.520000, length:, note:})
(label:movies, 2:0, {ID:, name:Sóló cón tu párejâ, orgCode:, mark:, score:, history:, licenseValidInterval:, rating:, length:126, note: this is a very very good movie})
(label:movies, 2:1, {ID:, name:The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie, orgCode:, mark:, score:, history:, licenseValidInterval:, rating:, length:2544, note: the movie is very very good})
(label:movies, 2:2, {ID:, name:Roma, orgCode:, mark:, score:, history:, licenseValidInterval:, rating:, length:298, note:the movie is very interesting and funny})

0 comments on commit 539f6e4

Please sign in to comment.