[FEA] Enable UMAP to build knn graph using NN Descent #5910

Merged · 55 commits · Jul 31, 2024

Commits (55)
3a0cc67
enable umap nndescent
jinsolp May 31, 2024
bf39ef2
change return_distances to bool
jinsolp May 31, 2024
94eb579
add python test
jinsolp May 31, 2024
8e45779
fix styling
jinsolp May 31, 2024
7838388
fix comment
jinsolp May 31, 2024
a7bf6ba
fix styling
jinsolp May 31, 2024
7b0bb19
remove comment
jinsolp Jun 1, 2024
588a430
fix typo
jinsolp Jun 1, 2024
e88f183
use cuml logger
jinsolp Jun 3, 2024
8fc0f2d
change arg to dict and add documentation
jinsolp Jun 3, 2024
127bea5
change to RAFT_EXPECTS
jinsolp Jun 4, 2024
d42fd84
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jun 4, 2024
fd6793b
remove logger header
jinsolp Jun 4, 2024
9ba5092
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jun 7, 2024
d772098
enable l2sqrtexpanded dist + fix errors
jinsolp Jun 14, 2024
325d462
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jun 14, 2024
98645f9
change sqrt -> pow(0.5)
jinsolp Jun 17, 2024
76c7f31
refine distances due to precision issues
jinsolp Jun 19, 2024
5f0b8f6
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jun 19, 2024
e2b98c3
add to param names
jinsolp Jun 22, 2024
841f8ba
change threshold for test
jinsolp Jun 22, 2024
2b391f0
threshold for iris dataset
jinsolp Jun 24, 2024
e2a36a6
add warning for small dataset + nnd
jinsolp Jun 24, 2024
2759663
revert back to not refining
jinsolp Jun 24, 2024
4fafc31
fix tests
jinsolp Jun 24, 2024
8f85478
remove print
jinsolp Jun 24, 2024
7db8d38
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 11, 2024
a4705e0
change copy as kernel
jinsolp Jul 11, 2024
0fd6951
add detailed doc and warning
jinsolp Jul 11, 2024
8585538
newline
jinsolp Jul 11, 2024
7445922
raise error for sparse
jinsolp Jul 11, 2024
520a017
fix typo
jinsolp Jul 11, 2024
7f7179b
change tests
jinsolp Jul 11, 2024
f8f24b9
cleanup algo.cuh and test
jinsolp Jul 12, 2024
6330a70
change to using raft slicing kernel
jinsolp Jul 18, 2024
f27de4c
Merge branch 'branch-24.08' into add-umap-nndescent
jinsolp Jul 18, 2024
3e955ac
add header and change namespace
jinsolp Jul 18, 2024
3ce4016
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 23, 2024
a6a0b69
add auto option and enable host on device
jinsolp Jul 24, 2024
8de651f
change nrows for auto to run bfk
jinsolp Jul 24, 2024
e91c2a7
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 25, 2024
96cc875
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 25, 2024
592ad3c
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 26, 2024
750b207
address comments and add warning+link to issue
jinsolp Jul 26, 2024
4688aa3
change to RAFT_EXPECTS
jinsolp Jul 28, 2024
39b1b0a
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 28, 2024
744efec
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 30, 2024
ded9485
change warn to log
jinsolp Jul 30, 2024
17c4874
set logger level
jinsolp Jul 30, 2024
5133eaf
Merge branch 'rapidsai:branch-24.08' into add-umap-nndescent
jinsolp Jul 30, 2024
eed52a1
change warn to info
jinsolp Jul 30, 2024
9b2a72c
move logger set_level
jinsolp Jul 31, 2024
6e53c28
styling
jinsolp Jul 31, 2024
e04ff99
styling
jinsolp Jul 31, 2024
4a31ebd
styling
jinsolp Jul 31, 2024
11 changes: 11 additions & 0 deletions cpp/include/cuml/manifold/umapparams.h
@@ -20,12 +20,16 @@
#include <cuml/common/logger.hpp>

#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/nn_descent_types.hpp>

namespace ML {

using nn_index_params = raft::neighbors::experimental::nn_descent::index_params;

class UMAPParams {
public:
enum MetricType { EUCLIDEAN, CATEGORICAL };
enum graph_build_algo { BRUTE_FORCE_KNN, NN_DESCENT };

/**
* The number of neighbors to use to approximate geodesic distance.
@@ -140,6 +144,13 @@ class UMAPParams {
*/
int init = 1;

/**
* KNN graph build algorithm
*/
graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;

nn_index_params nn_descent_params = {};

/**
* The number of nearest neighbors to use to construct the target simplicial
* set. If set to -1, use the n_neighbors value.
104 changes: 85 additions & 19 deletions cpp/src/umap/knn_graph/algo.cuh
@@ -16,20 +16,30 @@

#pragma once

#include <cuml/common/utils.hpp>
#include <cuml/manifold/common.hpp>
#include <cuml/manifold/umapparams.h>
#include <cuml/neighbors/knn_sparse.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/slice.cuh>
#include <raft/neighbors/nn_descent.cuh>
#include <raft/neighbors/nn_descent_types.hpp>
#include <raft/sparse/selection/knn.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/cudart_utils.hpp>

#include <iostream>

namespace NNDescent = raft::neighbors::experimental::nn_descent;

namespace UMAPAlgo {
namespace kNNGraph {
namespace Algo {
@@ -47,6 +57,30 @@ void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream);

// Functor to post-process distances as L2Sqrt*
template <typename value_idx, typename value_t = float>
struct DistancePostProcessSqrt {
DI value_t operator()(value_t value, value_idx row, value_idx col) const { return sqrtf(value); }
};
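NN Descent accumulates expanded (squared) L2 distances while it builds the graph, and this functor is passed as an epilogue so that the rest of UMAP sees ordinary L2 distances. A minimal NumPy illustration of the post-processing step (stand-in code, not part of this PR):

import numpy as np

a = np.array([0.0, 3.0], dtype=np.float32)
b = np.array([4.0, 0.0], dtype=np.float32)

squared = ((a - b) ** 2).sum()  # 25.0 -- the squared L2 value the graph build produces
post = np.sqrt(squared)         # 5.0  -- what DistancePostProcessSqrt returns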

auto get_graph_nnd(const raft::handle_t& handle,
const ML::manifold_dense_inputs_t<float>& inputs,
const ML::UMAPParams* params)
{
auto epilogue = DistancePostProcessSqrt<int64_t, float>{};
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputs.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
if (ptr != nullptr) {
auto dataset =
raft::make_device_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return NNDescent::build<float, int64_t>(handle, params->nn_descent_params, dataset, epilogue);
} else {
auto dataset = raft::make_host_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return NNDescent::build<float, int64_t>(handle, params->nn_descent_params, dataset, epilogue);
}
}
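get_graph_nnd inspects the pointer with cudaPointerGetAttributes, so the same entry point accepts the dataset in either device or host memory. A sketch of the host-memory path through the Python API added in this PR (dataset shape and sizes are illustrative):

import numpy as np
from cuml.manifold import UMAP

X = np.random.rand(100_000, 32).astype(np.float32)  # stays in host memory

umap = UMAP(build_algo="nn_descent")
embedding = umap.fit_transform(X, data_on_host=True)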

// Instantiation for dense inputs, int64_t indices
template <>
inline void launcher(const raft::handle_t& handle,
@@ -57,25 +91,55 @@ inline void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream)
{
if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) {
std::vector<float*> ptrs(1);
std::vector<int> sizes(1);
ptrs[0] = inputsA.X;
sizes[0] = inputsA.n;

raft::spatial::knn::brute_force_knn(handle,
ptrs,
sizes,
inputsA.d,
inputsB.X,
inputsB.n,
out.knn_indices,
out.knn_dists,
n_neighbors,
true,
true,
static_cast<std::vector<int64_t>*>(nullptr),
params->metric,
params->p);
} else { // nn_descent
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
"n_neighbors should be smaller than or equal to the graph degree computed by nn descent");

auto graph = get_graph_nnd(handle, inputsA, params);

auto indices_d = raft::make_device_matrix<int64_t, int64_t>(
handle, inputsA.n, params->nn_descent_params.graph_degree);

raft::copy(indices_d.data_handle(),
graph.graph().data_handle(),
inputsA.n * params->nn_descent_params.graph_degree,
stream);

raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(inputsA.n),
static_cast<int64_t>(n_neighbors)};

RAFT_EXPECTS(graph.distances().has_value(),
"return_distances for nn descent should be set to true to be used for UMAP");
auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors);
raft::matrix::slice<float, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords);
auto out_knn_indices_view =
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors);
raft::matrix::slice<int64_t, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords);
}
}
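The NN Descent branch builds an (n, graph_degree) graph and then slices out the first n_neighbors columns into the output buffers, which is why the RAFT_EXPECTS above requires n_neighbors <= graph_degree. A sketch of the constraint as seen from Python (parameter values are illustrative):

from cuml.manifold import UMAP

ok = UMAP(n_neighbors=16, build_algo="nn_descent",
          build_kwds={"nnd_graph_degree": 64})  # 16 <= 64: fine
bad = UMAP(n_neighbors=16, build_algo="nn_descent",
          build_kwds={"nnd_graph_degree": 8})   # 16 > 8: fit() trips the check above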

// Instantiation for dense inputs, int indices
@@ -100,6 +164,8 @@ inline void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream)
{
RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN,
"nn_descent does not support sparse inputs");
raft::sparse::selection::brute_force_knn(inputsA.indptr,
inputsA.indices,
inputsA.data,
4 changes: 4 additions & 0 deletions cpp/src/umap/umap.cu
@@ -187,6 +187,8 @@ void transform(const raft::handle_t& handle,
UMAPParams* params,
float* transformed)
{
RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN,
"build algo nn_descent not supported for transform()");
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
@@ -210,6 +212,8 @@ void transform_sparse(const raft::handle_t& handle,
UMAPParams* params,
float* transformed)
{
RAFT_EXPECTS(params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN,
"build algo nn_descent not supported for transform()");
manifold_sparse_inputs_t<knn_indices_sparse_t, float> inputs(
indptr, indices, data, nullptr, nnz, n, d);
manifold_sparse_inputs_t<knn_indices_sparse_t, float> orig_x_inputs(
84 changes: 78 additions & 6 deletions python/cuml/cuml/manifold/umap.pyx
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ np = cpu_only_import('numpy')
pd = cpu_only_import('pandas')

import joblib
import warnings

from cuml.internals.safe_imports import gpu_only_import
cupy = gpu_only_import('cupy')
@@ -40,6 +41,7 @@ from cuml.internals.available_devices import is_cuda_available
from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.array import CumlArray
from cuml.internals.array_sparse import SparseCumlArray
from cuml.internals.mem_type import MemoryType
from cuml.internals.mixins import CMajorInputTagMixin
from cuml.common.sparse_utils import is_sparse

@@ -289,6 +291,13 @@ class UMAP(UniversalBase,
type. If None, the output type set at the module level
(`cuml.global_settings.output_type`) will be used. See
:ref:`output-data-type-configuration` for more info.
build_algo: string (default='auto')
How to build the knn graph. Supported build algorithms are ['auto', 'brute_force_knn',
'nn_descent']. 'auto' chooses brute force knn when the number of data rows is 50,000 or
fewer, and nn descent otherwise.
build_kwds: dict (optional, default=None)
Build algorithm keyword arguments and their defaults: {'nnd_graph_degree': 64,
'nnd_intermediate_graph_degree': 128, 'nnd_max_iterations': 20,
'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True}
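A minimal usage sketch of the two new parameters (values illustrative, not tuned):

import numpy as np
from cuml.manifold import UMAP

X = np.random.rand(60_000, 32).astype(np.float32)

umap = UMAP(n_neighbors=15,
            build_algo="nn_descent",
            build_kwds={"nnd_graph_degree": 64, "nnd_return_distances": True})
embedding = umap.fit_transform(X)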

Notes
-----
@@ -348,6 +357,8 @@ class UMAP(UniversalBase,
callback=None,
handle=None,
verbose=False,
build_algo="auto",
build_kwds=None,
output_type=None):

super().__init__(handle=handle,
@@ -420,6 +431,21 @@ class UMAP(UniversalBase,
self.precomputed_knn = extract_knn_infos(precomputed_knn,
n_neighbors)

logger.set_level(verbose)

if build_algo == "auto" or build_algo == "brute_force_knn" or build_algo == "nn_descent":
if self.deterministic and build_algo == "auto":
# TODO: for now, users should be able to see the same results as previous version
# (i.e. running brute force knn) when they explicitly pass random_state
# https://github.com/rapidsai/cuml/issues/5985
logger.info("build_algo set to brute_force_knn because random_state is given")
self.build_algo ="brute_force_knn"
self.build_algo = build_algo
else:
raise Exception("Invalid build algo: {}. Only support auto, brute_force_knn and nn_descent" % build_algo)

self.build_kwds = build_kwds

def validate_hyperparams(self):

if self.min_dist > self.spread:
@@ -452,6 +478,22 @@ class UMAP(UniversalBase,
umap_params.target_metric = MetricType.EUCLIDEAN
else: # self.target_metric == "categorical"
umap_params.target_metric = MetricType.CATEGORICAL
if cls.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else: # cls.build_algo == "nn_descent"
umap_params.build_algo = graph_build_algo.NN_DESCENT
if cls.build_kwds is None:
umap_params.nn_descent_params.graph_degree = <uint64_t> 64
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> 128
umap_params.nn_descent_params.max_iterations = <uint64_t> 20
umap_params.nn_descent_params.termination_threshold = <float> 0.0001
umap_params.nn_descent_params.return_distances = <bool> True
else:
umap_params.nn_descent_params.graph_degree = <uint64_t> cls.build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> cls.build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> cls.build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> cls.build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> cls.build_kwds.get("nnd_return_distances", True)
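For reference, the defaults applied in the branch above are equivalent to passing this build_kwds dict explicitly (a restatement of the values above, not new behavior):

build_kwds = {"nnd_graph_degree": 64,
              "nnd_intermediate_graph_degree": 128,
              "nnd_max_iterations": 20,
              "nnd_termination_threshold": 0.0001,
              "nnd_return_distances": True}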
umap_params.target_weight = <float> cls.target_weight
umap_params.random_state = <uint64_t> cls.random_state
umap_params.deterministic = <bool> cls.deterministic
@@ -495,7 +537,7 @@ class UMAP(UniversalBase,
skip_parameters_heading=True)
@enable_device_interop
def fit(self, X, y=None, convert_dtype=True,
knn_graph=None) -> "UMAP":
knn_graph=None, data_on_host=False) -> "UMAP":
"""
Fit X into an embedded space.

@@ -526,18 +568,41 @@ class UMAP(UniversalBase,
convert_format=False)
self.n_rows, self.n_dims = self._raw_data.shape
self.sparse_fit = True
if self.build_algo == "nn_descent":
raise ValueError("NN Descent does not support sparse inputs")

# Handle dense inputs
else:
if data_on_host:
convert_to_mem_type = MemoryType.host
else:
convert_to_mem_type = MemoryType.device

self._raw_data, self.n_rows, self.n_dims, _ = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))
else None),
convert_to_mem_type=convert_to_mem_type)

if self.build_algo == "auto":
if self.n_rows <= 50000 or self.sparse_fit:
# brute force is faster for small datasets
logger.info("Building knn graph using brute force")
self.build_algo = "brute_force_knn"
else:
logger.info("Building knn graph using nn descent")
self.build_algo = "nn_descent"

if self.build_algo == "brute_force_knn" and data_on_host:
raise ValueError("Data cannot be on host for building with brute force knn")

if self.n_rows <= 1:
raise ValueError("There needs to be more than 1 sample to "
"build nearest the neighbors graph")
if self.build_algo == "nn_descent" and self.n_rows < 150:
# https://github.com/rapidsai/cuvs/issues/184
warnings.warn("using nn_descent as build_algo on a small dataset (< 150 samples) is unstable")

cdef uintptr_t _knn_dists_ptr = 0
cdef uintptr_t _knn_indices_ptr = 0
@@ -630,7 +695,7 @@ class UMAP(UniversalBase,
@cuml.internals.api_base_fit_transform()
@enable_device_interop
def fit_transform(self, X, y=None, convert_dtype=True,
knn_graph=None, data_on_host=False) -> CumlArray:
"""
Fit X into an embedded space and return that transformed
output.
@@ -663,7 +728,7 @@ class UMAP(UniversalBase,
CSR/COO preferred other formats will go through conversion to CSR

"""
self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph)
self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph, data_on_host=data_on_host)

return self.embedding_

@@ -734,6 +799,11 @@ class UMAP(UniversalBase,

cdef uintptr_t _embed_ptr = self.embedding_.ptr

# NN Descent doesn't support transform yet
if self.build_algo == "nn_descent" or self.build_algo == "auto":
self.build_algo = "brute_force_knn"
logger.info("Transform can only be run with brute force. Using brute force.")

IF GPUBUILD == 1:
cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> UMAP._build_umap_params(self,
@@ -799,7 +869,9 @@ class UMAP(UniversalBase,
"callback",
"metric",
"metric_kwds",
"precomputed_knn"
"precomputed_knn",
"build_algo",
"build_kwds"
]

def get_attr_names(self):