From 54a7ce07957a3db0f61a330a5579814d2157b52e Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 23 Mar 2024 09:38:53 +0100 Subject: [PATCH 1/7] implement n_distinct() for multiple arguments using duckdb structs --- R/backend-dbplyr__duckdb_connection.R | 32 +++++++++++++++++-- .../test-backend-dbplyr__duckdb_connection.R | 6 ++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index 82b6b6147..424b251b2 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -77,6 +77,26 @@ duckdb_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed = } } +duckdb_n_distinct <- function(..., na.rm = FALSE) { + check_na_rm <- pkg_method("check_na_rm", "dbplyr") + glue_sql2 <- pkg_method("glue_sql2", "dbplyr") + + check_na_rm(na.rm) + + vars <- list(...) + str_struct <- + paste0("{", paste0( + lapply( + seq_along(vars), + \(i) glue::glue("'v{i}' : {vars[[i]]}") + ), + collapse = ", " + ), "}") + glue_sql2( + sql_current_con(), + "COUNT(DISTINCT {str_struct})" + ) +} # Customized translation functions for DuckDB SQL # @param con A \code{\link{dbConnect}} object, as returned by \code{dbConnect()} @@ -316,7 +336,8 @@ sql_translation.duckdb_connection <- function(con) { any = sql_aggregate("BOOL_OR", "any"), str_flatten = function(x, collapse) sql_expr(STRING_AGG(!!x, !!collapse)), first = sql_prefix("FIRST", 1), - last = sql_prefix("LAST", 1) + last = sql_prefix("LAST", 1), + n_distinct = duckdb_n_distinct ), sql_translator( .parent = base_win, @@ -333,7 +354,14 @@ sql_translation.duckdb_connection <- function(con) { partition = win_current_group(), order = win_current_order() ) - } + }, + n_distinct = + function(..., na.rm = FALSE) { + win_over( + duckdb_n_distinct(..., na.rm = na.rm), + partition = win_current_group() + ) + } ) ) } diff --git a/tests/testthat/test-backend-dbplyr__duckdb_connection.R 
b/tests/testthat/test-backend-dbplyr__duckdb_connection.R index 21e8074ac..986c87311 100644 --- a/tests/testthat/test-backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test-backend-dbplyr__duckdb_connection.R @@ -206,6 +206,9 @@ test_that("aggregators translated correctly", { expect_equal(translate(str_flatten(x, ","), window = FALSE), sql(r"{STRING_AGG(x, ',')}")) expect_equal(translate(str_flatten(x, ","), window = TRUE), sql(r"{STRING_AGG(x, ',') OVER ()}")) + + expect_equal(translate(n_distinct(x, na.rm = TRUE), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x})}")) + expect_equal(translate(n_distinct(x, na.rm = TRUE), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x}) OVER ()}")) }) test_that("two variable aggregates are translated correctly", { @@ -218,6 +221,9 @@ test_that("two variable aggregates are translated correctly", { expect_equal(translate(cor(x, y), window = FALSE), sql(r"{CORR(x, y)}")) expect_equal(translate(cor(x, y), window = TRUE), sql(r"{CORR(x, y) OVER ()}")) + + expect_equal(translate(n_distinct(x, y, na.rm = TRUE), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y})}")) + expect_equal(translate(n_distinct(x, y, na.rm = TRUE), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y}) OVER ()}")) }) From dee31f7e9e6e031e923aef0d1367fa932522771b Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 30 Mar 2024 10:05:16 +0100 Subject: [PATCH 2/7] remove wrong `na.rm = TRUE` check: the implemented behaviour coincides with `na.rm = FALSE`. 
--- R/backend-dbplyr__duckdb_connection.R | 5 +++-- tests/testthat/test-backend-dbplyr__duckdb_connection.R | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index 424b251b2..75861615b 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -78,10 +78,11 @@ duckdb_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed = } duckdb_n_distinct <- function(..., na.rm = FALSE) { - check_na_rm <- pkg_method("check_na_rm", "dbplyr") glue_sql2 <- pkg_method("glue_sql2", "dbplyr") - check_na_rm(na.rm) + if (identical(na.rm, TRUE)) { + cli::cli_abort("`na.rm = TRUE` not implemented.") + } vars <- list(...) str_struct <- diff --git a/tests/testthat/test-backend-dbplyr__duckdb_connection.R b/tests/testthat/test-backend-dbplyr__duckdb_connection.R index 986c87311..7dc0a45ae 100644 --- a/tests/testthat/test-backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test-backend-dbplyr__duckdb_connection.R @@ -207,8 +207,8 @@ test_that("aggregators translated correctly", { expect_equal(translate(str_flatten(x, ","), window = FALSE), sql(r"{STRING_AGG(x, ',')}")) expect_equal(translate(str_flatten(x, ","), window = TRUE), sql(r"{STRING_AGG(x, ',') OVER ()}")) - expect_equal(translate(n_distinct(x, na.rm = TRUE), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x})}")) - expect_equal(translate(n_distinct(x, na.rm = TRUE), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x}) OVER ()}")) + expect_equal(translate(n_distinct(x), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x})}")) + expect_equal(translate(n_distinct(x), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x}) OVER ()}")) }) test_that("two variable aggregates are translated correctly", { @@ -222,8 +222,8 @@ test_that("two variable aggregates are translated correctly", { expect_equal(translate(cor(x, y), window = FALSE), sql(r"{CORR(x, y)}")) expect_equal(translate(cor(x, y), 
window = TRUE), sql(r"{CORR(x, y) OVER ()}")) - expect_equal(translate(n_distinct(x, y, na.rm = TRUE), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y})}")) - expect_equal(translate(n_distinct(x, y, na.rm = TRUE), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y}) OVER ()}")) + expect_equal(translate(n_distinct(x, y), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y})}")) + expect_equal(translate(n_distinct(x, y), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y}) OVER ()}")) }) From f7e1f655bdcf31ceb25f725d9e2a6e4d5a5fbafb Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 30 Mar 2024 10:09:39 +0100 Subject: [PATCH 3/7] more `n_distinct()` tests: check computation results --- .../test-backend-dbplyr__duckdb_connection.R | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/testthat/test-backend-dbplyr__duckdb_connection.R b/tests/testthat/test-backend-dbplyr__duckdb_connection.R index 7dc0a45ae..46f58816d 100644 --- a/tests/testthat/test-backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test-backend-dbplyr__duckdb_connection.R @@ -226,6 +226,57 @@ test_that("two variable aggregates are translated correctly", { expect_equal(translate(n_distinct(x, y), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y}) OVER ()}")) }) +test_that("n_distinct() computations are correct", { + skip_if_no_R4() + skip_if_not_installed("dplyr") + skip_if_not_installed("dbplyr") + con <- dbConnect(duckdb()) + on.exit(dbDisconnect(con, shutdown = TRUE)) + tbl <- dplyr::tbl + summarize <- dplyr::summarize + pull <- dplyr::pull + + duckdb_register(con, "df1", data.frame(x = c(1, 1, 2))) + duckdb_register(con, "df1_na", data.frame(x = c(1, 1, 2, NA))) + duckdb_register(con, "df2", data.frame(x = c(1, 1, 2, 2), y = c(1, 2, 2, 2))) + duckdb_register(con, "df2_na", data.frame(x = c(1, 1, 2, NA, NA), y = c(1, 2, NA, 2, NA))) + + expect_error( + tbl(con, "df1") |> + summarize(n = n_distinct(x, na.rm = TRUE)) |> 
+ pull(n) + ) + + # single column is working as usual + expect_equal( + tbl(con, "df1") |> + summarize(n = n_distinct(x)) |> + pull(n), + 2 + ) + expect_equal( + tbl(con, "df1_na") |> + summarize(n = n_distinct(x)) |> + pull(n), + 3 + ) + + # two columns return correct results + expect_equal( + tbl(con, "df2") |> + summarize(n = n_distinct(x, y)) |> + pull(n), + 3 + ) + + # two columns containing NAs return correct results + expect_equal( + tbl(con, "df2_na") |> + summarize(n = n_distinct(x, y)) |> + pull(n), + 5 + ) +}) From 5731d9996fef00778b0af61dee6bd0ce8d59be74 Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 30 Mar 2024 11:26:03 +0100 Subject: [PATCH 4/7] `n_distinct()`: simplify SQL by using duckdb's row() function and get rid of `dbplyr:::glue_sql2()` usage --- R/backend-dbplyr__duckdb_connection.R | 19 +++++-------------- .../test-backend-dbplyr__duckdb_connection.R | 8 ++++---- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index 75861615b..307f2d0e1 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -78,25 +78,16 @@ duckdb_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed = } duckdb_n_distinct <- function(..., na.rm = FALSE) { - glue_sql2 <- pkg_method("glue_sql2", "dbplyr") + sql <- pkg_method("sql", "dbplyr") if (identical(na.rm, TRUE)) { cli::cli_abort("`na.rm = TRUE` not implemented.") } - vars <- list(...) 
- str_struct <- - paste0("{", paste0( - lapply( - seq_along(vars), - \(i) glue::glue("'v{i}' : {vars[[i]]}") - ), - collapse = ", " - ), "}") - glue_sql2( - sql_current_con(), - "COUNT(DISTINCT {str_struct})" - ) + # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs-with-the-row-function + str_struct <- paste0("row(", paste0(list(...), collapse = ", "), ")") + + sql(paste0("COUNT(DISTINCT ", str_struct, ")")) } # Customized translation functions for DuckDB SQL diff --git a/tests/testthat/test-backend-dbplyr__duckdb_connection.R b/tests/testthat/test-backend-dbplyr__duckdb_connection.R index 46f58816d..3301537ba 100644 --- a/tests/testthat/test-backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test-backend-dbplyr__duckdb_connection.R @@ -207,8 +207,8 @@ test_that("aggregators translated correctly", { expect_equal(translate(str_flatten(x, ","), window = FALSE), sql(r"{STRING_AGG(x, ',')}")) expect_equal(translate(str_flatten(x, ","), window = TRUE), sql(r"{STRING_AGG(x, ',') OVER ()}")) - expect_equal(translate(n_distinct(x), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x})}")) - expect_equal(translate(n_distinct(x), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x}) OVER ()}")) + expect_equal(translate(n_distinct(x), window = FALSE), sql(r"{COUNT(DISTINCT row(x))}")) + expect_equal(translate(n_distinct(x), window = TRUE), sql(r"{COUNT(DISTINCT row(x)) OVER ()}")) }) test_that("two variable aggregates are translated correctly", { @@ -222,8 +222,8 @@ test_that("two variable aggregates are translated correctly", { expect_equal(translate(cor(x, y), window = FALSE), sql(r"{CORR(x, y)}")) expect_equal(translate(cor(x, y), window = TRUE), sql(r"{CORR(x, y) OVER ()}")) - expect_equal(translate(n_distinct(x, y), window = FALSE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y})}")) - expect_equal(translate(n_distinct(x, y), window = TRUE), sql(r"{COUNT(DISTINCT {'v1' : x, 'v2' : y}) OVER ()}")) + expect_equal(translate(n_distinct(x, y), window = FALSE), 
sql(r"{COUNT(DISTINCT row(x, y))}")) + expect_equal(translate(n_distinct(x, y), window = TRUE), sql(r"{COUNT(DISTINCT row(x, y)) OVER ()}")) }) test_that("n_distinct() computations are correct", { From 45ab3d2447149d7f585ee91e561ae374f3913414 Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 30 Mar 2024 11:26:55 +0100 Subject: [PATCH 5/7] replace `cli::cli_abort()` by `stop()` --- R/backend-dbplyr__duckdb_connection.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index 307f2d0e1..9e9cb80b2 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -81,7 +81,7 @@ duckdb_n_distinct <- function(..., na.rm = FALSE) { sql <- pkg_method("sql", "dbplyr") if (identical(na.rm, TRUE)) { - cli::cli_abort("`na.rm = TRUE` not implemented.") + stop("Parameter `na.rm = TRUE` in n_distinct() is currently not supported in DuckDB backend.", call. = FALSE) } # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs-with-the-row-function From 05f20b23459695776f2999c1d4bbed426f466cc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kirill=20M=C3=BCller?= Date: Sat, 30 Mar 2024 13:55:50 +0100 Subject: [PATCH 6/7] Strict --- R/backend-dbplyr__duckdb_connection.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/backend-dbplyr__duckdb_connection.R b/R/backend-dbplyr__duckdb_connection.R index 9e9cb80b2..dcb23b671 100644 --- a/R/backend-dbplyr__duckdb_connection.R +++ b/R/backend-dbplyr__duckdb_connection.R @@ -80,7 +80,7 @@ duckdb_grepl <- function(pattern, x, ignore.case = FALSE, perl = FALSE, fixed = duckdb_n_distinct <- function(..., na.rm = FALSE) { sql <- pkg_method("sql", "dbplyr") - if (identical(na.rm, TRUE)) { + if (!identical(na.rm, FALSE)) { stop("Parameter `na.rm = TRUE` in n_distinct() is currently not supported in DuckDB backend.", call. 
= FALSE) } From eb0eeb4e868f757d8cad6452cd669a611181bfac Mon Sep 17 00:00:00 2001 From: Lukas Schneiderbauer Date: Sat, 30 Mar 2024 18:04:21 +0100 Subject: [PATCH 7/7] compat: do not use pipe operator in tests --- .../test-backend-dbplyr__duckdb_connection.R | 30 +++++++------------ 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/testthat/test-backend-dbplyr__duckdb_connection.R b/tests/testthat/test-backend-dbplyr__duckdb_connection.R index 3301537ba..13eea17d9 100644 --- a/tests/testthat/test-backend-dbplyr__duckdb_connection.R +++ b/tests/testthat/test-backend-dbplyr__duckdb_connection.R @@ -236,44 +236,36 @@ test_that("n_distinct() computations are correct", { summarize <- dplyr::summarize pull <- dplyr::pull - duckdb_register(con, "df1", data.frame(x = c(1, 1, 2))) - duckdb_register(con, "df1_na", data.frame(x = c(1, 1, 2, NA))) - duckdb_register(con, "df2", data.frame(x = c(1, 1, 2, 2), y = c(1, 2, 2, 2))) - duckdb_register(con, "df2_na", data.frame(x = c(1, 1, 2, NA, NA), y = c(1, 2, NA, 2, NA))) + duckdb_register(con, "df", data.frame(x = c(1, 1, 2, 2), y = c(1, 2, 2, 2))) + duckdb_register(con, "df_na", data.frame(x = c(1, 1, 2, NA, NA), y = c(1, 2, NA, 2, NA))) + + df <- tbl(con, "df") + df_na <- tbl(con, "df_na") expect_error( - tbl(con, "df1") |> - summarize(n = n_distinct(x, na.rm = TRUE)) |> - pull(n) + pull(summarize(df, n = n_distinct(x, na.rm = TRUE)), n) ) # single column is working as usual expect_equal( - tbl(con, "df1") |> - summarize(n = n_distinct(x)) |> - pull(n), + pull(summarize(df, n = n_distinct(x)), n), 2 ) + expect_equal( - tbl(con, "df1_na") |> - summarize(n = n_distinct(x)) |> - pull(n), + pull(summarize(df_na, n = n_distinct(x)), n), 3 ) # two columns return correct results expect_equal( - tbl(con, "df2") |> - summarize(n = n_distinct(x, y)) |> - pull(n), + pull(summarize(df, n = n_distinct(x, y)), n), 3 ) # two columns containing NAs return correct results expect_equal( - tbl(con, "df2_na") |> - summarize(n = 
n_distinct(x, y)) |> - pull(n), + pull(summarize(df_na, n = n_distinct(x, y)), n), 5 ) })