Skip to content

Commit

Permalink
feat: allow passing in precomputed centroids to lance.util.KMeans (#2586
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jiachengdb committed Jul 11, 2024
1 parent c89223b commit 7a2f828
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
6 changes: 5 additions & 1 deletion python/python/lance/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
k: int,
metric_type: Literal["l2", "dot", "cosine"] = "l2",
max_iters: int = 50,
centroids: Optional[pa.FixedSizeListArray] = None,
):
"""Create a KMeans model.
Expand All @@ -90,6 +91,7 @@ def __init__(
Supported distance metrics: "l2", "cosine", "dot"
max_iters: int
The maximum number of iterations to run the KMeans algorithm. Default: 50.
centroids (pyarrow.FixedSizeListArray, optional.) – Provide existing centroids.
"""
metric_type = metric_type.lower()
if metric_type not in ["l2", "dot", "cosine"]:
Expand All @@ -98,7 +100,9 @@ def __init__(
)
self.k = k
self._metric_type = metric_type
self._kmeans = _KMeans(k, metric_type, max_iters=max_iters)
self._kmeans = _KMeans(
k, metric_type, max_iters=max_iters, centroids_arr=centroids
)

def __repr__(self) -> str:
return f"lance.KMeans(k={self.k}, metric_type={self._metric_type})"
Expand Down
17 changes: 17 additions & 0 deletions python/python/tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,20 @@ def test_kmeans_dot():
kmeans = lance.util.KMeans(32, metric_type="dot")
data = np.random.randn(1000, 128).astype(np.float32)
kmeans.fit(data)


def test_precomputed_kmeans():
data = np.random.randn(1000, 128).astype(np.float32)
kmeans = lance.util.KMeans(8, metric_type="l2")
kmeans.fit(data)
original_clusters = kmeans.predict(data)

values = np.stack(kmeans.centroids.to_numpy(zero_copy_only=False)).flatten()
centroids = pa.FixedSizeListArray.from_arrays(values, list_size=128)

# Initialize a new KMeans with precomputed centroids.
new_kmeans = lance.util.KMeans(8, metric_type="l2", centroids=centroids)
new_clusters = new_kmeans.predict(data)

# Verify the predictions are the same for both KMeans instances.
assert original_clusters == new_clusters
33 changes: 30 additions & 3 deletions python/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,40 @@ pub struct KMeans {
#[pymethods]
impl KMeans {
#[new]
#[pyo3(signature = (k, metric_type="l2", max_iters=50))]
fn new(k: usize, metric_type: &str, max_iters: u32) -> PyResult<Self> {
#[pyo3(signature = (k, metric_type="l2", max_iters=50, centroids_arr=None))]
fn new(
k: usize,
metric_type: &str,
max_iters: u32,
centroids_arr: Option<&PyAny>,
) -> PyResult<Self> {
let trained_kmeans = if let Some(arr) = centroids_arr {
let data = ArrayData::from_pyarrow(arr)?;
if !matches!(data.data_type(), DataType::FixedSizeList(_, _)) {
return Err(PyValueError::new_err("Must be a FixedSizeList"));
}
let fixed_size_arr = FixedSizeListArray::from(data);
let params = KMeansParams {
distance_type: metric_type.try_into().unwrap(),
max_iters,
..Default::default()
};
let kmeans =
LanceKMeans::new_with_params(&fixed_size_arr, k, &params).map_err(|e| {
PyRuntimeError::new_err(format!(
"Error initialing KMeans from existing centroids: {}",
e
))
})?;
Some(kmeans)
} else {
None
};
Ok(Self {
k,
metric_type: metric_type.try_into().unwrap(),
max_iters,
trained_kmeans: None,
trained_kmeans,
})
}

Expand Down

0 comments on commit 7a2f828

Please sign in to comment.