Skip to content

Commit

Permalink
Merge pull request #122 from lschneiderbauer/feature/n_distinct
Browse files Browse the repository at this point in the history
  • Loading branch information
krlmlr committed Apr 29, 2024
2 parents 944679a + eb0eeb4 commit 5013908
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
24 changes: 22 additions & 2 deletions R/backend-dbplyr__duckdb_connection.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ duckdb_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed =
}
}

duckdb_n_distinct <- function(..., na.rm = FALSE) {
sql <- pkg_method("sql", "dbplyr")

if (!identical(na.rm, FALSE)) {
stop("Parameter `na.rm = TRUE` in n_distinct() is currently not supported in DuckDB backend.", call. = FALSE)
}

# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs-with-the-row-function
str_struct <- paste0("row(", paste0(list(...), collapse = ", "), ")")

sql(paste0("COUNT(DISTINCT ", str_struct, ")"))
}

# Customized translation functions for DuckDB SQL
# @param con A \code{\link{dbConnect}} object, as returned by \code{dbConnect()}
Expand Down Expand Up @@ -316,7 +328,8 @@ sql_translation.duckdb_connection <- function(con) {
any = sql_aggregate("BOOL_OR", "any"),
str_flatten = function(x, collapse) sql_expr(STRING_AGG(!!x, !!collapse)),
first = sql_prefix("FIRST", 1),
last = sql_prefix("LAST", 1)
last = sql_prefix("LAST", 1),
n_distinct = duckdb_n_distinct
),
sql_translator(
.parent = base_win,
Expand All @@ -333,7 +346,14 @@ sql_translation.duckdb_connection <- function(con) {
partition = win_current_group(),
order = win_current_order()
)
}
},
n_distinct =
function(..., na.rm = FALSE) {
win_over(
duckdb_n_distinct(..., na.rm = na.rm),
partition = win_current_group()
)
}
)
)
}
Expand Down
49 changes: 49 additions & 0 deletions tests/testthat/test-backend-dbplyr__duckdb_connection.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ test_that("aggregators translated correctly", {

expect_equal(translate(str_flatten(x, ","), window = FALSE), sql(r"{STRING_AGG(x, ',')}"))
expect_equal(translate(str_flatten(x, ","), window = TRUE), sql(r"{STRING_AGG(x, ',') OVER ()}"))

expect_equal(translate(n_distinct(x), window = FALSE), sql(r"{COUNT(DISTINCT row(x))}"))
expect_equal(translate(n_distinct(x), window = TRUE), sql(r"{COUNT(DISTINCT row(x)) OVER ()}"))
})

test_that("two variable aggregates are translated correctly", {
Expand All @@ -218,8 +221,54 @@ test_that("two variable aggregates are translated correctly", {

expect_equal(translate(cor(x, y), window = FALSE), sql(r"{CORR(x, y)}"))
expect_equal(translate(cor(x, y), window = TRUE), sql(r"{CORR(x, y) OVER ()}"))

expect_equal(translate(n_distinct(x, y), window = FALSE), sql(r"{COUNT(DISTINCT row(x, y))}"))
expect_equal(translate(n_distinct(x, y), window = TRUE), sql(r"{COUNT(DISTINCT row(x, y)) OVER ()}"))
})

test_that("n_distinct() computations are correct", {
skip_if_no_R4()
skip_if_not_installed("dplyr")
skip_if_not_installed("dbplyr")
con <- dbConnect(duckdb())
on.exit(dbDisconnect(con, shutdown = TRUE))
tbl <- dplyr::tbl
summarize <- dplyr::summarize
pull <- dplyr::pull

duckdb_register(con, "df", data.frame(x = c(1, 1, 2, 2), y = c(1, 2, 2, 2)))
duckdb_register(con, "df_na", data.frame(x = c(1, 1, 2, NA, NA), y = c(1, 2, NA, 2, NA)))

df <- tbl(con, "df")
df_na <- tbl(con, "df_na")

expect_error(
pull(summarize(df, n = n_distinct(x, na.rm = TRUE)), n)
)

# single column is working as usual
expect_equal(
pull(summarize(df, n = n_distinct(x)), n),
2
)

expect_equal(
pull(summarize(df_na, n = n_distinct(x)), n),
3
)

# two columns return correct results
expect_equal(
pull(summarize(df, n = n_distinct(x, y)), n),
3
)

# two columns containing NAs return correct results
expect_equal(
pull(summarize(df_na, n = n_distinct(x, y)), n),
5
)
})



Expand Down

0 comments on commit 5013908

Please sign in to comment.