Skip to content

Commit

Permalink
Fixing a few compile errors in new APIs (#874)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #874
  • Loading branch information
cjnolet authored Oct 3, 2022
1 parent da9da83 commit e9959ac
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions cpp/include/raft/cluster/single_linkage_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class linkage_output {
}
};

class linkage_output_int_float : public linkage_output<int> {
class linkage_output_int : public linkage_output<int> {
};
class linkage_output__int64_float : public linkage_output<int64_t> {
class linkage_output_int64 : public linkage_output<int64_t> {
};

}; // namespace raft::cluster
7 changes: 5 additions & 2 deletions cpp/include/raft/linalg/axpy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ void axpy(const raft::handle_t& handle,
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_input_device_mdspan<InType>,
typename = raft::enable_if_output_device_mdspan<OutType>>
void axpy(const raft::handle_t& handle,
Expand All @@ -79,7 +80,7 @@ void axpy(const raft::handle_t& handle,
const int incx,
const int incy)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input")
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename InType::value_type, true>(handle,
y.size(),
Expand All @@ -105,6 +106,8 @@ void axpy(const raft::handle_t& handle,
* @param [in] incy stride between consecutive elements of y
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_input_device_mdspan<InType>,
typename = raft::enable_if_output_device_mdspan<OutType>>
void axpy(const raft::handle_t& handle,
Expand All @@ -114,7 +117,7 @@ void axpy(const raft::handle_t& handle,
const int incx,
const int incy)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input")
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename InType::value_type, false>(handle,
y.size(),
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/linalg/mean_squared_error.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ namespace linalg {
* @param weight weight to apply to every term in the mean squared error calculation
* @param stream cuda-stream where to launch this kernel
*/
template <typename in_t, typename out_t = in_t, typname idx_t = std::uint32_t, tyint TPB = 256>
template <typename in_t, typename out_t, typename idx_t>
void meanSquaredError(
math_t* out, const math_t* A, const math_t* B, size_t len, math_t weight, cudaStream_t stream)
out_t* out, const in_t* A, const in_t* B, size_t len, in_t weight, cudaStream_t stream)
{
detail::meanSquaredError(out, A, B, len, weight, stream);
}
Expand All @@ -58,7 +58,7 @@ void meanSquaredError(
* @param[out] out the output mean squared error value of type raft::device_scalar_view
* @param[in] weight weight to apply to every term in the mean squared error calculation
*/
template <typename InValueType, typename IndexType, typename OutValueType, int TPB = 256>
template <typename InValueType, typename IndexType, typename OutValueType>
void mean_squared_error(const raft::handle_t& handle,
raft::device_vector_view<const InValueType, IndexType> A,
raft::device_vector_view<const InValueType, IndexType> B,
Expand All @@ -68,7 +68,7 @@ void mean_squared_error(const raft::handle_t& handle,
RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs");

meanSquaredError(
out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, stream);
out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, handle.get_stream());
}

/** @} */ // end of group mean_squared_error
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/sparse/hierarchy/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace raft::hierarchy {
using raft::cluster::linkage_output;
using raft::cluster::linkage_output__int64_float;
using raft::cluster::linkage_output_int_float;
using raft::cluster::linkage_output_int;
using raft::cluster::linkage_output_int64;
using raft::cluster::LinkageDistance;
} // namespace raft::hierarchy

0 comments on commit e9959ac

Please sign in to comment.