From 403e955a87b93da656b1a2c50de08f3a858f3100 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 17 Nov 2022 10:22:49 -0800 Subject: [PATCH 1/5] Expose cluster_cost to python Add cython bindings for the cluster_cost function, to allow computing inertia from python. Closes https://github.com/rapidsai/raft/issues/972 --- cpp/CMakeLists.txt | 2 + cpp/include/raft_distance/kmeans.hpp | 17 ++- cpp/src/distance/cluster_cost.cuh | 79 +++++++++++ cpp/src/distance/cluster_cost_double.cu | 34 +++++ cpp/src/distance/cluster_cost_float.cu | 34 +++++ python/pylibraft/pylibraft/cluster/kmeans.pyx | 123 +++++++++++++----- python/pylibraft/pylibraft/cpp/__init__.pxd | 0 python/pylibraft/pylibraft/cpp/__init__.py | 0 python/pylibraft/pylibraft/cpp/kmeans.pxd | 67 ++++++++++ .../pylibraft/pylibraft/test/test_kmeans.py | 18 ++- 10 files changed, 343 insertions(+), 31 deletions(-) create mode 100644 cpp/src/distance/cluster_cost.cuh create mode 100644 cpp/src/distance/cluster_cost_double.cu create mode 100644 cpp/src/distance/cluster_cost_float.cu create mode 100644 python/pylibraft/pylibraft/cpp/__init__.pxd create mode 100644 python/pylibraft/pylibraft/cpp/__init__.py create mode 100644 python/pylibraft/pylibraft/cpp/kmeans.pxd diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 94e693f861..3e692a3b4a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -286,6 +286,8 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/fused_l2_min_arg.cu src/distance/update_centroids_float.cu src/distance/update_centroids_double.cu + src/distance/cluster_cost_float.cu + src/distance/cluster_cost_double.cu src/distance/specializations/detail/canberra.cu src/distance/specializations/detail/chebyshev.cu src/distance/specializations/detail/correlation.cu diff --git a/cpp/include/raft_distance/kmeans.hpp b/cpp/include/raft_distance/kmeans.hpp index 19f92dd977..a56021b110 100644 --- a/cpp/include/raft_distance/kmeans.hpp +++ b/cpp/include/raft_distance/kmeans.hpp @@ -41,4 +41,19 @@ void update_centroids(raft::handle_t const& handle, double* new_centroids, double* weight_per_cluster); -} // namespace raft::cluster::kmeans::runtime \ No newline at end of file +void cluster_cost(raft::handle_t const& handle, + const float* X, + int n_samples, + int n_features, + int n_clusters, + const float* centroids, + float* cost); + +void cluster_cost(raft::handle_t const& handle, + const double* X, + int n_samples, + int n_features, + int n_clusters, + const double* centroids, + double* cost); +} // namespace raft::cluster::kmeans::runtime diff --git a/cpp/src/distance/cluster_cost.cuh b/cpp/src/distance/cluster_cost.cuh new file mode 100644 index 0000000000..907e34d015 --- /dev/null +++ b/cpp/src/distance/cluster_cost.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +namespace raft::cluster::kmeans::runtime { +template +void cluster_cost(const raft::handle_t& handle, + const ElementType* X, + IndexType n_samples, + IndexType n_features, + IndexType n_clusters, + const ElementType* centroids, + ElementType* cost) +{ + rmm::device_uvector workspace(n_samples * sizeof(IndexType), handle.get_stream()); + + rmm::device_uvector x_norms(n_samples, handle.get_stream()); + rmm::device_uvector centroid_norms(n_clusters, handle.get_stream()); + raft::linalg::rowNorm( + x_norms.data(), X, n_samples, n_features, raft::linalg::L2Norm, true, handle.get_stream()); + raft::linalg::rowNorm(centroid_norms.data(), + centroids, + n_clusters, + n_features, + raft::linalg::L2Norm, + true, + handle.get_stream()); + + auto min_cluster_distance = + raft::make_device_vector>(handle, n_samples); + raft::distance::fusedL2NNMinReduce(min_cluster_distance.data_handle(), + X, + centroids, + x_norms.data(), + centroid_norms.data(), + n_samples, + n_features, + n_clusters, + (void*)workspace.data(), + true, + true, + handle.get_stream()); + + auto distances = raft::make_device_vector(handle, n_samples); + thrust::transform( + handle.get_thrust_policy(), + min_cluster_distance.data_handle(), + min_cluster_distance.data_handle() + n_samples, + distances.data_handle(), + [] __device__(const raft::KeyValuePair& a) { return a.value; }); + + rmm::device_scalar device_cost(0, handle.get_stream()); + raft::cluster::kmeans::cluster_cost( + handle, + distances.view(), + workspace, + make_device_scalar_view(device_cost.data()), + [] __device__(const ElementType& a, const ElementType& b) { return a + b; }); + + raft::update_host(cost, device_cost.data(), 1, handle.get_stream()); +} +} // namespace raft::cluster::kmeans::runtime diff --git a/cpp/src/distance/cluster_cost_double.cu b/cpp/src/distance/cluster_cost_double.cu new file mode 100644 index 0000000000..b811b0bf8d --- /dev/null +++ b/cpp/src/distance/cluster_cost_double.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cluster_cost.cuh" +#include +#include +#include + +namespace raft::cluster::kmeans::runtime { + +void cluster_cost(const raft::handle_t& handle, + const double* X, + int n_samples, + int n_features, + int n_clusters, + const double* centroids, + double* cost) +{ + cluster_cost(handle, X, n_samples, n_features, n_clusters, centroids, cost); +} +} // namespace raft::cluster::kmeans::runtime diff --git a/cpp/src/distance/cluster_cost_float.cu b/cpp/src/distance/cluster_cost_float.cu new file mode 100644 index 0000000000..d78ea446da --- /dev/null +++ b/cpp/src/distance/cluster_cost_float.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cluster_cost.cuh" +#include +#include +#include + +namespace raft::cluster::kmeans::runtime { + +void cluster_cost(const raft::handle_t& handle, + const float* X, + int n_samples, + int n_features, + int n_clusters, + const float* centroids, + float* cost) +{ + cluster_cost(handle, X, n_samples, n_features, n_clusters, centroids, cost); +} +} // namespace raft::cluster::kmeans::runtime diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 732a78585d..679523cef4 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -26,11 +26,17 @@ from libcpp cimport bool, nullptr from pylibraft.common import Handle from pylibraft.common.handle import auto_sync_handle + from pylibraft.common.handle cimport handle_t from pylibraft.common.input_validation import * from pylibraft.distance import DISTANCE_TYPES +from pylibraft.cpp.kmeans cimport ( + cluster_cost as cpp_cluster_cost, + update_centroids, +) + def is_c_cont(cai, dt): return "strides" not in cai or \ @@ -38,34 +44,6 @@ def is_c_cont(cai, dt): cai["strides"][1] == dt.itemsize -cdef extern from "raft_distance/kmeans.hpp" \ - namespace "raft::cluster::kmeans::runtime": - - cdef void update_centroids( - const handle_t& handle, - const double *X, - int n_samples, - int n_features, - int n_clusters, - const double *sample_weights, - const double *centroids, - const int* labels, - double *new_centroids, - double *weight_per_cluster) except + - - cdef void update_centroids( - const handle_t& handle, - const float *X, - int n_samples, - int n_features, - int n_clusters, - const float *sample_weights, - const float *centroids, - const int* labels, - float *new_centroids, - float *weight_per_cluster) except + - - @auto_sync_handle def compute_new_centroids(X, centroids, @@ -109,7 +87,6 @@ def compute_new_centroids(X, from pylibraft.common import Handle from pylibraft.cluster.kmeans import compute_new_centroids - from pylibraft.distance import fused_l2_nn_argmin # A single RAFT handle can optionally be reused across # pylibraft functions. @@ -220,3 +197,91 @@ def compute_new_centroids(X, weight_per_cluster_ptr) else: raise ValueError("dtype %s not supported" % x_dt) + + +@auto_sync_handle +def cluster_cost(X, centroids, handle=None): + """ + Compute cluster cost given an input matrix and existing centroids + + Parameters + ---------- + X : Input CUDA array interface compliant matrix shape (m, k) + centroids : Input CUDA array interface compliant matrix shape + (n_clusters, k) + {handle_docstring} + + Examples + -------- + + .. code-block:: python + import cupy as cp + + from pylibraft.cluster.kmeans import cluster_cost + + n_samples = 5000 + n_features = 50 + n_clusters = 3 + + X = cp.random.random_sample((n_samples, n_features), + dtype=cp.float32) + + centroids = cp.random.random_sample((n_clusters, n_features), + dtype=cp.float32) + + inertia = cluster_cost(X, centroids) + """ + x_cai = X.__cuda_array_interface__ + centroids_cai = centroids.__cuda_array_interface__ + + m = x_cai["shape"][0] + x_k = x_cai["shape"][1] + n_clusters = centroids_cai["shape"][0] + + centroids_k = centroids_cai["shape"][1] + + x_dt = np.dtype(x_cai["typestr"]) + centroids_dt = np.dtype(centroids_cai["typestr"]) + + if not do_cols_match(X, centroids): + raise ValueError("X and centroids must have same number of columns.") + + x_ptr = x_cai["data"][0] + centroids_ptr = centroids_cai["data"][0] + + handle = handle if handle is not None else Handle() + cdef handle_t *h = handle.getHandle() + + x_c_contiguous = is_c_cont(x_cai, x_dt) + centroids_c_contiguous = is_c_cont(centroids_cai, centroids_dt) + + if not x_c_contiguous or not centroids_c_contiguous: + raise ValueError("Inputs must all be c contiguous") + + if not do_dtypes_match(X, centroids): + raise ValueError("Inputs must all have the same dtypes " + "(float32 or float64)") + + cdef float f_cost = 0 + cdef double d_cost = 0 + + if x_dt == np.float32: + cpp_cluster_cost(deref(h), + x_ptr, + m, + x_k, + n_clusters, + centroids_ptr, + &f_cost) + return f_cost + elif x_dt == np.float64: + cpp_cluster_cost(deref(h), + x_ptr, + m, + x_k, + n_clusters, + centroids_ptr, + &d_cost) + return d_cost + else: + raise ValueError("dtype %s not supported" % x_dt) diff --git a/python/pylibraft/pylibraft/cpp/__init__.pxd b/python/pylibraft/pylibraft/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/cpp/__init__.py b/python/pylibraft/pylibraft/cpp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/cpp/kmeans.pxd b/python/pylibraft/pylibraft/cpp/kmeans.pxd new file mode 100644 index 0000000000..b263952522 --- /dev/null +++ b/python/pylibraft/pylibraft/cpp/kmeans.pxd @@ -0,0 +1,67 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from pylibraft.common.handle cimport handle_t + + +cdef extern from "raft_distance/kmeans.hpp" \ + namespace "raft::cluster::kmeans::runtime": + + cdef void update_centroids( + const handle_t& handle, + const double *X, + int n_samples, + int n_features, + int n_clusters, + const double *sample_weights, + const double *centroids, + const int* labels, + double *new_centroids, + double *weight_per_cluster) except + + + cdef void update_centroids( + const handle_t& handle, + const float *X, + int n_samples, + int n_features, + int n_clusters, + const float *sample_weights, + const float *centroids, + const int* labels, + float *new_centroids, + float *weight_per_cluster) except + + + cdef void cluster_cost( + const handle_t& handle, + const float* X, + int n_samples, + int n_features, + int n_clusters, + const float * centroids, + float * cost) except + + + cdef void cluster_cost( + const handle_t& handle, + const double* X, + int n_samples, + int n_features, + int n_clusters, + const double * centroids, + double * cost) except + diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index 58028e90e8..2088ac529e 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -16,7 +16,7 @@ import numpy as np import pytest -from pylibraft.cluster.kmeans import compute_new_centroids +from pylibraft.cluster.kmeans import cluster_cost, compute_new_centroids from pylibraft.common import Handle, device_ndarray from pylibraft.distance import pairwise_distance @@ -88,3 +88,19 @@ def test_compute_new_centroids( actual_centers = new_centroids_device.copy_to_host() assert np.allclose(expected_centers, actual_centers, rtol=1e-6) + + +@pytest.mark.parametrize("n_rows", [100]) +@pytest.mark.parametrize("n_cols", [5, 25]) +@pytest.mark.parametrize("n_clusters", [5, 15]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): + + X = np.random.random_sample((n_rows, n_cols)).astype(dtype) + X_device = device_ndarray(X) + + centroids = X[:n_clusters] + centroids_device = device_ndarray(centroids) + + # TODO: compute inertia naively, make sure is close + inertia = cluster_cost(X_device, centroids_device) # noqa From 89863ef3ff0fdcfeaacfccd05f06bcb41d7b36dd Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 18 Nov 2022 10:34:29 -0800 Subject: [PATCH 2/5] Add pytest for inertia --- cpp/src/distance/cluster_cost.cuh | 6 ++-- .../pylibraft/pylibraft/test/test_kmeans.py | 28 +++++++++++++++---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/cpp/src/distance/cluster_cost.cuh b/cpp/src/distance/cluster_cost.cuh index 907e34d015..7b9fd049e0 100644 --- a/cpp/src/distance/cluster_cost.cuh +++ b/cpp/src/distance/cluster_cost.cuh @@ -34,11 +34,11 @@ void cluster_cost(const raft::handle_t& handle, rmm::device_uvector x_norms(n_samples, handle.get_stream()); rmm::device_uvector centroid_norms(n_clusters, handle.get_stream()); raft::linalg::rowNorm( - x_norms.data(), X, n_samples, n_features, raft::linalg::L2Norm, true, handle.get_stream()); + x_norms.data(), X, n_features, n_samples, raft::linalg::L2Norm, true, handle.get_stream()); raft::linalg::rowNorm(centroid_norms.data(), centroids, - n_clusters, n_features, + n_clusters, raft::linalg::L2Norm, true, handle.get_stream()); @@ -51,8 +51,8 @@ void cluster_cost(const raft::handle_t& handle, x_norms.data(), centroid_norms.data(), n_samples, - n_features, n_clusters, + n_features, (void*)workspace.data(), true, true, diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index 2088ac529e..faca219a32 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -18,7 +18,7 @@ from pylibraft.cluster.kmeans import cluster_cost, compute_new_centroids from pylibraft.common import Handle, device_ndarray -from pylibraft.distance import pairwise_distance +from pylibraft.distance import fused_l2_nn_argmin, pairwise_distance @pytest.mark.parametrize("n_rows", [100]) @@ -90,17 +90,35 @@ def test_compute_new_centroids( assert np.allclose(expected_centers, actual_centers, rtol=1e-6) -@pytest.mark.parametrize("n_rows", [100]) +@pytest.mark.parametrize("n_rows", [8]) @pytest.mark.parametrize("n_cols", [5, 25]) -@pytest.mark.parametrize("n_clusters", [5, 15]) +@pytest.mark.parametrize("n_clusters", [4, 15]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): - X = np.random.random_sample((n_rows, n_cols)).astype(dtype) X_device = device_ndarray(X) centroids = X[:n_clusters] centroids_device = device_ndarray(centroids) - # TODO: compute inertia naively, make sure is close inertia = cluster_cost(X_device, centroids_device) # noqa + + # compute the nearest centroid to each sample + distances = pairwise_distance(X_device, centroids_device).copy_to_host() + cluster_ids = np.argmin(distances, axis=1) + + # TODO: the cluster_ids above don't match whats being computed in + # cluster_cost: https://github.com/rapidsai/raft/issues/1036 + # use the same fused_l2_nn_argmin implementation as used internally + # so we get the same inertia result here + cluster_ids = fused_l2_nn_argmin( + X_device, centroids_device, sqrt=True + ).copy_to_host() + + cluster_distances = np.take_along_axis( + distances, cluster_ids[:, None], axis=1 + ) + + # need reduced tolerance for float32 + rtol = 1e-3 if dtype == np.float32 else 1e-6 + assert np.allclose(inertia, sum(cluster_distances), rtol=rtol) From 05fab11f8f83ef6463ac9d20dff5a81fd79f5eb6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 18 Nov 2022 10:38:53 -0800 Subject: [PATCH 3/5] remove errant noqa statement --- python/pylibraft/pylibraft/test/test_kmeans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index faca219a32..b8d6e1f712 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -101,7 +101,7 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): centroids = X[:n_clusters] centroids_device = device_ndarray(centroids) - inertia = cluster_cost(X_device, centroids_device) # noqa + inertia = cluster_cost(X_device, centroids_device) # compute the nearest centroid to each sample distances = pairwise_distance(X_device, centroids_device).copy_to_host() From aed934d5267a2933b2e837d42e1189d9e3fb9e11 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 18 Nov 2022 15:56:00 -0800 Subject: [PATCH 4/5] Fix inertia calculation inertia is sum of squared distances, we were doing sum of euclidean distances --- cpp/src/distance/cluster_cost.cuh | 2 +- python/pylibraft/pylibraft/test/test_kmeans.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/cpp/src/distance/cluster_cost.cuh b/cpp/src/distance/cluster_cost.cuh index 7b9fd049e0..344673830b 100644 --- a/cpp/src/distance/cluster_cost.cuh +++ b/cpp/src/distance/cluster_cost.cuh @@ -54,7 +54,7 @@ void cluster_cost(const raft::handle_t& handle, n_clusters, n_features, (void*)workspace.data(), - true, + false, true, handle.get_stream()); diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index b8d6e1f712..64632e15a7 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -18,7 +18,7 @@ from pylibraft.cluster.kmeans import cluster_cost, compute_new_centroids from pylibraft.common import Handle, device_ndarray -from pylibraft.distance import fused_l2_nn_argmin, pairwise_distance +from pylibraft.distance import pairwise_distance @pytest.mark.parametrize("n_rows", [100]) @@ -104,21 +104,15 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): inertia = cluster_cost(X_device, centroids_device) # compute the nearest centroid to each sample - distances = pairwise_distance(X_device, centroids_device).copy_to_host() - cluster_ids = np.argmin(distances, axis=1) - - # TODO: the cluster_ids above don't match whats being computed in - # cluster_cost: https://github.com/rapidsai/raft/issues/1036 - # use the same fused_l2_nn_argmin implementation as used internally - # so we get the same inertia result here - cluster_ids = fused_l2_nn_argmin( - X_device, centroids_device, sqrt=True + distances = pairwise_distance( + X_device, centroids_device, metric="sqeuclidean" ).copy_to_host() + cluster_ids = np.argmin(distances, axis=1) cluster_distances = np.take_along_axis( distances, cluster_ids[:, None], axis=1 ) # need reduced tolerance for float32 - rtol = 1e-3 if dtype == np.float32 else 1e-6 - assert np.allclose(inertia, sum(cluster_distances), rtol=rtol) + tol = 1e-3 if dtype == np.float32 else 1e-6 + assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) From 3b503728975ed524ec65632b4d7dbc694a749ecc Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 18 Nov 2022 16:19:41 -0800 Subject: [PATCH 5/5] increase # of rows in test --- python/pylibraft/pylibraft/test/test_kmeans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibraft/pylibraft/test/test_kmeans.py b/python/pylibraft/pylibraft/test/test_kmeans.py index 64632e15a7..44f60be310 100644 --- a/python/pylibraft/pylibraft/test/test_kmeans.py +++ b/python/pylibraft/pylibraft/test/test_kmeans.py @@ -90,7 +90,7 @@ def test_compute_new_centroids( assert np.allclose(expected_centers, actual_centers, rtol=1e-6) -@pytest.mark.parametrize("n_rows", [8]) +@pytest.mark.parametrize("n_rows", [100]) @pytest.mark.parametrize("n_cols", [5, 25]) @pytest.mark.parametrize("n_clusters", [4, 15]) @pytest.mark.parametrize("dtype", [np.float32, np.float64])