Skip to content

Commit

Permalink
Add support for nosync thrust exec policy.
Browse files Browse the repository at this point in the history
  • Loading branch information
abc99lr committed May 6, 2024
1 parent fd64c24 commit c6520fa
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/thrust_nosync_policy.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -102,7 +103,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(resource::get_thrust_policy(handle),
thrust::fill(resource::get_thrust_nosync_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);
Expand All @@ -128,7 +129,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
thrust::transform(resource::get_thrust_policy(handle),
thrust::transform(resource::get_thrust_nosync_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ enum resource_type {
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
THRUST_NOSYNC_POLICY, // thrust nosync execution policy
WORKSPACE_RESOURCE, // rmm device memory resource
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
Expand Down
79 changes: 79 additions & 0 deletions cpp/include/raft/core/resource/thrust_nosync_policy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <rmm/exec_policy.hpp>
namespace raft::resource {
class thrust_nosync_policy_resource : public resource {
public:
thrust_nosync_policy_resource(rmm::cuda_stream_view stream_view)
: thrust_nosync_policy_(std::make_unique<rmm::exec_policy_nosync>(stream_view))
{
}
void* get_resource() override { return thrust_nosync_policy_.get(); }

~thrust_nosync_policy_resource() override {}

private:
std::unique_ptr<rmm::exec_policy_nosync> thrust_nosync_policy_;
};

/**
 * @brief Factory producing `thrust_nosync_policy_resource` instances.
 *
 * The factory remembers the stream view it was constructed with, reports
 * `resource_type::THRUST_NOSYNC_POLICY` as the slot it populates, and builds
 * a new resource bound to that stream view on request.
 */
class thrust_nosync_policy_resource_factory : public resource_factory {
 public:
  /**
   * @param stream_view stream view handed to every resource this factory creates
   */
  thrust_nosync_policy_resource_factory(rmm::cuda_stream_view stream_view)
    : stream_view_(stream_view)
  {
  }

  /**
   * @return the resource slot this factory is responsible for
   */
  resource_type get_resource_type() override
  {
    return resource_type::THRUST_NOSYNC_POLICY;
  }

  /**
   * @return newly allocated `thrust_nosync_policy_resource`
   *
   * NOTE(review): the raw pointer is presumably adopted by the caller
   * (the resources container) -- confirm against the resource_factory contract.
   */
  resource* make_resource() override
  {
    auto* created = new thrust_nosync_policy_resource(stream_view_);
    return created;
  }

 private:
  // Stream view captured at construction; passed on to each created resource.
  rmm::cuda_stream_view stream_view_;
};

/**
* @defgroup resource_thrust_nosync_policy Thrust nosync policy resource functions
* @{
*/

/**
 * @brief Load the thrust nosync execution policy from a resources object.
 *
 * If no factory for `resource_type::THRUST_NOSYNC_POLICY` has been registered
 * on `res` yet, one is registered first, bound to the stream returned by
 * `get_cuda_stream(res)` at the time of this call.
 *
 * NOTE(review): the factory captures the stream view at first use; whether a
 * later change of the stream on `res` is reflected in the returned policy
 * depends on the resources implementation -- confirm before relying on it.
 *
 * @param res raft resources object for managing resources
 * @return reference to the `rmm::exec_policy_nosync` held by `res`
 */
inline rmm::exec_policy_nosync& get_thrust_nosync_policy(resources const& res)
{
  if (!res.has_resource_factory(resource_type::THRUST_NOSYNC_POLICY)) {
    // Lazily register a factory for the policy, bound to the stream associated with `res`.
    rmm::cuda_stream_view stream = get_cuda_stream(res);
    res.add_resource_factory(std::make_shared<thrust_nosync_policy_resource_factory>(stream));
  }
  return *res.get_resource<rmm::exec_policy_nosync>(resource_type::THRUST_NOSYNC_POLICY);
}

/**
* @}
*/

} // namespace raft::resource

0 comments on commit c6520fa

Please sign in to comment.