Skip to content

Commit

Permalink
Merge pull request #1410 from kuzudb/agg-func
Browse files Browse the repository at this point in the history
Add min/max agg function support for more types
  • Loading branch information
acquamarin committed Mar 25, 2023
2 parents 4dc00f0 + ccba576 commit 73f3f91
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ std::vector<DataTypeID> DataType::getNumericalTypeIDs() {
return std::vector<DataTypeID>{INT64, INT32, INT16, DOUBLE, FLOAT};
}

std::vector<DataTypeID> DataType::getAllValidComparableTypes() {
return std::vector<DataTypeID>{
BOOL, INT64, INT32, INT16, DOUBLE, FLOAT, DATE, TIMESTAMP, INTERVAL, STRING};
}

std::vector<DataTypeID> DataType::getAllValidTypeIDs() {
// TODO(Ziyi): Add FIX_LIST type to allValidTypeID when we support functions on VAR_LIST.
return std::vector<DataTypeID>{INTERNAL_ID, BOOL, INT64, INT32, INT16, DOUBLE, STRING, DATE,
Expand Down
11 changes: 11 additions & 0 deletions src/function/aggregate_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getMinMaxFunction(
MinMaxFunction<date_t>::updateAll<FUNC>, MinMaxFunction<date_t>::updatePos<FUNC>,
MinMaxFunction<date_t>::combine<FUNC>, MinMaxFunction<date_t>::finalize, inputType,
isDistinct);
case TIMESTAMP:
return std::make_unique<AggregateFunction>(MinMaxFunction<timestamp_t>::initialize,
MinMaxFunction<timestamp_t>::updateAll<FUNC>,
MinMaxFunction<timestamp_t>::updatePos<FUNC>,
MinMaxFunction<timestamp_t>::combine<FUNC>, MinMaxFunction<timestamp_t>::finalize,
inputType, isDistinct);
case INTERVAL:
return std::make_unique<AggregateFunction>(MinMaxFunction<interval_t>::initialize,
MinMaxFunction<interval_t>::updateAll<FUNC>,
MinMaxFunction<interval_t>::updatePos<FUNC>, MinMaxFunction<interval_t>::combine<FUNC>,
MinMaxFunction<interval_t>::finalize, inputType, isDistinct);
case STRING:
return std::make_unique<AggregateFunction>(MinMaxFunction<ku_string_t>::initialize,
MinMaxFunction<ku_string_t>::updateAll<FUNC>,
Expand Down
4 changes: 2 additions & 2 deletions src/function/built_in_aggregate_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void BuiltInAggregateFunctions::registerAvg() {

void BuiltInAggregateFunctions::registerMin() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto typeID : std::vector<DataTypeID>{BOOL, INT64, DOUBLE, DATE, STRING}) {
for (auto typeID : DataType::getAllValidComparableTypes()) {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(MIN_FUNC_NAME,
std::vector<DataTypeID>{typeID}, typeID,
Expand All @@ -134,7 +134,7 @@ void BuiltInAggregateFunctions::registerMin() {

void BuiltInAggregateFunctions::registerMax() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto typeID : std::vector<DataTypeID>{BOOL, INT64, DOUBLE, DATE, STRING}) {
for (auto typeID : DataType::getAllValidComparableTypes()) {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(MAX_FUNC_NAME,
std::vector<DataTypeID>{typeID}, typeID,
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class DataType {
KUZU_API DataType(DataType&& other) noexcept;

// Returns the type IDs of the numerical types: INT64, INT32, INT16, DOUBLE, FLOAT.
static std::vector<DataTypeID> getNumericalTypeIDs();
// Returns the type IDs treated as comparable (BOOL, the numerical types, DATE, TIMESTAMP,
// INTERVAL, STRING). The MIN and MAX aggregate functions are registered for each of them.
static std::vector<DataTypeID> getAllValidComparableTypes();
static std::vector<DataTypeID> getAllValidTypeIDs();

KUZU_API DataType& operator=(const DataType& other);
Expand Down
9 changes: 6 additions & 3 deletions test/binder/binder_error_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,12 @@ TEST_F(BinderErrorTest, MaxNodeID) {
std::string expectedException =
"Binder exception: Cannot match a built-in function for given function MIN(INTERNAL_ID). "
"Supported inputs are\nDISTINCT (BOOL) -> BOOL\n(BOOL) -> BOOL\nDISTINCT (INT64) -> "
"INT64\n(INT64) -> INT64\nDISTINCT (DOUBLE) -> DOUBLE\n(DOUBLE) -> DOUBLE\nDISTINCT "
"(DATE) -> DATE\n(DATE) -> DATE\nDISTINCT (STRING) -> STRING\n(STRING) -> "
"STRING\n";
"INT64\n(INT64) -> INT64\nDISTINCT (INT32) -> INT32\n(INT32) -> INT32\nDISTINCT (INT16) -> "
"INT16\n(INT16) -> INT16\nDISTINCT (DOUBLE) -> DOUBLE\n(DOUBLE) -> DOUBLE\nDISTINCT "
"(FLOAT) -> FLOAT\n(FLOAT) -> FLOAT\nDISTINCT "
"(DATE) -> DATE\n(DATE) -> DATE\nDISTINCT (TIMESTAMP) -> TIMESTAMP\n(TIMESTAMP) -> "
"TIMESTAMP\nDISTINCT (INTERVAL) -> INTERVAL\n(INTERVAL) -> INTERVAL\nDISTINCT (STRING) -> "
"STRING\n(STRING) -> STRING\n";
auto input = "MATCH (a:person) RETURN MIN(a);";
ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str());
}
Expand Down
49 changes: 49 additions & 0 deletions test/test_files/tinysnb/agg/simple.test
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,52 @@ False
-PARALLELISM 8
---- 1
[[10,5],[12,8],[4,5],[1,9],[2],[3,4,5,6,7],[1],[10,11,12,3,4,5,6,7]]

-NAME SimpleAggMinTimestampTest
-QUERY MATCH (a:person) RETURN MIN(a.registerTime)
-PARALLELISM 8
-ENUMERATE
---- 1
1911-08-20 02:32:21

-NAME SimpleAggMinDateTest
-QUERY MATCH (a:person) RETURN MIN(a.birthdate)
-PARALLELISM 7
-ENUMERATE
---- 1
1900-01-01

-NAME SimpleAggMinIntervalTest
-QUERY MATCH (a:person) RETURN MIN(a.lastJobDuration)
-PARALLELISM 4
-ENUMERATE
---- 1
00:18:00.024

-NAME SimpleAggMaxFloatTest
-QUERY MATCH (:person)-[w:workAt]->(:organisation) RETURN MAX(w.rating)
-PARALLELISM 3
-ENUMERATE
---- 1
9.200000

-NAME SimpleAggMaxInt16Test
-QUERY MATCH (:person)-[s:studyAt]->(:organisation) RETURN MAX(s.length)
-PARALLELISM 4
-ENUMERATE
---- 1
55

-NAME SimpleAggSumInt16Test
-QUERY MATCH (:person)-[s:studyAt]->(:organisation) RETURN SUM(s.length)
-PARALLELISM 2
-ENUMERATE
---- 1
82

-NAME SimpleAggAvgInt16Test
-QUERY MATCH (m:movies) RETURN AVG(m.length)
-PARALLELISM 7
-ENUMERATE
---- 1
989.333333

0 comments on commit 73f3f91

Please sign in to comment.