Skip to content

Commit

Permalink
Provide explicit pool size for pool_memory_resources and clean up includes (#2088)
Browse files Browse the repository at this point in the history

This PR fixes up RAFT to avoid usage that will soon be deprecated in RMM.

Depends on rapidsai/rmm#1417

Fixes #2087

Authors:
  - Mark Harris (https://github.com/harrism)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2088
  • Loading branch information
harrism authored Jan 17, 2024
1 parent b75fbd0 commit dbc33ea
Show file tree
Hide file tree
Showing 18 changed files with 118 additions and 85 deletions.
15 changes: 10 additions & 5 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,15 +14,20 @@
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down
14 changes: 7 additions & 7 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@
*/

#include "../common/ann_types.hpp"

#include "raft_ann_bench_param_parser.h"

#include <raft/core/logger.hpp>

#include <rmm/mr/device/per_device_resource.hpp>

#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>

#include <algorithm>
#include <cmath>
#include <memory>
#include <raft/core/logger.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>

namespace raft::bench::ann {

template <typename T>
Expand Down
5 changes: 4 additions & 1 deletion cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "raft_ann_bench_param_parser.h"
#include "raft_cagra_hnswlib_wrapper.h"

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#define JSON_DIAGNOSTICS 1
Expand Down Expand Up @@ -85,7 +86,9 @@ int main(int argc, char** argv)
{
rmm::mr::cuda_memory_resource cuda_mr;
// Construct a resource that uses a coalescing best-fit pool allocator
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{&cuda_mr};
// and is initially sized to half of free device memory.
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
rmm::mr::set_current_device_resource(
&pool_mr); // Updates the current device resource pointer to `pool_mr`
rmm::mr::device_memory_resource* mr =
Expand Down
20 changes: 9 additions & 11 deletions cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,9 @@
*/
#pragma once

#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
Expand All @@ -28,16 +27,15 @@
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"
#include <raft/util/cudart_utils.hpp>

namespace raft::bench::ann {

template <typename T, typename IdxT>
Expand Down
12 changes: 6 additions & 6 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,9 @@
*/
#pragma once

#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
Expand All @@ -28,14 +31,11 @@
#include <raft/util/cudart_utils.hpp>
#include <raft_runtime/neighbors/ivf_pq.hpp>
#include <raft_runtime/neighbors/refine.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <type_traits>

#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"
#include <raft/util/cudart_utils.hpp>
#include <type_traits>

namespace raft::bench::ann {

Expand Down
12 changes: 8 additions & 4 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,7 @@

#include <benchmark/benchmark.h>

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
Expand All @@ -43,7 +44,7 @@ namespace raft::bench {
struct using_pool_memory_res {
private:
rmm::mr::device_memory_resource* orig_res_;
rmm::mr::cuda_memory_resource cuda_res_;
rmm::mr::cuda_memory_resource cuda_res_{};
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;

public:
Expand All @@ -54,7 +55,9 @@ struct using_pool_memory_res {
rmm::mr::set_current_device_resource(&pool_res_);
}

using_pool_memory_res() : orig_res_(rmm::mr::get_current_device_resource()), pool_res_(&cuda_res_)
using_pool_memory_res()
: orig_res_(rmm::mr::get_current_device_resource()),
pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50))
{
rmm::mr::set_current_device_resource(&pool_res_);
}
Expand Down Expand Up @@ -114,7 +117,8 @@ class fixture {
raft::device_resources handle;
rmm::cuda_stream_view stream;

fixture(bool use_pool_memory_resource = false) : stream{resource::get_cuda_stream(handle)}
explicit fixture(bool use_pool_memory_resource = false)
: stream{resource::get_cuda_stream(handle)}
{
// Cache memory pool between test runs, since it is expensive to create.
// This speeds up the time required to run the select_k bench by over 3x.
Expand Down
1 change: 0 additions & 1 deletion cpp/bench/prims/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cstdint>
#include <cstring>
Expand Down
6 changes: 4 additions & 2 deletions cpp/bench/prims/neighbors/refine.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,6 +27,7 @@
#include <raft/neighbors/refine.cuh>
#include <raft/random/rng.cuh>

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
Expand Down Expand Up @@ -58,7 +59,8 @@ class RefineAnn : public fixture {
state.SetLabel(label_stream.str());

auto old_mr = rmm::mr::get_current_device_resource();
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr(old_mr);
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr(
old_mr, rmm::percent_of_free_device_memory(50));
rmm::mr::set_current_device_resource(&pool_mr);

if (data.p.host_data) {
Expand Down
7 changes: 5 additions & 2 deletions cpp/include/raft/core/device_resources_manager.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
#include <optional>
#include <raft/core/device_resources.hpp>
#include <raft/core/device_setter.hpp>
#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
Expand Down Expand Up @@ -170,7 +171,9 @@ struct device_resources_manager {
if (upstream != nullptr) {
result =
std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
upstream, params.init_mem_pool_size, params.max_mem_pool_size);
upstream,
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
params.max_mem_pool_size);
rmm::mr::set_current_device_resource(result.get());
} else {
RAFT_LOG_WARN(
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/ivf_flat-inl.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,9 +21,9 @@
#include <raft/neighbors/ivf_flat_serialize.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

#include <raft/core/device_mdspan.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

Expand Down
7 changes: 4 additions & 3 deletions cpp/include/raft/util/memory_pool-ext.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,11 @@
*/

#pragma once
#include <cstddef> // size_t
#include <memory> // std::unique_ptr
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::mr::device_memory_resource

#include <cstddef> // size_t
#include <memory> // std::unique_ptr

namespace raft {

std::unique_ptr<rmm::mr::device_memory_resource> get_pool_memory_resource(
Expand Down
16 changes: 10 additions & 6 deletions cpp/include/raft/util/memory_pool-inl.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,14 +15,17 @@
*/

#pragma once
#include <cstddef>
#include <memory>

#include <raft/core/detail/macros.hpp> // RAFT_INLINE_CONDITIONAL

#include <rmm/aligned.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cstddef>
#include <memory>

namespace raft {

/**
Expand Down Expand Up @@ -65,14 +68,15 @@ RAFT_INLINE_CONDITIONAL std::unique_ptr<rmm::mr::device_memory_resource> get_poo
rmm::mr::device_memory_resource*& mr, size_t initial_size)
{
using pool_res_t = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
std::unique_ptr<pool_res_t> pool_res{};
std::unique_ptr<pool_res_t> pool_res{nullptr};
if (mr) return pool_res;
mr = rmm::mr::get_current_device_resource();
if (!dynamic_cast<pool_res_t*>(mr) &&
!dynamic_cast<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>*>(mr) &&
!dynamic_cast<rmm::mr::pool_memory_resource<rmm::mr::managed_memory_resource>*>(mr)) {
pool_res = std::make_unique<pool_res_t>(mr, (initial_size + 255) & (~255));
mr = pool_res.get();
pool_res = std::make_unique<pool_res_t>(
mr, rmm::align_down(initial_size, rmm::CUDA_ALLOCATION_ALIGNMENT));
mr = pool_res.get();
}
return pool_res;
}
Expand Down
7 changes: 4 additions & 3 deletions cpp/template/src/cagra_example.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,8 @@
* limitations under the License.
*/

#include <cstdint>
#include "common.cuh"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/cagra.cuh>
Expand All @@ -23,7 +24,7 @@
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include "common.cuh"
#include <cstdint>

void cagra_build_search_simple(raft::device_resources const& dev_resources,
raft::device_matrix_view<const float, int64_t> dataset,
Expand Down
3 changes: 2 additions & 1 deletion cpp/template/src/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/

#include <cstdint>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
Expand All @@ -28,6 +27,8 @@
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>

#include <cstdint>

// Fill dataset and queries with synthetic data.
void generate_dataset(raft::device_resources const& dev_resources,
raft::device_matrix_view<float, int64_t> dataset,
Expand Down
Loading

0 comments on commit dbc33ea

Please sign in to comment.