From 7673880b9f8cd019dfe87264a48b337299cc11df Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 1 Sep 2021 00:32:07 -0400 Subject: [PATCH] Fixing remaining hdbscan bug (#4179) Closes #4054 Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/4179 --- cpp/include/cuml/cluster/hdbscan.hpp | 1 - cpp/src/hdbscan/detail/reachability.cuh | 28 +++++------ cpp/src/hdbscan/detail/reachability_faiss.cuh | 26 +++++----- cpp/src/hdbscan/runner.h | 5 +- cpp/test/sg/hdbscan_test.cu | 3 +- docs/source/api.rst | 14 +++--- python/cuml/__init__.py | 1 + python/cuml/cluster/__init__.py | 2 + .../{experimental => }/cluster/hdbscan.pyx | 16 ++---- python/cuml/experimental/cluster/__init__.py | 19 ------- python/cuml/test/test_hdbscan.py | 50 ++++++++++++++++--- python/cuml/test/test_pickle.py | 6 ++- 12 files changed, 92 insertions(+), 79 deletions(-) rename python/cuml/{experimental => }/cluster/hdbscan.pyx (98%) delete mode 100644 python/cuml/experimental/cluster/__init__.py diff --git a/cpp/include/cuml/cluster/hdbscan.hpp b/cpp/include/cuml/cluster/hdbscan.hpp index 2df1d72772..57ab0facd5 100644 --- a/cpp/include/cuml/cluster/hdbscan.hpp +++ b/cpp/include/cuml/cluster/hdbscan.hpp @@ -137,7 +137,6 @@ enum CLUSTER_SELECTION_METHOD { EOM = 0, LEAF = 1 }; class RobustSingleLinkageParams { public: - int k = 5; int min_samples = 5; int min_cluster_size = 5; int max_cluster_size = 0; diff --git a/cpp/src/hdbscan/detail/reachability.cuh b/cpp/src/hdbscan/detail/reachability.cuh index 2449cd4196..4397cb91a2 100644 --- a/cpp/src/hdbscan/detail/reachability.cuh +++ b/cpp/src/hdbscan/detail/reachability.cuh @@ -50,7 +50,6 @@ namespace Reachability { * @tparam value_t data type for distance * @tparam tpb block size for kernel * @param[in] knn_dists knn distance array (size n * k) - * @param[in] k neighborhood size * @param[in] min_samples this neighbor will be selected for core distances * @param[in] n number of samples * @param[out] out output array (size n) @@ -58,7 +57,7 @@ namespace Reachability { */ template void core_distances( - value_t* knn_dists, int k, int min_samples, size_t n, value_t* out, cudaStream_t stream) + value_t* knn_dists, int min_samples, size_t n, value_t* out, cudaStream_t stream) { int blocks = raft::ceildiv(n, (size_t)tpb); @@ -67,7 +66,7 @@ void core_distances( auto indices = thrust::make_counting_iterator(0); thrust::transform(exec_policy, indices, indices + n, out, [=] __device__(value_idx row) { - return knn_dists[row * k + (min_samples - 1)]; + return knn_dists[row * min_samples + (min_samples - 1)]; }); } @@ -118,7 +117,6 @@ void mutual_reachability_graph(const raft::handle_t& handle, size_t m, size_t n, raft::distance::DistanceType metric, - int k, int min_samples, value_t alpha, value_idx* indptr, @@ -139,10 +137,10 @@ void mutual_reachability_graph(const raft::handle_t& handle, // This is temporary. Once faiss is updated, we should be able to // pass value_idx through to knn. - rmm::device_uvector coo_rows(k * m, stream); - rmm::device_uvector int64_indices(k * m, stream); - rmm::device_uvector inds(k * m, stream); - rmm::device_uvector dists(k * m, stream); + rmm::device_uvector coo_rows(min_samples * m, stream); + rmm::device_uvector int64_indices(min_samples * m, stream); + rmm::device_uvector inds(min_samples * m, stream); + rmm::device_uvector dists(min_samples * m, stream); // perform knn brute_force_knn(handle, @@ -153,7 +151,7 @@ void mutual_reachability_graph(const raft::handle_t& handle, m, int64_indices.data(), dists.data(), - k, + min_samples, true, true, metric); @@ -166,24 +164,24 @@ void mutual_reachability_graph(const raft::handle_t& handle, [] __device__(int64_t in) -> value_idx { return in; }); // Slice core distances (distances to kth nearest neighbor) - core_distances(dists.data(), k, min_samples, m, core_dists, stream); + core_distances(dists.data(), min_samples, m, core_dists, stream); /** * Compute L2 norm */ mutual_reachability_knn_l2( - handle, inds.data(), dists.data(), X, m, n, k, core_dists, (value_t)1.0 / alpha); + handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha); // self-loops get max distance auto coo_rows_counting_itr = thrust::make_counting_iterator(0); thrust::transform(exec_policy, coo_rows_counting_itr, - coo_rows_counting_itr + (m * k), + coo_rows_counting_itr + (m * min_samples), coo_rows.data(), - [k] __device__(value_idx c) -> value_idx { return c / k; }); + [min_samples] __device__(value_idx c) -> value_idx { return c / min_samples; }); raft::sparse::linalg::symmetrize( - handle, coo_rows.data(), inds.data(), dists.data(), m, m, k * m, out); + handle, coo_rows.data(), inds.data(), dists.data(), m, m, min_samples * m, out); raft::sparse::convert::sorted_coo_to_csr(out.rows(), out.nnz, indptr, m + 1, stream); @@ -205,4 +203,4 @@ void mutual_reachability_graph(const raft::handle_t& handle, }; // end namespace Reachability }; // end namespace detail }; // end namespace HDBSCAN -}; // end namespace ML \ No newline at end of file +}; // end namespace ML diff --git a/cpp/src/hdbscan/detail/reachability_faiss.cuh b/cpp/src/hdbscan/detail/reachability_faiss.cuh index a70d797b97..2caedddf48 100644 --- a/cpp/src/hdbscan/detail/reachability_faiss.cuh +++ b/cpp/src/hdbscan/detail/reachability_faiss.cuh @@ -57,7 +57,8 @@ __global__ void l2SelectMinK(faiss::gpu::Tensor inner_products faiss::gpu::Tensor core_dists, faiss::gpu::Tensor out_dists, faiss::gpu::Tensor out_inds, - int batch_offset, + int batch_offset_i, + int batch_offset_j, int k, value_t initK, value_t alpha) @@ -85,19 +86,19 @@ __global__ void l2SelectMinK(faiss::gpu::Tensor inner_products for (; i < limit; i += blockDim.x) { value_t v = sqrt(faiss::gpu::Math::add( - sq_norms[row + batch_offset], - faiss::gpu::Math::add(sq_norms[i], inner_products[row][i]))); + sq_norms[row + batch_offset_i], + faiss::gpu::Math::add(sq_norms[i + batch_offset_j], inner_products[row][i]))); - v = max(core_dists[i], max(core_dists[row + batch_offset], alpha * v)); + v = max(core_dists[i + batch_offset_j], max(core_dists[row + batch_offset_i], alpha * v)); heap.add(v, i); } if (i < inner_products.getSize(1)) { value_t v = sqrt(faiss::gpu::Math::add( - sq_norms[row + batch_offset], - faiss::gpu::Math::add(sq_norms[i], inner_products[row][i]))); + sq_norms[row + batch_offset_i], + faiss::gpu::Math::add(sq_norms[i + batch_offset_j], inner_products[row][i]))); - v = max(core_dists[i], max(core_dists[row + batch_offset], alpha * v)); + v = max(core_dists[i + batch_offset_j], max(core_dists[row + batch_offset_i], alpha * v)); heap.addThreadQ(v, i); } @@ -127,7 +128,8 @@ void runL2SelectMin(faiss::gpu::Tensor& productDistances, faiss::gpu::Tensor& coreDistances, faiss::gpu::Tensor& outDistances, faiss::gpu::Tensor& outIndices, - int batch_offset, + int batch_offset_i, + int batch_offset_j, int k, value_t alpha, cudaStream_t stream) @@ -149,7 +151,8 @@ void runL2SelectMin(faiss::gpu::Tensor& productDistances, coreDistances, \ outDistances, \ outIndices, \ - batch_offset, \ + batch_offset_i, \ + batch_offset_j, \ k, \ faiss::gpu::Limits::getMax(), \ alpha); \ @@ -323,7 +326,6 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle, auto outIndexView = out_inds_tensor.narrow(0, i, curQuerySize); auto queryView = x_tensor.narrow(0, i, curQuerySize); - norms_tensor.narrow(0, i, curQuerySize); auto outDistanceBufRowView = outDistanceBufs[curStream]->narrow(0, 0, curQuerySize); auto outIndexBufRowView = outIndexBufs[curStream]->narrow(0, 0, curQuerySize); @@ -365,12 +367,11 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle, outDistanceView, outIndexView, i, + j, k, alpha, streams[curStream]); } else { - norms_tensor.narrow(0, j, curCentroidSize); - // Write into our intermediate output runL2SelectMin(distanceBufView, norms_tensor, @@ -378,6 +379,7 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle, outDistanceBufColView, outIndexBufColView, i, + j, k, alpha, streams[curStream]); diff --git a/cpp/src/hdbscan/runner.h b/cpp/src/hdbscan/runner.h index c7600bdcff..c14861e834 100644 --- a/cpp/src/hdbscan/runner.h +++ b/cpp/src/hdbscan/runner.h @@ -125,13 +125,11 @@ void build_linkage(const raft::handle_t& handle, { auto stream = handle.get_stream(); - int k = params.k + 1; - /** * Mutual reachability graph */ rmm::device_uvector mutual_reachability_indptr(m + 1, stream); - raft::sparse::COO mutual_reachability_coo(stream, k * m * 2); + raft::sparse::COO mutual_reachability_coo(stream, params.min_samples * m * 2); rmm::device_uvector core_dists(m, stream); detail::Reachability::mutual_reachability_graph(handle, @@ -139,7 +137,6 @@ void build_linkage(const raft::handle_t& handle, (size_t)m, (size_t)n, metric, - k, params.min_samples, params.alpha, mutual_reachability_indptr.data(), diff --git a/cpp/test/sg/hdbscan_test.cu b/cpp/test/sg/hdbscan_test.cu index 8be7ee3988..a9299cb1d7 100644 --- a/cpp/test/sg/hdbscan_test.cu +++ b/cpp/test/sg/hdbscan_test.cu @@ -94,7 +94,6 @@ class HDBSCANTest : public ::testing::TestWithParam> { mst_weights.data()); HDBSCAN::Common::HDBSCANParams hdbscan_params; - hdbscan_params.k = params.k; hdbscan_params.min_cluster_size = params.min_cluster_size; hdbscan_params.min_samples = params.min_pts; @@ -116,6 +115,7 @@ class HDBSCANTest : public ::testing::TestWithParam> { protected: HDBSCANInputs params; + IdxT* labels_ref; int k; double score; @@ -218,7 +218,6 @@ class ClusterCondensingTest : public ::testing::TestWithParam params; - int k; double score; }; diff --git a/docs/source/api.rst b/docs/source/api.rst index 889c307b3c..6874b72f34 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -347,6 +347,14 @@ Agglomerative Clustering .. autoclass:: cuml.AgglomerativeClustering :members: + +HDBSCAN +------- + +.. autoclass:: cuml.cluster.HDBSCAN + :members: + + Dimensionality Reduction and Manifold Learning ============================================== @@ -563,12 +571,6 @@ Preprocessing add_dummy_feature, binarize, minmax_scale, normalize, PolynomialFeatures, robust_scale, scale -Clustering ----------- - -.. autoclass:: cuml.experimental.cluster.HDBSCAN - :members: - Linear Models ------------- .. autoclass:: cuml.experimental.linear_model.Lars diff --git a/python/cuml/__init__.py b/python/cuml/__init__.py index 4379faa1fa..b68a6922df 100644 --- a/python/cuml/__init__.py +++ b/python/cuml/__init__.py @@ -21,6 +21,7 @@ from cuml.cluster.dbscan import DBSCAN from cuml.cluster.kmeans import KMeans from cuml.cluster.agglomerative import AgglomerativeClustering +from cuml.cluster.hdbscan import HDBSCAN from cuml.datasets.arima import make_arima from cuml.datasets.blobs import make_blobs diff --git a/python/cuml/cluster/__init__.py b/python/cuml/cluster/__init__.py index 40b99974c6..269641c3cc 100644 --- a/python/cuml/cluster/__init__.py +++ b/python/cuml/cluster/__init__.py @@ -17,3 +17,5 @@ from cuml.cluster.dbscan import DBSCAN from cuml.cluster.kmeans import KMeans from cuml.cluster.agglomerative import AgglomerativeClustering +from cuml.cluster.hdbscan import HDBSCAN +from cuml.cluster.hdbscan import condense_hierarchy diff --git a/python/cuml/experimental/cluster/hdbscan.pyx b/python/cuml/cluster/hdbscan.pyx similarity index 98% rename from python/cuml/experimental/cluster/hdbscan.pyx rename to python/cuml/cluster/hdbscan.pyx index 2dd7820833..e5c32d41f6 100644 --- a/python/cuml/experimental/cluster/hdbscan.pyx +++ b/python/cuml/cluster/hdbscan.pyx @@ -72,7 +72,6 @@ cdef extern from "cuml/cluster/hdbscan.hpp" namespace "ML::HDBSCAN::Common": CondensedHierarchy[int, float] &get_condensed_tree() cdef cppclass HDBSCANParams: - int k int min_samples int min_cluster_size int max_cluster_size, @@ -435,12 +434,11 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): handle=None, verbose=False, connectivity='knn', - n_neighbors=10, output_type=None): - super().__init__(handle, - verbose, - output_type) + super().__init__(handle=handle, + verbose=verbose, + output_type=output_type) if min_samples is None: min_samples = min_cluster_size @@ -449,8 +447,8 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): raise ValueError("'connectivity' can only be one of " "{'knn', 'pairwise'}") - if n_neighbors > 1023 or n_neighbors < 2: - raise ValueError("'n_neighbors' must be a positive number " + if 2 < min_samples and min_samples > 1023: + raise ValueError("'min_samples' must be a positive number " "between 2 and 1023") self.min_cluster_size = min_cluster_size @@ -462,7 +460,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): self.alpha = alpha self.cluster_selection_method = cluster_selection_method self.allow_single_cluster = allow_single_cluster - self.n_neighbors = n_neighbors self.connectivity = connectivity self.fit_called_ = False @@ -619,7 +616,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): self.hdbscan_output_ = linkage_output cdef HDBSCANParams params - params.k = self.n_neighbors params.min_samples = self.min_samples # params.alpha = self.alpha params.min_cluster_size = self.min_cluster_size @@ -730,7 +726,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): def get_param_names(self): return super().get_param_names() + [ - "n_neighbors", "metric", "min_cluster_size", "max_cluster_size", @@ -740,7 +735,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): "p", "allow_single_cluster", "connectivity", - "n_neighbors", "alpha", "gen_min_span_tree", ] diff --git a/python/cuml/experimental/cluster/__init__.py b/python/cuml/experimental/cluster/__init__.py deleted file mode 100644 index ccebf13054..0000000000 --- a/python/cuml/experimental/cluster/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from cuml.experimental.cluster.hdbscan import HDBSCAN -from cuml.experimental.cluster.hdbscan import condense_hierarchy diff --git a/python/cuml/test/test_hdbscan.py b/python/cuml/test/test_hdbscan.py index ef9111fe49..6d898333d5 100644 --- a/python/cuml/test/test_hdbscan.py +++ b/python/cuml/test/test_hdbscan.py @@ -16,11 +16,10 @@ import pytest -from cuml.experimental.cluster import HDBSCAN -from cuml.experimental.cluster import condense_hierarchy +from cuml.cluster import HDBSCAN +from cuml.cluster import condense_hierarchy from sklearn.datasets import make_blobs - from cuml.metrics import adjusted_rand_score from cuml.test.utils import get_pattern @@ -29,6 +28,7 @@ from cuml.common import logger import hdbscan +from hdbscan.plots import CondensedTree from sklearn import datasets @@ -159,7 +159,6 @@ def test_hdbscan_blobs(nrows, ncols, nclusters, cuml_agg = HDBSCAN(verbose=logger.level_info, allow_single_cluster=allow_single_cluster, - n_neighbors=min_samples+1, min_samples=min_samples, max_cluster_size=max_cluster_size, min_cluster_size=min_cluster_size, @@ -210,7 +209,6 @@ def test_hdbscan_sklearn_datasets(dataset, cuml_agg = HDBSCAN(verbose=logger.level_info, allow_single_cluster=allow_single_cluster, - n_neighbors=min_samples, gen_min_span_tree=True, min_samples=min_samples, max_cluster_size=max_cluster_size, @@ -263,7 +261,6 @@ def test_hdbscan_sklearn_extract_clusters(dataset, cuml_agg = HDBSCAN(verbose=logger.level_info, allow_single_cluster=allow_single_cluster, - n_neighbors=min_samples, gen_min_span_tree=True, min_samples=min_samples, max_cluster_size=max_cluster_size, @@ -313,7 +310,6 @@ def test_hdbscan_cluster_patterns(dataset, nrows, cuml_agg = HDBSCAN(verbose=logger.level_info, allow_single_cluster=allow_single_cluster, - n_neighbors=min_samples, min_samples=min_samples, max_cluster_size=max_cluster_size, min_cluster_size=min_cluster_size, @@ -367,7 +363,6 @@ def test_hdbscan_cluster_patterns_extract_clusters(dataset, nrows, cuml_agg = HDBSCAN(verbose=logger.level_info, allow_single_cluster=allow_single_cluster, - n_neighbors=min_samples, min_samples=min_samples, max_cluster_size=max_cluster_size, min_cluster_size=min_cluster_size, @@ -393,6 +388,45 @@ def test_hdbscan_cluster_patterns_extract_clusters(dataset, nrows, sk_agg.probabilities_) +def test_hdbscan_core_dists_bug_4054(): + """ + This test explicitly verifies that the MRE from + https://github.com/rapidsai/cuml/issues/4054 + matches the reference impl + """ + + X, y = datasets.make_moons(n_samples=10000, noise=0.12, random_state=0) + + cu_labels_ = HDBSCAN(min_samples=25, min_cluster_size=25).fit_predict(X) + sk_labels_ = hdbscan.HDBSCAN(min_samples=25, + min_cluster_size=25, + approx_min_span_tree=False).fit_predict(X) + + assert adjusted_rand_score(cu_labels_, sk_labels_) > 0.99 + + +def test_hdbscan_empty_cluster_tree(): + + raw_tree = np.recarray(shape=(5,), + formats=[np.intp, np.intp, float, np.intp], + names=('parent', 'child', 'lambda_val', + 'child_size')) + + raw_tree['parent'] = np.asarray([5, 5, 5, 5, 5]) + raw_tree['child'] = [0, 1, 2, 3, 4] + raw_tree['lambda_val'] = [1.0, 1.0, 1.0, 1.0, 1.0] + raw_tree['child_size'] = [1, 1, 1, 1, 1] + + condensed_tree = CondensedTree(raw_tree, 0.0, True) + + cuml_agg = HDBSCAN(allow_single_cluster=True, + cluster_selection_method="eom") + cuml_agg._extract_clusters(condensed_tree) + + # We just care that all points are assigned to the root cluster + assert np.sum(cuml_agg.labels_test.to_output("numpy")) == 0 + + def test_hdbscan_plots(): X, y = make_blobs(int(100), diff --git a/python/cuml/test/test_pickle.py b/python/cuml/test/test_pickle.py index f1da3b6c0d..adf912f6f1 100644 --- a/python/cuml/test/test_pickle.py +++ b/python/cuml/test/test_pickle.py @@ -42,7 +42,8 @@ cluster_config = ClassEnumerator( module=cuml.cluster, exclude_classes=[cuml.DBSCAN, - cuml.AgglomerativeClustering] + cuml.AgglomerativeClustering, + cuml.HDBSCAN] ) cluster_models = cluster_config.get_models() @@ -59,6 +60,8 @@ agglomerative_model = {"AgglomerativeClustering": cuml.AgglomerativeClustering} +hdbscan_model = {"HDBSCAN": cuml.HDBSCAN} + umap_model = {"UMAP": cuml.UMAP} rf_module = ClassEnumerator(module=cuml.ensemble) @@ -98,6 +101,7 @@ **decomposition_models_xfail, **neighbor_models, **dbscan_model, + **hdbscan_model, **agglomerative_model, **umap_model, **rf_models,