Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support fetching MAP type #165

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ SEXP duckdb_r_allocate(const LogicalType &type, idx_t nrows) {
case LogicalTypeId::INTERVAL:
return NEW_NUMERIC(nrows);
case LogicalTypeId::LIST:
case LogicalTypeId::MAP:
return NEW_LIST(nrows);
case LogicalTypeId::STRUCT: {
cpp11::writable::list dest_list;
Expand Down Expand Up @@ -147,6 +148,7 @@ void duckdb_r_decorate(const LogicalType &type, const SEXP dest, bool integer64)
case LogicalTypeId::BLOB:
case LogicalTypeId::UUID:
case LogicalTypeId::LIST:
case LogicalTypeId::MAP:
break; // no extra decoration required, do nothing
case LogicalTypeId::TIMESTAMP_SEC:
case LogicalTypeId::TIMESTAMP_MS:
Expand Down Expand Up @@ -432,6 +434,56 @@ void duckdb_r_transform(Vector &src_vec, const SEXP dest, idx_t dest_offset, idx

break;
}

case LogicalTypeId::MAP: {
auto src_data = ListVector::GetData(src_vec);

auto &key_type = MapType::KeyType(src_vec.GetType());
auto &value_type = MapType::ValueType(src_vec.GetType());

Vector key_child(key_type, nullptr);
Vector value_child(value_type, nullptr);

for (size_t row_idx = 0; row_idx < n; row_idx++) {
if (!FlatVector::Validity(src_vec).RowIsValid(row_idx)) {
SET_ELEMENT(dest, dest_offset + row_idx, R_NilValue);
} else {
auto offset = src_data[row_idx].offset;
auto length = src_data[row_idx].length;
const auto end = offset + length;

key_child.Slice(MapVector::GetKeys(src_vec), offset, end);
value_child.Slice(MapVector::GetValues(src_vec), offset, end);

cpp11::sexp key_sexp = duckdb_r_allocate(key_type, length);
cpp11::sexp value_sexp = duckdb_r_allocate(value_type, length);

duckdb_r_decorate(key_type, key_sexp, integer64);
duckdb_r_decorate(value_type, value_sexp, integer64);

duckdb_r_transform(key_child, key_sexp, 0, length, integer64);
duckdb_r_transform(value_child, value_sexp, 0, length, integer64);

cpp11::writable::list dest_list;
dest_list.reserve(2);

dest_list.push_back(cpp11::named_arg("key") = std::move(key_sexp));
dest_list.push_back(cpp11::named_arg("value") = std::move(value_sexp));

// convert to SEXP, with potential side effect of truncation
(void)(SEXP)dest_list;

// Note we cannot use cpp11's data frame here as it tries to calculate the number of rows itself,
// but gives the wrong answer if the first column is another data frame or the struct is empty.
dest_list.attr(R_ClassSymbol) = RStrings::get().dataframe_str;
dest_list.attr(R_RowNamesSymbol) = {NA_INTEGER, -static_cast<int>(length)};
// call R's own extract subset method
SET_ELEMENT(dest, dest_offset + row_idx, dest_list);
}
}
break;
}

case LogicalTypeId::BLOB: {
auto src_ptr = FlatVector::GetData<string_t>(src_vec);
auto &mask = FlatVector::Validity(src_vec);
Expand Down
3 changes: 3 additions & 0 deletions src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,15 @@ string RApiTypes::DetectLogicalType(const LogicalType &stype, const char *caller
case LogicalTypeId::LIST:
return "list";
case LogicalTypeId::STRUCT:
case LogicalTypeId::MAP:
return "data.frame";
return "data.frame";
case LogicalTypeId::ENUM:
return "factor";
case LogicalTypeId::UNKNOWN:
case LogicalTypeId::SQLNULL:
return "unknown";

default:
cpp11::stop("%s: Unknown column type for prepare: %s", caller, stype.ToString().c_str());
break;
Expand Down
16 changes: 15 additions & 1 deletion tests/testthat/_snaps/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Code
as.list(dbGetQuery(con,
"SELECT * EXCLUDE (timestamp_tz, time_tz, timestamp_ns, timestamp_array, timestamptz_array, map, bit, \"union\", fixed_int_array, fixed_varchar_array, fixed_nested_int_array, fixed_nested_varchar_array, fixed_struct_array, struct_of_fixed_array, fixed_array_of_int_list, list_of_fixed_int_array) REPLACE(replace(varchar, chr(0), '') AS varchar) FROM test_all_types(use_large_enum=true)"))
"SELECT * EXCLUDE (timestamp_tz, time_tz, timestamp_ns, timestamp_array, timestamptz_array, bit, \"union\", fixed_int_array, fixed_varchar_array, fixed_nested_int_array, fixed_nested_varchar_array, fixed_struct_array, struct_of_fixed_array, fixed_array_of_int_list, list_of_fixed_int_array) REPLACE(replace(varchar, chr(0), '') AS varchar) FROM test_all_types(use_large_enum=true)"))
Output
$bool
[1] FALSE TRUE NA
Expand Down Expand Up @@ -207,4 +207,18 @@
NULL
$map
$map[[1]]
[1] key value
<0 rows> (or 0-length row.names)
$map[[2]]
key value
1 key1 🦆🦆🦆🦆🦆🦆
2 key2 goose
$map[[3]]
NULL

87 changes: 87 additions & 0 deletions tests/testthat/test-map.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
test_that("maps can be read", {
skip_if_not_installed("vctrs")

con <- dbConnect(duckdb())
on.exit(dbDisconnect(con, shutdown = TRUE))

res <- dbGetQuery(
con,
"SELECT map([1,2],['a','b']) AS x"
)
expect_equal(res, vctrs::data_frame(
x = list(
vctrs::data_frame(key = 1:2, value = letters[1:2])
)
))

res <- dbGetQuery(
con,
"SELECT 1 as a, map([1,2],[1.5,2.5]) AS x UNION SELECT 2, map([3,4,5],[5.5,4.5,3.5]) ORDER BY a"
)
expect_equal(res, vctrs::data_frame(
a = 1:2,
x = list(
vctrs::data_frame(key = 1:2, value = 1:2 + 0.5),
vctrs::data_frame(key = 3:5, value = 5:3 + 0.5)
)
))

res <- dbGetQuery(
con,
"SELECT 1 as a, map([1,2],[TRUE,FALSE]) AS x UNION SELECT 2, NULL ORDER BY a"
)
expect_equal(res, vctrs::data_frame(
a = 1:2,
x = list(
vctrs::data_frame(key = 1:2, value = c(TRUE, FALSE)),
NULL
)
))
})

test_that("structs give the same results via Arrow", {
skip_on_cran()
skip_if_not_installed("vctrs")
skip_if_not_installed("tibble")
skip_if_not_installed("arrow", "13.0.0")

con <- dbConnect(duckdb())
on.exit(dbDisconnect(con, shutdown = TRUE))

res <- dbGetQuery(
con,
"SELECT map([1,2],['a','b']) AS x",
arrow = TRUE
)
expect_equal(res, vctrs::data_frame(
x = structure(class = c("arrow_list", class(vctrs::list_of(logical()))), vctrs::list_of(
tibble::tibble(key = 1:2, value = letters[1:2])
))
))

res <- dbGetQuery(
con,
"SELECT 1 as a, map([1,2],[1.5,2.5]) AS x UNION SELECT 2, map([3,4,5],[5.5,4.5,3.5]) ORDER BY a",
arrow = TRUE
)
expect_equal(res, vctrs::data_frame(
a = 1:2,
x = structure(class = c("arrow_list", class(vctrs::list_of(logical()))), vctrs::list_of(
tibble::tibble(key = 1:2, value = 1:2 + 0.5),
tibble::tibble(key = 3:5, value = 5:3 + 0.5)
))
))

res <- dbGetQuery(
con,
"SELECT 1 as a, map([1,2],[TRUE,FALSE]) AS x UNION SELECT 2, NULL ORDER BY a",
arrow = TRUE
)
expect_equal(res, vctrs::data_frame(
a = 1:2,
x = structure(class = c("arrow_list", class(vctrs::list_of(logical()))), vctrs::list_of(
tibble::tibble(key = 1:2, value = c(TRUE, FALSE)),
NULL
))
))
})
2 changes: 1 addition & 1 deletion tests/testthat/test-types.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ test_that("test_all_types() output", {

# Need to omit timestamp columns, likely due to https://bugs.r-project.org/show_bug.cgi?id=16856
expect_snapshot({
as.list(dbGetQuery(con, "SELECT * EXCLUDE (timestamp_tz, time_tz, timestamp_ns, timestamp_array, timestamptz_array, map, bit, \"union\", fixed_int_array, fixed_varchar_array, fixed_nested_int_array, fixed_nested_varchar_array, fixed_struct_array, struct_of_fixed_array, fixed_array_of_int_list, list_of_fixed_int_array) REPLACE(replace(varchar, chr(0), '') AS varchar) FROM test_all_types(use_large_enum=true)"))
as.list(dbGetQuery(con, "SELECT * EXCLUDE (timestamp_tz, time_tz, timestamp_ns, timestamp_array, timestamptz_array, bit, \"union\", fixed_int_array, fixed_varchar_array, fixed_nested_int_array, fixed_nested_varchar_array, fixed_struct_array, struct_of_fixed_array, fixed_array_of_int_list, list_of_fixed_int_array) REPLACE(replace(varchar, chr(0), '') AS varchar) FROM test_all_types(use_large_enum=true)"))
})
})
Loading