Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding KMeans PyTorch Implementation to cfa model #998

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/anomalib/models/components/cluster/kmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""KMeans clustering algorithm implementation using PyTorch."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch


class KMeans:
    """K-means clustering implemented with PyTorch tensor operations.

    After a successful call to :meth:`fit` the instance exposes two attributes:
    ``cluster_centers_`` (the centroids) and ``labels_`` (the cluster index
    assigned to every training sample).

    Args:
        n_clusters (int): The number of clusters to create.
        max_iter (int, optional): The maximum number of iterations to run the algorithm. Defaults to 10.
    """

    def __init__(self, n_clusters: int, max_iter: int = 10) -> None:
        """Initializes the KMeans object.

        Args:
            n_clusters (int): The number of clusters to create.
            max_iter (int, optional): The maximum number of iterations to run the algorithm. Defaults to 10.
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

    def fit(self, inputs: torch.Tensor) -> tuple:
        """Fits the K-means algorithm to the input data.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            tuple: A tuple containing the labels of the input data with respect to the identified clusters
                and the cluster centers themselves. The labels have a shape of (batch_size,) and the
                cluster centers have a shape of (n_clusters, n_features). The labels are always the
                nearest-centroid assignment for the returned cluster centers.

        Raises:
            ValueError: If the number of clusters is less than or equal to 0.
        """
        if self.n_clusters <= 0:
            raise ValueError(f"n_clusters must be a positive integer, got {self.n_clusters}.")

        batch_size, _ = inputs.shape

        # Pick the initial centroids among the data points. Sampling without
        # replacement avoids duplicate centroids, which would leave clusters
        # permanently empty. When more clusters than samples are requested,
        # duplicates are unavoidable, so fall back to sampling with replacement.
        if self.n_clusters <= batch_size:
            centroid_indices = torch.randperm(batch_size)[: self.n_clusters]
        else:
            centroid_indices = torch.randint(0, batch_size, (self.n_clusters,))
        # Indexing with an index tensor returns a copy, so the in-place
        # centroid updates below never write into ``inputs``.
        centers = inputs[centroid_indices]

        labels = None
        converged = False
        for _ in range(self.max_iter):
            # Assign each data point to its closest centroid.
            new_labels = torch.argmin(torch.cdist(inputs, centers), dim=1)

            # Unchanged assignments mean the centroid update would reproduce
            # the current centroids, so further iterations are no-ops.
            if labels is not None and torch.equal(new_labels, labels):
                converged = True
                break
            labels = new_labels

            # Move each centroid to the mean of the points assigned to it.
            # Empty clusters keep their previous centroid.
            for j in range(self.n_clusters):
                mask = labels == j
                if mask.any():
                    centers[j] = inputs[mask].mean(dim=0)

        if not converged:
            # Either no iteration ran (max_iter <= 0) or the centroids moved
            # after the last assignment step: recompute the labels so they
            # match the returned centroids (and therefore ``predict``).
            labels = torch.argmin(torch.cdist(inputs, centers), dim=1)

        self.cluster_centers_ = centers
        self.labels_ = labels
        return self.labels_, self.cluster_centers_

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        """Predicts the labels of input data based on the fitted model.

        Args:
            inputs (torch.Tensor): Input data of shape (batch_size, n_features).

        Returns:
            torch.Tensor: The predicted labels of the input data with respect to the identified clusters.

        Raises:
            AttributeError: If the KMeans object has not been fitted to input data.
        """
        # ``cluster_centers_`` only exists after ``fit``; accessing it on an
        # unfitted instance raises the documented AttributeError.
        distances = torch.cdist(inputs, self.cluster_centers_)
        return torch.argmin(distances, dim=1)
Loading