From 3a0cc673b65b513f0ab23cd890262234ab450236 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 31 May 2024 17:36:30 +0000 Subject: [PATCH 01/42] enable umap nndescent --- cpp/include/cuml/manifold/umapparams.h | 13 +++ cpp/src/umap/knn_graph/algo.cuh | 122 +++++++++++++++++-------- python/cuml/manifold/umap.pyx | 29 +++++- python/cuml/manifold/umap_utils.pxd | 16 +++- 4 files changed, 139 insertions(+), 41 deletions(-) diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index 227f46982e..45a13a1ff1 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -20,12 +20,18 @@ #include #include +#include + +// namespace NNDescent = raft::neighbors::experimental::nn_descent; namespace ML { +using nn_index_params = raft::neighbors::experimental::nn_descent::index_params; + class UMAPParams { public: enum MetricType { EUCLIDEAN, CATEGORICAL }; + enum graph_build_algo {BRUTE_FORCE_KNN, NN_DESCENT}; /** * The number of neighbors to use to approximate geodesic distance. @@ -140,6 +146,13 @@ class UMAPParams { */ int init = 1; + /** + * KNN graph build algorithm + */ + graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN; + + nn_index_params nn_descent_params = {}; + /** * The number of nearest neighbors to use to construct the target simplicial * set. If set to -1, use the n_neighbors value. diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index ab474122d8..53e09d720f 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -22,14 +22,20 @@ #include #include +#include #include #include #include #include #include +#include +#include + #include +namespace NNDescent = raft::neighbors::experimental::nn_descent; + namespace UMAPAlgo { namespace kNNGraph { namespace Algo { @@ -57,25 +63,45 @@ inline void launcher(const raft::handle_t& handle, const ML::UMAPParams* params, cudaStream_t stream) { - std::vector ptrs(1); - std::vector sizes(1); - ptrs[0] = inputsA.X; - sizes[0] = inputsA.n; - - raft::spatial::knn::brute_force_knn(handle, - ptrs, - sizes, - inputsA.d, - inputsB.X, - inputsB.n, - out.knn_indices, - out.knn_dists, - n_neighbors, - true, - true, - static_cast*>(nullptr), - params->metric, - params->p); + + if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { + std::vector ptrs(1); + std::vector sizes(1); + ptrs[0] = inputsA.X; + sizes[0] = inputsA.n; + + raft::spatial::knn::brute_force_knn(handle, + ptrs, + sizes, + inputsA.d, + inputsB.X, + inputsB.n, + out.knn_indices, + out.knn_dists, + n_neighbors, + true, + true, + static_cast*>(nullptr), + params->metric, + params->p); + } else { // nn_descent + // number of columns (n_neightbors) should be smaller than the graph degree computed by nn descent + assert(n_neighbors <= params->nn_descent_params.graph_degree); + auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); + auto graph = NNDescent::detail::build(handle, params->nn_descent_params, dataset); + + for(int i = 0; i < inputsA.n; i++) { + raft::copy(out.knn_dists + i * n_neighbors, + graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); + raft::copy(out.knn_indices + i * n_neighbors, + graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); + } + } + } // Instantiation for dense inputs, int indices @@ -100,26 +126,44 @@ inline void launcher(const raft::handle_t& handle, const ML::UMAPParams* params, cudaStream_t stream) { - raft::sparse::selection::brute_force_knn(inputsA.indptr, - inputsA.indices, - inputsA.data, - inputsA.nnz, - inputsA.n, - inputsA.d, - inputsB.indptr, - inputsB.indices, - inputsB.data, - inputsB.nnz, - inputsB.n, - inputsB.d, - out.knn_indices, - out.knn_dists, - n_neighbors, - handle, - ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, - params->metric, - params->p); + if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { // brute_force_knn + raft::sparse::selection::brute_force_knn(inputsA.indptr, + inputsA.indices, + inputsA.data, + inputsA.nnz, + inputsA.n, + inputsA.d, + inputsB.indptr, + inputsB.indices, + inputsB.data, + inputsB.nnz, + inputsB.n, + inputsB.d, + out.knn_indices, + out.knn_dists, + n_neighbors, + handle, + ML::Sparse::DEFAULT_BATCH_SIZE, + ML::Sparse::DEFAULT_BATCH_SIZE, + params->metric, + params->p); + } else { // nn_descent + // number of columns (n_neightbors) should be smaller than the graph degree computed by nn descent + assert(n_neighbors <= params->nn_descent_params.graph_degree); + auto dataset = raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); + auto graph = NNDescent::detail::build(handle, params->nn_descent_params, dataset); + + for(int i = 0; i < inputsA.n; i++) { + raft::copy(out.knn_dists + i * n_neighbors, + graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); + raft::copy(out.knn_indices + i * n_neighbors, + graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); + } + } } template <> diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index cc9c492fba..62d989628a 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -348,7 +348,13 @@ class UMAP(UniversalBase, callback=None, handle=None, verbose=False, - output_type=None): + output_type=None, + build_algo="brute_force_knn", + nnd_graph_degree=64, + nnd_intermediate_graph_degree=128, + nnd_max_iterations=20, + nnd_termination_threshold=0.0001, + nnd_return_distances=0): super().__init__(handle=handle, verbose=verbose, @@ -419,6 +425,17 @@ class UMAP(UniversalBase, self.precomputed_knn = extract_knn_infos(precomputed_knn, n_neighbors) + + if build_algo == "brute_force_knn" or build_algo == "nn_descent": + self.build_algo = build_algo + else: + raise Exception("Invalid build algo: {}. Only support brute_force_knn and nn_descent" % build_algo) + + self.nnd_graph_degree = nnd_graph_degree + self.nnd_intermediate_graph_degree = nnd_intermediate_graph_degree + self.nnd_max_iterations = nnd_max_iterations + self.nnd_termination_threshold = nnd_termination_threshold + self.nnd_return_distances = nnd_return_distances def validate_hyperparams(self): @@ -452,6 +469,15 @@ class UMAP(UniversalBase, umap_params.target_metric = MetricType.EUCLIDEAN else: # self.target_metric == "categorical" umap_params.target_metric = MetricType.CATEGORICAL + if cls.build_algo == "brute_force_knn": + umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN + else: # self.init == "nn_descent" + umap_params.build_algo = graph_build_algo.NN_DESCENT + umap_params.nn_descent_params.graph_degree = cls.nnd_graph_degree + umap_params.nn_descent_params.intermediate_graph_degree = cls.nnd_intermediate_graph_degree + umap_params.nn_descent_params.max_iterations = cls.nnd_max_iterations + umap_params.nn_descent_params.termination_threshold = cls.nnd_termination_threshold + umap_params.nn_descent_params.return_distances = 1 umap_params.target_weight = cls.target_weight umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic @@ -804,3 +830,4 @@ class UMAP(UniversalBase, def get_attr_names(self): return ['_raw_data', 'embedding_', '_input_hash', '_small_data'] + \ No newline at end of file diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index abf4698b75..9e18614574 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -30,11 +30,22 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML::UMAPParams": enum MetricType: EUCLIDEAN = 0, CATEGORICAL = 1 + enum graph_build_algo: + BRUTE_FORCE_KNN = 0, + NN_DESCENT = 1 cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals": cdef cppclass GraphBasedDimRedCallback +cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbors::experimental::nn_descent": + cdef struct index_params: + int64_t graph_degree, + int64_t intermediate_graph_degree, + int64_t max_iterations, + float termination_threshold, + int return_distances + cdef extern from "cuml/manifold/umapparams.h" namespace "ML": cdef cppclass UMAPParams: @@ -54,6 +65,7 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML": float b, float initial_alpha, int init, + graph_build_algo build_algo, int target_n_neighbors, MetricType target_metric, float target_weight, @@ -61,7 +73,8 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML": bool deterministic, DistanceType metric, float p, - GraphBasedDimRedCallback * callback + GraphBasedDimRedCallback * callback, + index_params nn_descent_params cdef extern from "raft/sparse/coo.hpp": cdef cppclass COO "raft::sparse::COO": @@ -90,3 +103,4 @@ cdef class GraphHolder: cdef uintptr_t rows(GraphHolder self) cdef uintptr_t cols(GraphHolder self) cdef uint64_t get_nnz(GraphHolder self) + \ No newline at end of file From bf39ef23cf47ebe72fd256293ef2f42137259bec Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 31 May 2024 22:17:13 +0000 Subject: [PATCH 02/42] change return_distances to bool --- python/cuml/manifold/umap.pyx | 4 ++-- python/cuml/manifold/umap_utils.pxd | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 62d989628a..6541db645d 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -354,7 +354,7 @@ class UMAP(UniversalBase, nnd_intermediate_graph_degree=128, nnd_max_iterations=20, nnd_termination_threshold=0.0001, - nnd_return_distances=0): + nnd_return_distances=False): super().__init__(handle=handle, verbose=verbose, @@ -477,7 +477,7 @@ class UMAP(UniversalBase, umap_params.nn_descent_params.intermediate_graph_degree = cls.nnd_intermediate_graph_degree umap_params.nn_descent_params.max_iterations = cls.nnd_max_iterations umap_params.nn_descent_params.termination_threshold = cls.nnd_termination_threshold - umap_params.nn_descent_params.return_distances = 1 + umap_params.nn_descent_params.return_distances = True umap_params.target_weight = cls.target_weight umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index 9e18614574..cba663e56f 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -44,7 +44,7 @@ cdef extern from "raft/neighbors/nn_descent_types.hpp" namespace "raft::neighbor int64_t intermediate_graph_degree, int64_t max_iterations, float termination_threshold, - int return_distances + bool return_distances cdef extern from "cuml/manifold/umapparams.h" namespace "ML": From 94eb5793f64b1532c14d6f1173b572c24f846956 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 31 May 2024 22:44:14 +0000 Subject: [PATCH 03/42] add python test --- python/cuml/tests/test_umap.py | 36 ++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 6faa4ad8d3..8258289af3 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -27,8 +27,10 @@ from sklearn.cluster import KMeans from sklearn.neighbors import NearestNeighbors from sklearn import datasets +from cuml.datasets import make_blobs as cu_make_blobs from cuml.internals import logger from cuml.metrics import pairwise_distances +from cuml.metrics import trustworthiness as cu_trustworthiness from cuml.testing.utils import ( array_equal, unit_param, @@ -790,3 +792,37 @@ def test_umap_distance_metrics_fit_transform_trust_on_sparse_input( if umap_learn_supported: assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) + + +def test_umap_nn_descent(): + X_blobs, y_blobs = cu_make_blobs(n_samples = 1000, + cluster_std = 0.1, + n_features = 100, + random_state = 0, + dtype=np.float32) + + # Dense + trained_UMAP_bf = cuUMAP(n_neighbors = 16).fit(X_blobs) + X_embedded_bf = trained_UMAP_bf.transform(X_blobs) + cu_score_bf = cu_trustworthiness(X_blobs, X_embedded_bf) + + trained_UMAP_nnd = cuUMAP(n_neighbors = 16, build_algo="nn_descent").fit(X_blobs) + X_embedded_nnd = trained_UMAP_nnd.transform(X_blobs) + cu_score_nnd = cu_trustworthiness(X_blobs, X_embedded_nnd) + + score_diff = np.abs(cu_score_bf - cu_score_nnd) + assert score_diff < 0.1 + + # Sparse + X_blobs_sparse = cupyx.scipy.sparse.csr_matrix(X_blobs) + + trained_UMAP_bf = cuUMAP(n_neighbors = 10).fit(X_blobs_sparse) + X_embedded_bf = trained_UMAP_bf.transform(X_blobs_sparse) + cu_score_bf = cu_trustworthiness(X_blobs, X_embedded_bf) + + trained_UMAP_nnd = cuUMAP(n_neighbors = 16, build_algo="nn_descent").fit(X_blobs_sparse) + X_embedded_nnd = trained_UMAP_nnd.transform(X_blobs_sparse) + cu_score_nnd = cu_trustworthiness(X_blobs, X_embedded_nnd) + + score_diff = np.abs(cu_score_bf - cu_score_nnd) + assert score_diff < 0.1 From 8e45779a491d1d87596e799bb86e0896fdc6d787 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 31 May 2024 23:01:32 +0000 Subject: [PATCH 04/42] fix styling --- cpp/include/cuml/manifold/umapparams.h | 2 +- cpp/src/umap/knn_graph/algo.cuh | 101 +++++++++++++------------ python/cuml/manifold/umap.pyx | 2 +- python/cuml/manifold/umap_utils.pxd | 2 +- python/cuml/tests/test_umap.py | 38 ++++++---- 5 files changed, 76 insertions(+), 69 deletions(-) diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index 45a13a1ff1..c5535f6603 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -31,7 +31,7 @@ using nn_index_params = raft::neighbors::experimental::nn_descent::index_params; class UMAPParams { public: enum MetricType { EUCLIDEAN, CATEGORICAL }; - enum graph_build_algo {BRUTE_FORCE_KNN, NN_DESCENT}; + enum graph_build_algo { BRUTE_FORCE_KNN, NN_DESCENT }; /** * The number of neighbors to use to approximate geodesic distance. diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 53e09d720f..eb8a166e03 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -25,13 +25,12 @@ #include #include #include +#include +#include #include #include #include -#include -#include - #include namespace NNDescent = raft::neighbors::experimental::nn_descent; @@ -63,7 +62,6 @@ inline void launcher(const raft::handle_t& handle, const ML::UMAPParams* params, cudaStream_t stream) { - if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { std::vector ptrs(1); std::vector sizes(1); @@ -84,24 +82,26 @@ inline void launcher(const raft::handle_t& handle, static_cast*>(nullptr), params->metric, params->p); - } else { // nn_descent - // number of columns (n_neightbors) should be smaller than the graph degree computed by nn descent + } else { // nn_descent + // number of columns (n_neightbors) should be smaller than the graph degree computed by nn + // descent assert(n_neighbors <= params->nn_descent_params.graph_degree); - auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); - auto graph = NNDescent::detail::build(handle, params->nn_descent_params, dataset); - - for(int i = 0; i < inputsA.n; i++) { - raft::copy(out.knn_dists + i * n_neighbors, - graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); - raft::copy(out.knn_indices + i * n_neighbors, - graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); + auto dataset = + raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); + auto graph = + NNDescent::detail::build(handle, params->nn_descent_params, dataset); + + for (int i = 0; i < inputsA.n; i++) { + raft::copy(out.knn_dists + i * n_neighbors, + graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); + raft::copy(out.knn_indices + i * n_neighbors, + graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); } } - } // Instantiation for dense inputs, int indices @@ -128,40 +128,41 @@ inline void launcher(const raft::handle_t& handle, { if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { // brute_force_knn raft::sparse::selection::brute_force_knn(inputsA.indptr, - inputsA.indices, - inputsA.data, - inputsA.nnz, - inputsA.n, - inputsA.d, - inputsB.indptr, - inputsB.indices, - inputsB.data, - inputsB.nnz, - inputsB.n, - inputsB.d, - out.knn_indices, - out.knn_dists, - n_neighbors, - handle, - ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, - params->metric, - params->p); + inputsA.indices, + inputsA.data, + inputsA.nnz, + inputsA.n, + inputsA.d, + inputsB.indptr, + inputsB.indices, + inputsB.data, + inputsB.nnz, + inputsB.n, + inputsB.d, + out.knn_indices, + out.knn_dists, + n_neighbors, + handle, + ML::Sparse::DEFAULT_BATCH_SIZE, + ML::Sparse::DEFAULT_BATCH_SIZE, + params->metric, + params->p); } else { // nn_descent - // number of columns (n_neightbors) should be smaller than the graph degree computed by nn descent + // n_neightbors should be smaller than the graph degree computed by nn descent assert(n_neighbors <= params->nn_descent_params.graph_degree); - auto dataset = raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); + auto dataset = + raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); auto graph = NNDescent::detail::build(handle, params->nn_descent_params, dataset); - - for(int i = 0; i < inputsA.n; i++) { - raft::copy(out.knn_dists + i * n_neighbors, - graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); - raft::copy(out.knn_indices + i * n_neighbors, - graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); + + for (int i = 0; i < inputsA.n; i++) { + raft::copy(out.knn_dists + i * n_neighbors, + graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); + raft::copy(out.knn_indices + i * n_neighbors, + graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors, + handle.get_stream()); } } } diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 6541db645d..d5645c946f 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index cba663e56f..0a9b899269 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 8258289af3..7d323b7797 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -795,34 +795,40 @@ def test_umap_distance_metrics_fit_transform_trust_on_sparse_input( def test_umap_nn_descent(): - X_blobs, y_blobs = cu_make_blobs(n_samples = 1000, - cluster_std = 0.1, - n_features = 100, - random_state = 0, - dtype=np.float32) - + X_blobs, y_blobs = cu_make_blobs( + n_samples=1000, + cluster_std=0.1, + n_features=100, + random_state=0, + dtype=np.float32, + ) + # Dense - trained_UMAP_bf = cuUMAP(n_neighbors = 16).fit(X_blobs) + trained_UMAP_bf = cuUMAP(n_neighbors=16).fit(X_blobs) X_embedded_bf = trained_UMAP_bf.transform(X_blobs) cu_score_bf = cu_trustworthiness(X_blobs, X_embedded_bf) - - trained_UMAP_nnd = cuUMAP(n_neighbors = 16, build_algo="nn_descent").fit(X_blobs) + + trained_UMAP_nnd = cuUMAP(n_neighbors=16, build_algo="nn_descent").fit( + X_blobs + ) X_embedded_nnd = trained_UMAP_nnd.transform(X_blobs) cu_score_nnd = cu_trustworthiness(X_blobs, X_embedded_nnd) - + score_diff = np.abs(cu_score_bf - cu_score_nnd) assert score_diff < 0.1 - + # Sparse X_blobs_sparse = cupyx.scipy.sparse.csr_matrix(X_blobs) - - trained_UMAP_bf = cuUMAP(n_neighbors = 10).fit(X_blobs_sparse) + + trained_UMAP_bf = cuUMAP(n_neighbors=10).fit(X_blobs_sparse) X_embedded_bf = trained_UMAP_bf.transform(X_blobs_sparse) cu_score_bf = cu_trustworthiness(X_blobs, X_embedded_bf) - - trained_UMAP_nnd = cuUMAP(n_neighbors = 16, build_algo="nn_descent").fit(X_blobs_sparse) + + trained_UMAP_nnd = cuUMAP(n_neighbors=16, build_algo="nn_descent").fit( + X_blobs_sparse + ) X_embedded_nnd = trained_UMAP_nnd.transform(X_blobs_sparse) cu_score_nnd = cu_trustworthiness(X_blobs, X_embedded_nnd) - + score_diff = np.abs(cu_score_bf - cu_score_nnd) assert score_diff < 0.1 From 783838871c1d552678e37969be739e375c158e69 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 31 May 2024 23:02:44 +0000 Subject: [PATCH 05/42] fix comment --- cpp/src/umap/knn_graph/algo.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index eb8a166e03..5106347418 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -83,8 +83,7 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - // number of columns (n_neightbors) should be smaller than the graph degree computed by nn - // descent + // n_neightbors should be smaller than the graph degree computed by nn descent assert(n_neighbors <= params->nn_descent_params.graph_degree); auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); From a7bf6bafc487062ab01d86272ae360a6b9483645 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 31 May 2024 23:03:58 +0000 Subject: [PATCH 06/42] fix styling --- python/cuml/manifold/umap.pyx | 1 - python/cuml/manifold/umap_utils.pxd | 1 - 2 files changed, 2 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index d5645c946f..4e95ba84d3 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -830,4 +830,3 @@ class UMAP(UniversalBase, def get_attr_names(self): return ['_raw_data', 'embedding_', '_input_hash', '_small_data'] - \ No newline at end of file diff --git a/python/cuml/manifold/umap_utils.pxd b/python/cuml/manifold/umap_utils.pxd index 0a9b899269..b57b25a25a 100644 --- a/python/cuml/manifold/umap_utils.pxd +++ b/python/cuml/manifold/umap_utils.pxd @@ -103,4 +103,3 @@ cdef class GraphHolder: cdef uintptr_t rows(GraphHolder self) cdef uintptr_t cols(GraphHolder self) cdef uint64_t get_nnz(GraphHolder self) - \ No newline at end of file From 7b0bb1903a17774b65c01dbc01a047da3060f1d2 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Sat, 1 Jun 2024 00:34:45 +0000 Subject: [PATCH 07/42] remove comment --- cpp/include/cuml/manifold/umapparams.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/include/cuml/manifold/umapparams.h b/cpp/include/cuml/manifold/umapparams.h index c5535f6603..71418198cf 100644 --- a/cpp/include/cuml/manifold/umapparams.h +++ b/cpp/include/cuml/manifold/umapparams.h @@ -22,8 +22,6 @@ #include #include -// namespace NNDescent = raft::neighbors::experimental::nn_descent; - namespace ML { using nn_index_params = raft::neighbors::experimental::nn_descent::index_params; From 588a430871b7fc59e6086c2666a8dcb723ccf7f5 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Sat, 1 Jun 2024 00:35:59 +0000 Subject: [PATCH 08/42] fix typo --- cpp/src/umap/knn_graph/algo.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 5106347418..49594128b9 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -83,7 +83,7 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - // n_neightbors should be smaller than the graph degree computed by nn descent + // n_neighbors should be smaller than the graph degree computed by nn descent assert(n_neighbors <= params->nn_descent_params.graph_degree); auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); @@ -147,7 +147,7 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - // n_neightbors should be smaller than the graph degree computed by nn descent + // n_neighbors should be smaller than the graph degree computed by nn descent assert(n_neighbors <= params->nn_descent_params.graph_degree); auto dataset = raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); From e88f183c9158b9581a484cc55f485a1b8f210990 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 3 Jun 2024 16:44:49 +0000 Subject: [PATCH 09/42] use cuml logger --- cpp/src/umap/knn_graph/algo.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 49594128b9..fd90740b82 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -16,6 +16,7 @@ #pragma once +#include "cuml/common/logger.hpp" #include #include #include @@ -83,7 +84,7 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - // n_neighbors should be smaller than the graph degree computed by nn descent + CUML_LOG_DEBUG("n_neighbors should be smaller than the graph degree computed by nn descent"); assert(n_neighbors <= params->nn_descent_params.graph_degree); auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); @@ -147,7 +148,7 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - // n_neighbors should be smaller than the graph degree computed by nn descent + CUML_LOG_DEBUG("n_neighbors should be smaller than the graph degree computed by nn descent"); assert(n_neighbors <= params->nn_descent_params.graph_degree); auto dataset = raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); From 8fc0f2d57ec35dea309234b2840e8564a8656fb9 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 3 Jun 2024 18:54:01 +0000 Subject: [PATCH 10/42] change arg to dict and add documentation --- python/cuml/manifold/umap.pyx | 37 ++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 4e95ba84d3..663e7b7e45 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -289,6 +289,11 @@ class UMAP(UniversalBase, type. If None, the output type set at the module level (`cuml.global_settings.output_type`) will be used. See :ref:`output-data-type-configuration` for more info. + build_algo: string (default='brute_force_knn') + How to build the knn graph. Supported build algorithms are ['brute_force_knn', + 'nn_descent'] + metric_kwds: dict (optional, default=None) + Build algorithm argument Notes ----- @@ -350,11 +355,8 @@ class UMAP(UniversalBase, verbose=False, output_type=None, build_algo="brute_force_knn", - nnd_graph_degree=64, - nnd_intermediate_graph_degree=128, - nnd_max_iterations=20, - nnd_termination_threshold=0.0001, - nnd_return_distances=False): + build_kwds=None): + super().__init__(handle=handle, verbose=verbose, @@ -430,12 +432,8 @@ class UMAP(UniversalBase, self.build_algo = build_algo else: raise Exception("Invalid build algo: {}. Only support brute_force_knn and nn_descent" % build_algo) - - self.nnd_graph_degree = nnd_graph_degree - self.nnd_intermediate_graph_degree = nnd_intermediate_graph_degree - self.nnd_max_iterations = nnd_max_iterations - self.nnd_termination_threshold = nnd_termination_threshold - self.nnd_return_distances = nnd_return_distances + + self.build_kwds = build_kwds def validate_hyperparams(self): @@ -473,11 +471,18 @@ class UMAP(UniversalBase, umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN else: # self.init == "nn_descent" umap_params.build_algo = graph_build_algo.NN_DESCENT - umap_params.nn_descent_params.graph_degree = cls.nnd_graph_degree - umap_params.nn_descent_params.intermediate_graph_degree = cls.nnd_intermediate_graph_degree - umap_params.nn_descent_params.max_iterations = cls.nnd_max_iterations - umap_params.nn_descent_params.termination_threshold = cls.nnd_termination_threshold - umap_params.nn_descent_params.return_distances = True + if cls.build_kwds is None: + umap_params.nn_descent_params.graph_degree = 64 + umap_params.nn_descent_params.intermediate_graph_degree = 128 + umap_params.nn_descent_params.max_iterations = 20 + umap_params.nn_descent_params.termination_threshold = 0.0001 + umap_params.nn_descent_params.return_distances = True + else: + umap_params.nn_descent_params.graph_degree = cls.build_kwds.get("nnd_graph_degree", 64) + umap_params.nn_descent_params.intermediate_graph_degree = cls.build_kwds.get("nnd_intermediate_graph_degree", 128) + umap_params.nn_descent_params.max_iterations = cls.build_kwds.get("nnd_max_iterations", 20) + umap_params.nn_descent_params.termination_threshold = cls.build_kwds.get("nnd_termination_threshold", 0.0001) + umap_params.nn_descent_params.return_distances = True umap_params.target_weight = cls.target_weight umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic From 127bea55859c70688d63c2590181433901f1b2e4 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Tue, 4 Jun 2024 17:53:30 +0000 Subject: [PATCH 11/42] change to RAFT_EXPECTS --- cpp/src/umap/knn_graph/algo.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index fd90740b82..23d442d604 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -84,8 +84,8 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - CUML_LOG_DEBUG("n_neighbors should be smaller than the graph degree computed by nn descent"); - assert(n_neighbors <= params->nn_descent_params.graph_degree); + RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); + auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); auto graph = @@ -148,8 +148,8 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - CUML_LOG_DEBUG("n_neighbors should be smaller than the graph degree computed by nn descent"); - assert(n_neighbors <= params->nn_descent_params.graph_degree); + RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); + auto dataset = raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); auto graph = NNDescent::detail::build(handle, params->nn_descent_params, dataset); From fd6793b57c0c1a68120d457569282b2adac8982a Mon Sep 17 00:00:00 2001 From: jinsolp Date: Tue, 4 Jun 2024 17:57:10 +0000 Subject: [PATCH 12/42] remove logger header --- cpp/src/umap/knn_graph/algo.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 23d442d604..7eadbf5f64 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -16,7 +16,6 @@ #pragma once -#include "cuml/common/logger.hpp" #include #include #include From d772098cfb8b627a158268c9ba521fd44387c5aa Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 14 Jun 2024 20:49:25 +0000 Subject: [PATCH 13/42] enable l2sqrtexpanded dist + fix errors --- cpp/src/umap/knn_graph/algo.cuh | 103 ++++++++++---------- cpp/src/umap/umap.cu | 4 + python/cuml/manifold/umap.pyx | 6 +- python/cuml/tests/test_umap.py | 162 +++++++++++++++++++------------- 4 files changed, 157 insertions(+), 118 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 7eadbf5f64..dd9ba66644 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -35,6 +35,12 @@ namespace NNDescent = raft::neighbors::experimental::nn_descent; +// Functor to post-process distances as L2Sqrt* +template +struct DistancePostProcessSqrt { + DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); } +}; + namespace UMAPAlgo { namespace kNNGraph { namespace Algo { @@ -83,22 +89,36 @@ inline void launcher(const raft::handle_t& handle, params->metric, params->p); } else { // nn_descent - RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); + RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, + "n_neighbors should be smaller than the graph degree computed by nn descent"); + + auto epilogue = DistancePostProcessSqrt{}; auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); - auto graph = - NNDescent::detail::build(handle, params->nn_descent_params, dataset); - - for (int i = 0; i < inputsA.n; i++) { - raft::copy(out.knn_dists + i * n_neighbors, - graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); - raft::copy(out.knn_indices + i * n_neighbors, + auto graph = NNDescent::detail::build( + handle, params->nn_descent_params, dataset, epilogue); + + for (int i = 0; i < inputsB.n; i++) { + if (graph.distances().has_value()) { + raft::copy( + out.knn_dists + i * n_neighbors + 1, + graph.distances().value().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors - 1, + handle.get_stream()); + thrust::fill(thrust::device.on(stream), + out.knn_dists + i * n_neighbors, + out.knn_dists + i * n_neighbors + 1, + 0.0); + } + raft::copy(out.knn_indices + i * n_neighbors + 1, graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, + n_neighbors - 1, handle.get_stream()); + thrust::fill(thrust::device.on(stream), + out.knn_indices + i * n_neighbors, + out.knn_indices + i * n_neighbors + 1, + i); } } } @@ -125,45 +145,28 @@ inline void launcher(const raft::handle_t& handle, const ML::UMAPParams* params, cudaStream_t stream) { - if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) { // brute_force_knn - raft::sparse::selection::brute_force_knn(inputsA.indptr, - inputsA.indices, - inputsA.data, - inputsA.nnz, - inputsA.n, - inputsA.d, - inputsB.indptr, - inputsB.indices, - inputsB.data, - inputsB.nnz, - inputsB.n, - inputsB.d, - out.knn_indices, - out.knn_dists, - n_neighbors, - handle, - ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, - params->metric, - params->p); - } else { // nn_descent - RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); - - auto dataset = - raft::make_host_matrix_view(inputsA.data, inputsA.n, inputsA.d); - auto graph = NNDescent::detail::build(handle, params->nn_descent_params, dataset); - - for (int i = 0; i < inputsA.n; i++) { - raft::copy(out.knn_dists + i * n_neighbors, - graph.distances().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); - raft::copy(out.knn_indices + i * n_neighbors, - graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors, - handle.get_stream()); - } - } + RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN, + "nn_descent does not support sparse inputs"); + raft::sparse::selection::brute_force_knn(inputsA.indptr, + inputsA.indices, + inputsA.data, + inputsA.nnz, + inputsA.n, + inputsA.d, + inputsB.indptr, + inputsB.indices, + inputsB.data, + inputsB.nnz, + inputsB.n, + inputsB.d, + out.knn_indices, + out.knn_dists, + n_neighbors, + handle, + ML::Sparse::DEFAULT_BATCH_SIZE, + ML::Sparse::DEFAULT_BATCH_SIZE, + params->metric, + params->p); } template <> diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 016a79d2d4..86799ae6bc 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -187,6 +187,8 @@ void transform(const raft::handle_t& handle, UMAPParams* params, float* transformed) { + RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN, + "build algo nn_descent not supported for transform()"); manifold_dense_inputs_t inputs(X, nullptr, n, d); manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); UMAPAlgo::_transform, TPB_X>( @@ -210,6 +212,8 @@ void transform_sparse(const raft::handle_t& handle, UMAPParams* params, float* transformed) { + RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN, + "build algo nn_descent not supported for transform()"); manifold_sparse_inputs_t inputs( indptr, indices, data, nullptr, nnz, n, d); manifold_sparse_inputs_t orig_x_inputs( diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 663e7b7e45..2088bf0923 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -356,7 +356,6 @@ class UMAP(UniversalBase, output_type=None, build_algo="brute_force_knn", build_kwds=None): - super().__init__(handle=handle, verbose=verbose, @@ -427,7 +426,7 @@ class UMAP(UniversalBase, self.precomputed_knn = extract_knn_infos(precomputed_knn, n_neighbors) - + if build_algo == "brute_force_knn" or build_algo == "nn_descent": self.build_algo = build_algo else: @@ -769,6 +768,9 @@ class UMAP(UniversalBase, cdef UMAPParams* umap_params = \ UMAP._build_umap_params(self, self.sparse_fit) + # NN Descent doesn't support transform yet + umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN + cdef handle_t * handle_ = \ self.handle.getHandle() if self.sparse_fit: diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 7d323b7797..8a167a4038 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -62,12 +62,15 @@ @pytest.mark.parametrize( "n_feats", [unit_param(20), quality_param(100), stress_param(1000)] ) -def test_blobs_cluster(nrows, n_feats): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_blobs_cluster(nrows, n_feats, build_algo): data, labels = datasets.make_blobs( n_samples=nrows, n_features=n_feats, centers=5, random_state=0 ) - embedding = cuUMAP().fit_transform(data, convert_dtype=True) + embedding = cuUMAP(build_algo=build_algo).fit_transform( + data, convert_dtype=True + ) if nrows < 500000: score = adjusted_rand_score(labels, KMeans(5).fit_predict(embedding)) @@ -83,7 +86,8 @@ def test_blobs_cluster(nrows, n_feats): @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_fit_transform_score(nrows, n_feats): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_fit_transform_score(nrows, n_feats, build_algo): n_samples = nrows n_features = n_feats @@ -93,7 +97,7 @@ def test_umap_fit_transform_score(nrows, n_feats): ) model = umap.UMAP(n_neighbors=10, min_dist=0.1) - cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01) + cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01, build_algo=build_algo) embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) @@ -223,7 +227,8 @@ def test_umap_transform_on_digits_sparse( @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) -def test_umap_transform_on_digits(target_metric): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_transform_on_digits(target_metric, build_algo): digits = datasets.load_digits() @@ -240,6 +245,7 @@ def test_umap_transform_on_digits(target_metric): min_dist=0.01, random_state=42, target_metric=target_metric, + build_algo=build_algo, ) fitter.fit(data, convert_dtype=True) @@ -253,11 +259,22 @@ def test_umap_transform_on_digits(target_metric): @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) -@pytest.mark.parametrize("name", dataset_names) +@pytest.mark.parametrize( + "name,build_algo", + [ + ("iris", "brute_force_knn"), + ("digits", "brute_force_knn"), + ("wine", "brute_force_knn"), + ("blobs", "brute_force_knn"), + ("digits", "nn_descent"), + ("wine", "nn_descent"), + ("blobs", "nn_descent"), + ], +) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_fit_transform_trust(name, target_metric): +def test_umap_fit_transform_trust(name, target_metric, build_algo): if name == "iris": iris = datasets.load_iris() @@ -282,7 +299,10 @@ def test_umap_fit_transform_trust(name, target_metric): n_neighbors=10, min_dist=0.01, target_metric=target_metric ) cuml_model = cuUMAP( - n_neighbors=10, min_dist=0.01, target_metric=target_metric + n_neighbors=10, + min_dist=0.01, + target_metric=target_metric, + build_algo=build_algo, ) embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) @@ -299,11 +319,18 @@ def test_umap_fit_transform_trust(name, target_metric): @pytest.mark.parametrize("n_feats", [quality_param(100), stress_param(1000)]) @pytest.mark.parametrize("should_downcast", [True]) @pytest.mark.parametrize("input_type", ["dataframe", "ndarray"]) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) def test_umap_data_formats( - input_type, should_downcast, nrows, n_feats, name, target_metric + input_type, + should_downcast, + nrows, + n_feats, + name, + target_metric, + build_algo, ): dtype = np.float32 if not should_downcast else np.float64 @@ -320,7 +347,12 @@ def test_umap_data_formats( n_samples=n_samples, n_features=n_feats, random_state=0 ) - umap = cuUMAP(n_neighbors=3, n_components=2, target_metric=target_metric) + umap = cuUMAP( + n_neighbors=3, + n_components=2, + target_metric=target_metric, + build_algo=build_algo, + ) embeds = umap.fit_transform(X) assert type(embeds) == np.ndarray @@ -328,10 +360,11 @@ def test_umap_data_formats( @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) @pytest.mark.filterwarnings("ignore:(.*)connected(.*):UserWarning:sklearn[.*]") +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_fit_transform_score_default(target_metric): +def test_umap_fit_transform_score_default(target_metric, build_algo): n_samples = 500 n_features = 20 @@ -341,7 +374,7 @@ def test_umap_fit_transform_score_default(target_metric): ) model = umap.UMAP(target_metric=target_metric) - cuml_model = cuUMAP(target_metric=target_metric) + cuml_model = cuUMAP(target_metric=target_metric, build_algo=build_algo) embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) @@ -354,7 +387,8 @@ def test_umap_fit_transform_score_default(target_metric): assert array_equal(score, cuml_score, 1e-2, with_sign=True) -def test_umap_fit_transform_against_fit_and_transform(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_fit_transform_against_fit_and_transform(build_algo): n_samples = 500 n_features = 20 @@ -367,7 +401,7 @@ def test_umap_fit_transform_against_fit_and_transform(): First test the default option does not hash the input """ - cuml_model = cuUMAP() + cuml_model = cuUMAP(build_algo=build_algo) ft_embedding = cuml_model.fit_transform(data, convert_dtype=True) fit_embedding_same_input = cuml_model.transform(data, convert_dtype=True) @@ -526,9 +560,12 @@ def test_umap_transform_trustworthiness_with_consistency_enabled(): @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_exp_decay_params(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_exp_decay_params(build_algo): def compare_exp_decay_params(a=None, b=None, min_dist=0.1, spread=1.0): - cuml_model = cuUMAP(a=a, b=b, min_dist=min_dist, spread=spread) + cuml_model = cuUMAP( + a=a, b=b, min_dist=min_dist, spread=spread, build_algo=build_algo + ) state = cuml_model.__getstate__() cuml_a, cuml_b = state["a"], state["b"] skl_model = umap.UMAP(a=a, b=b, min_dist=min_dist, spread=spread) @@ -546,20 +583,31 @@ def compare_exp_decay_params(a=None, b=None, min_dist=0.1, spread=1.0): @pytest.mark.parametrize("n_neighbors", [5, 15]) -def test_umap_knn_graph(n_neighbors): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_knn_graph(n_neighbors, build_algo): data, labels = datasets.make_blobs( n_samples=2000, n_features=10, centers=5, random_state=0 ) data = data.astype(np.float32) def fit_transform_embed(knn_graph=None): - model = cuUMAP(random_state=42, init="random", n_neighbors=n_neighbors) + model = cuUMAP( + random_state=42, + init="random", + n_neighbors=n_neighbors, + build_algo=build_algo, + ) return model.fit_transform( data, knn_graph=knn_graph, convert_dtype=True ) def transform_embed(knn_graph=None): - model = cuUMAP(random_state=42, init="random", n_neighbors=n_neighbors) + model = cuUMAP( + random_state=42, + init="random", + n_neighbors=n_neighbors, + build_algo=build_algo, + ) model.fit(data, knn_graph=knn_graph, convert_dtype=True) return model.transform(data, convert_dtype=True) @@ -601,8 +649,15 @@ def test_equality(e1, e2): @pytest.mark.parametrize( "precomputed_type", ["knn_graph", "tuple", "pairwise"] ) -@pytest.mark.parametrize("sparse_input", [False, True]) -def test_umap_precomputed_knn(precomputed_type, sparse_input): +@pytest.mark.parametrize( + "sparse_input,build_algo", + [ + (False, "brute_force_knn"), + (True, "brute_force_knn"), + (False, "nn_descent"), + ], +) +def test_umap_precomputed_knn(precomputed_type, sparse_input, build_algo): data, labels = make_blobs( n_samples=2000, n_features=10, centers=5, random_state=0 ) @@ -629,7 +684,11 @@ def test_umap_precomputed_knn(precomputed_type, sparse_input): elif precomputed_type == "pairwise": precomputed_knn = pairwise_distances(data) - model = cuUMAP(n_neighbors=n_neighbors, precomputed_knn=precomputed_knn) + model = cuUMAP( + n_neighbors=n_neighbors, + precomputed_knn=precomputed_knn, + build_algo=build_algo, + ) embedding = model.fit_transform(data) trust = trustworthiness(data, embedding, n_neighbors=n_neighbors) assert trust >= 0.92 @@ -649,7 +708,8 @@ def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors, build_algo): n_clusters = 30 random_state = 42 @@ -660,7 +720,10 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): random_state=random_state, ) - model = cuUMAP(n_neighbors=n_neighbors) + model = cuUMAP( + n_neighbors=n_neighbors, + build_algo=build_algo, + ) model.fit(X) cu_fss_graph = model.graph_ @@ -692,10 +755,13 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): ("canberra", True), ], ) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_distance_metrics_fit_transform_trust(metric, supported): +def test_umap_distance_metrics_fit_transform_trust( + metric, supported, build_algo +): data, labels = make_blobs( n_samples=1000, n_features=64, centers=5, random_state=42 ) @@ -707,7 +773,11 @@ def test_umap_distance_metrics_fit_transform_trust(metric, supported): n_neighbors=10, min_dist=0.01, metric=metric, init="random" ) cuml_model = cuUMAP( - n_neighbors=10, min_dist=0.01, metric=metric, init="random" + n_neighbors=10, + min_dist=0.01, + metric=metric, + init="random", + build_algo=build_algo, ) if not supported: with pytest.raises(NotImplementedError): @@ -792,43 +862,3 @@ def test_umap_distance_metrics_fit_transform_trust_on_sparse_input( if umap_learn_supported: assert array_equal(umap_trust, cuml_trust, 0.05, with_sign=True) - - -def test_umap_nn_descent(): - X_blobs, y_blobs = cu_make_blobs( - n_samples=1000, - cluster_std=0.1, - n_features=100, - random_state=0, - dtype=np.float32, - ) - - # Dense - trained_UMAP_bf = cuUMAP(n_neighbors=16).fit(X_blobs) - X_embedded_bf = trained_UMAP_bf.transform(X_blobs) - cu_score_bf = cu_trustworthiness(X_blobs, X_embedded_bf) - - trained_UMAP_nnd = cuUMAP(n_neighbors=16, build_algo="nn_descent").fit( - X_blobs - ) - X_embedded_nnd = trained_UMAP_nnd.transform(X_blobs) - cu_score_nnd = cu_trustworthiness(X_blobs, X_embedded_nnd) - - score_diff = np.abs(cu_score_bf - cu_score_nnd) - assert score_diff < 0.1 - - # Sparse - X_blobs_sparse = cupyx.scipy.sparse.csr_matrix(X_blobs) - - trained_UMAP_bf = cuUMAP(n_neighbors=10).fit(X_blobs_sparse) - X_embedded_bf = trained_UMAP_bf.transform(X_blobs_sparse) - cu_score_bf = cu_trustworthiness(X_blobs, X_embedded_bf) - - trained_UMAP_nnd = cuUMAP(n_neighbors=16, build_algo="nn_descent").fit( - X_blobs_sparse - ) - X_embedded_nnd = trained_UMAP_nnd.transform(X_blobs_sparse) - cu_score_nnd = cu_trustworthiness(X_blobs, X_embedded_nnd) - - score_diff = np.abs(cu_score_bf - cu_score_nnd) - assert score_diff < 0.1 From 98645f99e43577e287949d0aee74755a00ff64a9 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 17 Jun 2024 16:38:18 +0000 Subject: [PATCH 14/42] change sqrt -> pow(0.5) --- cpp/src/umap/knn_graph/algo.cuh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index dd9ba66644..d245b7937e 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -35,10 +35,14 @@ namespace NNDescent = raft::neighbors::experimental::nn_descent; -// Functor to post-process distances as L2Sqrt* +// Functor to post-process distances by sqrt +// For usage with NN Descent which internally supports L2Expanded only template struct DistancePostProcessSqrt { - DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); } + DI value_t operator()(value_t value, value_idx row, value_idx col) const + { + return powf(fabsf(value), 0.5); + } }; namespace UMAPAlgo { From 76c7f3162f451c6e8a3cdc917e52f1fc79bd636b Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 19 Jun 2024 01:11:34 +0000 Subject: [PATCH 15/42] refine distances due to precision issues --- cpp/src/umap/knn_graph/algo.cuh | 70 ++++++++++++++++++--------------- python/cuml/manifold/umap.pyx | 4 +- python/cuml/tests/test_umap.py | 46 ++++++++++++++-------- 3 files changed, 70 insertions(+), 50 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index d245b7937e..ae72c3b182 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -1,3 +1,5 @@ + + /* * Copyright (c) 2019-2024, NVIDIA CORPORATION. * @@ -27,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -35,16 +38,6 @@ namespace NNDescent = raft::neighbors::experimental::nn_descent; -// Functor to post-process distances by sqrt -// For usage with NN Descent which internally supports L2Expanded only -template -struct DistancePostProcessSqrt { - DI value_t operator()(value_t value, value_idx row, value_idx col) const - { - return powf(fabsf(value), 0.5); - } -}; - namespace UMAPAlgo { namespace kNNGraph { namespace Algo { @@ -96,34 +89,47 @@ inline void launcher(const raft::handle_t& handle, RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); - auto epilogue = DistancePostProcessSqrt{}; - auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); - auto graph = NNDescent::detail::build( - handle, params->nn_descent_params, dataset, epilogue); + auto graph = + NNDescent::detail::build(handle, params->nn_descent_params, dataset); for (int i = 0; i < inputsB.n; i++) { - if (graph.distances().has_value()) { - raft::copy( - out.knn_dists + i * n_neighbors + 1, - graph.distances().value().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors - 1, - handle.get_stream()); - thrust::fill(thrust::device.on(stream), - out.knn_dists + i * n_neighbors, - out.knn_dists + i * n_neighbors + 1, - 0.0); + for (size_t j = n_neighbors - 1; j >= 1; j--) { + graph.graph().data_handle()[i * params->nn_descent_params.graph_degree + j] = + graph.graph().data_handle()[i * params->nn_descent_params.graph_degree + j - 1]; } - raft::copy(out.knn_indices + i * n_neighbors + 1, - graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors - 1, - handle.get_stream()); - thrust::fill(thrust::device.on(stream), - out.knn_indices + i * n_neighbors, - out.knn_indices + i * n_neighbors + 1, - i); + graph.graph().data_handle()[i * params->nn_descent_params.graph_degree] = i; } + + auto dataset_dev = + raft::make_device_matrix(handle, inputsB.n, inputsA.d); + raft::copy( + dataset_dev.data_handle(), dataset.data_handle(), inputsB.n * inputsA.d, handle.get_stream()); + auto dataset_dev_view = raft::make_device_matrix_view( + dataset_dev.data_handle(), inputsB.n, inputsA.d); + + auto neighbor_candidates = raft::make_device_matrix( + handle, inputsB.n, params->nn_descent_params.graph_degree); + raft::copy(neighbor_candidates.data_handle(), + graph.graph().data_handle(), + inputsB.n * params->nn_descent_params.graph_degree, + handle.get_stream()); + auto neighbor_candidates_view = + raft::make_device_matrix_view( + neighbor_candidates.data_handle(), inputsB.n, params->nn_descent_params.graph_degree); + + auto indices = + raft::make_device_matrix_view(out.knn_indices, inputsB.n, n_neighbors); + auto distances = + raft::make_device_matrix_view(out.knn_dists, inputsB.n, n_neighbors); + raft::neighbors::refine(handle, + dataset_dev_view, + dataset_dev_view, + neighbor_candidates_view, + indices, + distances, + params->metric); } } diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 2088bf0923..678a6b8936 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -475,13 +475,13 @@ class UMAP(UniversalBase, umap_params.nn_descent_params.intermediate_graph_degree = 128 umap_params.nn_descent_params.max_iterations = 20 umap_params.nn_descent_params.termination_threshold = 0.0001 - umap_params.nn_descent_params.return_distances = True + umap_params.nn_descent_params.return_distances = False else: umap_params.nn_descent_params.graph_degree = cls.build_kwds.get("nnd_graph_degree", 64) umap_params.nn_descent_params.intermediate_graph_degree = cls.build_kwds.get("nnd_intermediate_graph_degree", 128) umap_params.nn_descent_params.max_iterations = cls.build_kwds.get("nnd_max_iterations", 20) umap_params.nn_descent_params.termination_threshold = cls.build_kwds.get("nnd_termination_threshold", 0.0001) - umap_params.nn_descent_params.return_distances = True + umap_params.nn_descent_params.return_distances = cls.build_kwds.get("nnd_return_distances", False) umap_params.target_weight = cls.target_weight umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 8a167a4038..00d8f54cd1 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -114,41 +114,45 @@ def test_umap_fit_transform_score(nrows, n_feats, build_algo): assert array_equal(score, cuml_score, 1e-2, with_sign=True) -def test_supervised_umap_trustworthiness_on_iris(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_supervised_umap_trustworthiness_on_iris(build_algo): iris = datasets.load_iris() data = iris.data embedding = cuUMAP( - n_neighbors=10, random_state=0, min_dist=0.01 + n_neighbors=10, random_state=0, min_dist=0.01, build_algo=build_algo ).fit_transform(data, iris.target, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 -def test_semisupervised_umap_trustworthiness_on_iris(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_semisupervised_umap_trustworthiness_on_iris(build_algo): iris = datasets.load_iris() data = iris.data target = iris.target.copy() target[25:75] = -1 embedding = cuUMAP( - n_neighbors=10, random_state=0, min_dist=0.01 + n_neighbors=10, random_state=0, min_dist=0.01, build_algo=build_algo ).fit_transform(data, target, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 -def test_umap_trustworthiness_on_iris(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_trustworthiness_on_iris(build_algo): iris = datasets.load_iris() data = iris.data embedding = cuUMAP( - n_neighbors=10, min_dist=0.01, random_state=0 + n_neighbors=10, min_dist=0.01, random_state=0, build_algo=build_algo ).fit_transform(data, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) -def test_umap_transform_on_iris(target_metric): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_transform_on_iris(target_metric, build_algo): iris = datasets.load_iris() @@ -164,6 +168,7 @@ def test_umap_transform_on_iris(target_metric): min_dist=0.01, random_state=42, target_metric=target_metric, + build_algo=build_algo, ) fitter.fit(data, convert_dtype=True) new_data = iris.data[~iris_selection] @@ -266,6 +271,7 @@ def test_umap_transform_on_digits(target_metric, build_algo): ("digits", "brute_force_knn"), ("wine", "brute_force_knn"), ("blobs", "brute_force_knn"), + ("iris", "nn_descent"), ("digits", "nn_descent"), ("wine", "nn_descent"), ("blobs", "nn_descent"), @@ -528,18 +534,26 @@ def get_embedding(n_components, random_state): assert mean_diff > 0.5 -def test_umap_fit_transform_trustworthiness_with_consistency_enabled(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_fit_transform_trustworthiness_with_consistency_enabled( + build_algo, +): iris = datasets.load_iris() data = iris.data algo = cuUMAP( - n_neighbors=10, min_dist=0.01, init="random", random_state=42 + n_neighbors=10, + min_dist=0.01, + init="random", + random_state=42, + build_algo=build_algo, ) embedding = algo.fit_transform(data, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 -def test_umap_transform_trustworthiness_with_consistency_enabled(): +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) +def test_umap_transform_trustworthiness_with_consistency_enabled(build_algo): iris = datasets.load_iris() data = iris.data selection = np.random.RandomState(42).choice( @@ -548,7 +562,11 @@ def test_umap_transform_trustworthiness_with_consistency_enabled(): fit_data = data[selection] transform_data = data[~selection] model = cuUMAP( - n_neighbors=10, min_dist=0.01, init="random", random_state=42 + n_neighbors=10, + min_dist=0.01, + init="random", + random_state=42, + build_algo=build_algo, ) model.fit(fit_data, convert_dtype=True) embedding = model.transform(transform_data, convert_dtype=True) @@ -755,13 +773,10 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors, build_algo): ("canberra", True), ], ) -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_distance_metrics_fit_transform_trust( - metric, supported, build_algo -): +def test_umap_distance_metrics_fit_transform_trust(metric, supported): data, labels = make_blobs( n_samples=1000, n_features=64, centers=5, random_state=42 ) @@ -777,7 +792,6 @@ def test_umap_distance_metrics_fit_transform_trust( min_dist=0.01, metric=metric, init="random", - build_algo=build_algo, ) if not supported: with pytest.raises(NotImplementedError): From e2b98c390f483497a8124de308200b9e07514d35 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Sat, 22 Jun 2024 20:15:37 +0000 Subject: [PATCH 16/42] add to param names --- python/cuml/manifold/umap.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 678a6b8936..0e9fe24407 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -832,7 +832,9 @@ class UMAP(UniversalBase, "callback", "metric", "metric_kwds", - "precomputed_knn" + "precomputed_knn", + "build_algo", + "build_kwds" ] def get_attr_names(self): From 841f8baee3091eeb03b9e10ec7959b36910e7876 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Sat, 22 Jun 2024 23:00:07 +0000 Subject: [PATCH 17/42] change threshold for test --- python/cuml/tests/test_umap.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 00d8f54cd1..fd921ad79e 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -751,9 +751,14 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors, build_algo): cu_fss_graph = cu_fss_graph.todense() ref_fss_graph = cupyx.scipy.sparse.coo_matrix(ref_fss_graph).todense() - assert correctness_sparse( - ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.95 - ) + if build_algo == "brute_force_knn": + assert correctness_sparse( + ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.95 + ) + else: + assert correctness_sparse( + ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.92 + ) @pytest.mark.parametrize( From 2b391f007e6409e01bfd7bdc3d922623e8d83015 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 24 Jun 2024 17:00:50 +0000 Subject: [PATCH 18/42] threshold for iris dataset --- python/cuml/tests/test_umap.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index fd921ad79e..5604c996b7 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -177,7 +177,10 @@ def test_umap_transform_on_iris(target_metric, build_algo): assert not np.isnan(embedding).any() trust = trustworthiness(new_data, embedding, n_neighbors=10) - assert trust >= 0.85 + if build_algo == "brute_force_knn": + assert trust >= 0.85 + else: + assert trust >= 0.82 @pytest.mark.parametrize("input_type", ["cupy", "scipy"]) From e2a36a6c04f8a13c40dd8ec3e89d2329230962d6 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 24 Jun 2024 22:42:34 +0000 Subject: [PATCH 19/42] add warning for small dataset + nnd --- python/cuml/manifold/umap.pyx | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 0e9fe24407..c4d2722d20 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -21,6 +21,7 @@ np = cpu_only_import('numpy') pd = cpu_only_import('pandas') import joblib +import warnings from cuml.internals.safe_imports import gpu_only_import cupy = gpu_only_import('cupy') @@ -475,13 +476,13 @@ class UMAP(UniversalBase, umap_params.nn_descent_params.intermediate_graph_degree = 128 umap_params.nn_descent_params.max_iterations = 20 umap_params.nn_descent_params.termination_threshold = 0.0001 - umap_params.nn_descent_params.return_distances = False + umap_params.nn_descent_params.return_distances = True else: umap_params.nn_descent_params.graph_degree = cls.build_kwds.get("nnd_graph_degree", 64) umap_params.nn_descent_params.intermediate_graph_degree = cls.build_kwds.get("nnd_intermediate_graph_degree", 128) umap_params.nn_descent_params.max_iterations = cls.build_kwds.get("nnd_max_iterations", 20) umap_params.nn_descent_params.termination_threshold = cls.build_kwds.get("nnd_termination_threshold", 0.0001) - umap_params.nn_descent_params.return_distances = cls.build_kwds.get("nnd_return_distances", False) + umap_params.nn_descent_params.return_distances = cls.build_kwds.get("nnd_return_distances", True) umap_params.target_weight = cls.target_weight umap_params.random_state = cls.random_state umap_params.deterministic = cls.deterministic @@ -568,6 +569,9 @@ class UMAP(UniversalBase, if self.n_rows <= 1: raise ValueError("There needs to be more than 1 sample to " "build nearest the neighbors graph") + if self.build_algo == "nn_descent" and self.n_rows < 150: + # https://github.com/rapidsai/cuvs/issues/184 + warnings.warn("using nn_descent as build_algo on a small dataset (< 150 samples) is unstable") cdef uintptr_t _knn_dists_ptr = 0 cdef uintptr_t _knn_indices_ptr = 0 From 2759663c979db93d62f83184d6c4a45623971818 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 24 Jun 2024 22:43:07 +0000 Subject: [PATCH 20/42] revert back to not refining --- cpp/src/umap/knn_graph/algo.cuh | 63 +++++++++++++++------------------ 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index ae72c3b182..d505295084 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -55,6 +55,12 @@ void launcher(const raft::handle_t& handle, const ML::UMAPParams* params, cudaStream_t stream); +// Functor to post-process distances as L2Sqrt* +template +struct DistancePostProcessSqrt { + DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); } +}; + // Instantiation for dense inputs, int64_t indices template <> inline void launcher(const raft::handle_t& handle, @@ -88,48 +94,35 @@ inline void launcher(const raft::handle_t& handle, } else { // nn_descent RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); + auto epilogue = DistancePostProcessSqrt{}; auto dataset = raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); - auto graph = - NNDescent::detail::build(handle, params->nn_descent_params, dataset); + auto graph = NNDescent::detail::build( + handle, params->nn_descent_params, dataset, epilogue); + // nn descent does not include itself as its closest neighbor for (int i = 0; i < inputsB.n; i++) { - for (size_t j = n_neighbors - 1; j >= 1; j--) { - graph.graph().data_handle()[i * params->nn_descent_params.graph_degree + j] = - graph.graph().data_handle()[i * params->nn_descent_params.graph_degree + j - 1]; + if (graph.distances().has_value()) { + raft::copy( + out.knn_dists + i * n_neighbors + 1, + graph.distances().value().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors - 1, + handle.get_stream()); + thrust::fill(thrust::device.on(stream), + out.knn_dists + i * n_neighbors, + out.knn_dists + i * n_neighbors + 1, + 0.0); } - graph.graph().data_handle()[i * params->nn_descent_params.graph_degree] = i; + raft::copy(out.knn_indices + i * n_neighbors + 1, + graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, + n_neighbors - 1, + handle.get_stream()); + thrust::fill(thrust::device.on(stream), + out.knn_indices + i * n_neighbors, + out.knn_indices + i * n_neighbors + 1, + i); } - - auto dataset_dev = - raft::make_device_matrix(handle, inputsB.n, inputsA.d); - raft::copy( - dataset_dev.data_handle(), dataset.data_handle(), inputsB.n * inputsA.d, handle.get_stream()); - auto dataset_dev_view = raft::make_device_matrix_view( - dataset_dev.data_handle(), inputsB.n, inputsA.d); - - auto neighbor_candidates = raft::make_device_matrix( - handle, inputsB.n, params->nn_descent_params.graph_degree); - raft::copy(neighbor_candidates.data_handle(), - graph.graph().data_handle(), - inputsB.n * params->nn_descent_params.graph_degree, - handle.get_stream()); - auto neighbor_candidates_view = - raft::make_device_matrix_view( - neighbor_candidates.data_handle(), inputsB.n, params->nn_descent_params.graph_degree); - - auto indices = - raft::make_device_matrix_view(out.knn_indices, inputsB.n, n_neighbors); - auto distances = - raft::make_device_matrix_view(out.knn_dists, inputsB.n, n_neighbors); - raft::neighbors::refine(handle, - dataset_dev_view, - dataset_dev_view, - neighbor_candidates_view, - indices, - distances, - params->metric); } } From 4fafc312a28d86c8ccdb82376d1e69f0f9c47944 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 24 Jun 2024 23:04:25 +0000 Subject: [PATCH 21/42] fix tests --- python/cuml/tests/test_umap.py | 66 ++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 5604c996b7..0464cc9900 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -150,9 +150,9 @@ def test_umap_trustworthiness_on_iris(build_algo): assert trust >= 0.97 +# nn descent is unstable with small datasets @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_umap_transform_on_iris(target_metric, build_algo): +def test_umap_transform_on_iris(target_metric): iris = datasets.load_iris() @@ -168,7 +168,6 @@ def test_umap_transform_on_iris(target_metric, build_algo): min_dist=0.01, random_state=42, target_metric=target_metric, - build_algo=build_algo, ) fitter.fit(data, convert_dtype=True) new_data = iris.data[~iris_selection] @@ -177,10 +176,7 @@ def test_umap_transform_on_iris(target_metric, build_algo): assert not np.isnan(embedding).any() trust = trustworthiness(new_data, embedding, n_neighbors=10) - if build_algo == "brute_force_knn": - assert trust >= 0.85 - else: - assert trust >= 0.82 + assert trust >= 0.85 @pytest.mark.parametrize("input_type", ["cupy", "scipy"]) @@ -267,19 +263,8 @@ def test_umap_transform_on_digits(target_metric, build_algo): @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) -@pytest.mark.parametrize( - "name,build_algo", - [ - ("iris", "brute_force_knn"), - ("digits", "brute_force_knn"), - ("wine", "brute_force_knn"), - ("blobs", "brute_force_knn"), - ("iris", "nn_descent"), - ("digits", "nn_descent"), - ("wine", "nn_descent"), - ("blobs", "nn_descent"), - ], -) +@pytest.mark.parametrize("name", dataset_names) +@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) @@ -555,8 +540,8 @@ def test_umap_fit_transform_trustworthiness_with_consistency_enabled( assert trust >= 0.97 -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_umap_transform_trustworthiness_with_consistency_enabled(build_algo): +# nn descent is unstable with small datasets +def test_umap_transform_trustworthiness_with_consistency_enabled(): iris = datasets.load_iris() data = iris.data selection = np.random.RandomState(42).choice( @@ -569,7 +554,6 @@ def test_umap_transform_trustworthiness_with_consistency_enabled(build_algo): min_dist=0.01, init="random", random_state=42, - build_algo=build_algo, ) model.fit(fit_data, convert_dtype=True) embedding = model.transform(transform_data, convert_dtype=True) @@ -577,6 +561,31 @@ def test_umap_transform_trustworthiness_with_consistency_enabled(build_algo): assert trust >= 0.92 +@pytest.mark.parametrize("build_algo", ["nn_descent"]) +def test_umap_transform_trustworthiness_with_consistency_enabled_digits( + build_algo, +): + digits = datasets.load_digits() + data = digits.data + digits_selection = np.random.RandomState(42).choice( + [True, False], 1797, replace=True, p=[0.75, 0.25] + ) + fit_data = digits.data[digits_selection] + transform_data = data[~digits_selection] + model = cuUMAP( + n_neighbors=10, + min_dist=0.01, + init="random", + random_state=42, + build_algo=build_algo, + ) + print(fit_data.shape) + model.fit(fit_data, convert_dtype=True) + embedding = model.transform(transform_data, convert_dtype=True) + trust = trustworthiness(transform_data, embedding, n_neighbors=10) + assert trust >= 0.95 + + @pytest.mark.filterwarnings("ignore:(.*)zero(.*)::scipy[.*]|umap[.*]") @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" @@ -754,14 +763,9 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors, build_algo): cu_fss_graph = cu_fss_graph.todense() ref_fss_graph = cupyx.scipy.sparse.coo_matrix(ref_fss_graph).todense() - if build_algo == "brute_force_knn": - assert correctness_sparse( - ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.95 - ) - else: - assert correctness_sparse( - ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.92 - ) + assert correctness_sparse( + ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.93 + ) @pytest.mark.parametrize( From 8f8547873468f410a7701c6636bedf84ee6ab5f3 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Mon, 24 Jun 2024 23:09:07 +0000 Subject: [PATCH 22/42] remove print --- python/cuml/tests/test_umap.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 0464cc9900..c8d725560a 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -579,7 +579,6 @@ def test_umap_transform_trustworthiness_with_consistency_enabled_digits( random_state=42, build_algo=build_algo, ) - print(fit_data.shape) model.fit(fit_data, convert_dtype=True) embedding = model.transform(transform_data, convert_dtype=True) trust = trustworthiness(transform_data, embedding, n_neighbors=10) @@ -647,7 +646,6 @@ def test_trustworthiness(embedding): def test_equality(e1, e2): mean_diff = np.mean(np.abs(e1 - e2)) - print("mean diff: %s" % mean_diff) assert mean_diff < 1.0 neigh = NearestNeighbors(n_neighbors=n_neighbors) From a4705e0aa3ad905bf4cb8424b6026af19241d85f Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 11 Jul 2024 22:12:14 +0000 Subject: [PATCH 23/42] change copy as kernel --- cpp/src/umap/knn_graph/algo.cuh | 56 ++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index d505295084..3b374d4465 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -61,6 +62,17 @@ struct DistancePostProcessSqrt { DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); } }; +template +CUML_KERNEL void copy_first_k_cols(T* out, T* in, size_t out_k, size_t in_k, size_t nrows) +{ + size_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row < nrows) { + for (size_t i = 0; i < out_k; i++) { + out[row * out_k + i] = in[row * in_k + i]; + } + } +} + // Instantiation for dense inputs, int64_t indices template <> inline void launcher(const raft::handle_t& handle, @@ -101,28 +113,30 @@ inline void launcher(const raft::handle_t& handle, auto graph = NNDescent::detail::build( handle, params->nn_descent_params, dataset, epilogue); - // nn descent does not include itself as its closest neighbor - for (int i = 0; i < inputsB.n; i++) { - if (graph.distances().has_value()) { - raft::copy( - out.knn_dists + i * n_neighbors + 1, - graph.distances().value().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors - 1, - handle.get_stream()); - thrust::fill(thrust::device.on(stream), - out.knn_dists + i * n_neighbors, - out.knn_dists + i * n_neighbors + 1, - 0.0); - } - raft::copy(out.knn_indices + i * n_neighbors + 1, - graph.graph().data_handle() + i * params->nn_descent_params.graph_degree, - n_neighbors - 1, - handle.get_stream()); - thrust::fill(thrust::device.on(stream), - out.knn_indices + i * n_neighbors, - out.knn_indices + i * n_neighbors + 1, - i); + auto indices_d = raft::make_device_matrix( + handle, inputsA.n, params->nn_descent_params.graph_degree); + + raft::copy(indices_d.data_handle(), + graph.graph().data_handle(), + inputsA.n * params->nn_descent_params.graph_degree, + stream); + + size_t TPB = 256; + size_t num_blocks = static_cast((inputsA.n + TPB) / TPB); + if (graph.distances().has_value()) { + copy_first_k_cols + <<>>(out.knn_dists, + graph.distances().value().data_handle(), + static_cast(n_neighbors), + params->nn_descent_params.graph_degree, + inputsA.n); } + copy_first_k_cols + <<>>(out.knn_indices, + indices_d.data_handle(), + static_cast(n_neighbors), + params->nn_descent_params.graph_degree, + inputsA.n); } } From 0fd69516cb1086cc5c4e7761d759217fcb5c6708 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 11 Jul 2024 22:28:57 +0000 Subject: [PATCH 24/42] add detailed doc and warning --- python/cuml/manifold/umap.pyx | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index c4d2722d20..f36bb80ea3 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -293,8 +293,8 @@ class UMAP(UniversalBase, build_algo: string (default='brute_force_knn') How to build the knn graph. Supported build algorithms are ['brute_force_knn', 'nn_descent'] - metric_kwds: dict (optional, default=None) - Build algorithm argument + build_kwds: dict (optional, default=None) + Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128, 'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True} Notes ----- @@ -354,9 +354,9 @@ class UMAP(UniversalBase, callback=None, handle=None, verbose=False, - output_type=None, build_algo="brute_force_knn", - build_kwds=None): + build_kwds=None + output_type=None,): super().__init__(handle=handle, verbose=verbose, @@ -774,6 +774,7 @@ class UMAP(UniversalBase, self.sparse_fit) # NN Descent doesn't support transform yet umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN + logger.warn("NN Descent does not support transform. Using Brute force instead.") cdef handle_t * handle_ = \ self.handle.getHandle() From 8585538e430de923e82d8bcbb13a137a0d869c19 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 11 Jul 2024 22:29:51 +0000 Subject: [PATCH 25/42] newline --- python/cuml/manifold/umap.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index f36bb80ea3..3684f93f7f 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -294,7 +294,8 @@ class UMAP(UniversalBase, How to build the knn graph. Supported build algorithms are ['brute_force_knn', 'nn_descent'] build_kwds: dict (optional, default=None) - Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128, 'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True} + Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128, + 'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True} Notes ----- From 7445922c56b42a55d4a0d5a8b9d2be09f2b84c24 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 11 Jul 2024 22:38:09 +0000 Subject: [PATCH 26/42] raise error for sparse --- python/cuml/manifold/umap.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 3684f93f7f..b5ffbda7bb 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -558,6 +558,8 @@ class UMAP(UniversalBase, convert_format=False) self.n_rows, self.n_dims = self._raw_data.shape self.sparse_fit = True + if self.build_algo == "nn_descent": + raise ValueError("NN Descent does not support sparse inputs") # Handle dense inputs else: From 520a017d12e9a683c05ca826d98bc1d6eae0f071 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 11 Jul 2024 22:51:47 +0000 Subject: [PATCH 27/42] fix typo --- python/cuml/manifold/umap.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index b5ffbda7bb..59f3ed500d 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -356,8 +356,8 @@ class UMAP(UniversalBase, handle=None, verbose=False, build_algo="brute_force_knn", - build_kwds=None - output_type=None,): + build_kwds=None, + output_type=None): super().__init__(handle=handle, verbose=verbose, From 7f7179bf13efd3d0fc14390e44831aaf1174b4ce Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 11 Jul 2024 22:52:57 +0000 Subject: [PATCH 28/42] change tests --- python/cuml/tests/test_umap.py | 58 ++++++---------------------------- 1 file changed, 10 insertions(+), 48 deletions(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index c8d725560a..703a1b2c36 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -114,37 +114,34 @@ def test_umap_fit_transform_score(nrows, n_feats, build_algo): assert array_equal(score, cuml_score, 1e-2, with_sign=True) -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_supervised_umap_trustworthiness_on_iris(build_algo): +def test_supervised_umap_trustworthiness_on_iris(): iris = datasets.load_iris() data = iris.data embedding = cuUMAP( - n_neighbors=10, random_state=0, min_dist=0.01, build_algo=build_algo + n_neighbors=10, random_state=0, min_dist=0.01 ).fit_transform(data, iris.target, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_semisupervised_umap_trustworthiness_on_iris(build_algo): +def test_semisupervised_umap_trustworthiness_on_iris(): iris = datasets.load_iris() data = iris.data target = iris.target.copy() target[25:75] = -1 embedding = cuUMAP( - n_neighbors=10, random_state=0, min_dist=0.01, build_algo=build_algo + n_neighbors=10, random_state=0, min_dist=0.01 ).fit_transform(data, target, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_umap_trustworthiness_on_iris(build_algo): +def test_umap_trustworthiness_on_iris(): iris = datasets.load_iris() data = iris.data embedding = cuUMAP( - n_neighbors=10, min_dist=0.01, random_state=0, build_algo=build_algo + n_neighbors=10, min_dist=0.01, random_state=0 ).fit_transform(data, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 @@ -231,8 +228,7 @@ def test_umap_transform_on_digits_sparse( @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_umap_transform_on_digits(target_metric, build_algo): +def test_umap_transform_on_digits(target_metric): digits = datasets.load_digits() @@ -249,7 +245,6 @@ def test_umap_transform_on_digits(target_metric, build_algo): min_dist=0.01, random_state=42, target_metric=target_metric, - build_algo=build_algo, ) fitter.fit(data, convert_dtype=True) @@ -264,11 +259,10 @@ def test_umap_transform_on_digits(target_metric, build_algo): @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) @pytest.mark.parametrize("name", dataset_names) -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -def test_umap_fit_transform_trust(name, target_metric, build_algo): +def test_umap_fit_transform_trust(name, target_metric): if name == "iris": iris = datasets.load_iris() @@ -296,7 +290,6 @@ def test_umap_fit_transform_trust(name, target_metric, build_algo): n_neighbors=10, min_dist=0.01, target_metric=target_metric, - build_algo=build_algo, ) embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) @@ -522,10 +515,7 @@ def get_embedding(n_components, random_state): assert mean_diff > 0.5 -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_umap_fit_transform_trustworthiness_with_consistency_enabled( - build_algo, -): +def test_umap_fit_transform_trustworthiness_with_consistency_enabled(): iris = datasets.load_iris() data = iris.data algo = cuUMAP( @@ -533,14 +523,12 @@ def test_umap_fit_transform_trustworthiness_with_consistency_enabled( min_dist=0.01, init="random", random_state=42, - build_algo=build_algo, ) embedding = algo.fit_transform(data, convert_dtype=True) trust = trustworthiness(iris.data, embedding, n_neighbors=10) assert trust >= 0.97 -# nn descent is unstable with small datasets def test_umap_transform_trustworthiness_with_consistency_enabled(): iris = datasets.load_iris() data = iris.data @@ -561,30 +549,6 @@ def test_umap_transform_trustworthiness_with_consistency_enabled(): assert trust >= 0.92 -@pytest.mark.parametrize("build_algo", ["nn_descent"]) -def test_umap_transform_trustworthiness_with_consistency_enabled_digits( - build_algo, -): - digits = datasets.load_digits() - data = digits.data - digits_selection = np.random.RandomState(42).choice( - [True, False], 1797, replace=True, p=[0.75, 0.25] - ) - fit_data = digits.data[digits_selection] - transform_data = data[~digits_selection] - model = cuUMAP( - n_neighbors=10, - min_dist=0.01, - init="random", - random_state=42, - build_algo=build_algo, - ) - model.fit(fit_data, convert_dtype=True) - embedding = model.transform(transform_data, convert_dtype=True) - trust = trustworthiness(transform_data, embedding, n_neighbors=10) - assert trust >= 0.95 - - @pytest.mark.filterwarnings("ignore:(.*)zero(.*)::scipy[.*]|umap[.*]") @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" @@ -736,8 +700,7 @@ def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): @pytest.mark.skipif( IS_ARM, reason="https://github.com/rapidsai/cuml/issues/5441" ) -@pytest.mark.parametrize("build_algo", ["brute_force_knn", "nn_descent"]) -def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors, build_algo): +def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): n_clusters = 30 random_state = 42 @@ -750,7 +713,6 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors, build_algo): model = cuUMAP( n_neighbors=n_neighbors, - build_algo=build_algo, ) model.fit(X) cu_fss_graph = model.graph_ From f8f24b990eb14c249f91f27f0eeb0b6440f35b9b Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 12 Jul 2024 22:31:44 +0000 Subject: [PATCH 29/42] cleanup algo.cuh and test --- cpp/src/umap/knn_graph/algo.cuh | 3 --- python/cuml/tests/test_umap.py | 18 ++++-------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 3b374d4465..9684af325a 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -1,5 +1,3 @@ - - /* * Copyright (c) 2019-2024, NVIDIA CORPORATION. * @@ -30,7 +28,6 @@ #include #include #include -#include #include #include #include diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 703a1b2c36..24b88cc3ff 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -27,10 +27,8 @@ from sklearn.cluster import KMeans from sklearn.neighbors import NearestNeighbors from sklearn import datasets -from cuml.datasets import make_blobs as cu_make_blobs from cuml.internals import logger from cuml.metrics import pairwise_distances -from cuml.metrics import trustworthiness as cu_trustworthiness from cuml.testing.utils import ( array_equal, unit_param, @@ -147,7 +145,6 @@ def test_umap_trustworthiness_on_iris(): assert trust >= 0.97 -# nn descent is unstable with small datasets @pytest.mark.parametrize("target_metric", ["categorical", "euclidean"]) def test_umap_transform_on_iris(target_metric): @@ -287,9 +284,7 @@ def test_umap_fit_transform_trust(name, target_metric): n_neighbors=10, min_dist=0.01, target_metric=target_metric ) cuml_model = cuUMAP( - n_neighbors=10, - min_dist=0.01, - target_metric=target_metric, + n_neighbors=10, min_dist=0.01, target_metric=target_metric ) embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) @@ -711,9 +706,7 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): random_state=random_state, ) - model = cuUMAP( - n_neighbors=n_neighbors, - ) + model = cuUMAP(n_neighbors=n_neighbors) model.fit(X) cu_fss_graph = model.graph_ @@ -724,7 +717,7 @@ def test_fuzzy_simplicial_set(n_rows, n_features, n_neighbors): cu_fss_graph = cu_fss_graph.todense() ref_fss_graph = cupyx.scipy.sparse.coo_matrix(ref_fss_graph).todense() assert correctness_sparse( - ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.93 + ref_fss_graph, cu_fss_graph, atol=0.1, rtol=0.2, threshold=0.95 ) @@ -760,10 +753,7 @@ def test_umap_distance_metrics_fit_transform_trust(metric, supported): n_neighbors=10, min_dist=0.01, metric=metric, init="random" ) cuml_model = cuUMAP( - n_neighbors=10, - min_dist=0.01, - metric=metric, - init="random", + n_neighbors=10, min_dist=0.01, metric=metric, init="random" ) if not supported: with pytest.raises(NotImplementedError): From 6330a7060823895347cf0332f454edca196e3d80 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 18 Jul 2024 01:02:34 +0000 Subject: [PATCH 30/42] change to using raft slicing kernel --- cpp/src/umap/knn_graph/algo.cuh | 41 +++++++++++++-------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 9684af325a..54c2e51e10 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -21,9 +21,12 @@ #include #include +#include #include #include #include +#include +#include #include #include #include @@ -59,17 +62,6 @@ struct DistancePostProcessSqrt { DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); } }; -template -CUML_KERNEL void copy_first_k_cols(T* out, T* in, size_t out_k, size_t in_k, size_t nrows) -{ - size_t row = blockIdx.x * blockDim.x + threadIdx.x; - if (row < nrows) { - for (size_t i = 0; i < out_k; i++) { - out[row * out_k + i] = in[row * in_k + i]; - } - } -} - // Instantiation for dense inputs, int64_t indices template <> inline void launcher(const raft::handle_t& handle, @@ -118,22 +110,21 @@ inline void launcher(const raft::handle_t& handle, inputsA.n * params->nn_descent_params.graph_degree, stream); - size_t TPB = 256; - size_t num_blocks = static_cast((inputsA.n + TPB) / TPB); + raft::matrix::slice_coordinates coords{static_cast(0), + static_cast(0), + static_cast(inputsA.n), + static_cast(n_neighbors)}; + if (graph.distances().has_value()) { - copy_first_k_cols - <<>>(out.knn_dists, - graph.distances().value().data_handle(), - static_cast(n_neighbors), - params->nn_descent_params.graph_degree, - inputsA.n); + auto out_knn_dists_view = + raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors); + raft::matrix::slice( + handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); } - copy_first_k_cols - <<>>(out.knn_indices, - indices_d.data_handle(), - static_cast(n_neighbors), - params->nn_descent_params.graph_degree, - inputsA.n); + auto out_knn_indices_view = + raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors); + raft::matrix::slice( + handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); } } From 3e955aca2d11a5b964e9ce81887fe72f91f95240 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Thu, 18 Jul 2024 15:31:46 +0000 Subject: [PATCH 31/42] add header and change namespace --- cpp/src/umap/knn_graph/algo.cuh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 54c2e51e10..33f7033815 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -29,7 +29,8 @@ #include #include #include -#include +#include +#include #include #include #include @@ -98,9 +99,9 @@ inline void launcher(const raft::handle_t& handle, auto epilogue = DistancePostProcessSqrt{}; auto dataset = - raft::make_host_matrix_view(inputsA.X, inputsA.n, inputsA.d); - auto graph = NNDescent::detail::build( - handle, params->nn_descent_params, dataset, epilogue); + raft::make_device_matrix_view(inputsA.X, inputsA.n, inputsA.d); + auto graph = + NNDescent::build(handle, params->nn_descent_params, dataset, epilogue); auto indices_d = raft::make_device_matrix( handle, inputsA.n, params->nn_descent_params.graph_degree); From a6a0b69e394a7234c1daa9d2a54a2638e0c7722b Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 24 Jul 2024 00:33:25 +0000 Subject: [PATCH 32/42] add auto option and enable host on device --- cpp/src/umap/knn_graph/algo.cuh | 24 +++++++++++---- python/cuml/cuml/manifold/umap.pyx | 48 +++++++++++++++++++++--------- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 33f7033815..ec99e6d82c 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -63,6 +63,24 @@ struct DistancePostProcessSqrt { DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); } }; +auto get_graph_nnd(const raft::handle_t& handle, + const ML::manifold_dense_inputs_t& inputs, + const ML::UMAPParams* params) +{ + auto epilogue = DistancePostProcessSqrt{}; + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputs.X)); + float* ptr = reinterpret_cast(attr.devicePointer); + if (ptr != nullptr) { + auto dataset = + raft::make_device_matrix_view(inputs.X, inputs.n, inputs.d); + return NNDescent::build(handle, params->nn_descent_params, dataset, epilogue); + } else { + auto dataset = raft::make_host_matrix_view(inputs.X, inputs.n, inputs.d); + return NNDescent::build(handle, params->nn_descent_params, dataset, epilogue); + } +} + // Instantiation for dense inputs, int64_t indices template <> inline void launcher(const raft::handle_t& handle, @@ -96,12 +114,8 @@ inline void launcher(const raft::handle_t& handle, } else { // nn_descent RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); - auto epilogue = DistancePostProcessSqrt{}; - auto dataset = - raft::make_device_matrix_view(inputsA.X, inputsA.n, inputsA.d); - auto graph = - NNDescent::build(handle, params->nn_descent_params, dataset, epilogue); + auto graph = get_graph_nnd(handle, inputsA, params); auto indices_d = raft::make_device_matrix( handle, inputsA.n, params->nn_descent_params.graph_degree); diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 59f3ed500d..116f394280 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -41,6 +41,7 @@ from cuml.internals.available_devices import is_cuda_available from cuml.internals.input_utils import input_to_cuml_array from cuml.internals.array import CumlArray from cuml.internals.array_sparse import SparseCumlArray +from cuml.internals.mem_type import MemoryType from cuml.internals.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse @@ -290,9 +291,9 @@ class UMAP(UniversalBase, type. If None, the output type set at the module level (`cuml.global_settings.output_type`) will be used. See :ref:`output-data-type-configuration` for more info. - build_algo: string (default='brute_force_knn') - How to build the knn graph. Supported build algorithms are ['brute_force_knn', - 'nn_descent'] + build_algo: string (default='auto') + How to build the knn graph. Supported build algorithms are ['auto', 'brute_force_knn', + 'nn_descent']. 'auto' chooses to run with brute force knn or nn descent based on the dataset size. build_kwds: dict (optional, default=None) Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128, 'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True} @@ -355,7 +356,7 @@ class UMAP(UniversalBase, callback=None, handle=None, verbose=False, - build_algo="brute_force_knn", + build_algo="auto", build_kwds=None, output_type=None): @@ -429,10 +430,10 @@ class UMAP(UniversalBase, self.precomputed_knn = extract_knn_infos(precomputed_knn, n_neighbors) - if build_algo == "brute_force_knn" or build_algo == "nn_descent": + if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": self.build_algo = build_algo else: - raise Exception("Invalid build algo: {}. Only support brute_force_knn and nn_descent" % build_algo) + raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo) self.build_kwds = build_kwds @@ -527,7 +528,7 @@ class UMAP(UniversalBase, skip_parameters_heading=True) @enable_device_interop def fit(self, X, y=None, convert_dtype=True, - knn_graph=None) -> "UMAP": + knn_graph=None, data_on_host=False) -> "UMAP": """ Fit X into an embedded space. @@ -563,11 +564,29 @@ class UMAP(UniversalBase, # Handle dense inputs else: + if data_on_host: + convert_to_mem_type = MemoryType.host + else: + convert_to_mem_type = MemoryType.device + self._raw_data, self.n_rows, self.n_dims, _ = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype - else None)) + else None), + convert_to_mem_type=convert_to_mem_type) + + if self.build_algo == "auto": + if self.n_rows * self.n_dims < 45000000 or self.sparse_fit: + # brute force is faster for small datasets + logger.warn("Building knn graph using brute force") + self.build_algo = "brute_force_knn" + else: + logger.warn("Building knn graph using nn descent") + self.build_algo = "nn_descent" + + if self.build_algo == "brute_force_knn" and data_on_host: + raise ValueError("Data cannot be on host for building with brute force knn") if self.n_rows <= 1: raise ValueError("There needs to be more than 1 sample to " @@ -667,7 +686,7 @@ class UMAP(UniversalBase, @cuml.internals.api_base_fit_transform() @enable_device_interop def fit_transform(self, X, y=None, convert_dtype=True, - knn_graph=None) -> CumlArray: + knn_graph=None, data_on_host=False) -> CumlArray: """ Fit X into an embedded space and return that transformed output. @@ -700,7 +719,7 @@ class UMAP(UniversalBase, CSR/COO preferred other formats will go through conversion to CSR """ - self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph) + self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph, data_on_host=data_on_host) return self.embedding_ @@ -771,14 +790,15 @@ class UMAP(UniversalBase, cdef uintptr_t _embed_ptr = self.embedding_.ptr + # NN Descent doesn't support transform yet + if self.build_algo == "nn_descent" or self.build_algo == "auto": + self.build_algo = "brute_force_knn" + logger.warn("Transform can only be run with brute force. Using brute force.") + IF GPUBUILD == 1: cdef UMAPParams* umap_params = \ UMAP._build_umap_params(self, self.sparse_fit) - # NN Descent doesn't support transform yet - umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN - logger.warn("NN Descent does not support transform. Using Brute force instead.") - cdef handle_t * handle_ = \ self.handle.getHandle() if self.sparse_fit: From 8de651f5b5c2345455b1aafcb9d82e9ae96febd0 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 24 Jul 2024 21:06:51 +0000 Subject: [PATCH 33/42] change nrows for auto to run bfk --- python/cuml/cuml/manifold/umap.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 116f394280..fb78ed1ef8 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -577,7 +577,7 @@ class UMAP(UniversalBase, convert_to_mem_type=convert_to_mem_type) if self.build_algo == "auto": - if self.n_rows * self.n_dims < 45000000 or self.sparse_fit: + if self.n_rows <= 50000 or self.sparse_fit: # brute force is faster for small datasets logger.warn("Building knn graph using brute force") self.build_algo = "brute_force_knn" From 750b207fc4c0fa7e1e0aa564390ea840cd381efe Mon Sep 17 00:00:00 2001 From: jinsolp Date: Fri, 26 Jul 2024 21:56:02 +0000 Subject: [PATCH 34/42] address comments and add warning+link to issue --- python/cuml/cuml/manifold/umap.pyx | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index fb78ed1ef8..49d0ab8a94 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -293,7 +293,8 @@ class UMAP(UniversalBase, :ref:`output-data-type-configuration` for more info. build_algo: string (default='auto') How to build the knn graph. Supported build algorithms are ['auto', 'brute_force_knn', - 'nn_descent']. 'auto' chooses to run with brute force knn or nn descent based on the dataset size. + 'nn_descent']. 'auto' chooses to run with brute force knn if number of data rows is + smaller than or equal to 50K. Otherwise, runs with nn descent. build_kwds: dict (optional, default=None) Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128, 'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True} @@ -431,6 +432,12 @@ class UMAP(UniversalBase, n_neighbors) if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": + if self.deterministic and build_algo == "auto": + # TODO: for now, users should be able to see the same results as previous version + # (i.e. running brute force knn) when they explicitly pass random_state + # https://github.com/rapidsai/cuml/issues/5985 + logger.warn("build_algo set to brute_force_knn because random_state is given") + self.build_algo ="brute_force_knn" self.build_algo = build_algo else: raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo) From 4688aa3e1b477a6f32eb2324a6c01759da0b0519 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Sun, 28 Jul 2024 21:07:56 +0000 Subject: [PATCH 35/42] change to RAFT_EXPECTS --- cpp/src/umap/knn_graph/algo.cuh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index ec99e6d82c..f6284e6a91 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -130,12 +130,11 @@ inline void launcher(const raft::handle_t& handle, static_cast(inputsA.n), static_cast(n_neighbors)}; - if (graph.distances().has_value()) { - auto out_knn_dists_view = - raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors); - raft::matrix::slice( - handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); - } + RAFT_EXPECTS(graph.distances().has_value(), + "return_distances for nn descent should be set to true to be used for UMAP"); + auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors); + raft::matrix::slice( + handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); auto out_knn_indices_view = raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors); raft::matrix::slice( From ded94856ef9f81eeee2b463a733e5db8eabca9b6 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Tue, 30 Jul 2024 16:46:39 +0000 Subject: [PATCH 36/42] change warn to log --- python/cuml/cuml/manifold/umap.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 49d0ab8a94..ae9f65f1c8 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -586,10 +586,10 @@ class UMAP(UniversalBase, if self.build_algo == "auto": if self.n_rows <= 50000 or self.sparse_fit: # brute force is faster for small datasets - logger.warn("Building knn graph using brute force") + logger.info("Building knn graph using brute force") self.build_algo = "brute_force_knn" else: - logger.warn("Building knn graph using nn descent") + logger.info("Building knn graph using nn descent") self.build_algo = "nn_descent" if self.build_algo == "brute_force_knn" and data_on_host: From 17c4874f1c3623f2b3bd932eb4f125d2cd605a15 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Tue, 30 Jul 2024 19:06:19 +0000 Subject: [PATCH 37/42] set logger level --- python/cuml/cuml/manifold/umap.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index ae9f65f1c8..e1b60bfefe 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -444,6 +444,8 @@ class UMAP(UniversalBase, self.build_kwds = build_kwds + logger.set_level(verbose) + def validate_hyperparams(self): if self.min_dist > self.spread: From eed52a1a2815aad3a7bf02313eb3af800215cd8d Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Tue, 30 Jul 2024 14:17:05 -0700 Subject: [PATCH 38/42] change warn to info --- python/cuml/cuml/manifold/umap.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index e1b60bfefe..34978fe5f6 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -436,7 +436,7 @@ class UMAP(UniversalBase, # TODO: for now, users should be able to see the same results as previous version # (i.e. running brute force knn) when they explicitly pass random_state # https://github.com/rapidsai/cuml/issues/5985 - logger.warn("build_algo set to brute_force_knn because random_state is given") + logger.info("build_algo set to brute_force_knn because random_state is given") self.build_algo ="brute_force_knn" self.build_algo = build_algo else: @@ -802,7 +802,7 @@ class UMAP(UniversalBase, # NN Descent doesn't support transform yet if self.build_algo == "nn_descent" or self.build_algo == "auto": self.build_algo = "brute_force_knn" - logger.warn("Transform can only be run with brute force. Using brute force.") + logger.info("Transform can only be run with brute force. Using brute force.") IF GPUBUILD == 1: cdef UMAPParams* umap_params = \ From 9b2a72c33f81b2933f304dcf31d88de7502e7a16 Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Tue, 30 Jul 2024 17:05:11 -0700 Subject: [PATCH 39/42] move logger set_level --- python/cuml/cuml/manifold/umap.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 34978fe5f6..0a7eb4a6ce 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -431,6 +431,8 @@ class UMAP(UniversalBase, self.precomputed_knn = extract_knn_infos(precomputed_knn, n_neighbors) + logger.set_level(verbose) + if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": if self.deterministic and build_algo == "auto": # TODO: for now, users should be able to see the same results as previous version @@ -444,8 +446,6 @@ class UMAP(UniversalBase, self.build_kwds = build_kwds - logger.set_level(verbose) - def validate_hyperparams(self): if self.min_dist > self.spread: From 6e53c28de608e482e3c107155732d5fc23542600 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 31 Jul 2024 00:16:22 +0000 Subject: [PATCH 40/42] styling --- python/cuml/cuml/manifold/umap.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 0a7eb4a6ce..2b1b11c597 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -432,7 +432,7 @@ class UMAP(UniversalBase, n_neighbors) logger.set_level(verbose) - + if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": if self.deterministic and build_algo == "auto": # TODO: for now, users should be able to see the same results as previous version From e04ff99374f6985de042882da15e2d678a3fe095 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 31 Jul 2024 00:18:51 +0000 Subject: [PATCH 41/42] styling --- python/cuml/cuml/manifold/umap.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index 2b1b11c597..e37d935cd9 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -433,6 +433,7 @@ class UMAP(UniversalBase, logger.set_level(verbose) + if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": if self.deterministic and build_algo == "auto": # TODO: for now, users should be able to see the same results as previous version From 4a31ebddc6352eeadb0b6a8fd22d4746f6a16a96 Mon Sep 17 00:00:00 2001 From: jinsolp Date: Wed, 31 Jul 2024 00:26:10 +0000 Subject: [PATCH 42/42] styling --- python/cuml/cuml/manifold/umap.pyx | 1 - 1 file changed, 1 deletion(-) diff --git a/python/cuml/cuml/manifold/umap.pyx b/python/cuml/cuml/manifold/umap.pyx index e37d935cd9..2b1b11c597 100644 --- a/python/cuml/cuml/manifold/umap.pyx +++ b/python/cuml/cuml/manifold/umap.pyx @@ -433,7 +433,6 @@ class UMAP(UniversalBase, logger.set_level(verbose) - if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent": if self.deterministic and build_algo == "auto": # TODO: for now, users should be able to see the same results as previous version