Skip to content

Commit

Permalink
Switch rmm::get_current_device_resource to use references via `rmm:…
Browse files Browse the repository at this point in the history
…:get_current_device_resource_ref` (#2372)

* Switching get resource to use references

Signed-off-by: Mike Wilson <knobby@burntsheep.com>
  • Loading branch information
hyperbolic2346 authored Sep 3, 2024
1 parent fa67ada commit ad271f0
Show file tree
Hide file tree
Showing 23 changed files with 122 additions and 98 deletions.
8 changes: 4 additions & 4 deletions src/main/cpp/src/bloom_filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ std::unique_ptr<cudf::list_scalar> bloom_filter_create(
int num_hashes,
int bloom_filter_longs,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Inserts input values into a bloom filter.
Expand Down Expand Up @@ -79,7 +79,7 @@ std::unique_ptr<cudf::column> bloom_filter_probe(
cudf::column_view const& input,
cudf::device_span<uint8_t const> bloom_filter,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Probe a bloom filter with an input column of int64_t values.
Expand All @@ -96,7 +96,7 @@ std::unique_ptr<cudf::column> bloom_filter_probe(
cudf::column_view const& input,
cudf::list_scalar& bloom_filter,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Merge multiple bloom filters into a single output.
Expand All @@ -114,6 +114,6 @@ std::unique_ptr<cudf::column> bloom_filter_probe(
std::unique_ptr<cudf::list_scalar> bloom_filter_merge(
cudf::column_view const& bloom_filters,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
2 changes: 1 addition & 1 deletion src/main/cpp/src/case_when.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ namespace spark_rapids_jni {
std::unique_ptr<cudf::column> select_first_true_index(
cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
12 changes: 6 additions & 6 deletions src/main/cpp/src/cast_string.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ std::unique_ptr<cudf::column> string_to_integer(
bool ansi_mode,
bool strip,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Convert a string column into an decimal column.
Expand All @@ -97,7 +97,7 @@ std::unique_ptr<cudf::column> string_to_decimal(
bool ansi_mode,
bool strip,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Convert a string column into an float column.
Expand All @@ -115,22 +115,22 @@ std::unique_ptr<cudf::column> string_to_float(
cudf::strings_column_view const& string_col,
bool ansi_mode,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

std::unique_ptr<cudf::column> format_float(
cudf::column_view const& input,
int const digits,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

std::unique_ptr<cudf::column> float_to_string(
cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

std::unique_ptr<cudf::column> decimal_to_non_ansi_string(
cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
14 changes: 7 additions & 7 deletions src/main/cpp/src/decimal_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ std::unique_ptr<cudf::table> multiply_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1026,7 +1026,7 @@ std::unique_ptr<cudf::table> divide_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1060,7 +1060,7 @@ std::unique_ptr<cudf::table> integer_divide_decimal128(cudf::column_view const&
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1093,7 +1093,7 @@ std::unique_ptr<cudf::table> remainder_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1126,7 +1126,7 @@ std::unique_ptr<cudf::table> add_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1159,7 +1159,7 @@ std::unique_ptr<cudf::table> sub_decimal128(cudf::column_view const& a,
auto const num_rows = a.size();
CUDF_EXPECTS(num_rows == b.size(), "inputs have mismatched row counts");
auto [result_null_mask, result_null_count] = cudf::detail::bitmask_and(
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource());
cudf::table_view{{a, b}}, stream, rmm::mr::get_current_device_resource_ref());
std::vector<std::unique_ptr<cudf::column>> columns;
// copy the null mask here, as it will be used again later
columns.push_back(cudf::make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8},
Expand Down Expand Up @@ -1410,7 +1410,7 @@ std::pair<std::unique_ptr<cudf::column>, bool> floating_point_to_decimal(
output_type, input.size(), cudf::mask_state::UNALLOCATED, stream, mr);

auto const decimal_places = -output_type.scale();
auto const default_mr = rmm::mr::get_current_device_resource();
auto const default_mr = rmm::mr::get_current_device_resource_ref();

rmm::device_uvector<int8_t> validity(input.size(), stream, default_mr);
rmm::device_scalar<bool> has_failure(false, stream, default_mr);
Expand Down
2 changes: 1 addition & 1 deletion src/main/cpp/src/decimal_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ std::pair<std::unique_ptr<cudf::column>, bool> floating_point_to_decimal(
cudf::data_type output_type,
int32_t precision,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace cudf::jni
2 changes: 1 addition & 1 deletion src/main/cpp/src/from_json.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ namespace spark_rapids_jni {
std::unique_ptr<cudf::column> from_json_to_raw_map(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
6 changes: 3 additions & 3 deletions src/main/cpp/src/from_json_to_raw_map.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ rmm::device_uvector<char> unify_json_strings(cudf::strings_column_view const& in
{
if (input.is_empty()) {
return cudf::detail::make_device_uvector_async<char>(
std::vector<char>{'[', ']'}, stream, rmm::mr::get_current_device_resource());
std::vector<char>{'[', ']'}, stream, rmm::mr::get_current_device_resource_ref());
}

auto const d_strings = cudf::column_device_view::create(input.parent(), stream);
Expand All @@ -84,7 +84,7 @@ rmm::device_uvector<char> unify_json_strings(cudf::strings_column_view const& in
cudf::string_scalar(","), // append `,` character between the input rows
cudf::string_scalar("{}"), // replacement for null rows
stream,
rmm::mr::get_current_device_resource());
rmm::mr::get_current_device_resource_ref());
auto const joined_input_scv = cudf::strings_column_view{*joined_input};
auto const joined_input_size_bytes = joined_input_scv.chars_size(stream);
// TODO: This assertion requires a stream synchronization, may want to remove at some point.
Expand Down Expand Up @@ -656,7 +656,7 @@ std::unique_ptr<cudf::column> from_json_to_raw_map(cudf::strings_column_view con
cudf::device_span<char const>{unified_json_buff.data(), unified_json_buff.size()},
cudf::io::json_reader_options{},
stream,
rmm::mr::get_current_device_resource());
rmm::mr::get_current_device_resource_ref());

#ifdef DEBUG_FROM_JSON
print_debug(tokens, "Tokens", ", ", stream);
Expand Down
6 changes: 3 additions & 3 deletions src/main/cpp/src/get_json_object.cu
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ construct_path_commands(
d_path_commands.reserve(h_path_commands->size());
for (auto const& path_commands : *h_path_commands) {
d_path_commands.emplace_back(cudf::detail::make_device_uvector_async(
path_commands, stream, rmm::mr::get_current_device_resource()));
path_commands, stream, rmm::mr::get_current_device_resource_ref()));
}

return {std::move(d_path_commands),
Expand Down Expand Up @@ -1060,7 +1060,7 @@ std::vector<std::unique_ptr<cudf::column>> get_json_object_batch(
d_error_check.data() + idx});
}
auto d_path_data = cudf::detail::make_device_uvector_async(
h_path_data, stream, rmm::mr::get_current_device_resource());
h_path_data, stream, rmm::mr::get_current_device_resource_ref());
thrust::uninitialized_fill(
rmm::exec_policy(stream), d_error_check.begin(), d_error_check.end(), 0);

Expand Down Expand Up @@ -1130,7 +1130,7 @@ std::vector<std::unique_ptr<cudf::column>> get_json_object_batch(

// Push data to the GPU and launch the kernel again.
d_path_data = cudf::detail::make_device_uvector_async(
h_path_data, stream, rmm::mr::get_current_device_resource());
h_path_data, stream, rmm::mr::get_current_device_resource_ref());
thrust::uninitialized_fill(
rmm::exec_policy(stream), d_error_check.begin(), d_error_check.end(), 0);
kernel_launcher::exec(input, d_path_data, d_max_path_depth_exceeded, stream);
Expand Down
4 changes: 2 additions & 2 deletions src/main/cpp/src/get_json_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::unique_ptr<cudf::column> get_json_object(
cudf::strings_column_view const& input,
std::vector<std::tuple<path_instruction_type, std::string, int32_t>> const& instructions,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Extract multiple JSON objects from a JSON string based on the specified JSON paths.
Expand All @@ -67,6 +67,6 @@ std::vector<std::unique_ptr<cudf::column>> get_json_object_multiple_paths(
int64_t memory_budget_bytes,
int32_t parallel_override,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
6 changes: 3 additions & 3 deletions src/main/cpp/src/hash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::unique_ptr<cudf::column> murmur_hash3_32(
cudf::table_view const& input,
uint32_t seed = 0,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Computes the xxhash64 hash value of each row in the input set of columns.
Expand All @@ -56,7 +56,7 @@ std::unique_ptr<cudf::column> xxhash64(
cudf::table_view const& input,
int64_t seed = DEFAULT_XXHASH64_SEED,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Computes the Hive hash value of each row in the input set of columns.
Expand All @@ -70,6 +70,6 @@ std::unique_ptr<cudf::column> xxhash64(
std::unique_ptr<cudf::column> hive_hash(
cudf::table_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
8 changes: 4 additions & 4 deletions src/main/cpp/src/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ struct percentile_dispatcher {
// We may always have nulls in the output due to either:
// - Having nulls in the input, and/or,
// - Having empty histograms.
auto out_validities =
rmm::device_uvector<int8_t>(num_histograms, stream, rmm::mr::get_current_device_resource());
auto out_validities = rmm::device_uvector<int8_t>(
num_histograms, stream, rmm::mr::get_current_device_resource_ref());

auto const fill_percentile = [&](auto const sorted_validity_it) {
auto const sorted_input_it =
Expand Down Expand Up @@ -307,7 +307,7 @@ std::unique_ptr<cudf::column> create_histogram_if_valid(cudf::column_view const&
}
}

auto const default_mr = rmm::mr::get_current_device_resource();
auto const default_mr = rmm::mr::get_current_device_resource_ref();

// We only check if there is any row in frequencies that are negative (invalid) or zero.
auto check_invalid_and_zero =
Expand Down Expand Up @@ -439,7 +439,7 @@ std::unique_ptr<cudf::column> percentile_from_histogram(cudf::column_view const&
auto const data_col = cudf::structs_column_view{histograms}.get_sliced_child(0);
auto const counts_col = cudf::structs_column_view{histograms}.get_sliced_child(1);

auto const default_mr = rmm::mr::get_current_device_resource();
auto const default_mr = rmm::mr::get_current_device_resource_ref();
auto const d_data = cudf::column_device_view::create(data_col, stream);
auto const d_percentages =
cudf::detail::make_device_uvector_sync(percentages, stream, default_mr);
Expand Down
4 changes: 2 additions & 2 deletions src/main/cpp/src/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ std::unique_ptr<cudf::column> create_histogram_if_valid(
cudf::column_view const& frequencies,
bool output_as_lists,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Compute percentiles from the given histograms and percentage values.
Expand All @@ -72,6 +72,6 @@ std::unique_ptr<cudf::column> percentile_from_histogram(
std::vector<double> const& percentage,
bool output_as_lists,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
12 changes: 6 additions & 6 deletions src/main/cpp/src/parse_uri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace spark_rapids_jni {
std::unique_ptr<cudf::column> parse_uri_to_protocol(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Parse host and copy from the input string column to the output string column.
Expand All @@ -51,7 +51,7 @@ std::unique_ptr<cudf::column> parse_uri_to_protocol(
std::unique_ptr<cudf::column> parse_uri_to_host(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Parse query and copy from the input string column to the output string column.
Expand All @@ -64,7 +64,7 @@ std::unique_ptr<cudf::column> parse_uri_to_host(
std::unique_ptr<cudf::column> parse_uri_to_query(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Parse query and copy from the input string column to the output string column.
Expand All @@ -79,7 +79,7 @@ std::unique_ptr<cudf::column> parse_uri_to_query(
cudf::strings_column_view const& input,
std::string const& query_match,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Parse query and copy from the input string column to the output string column.
Expand All @@ -94,7 +94,7 @@ std::unique_ptr<cudf::column> parse_uri_to_query(
cudf::strings_column_view const& input,
cudf::strings_column_view const& query_match,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

/**
* @brief Parse path and copy from the input string column to the output string column.
Expand All @@ -107,6 +107,6 @@ std::unique_ptr<cudf::column> parse_uri_to_query(
std::unique_ptr<cudf::column> parse_uri_to_path(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());

} // namespace spark_rapids_jni
2 changes: 1 addition & 1 deletion src/main/cpp/src/regex_rewrite_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ std::unique_ptr<cudf::column> literal_range_pattern(
int const start,
int const end,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource_ref());
} // namespace spark_rapids_jni
Loading

0 comments on commit ad271f0

Please sign in to comment.