From 7a2f8286d17f2bba98208c6880c8a7d36bac061d Mon Sep 17 00:00:00 2001 From: Jiacheng Yang <92543367+jiachengdb@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:26:50 -0700 Subject: [PATCH] feat: allow passing in precomputed centroids to lance.util.KMeans (#2586) --- python/python/lance/util.py | 6 +++++- python/python/tests/test_kmeans.py | 17 +++++++++++++++ python/src/utils.rs | 33 +++++++++++++++++++++++++++--- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/python/python/lance/util.py b/python/python/lance/util.py index 208d8b41ff..1ddc6ffdcd 100644 --- a/python/python/lance/util.py +++ b/python/python/lance/util.py @@ -78,6 +78,7 @@ def __init__( k: int, metric_type: Literal["l2", "dot", "cosine"] = "l2", max_iters: int = 50, + centroids: Optional[pa.FixedSizeListArray] = None, ): """Create a KMeans model. @@ -90,6 +91,7 @@ def __init__( Supported distance metrics: "l2", "cosine", "dot" max_iters: int The maximum number of iterations to run the KMeans algorithm. Default: 50. + centroids (pyarrow.FixedSizeListArray, optional.) – Provide existing centroids. 
""" metric_type = metric_type.lower() if metric_type not in ["l2", "dot", "cosine"]: @@ -98,7 +100,9 @@ def __init__( ) self.k = k self._metric_type = metric_type - self._kmeans = _KMeans(k, metric_type, max_iters=max_iters) + self._kmeans = _KMeans( + k, metric_type, max_iters=max_iters, centroids_arr=centroids + ) def __repr__(self) -> str: return f"lance.KMeans(k={self.k}, metric_type={self._metric_type})" diff --git a/python/python/tests/test_kmeans.py b/python/python/tests/test_kmeans.py index 1f7f7715f3..e992019f24 100644 --- a/python/python/tests/test_kmeans.py +++ b/python/python/tests/test_kmeans.py @@ -26,3 +26,20 @@ def test_kmeans_dot(): kmeans = lance.util.KMeans(32, metric_type="dot") data = np.random.randn(1000, 128).astype(np.float32) kmeans.fit(data) + + +def test_precomputed_kmeans(): + data = np.random.randn(1000, 128).astype(np.float32) + kmeans = lance.util.KMeans(8, metric_type="l2") + kmeans.fit(data) + original_clusters = kmeans.predict(data) + + values = np.stack(kmeans.centroids.to_numpy(zero_copy_only=False)).flatten() + centroids = pa.FixedSizeListArray.from_arrays(values, list_size=128) + + # Initialize a new KMeans with precomputed centroids. + new_kmeans = lance.util.KMeans(8, metric_type="l2", centroids=centroids) + new_clusters = new_kmeans.predict(data) + + # Verify the predictions are the same for both KMeans instances. 
+ assert original_clusters == new_clusters diff --git a/python/src/utils.rs b/python/src/utils.rs index 101785de24..d8b98402e8 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -58,13 +58,40 @@ pub struct KMeans { #[pymethods] impl KMeans { #[new] - #[pyo3(signature = (k, metric_type="l2", max_iters=50))] - fn new(k: usize, metric_type: &str, max_iters: u32) -> PyResult<Self> { + #[pyo3(signature = (k, metric_type="l2", max_iters=50, centroids_arr=None))] + fn new( + k: usize, + metric_type: &str, + max_iters: u32, + centroids_arr: Option<&PyAny>, + ) -> PyResult<Self> { + let trained_kmeans = if let Some(arr) = centroids_arr { + let data = ArrayData::from_pyarrow(arr)?; + if !matches!(data.data_type(), DataType::FixedSizeList(_, _)) { + return Err(PyValueError::new_err("Must be a FixedSizeList")); + } + let fixed_size_arr = FixedSizeListArray::from(data); + let params = KMeansParams { + distance_type: metric_type.try_into().unwrap(), + max_iters, + ..Default::default() + }; + let kmeans = + LanceKMeans::new_with_params(&fixed_size_arr, k, &params).map_err(|e| { + PyRuntimeError::new_err(format!( + "Error initialing KMeans from existing centroids: {}", + e + )) + })?; + Some(kmeans) + } else { + None + }; Ok(Self { k, metric_type: metric_type.try_into().unwrap(), max_iters, - trained_kmeans: None, + trained_kmeans, }) }