
adding KMeans PyTorch Implementation to cfa model #998

Merged

Conversation


@aadhamm aadhamm commented Apr 2, 2023

Description

  • Provide an implementation of the KMeans clustering algorithm using the PyTorch framework

  • Fixes #956

Changes

  • Bug fix (non-breaking change which fixes an issue)
  • Refactor (non-breaking change which refactors the code base)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist

  • My code follows the pre-commit style and check guidelines of this project.
  • I have performed a self-review of my code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have added a summary of my changes to the CHANGELOG (not for minor changes, docs and tests).

@aadhamm aadhamm mentioned this pull request Apr 2, 2023

@samet-akcay samet-akcay left a comment


Thanks for creating this. I wanted to compare the scikit-learn and torch implementations, but had some trouble. Can you share the steps that you mentioned in this link?
#956 (comment)

@@ -78,6 +78,66 @@ def get_feature_extractor(backbone: str, return_nodes: list[str]) -> GraphModule
return feature_extractor


#Kmeans clustering algorithm implementation in PyTorch framework
class KMeans_torch:


Suggested change
class KMeans_torch:
class KMeans:


KMeans could potentially be used by other algorithms in the future. Therefore, it would be good to create a file named src/anomalib/models/components/cluster/kmeans.py and move this implementation there.
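A small sketch of the intent behind that suggestion; the import line below is hypothetical usage of the proposed module path, not existing code:

# src/anomalib/models/components/cluster/kmeans.py would hold the class,
# so other models could reuse it via a shared import:
from anomalib.models.components.cluster.kmeans import KMeans  # hypothetical path

kmeans = KMeans(n_clusters=3, max_iter=10)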

@@ -78,6 +78,66 @@ def get_feature_extractor(backbone: str, return_nodes: list[str]) -> GraphModule
return feature_extractor


#Kmeans clustering algorithm implementation in PyTorch framework
class KMeans_torch:
def __init__(self, n_clusters, max_iter=10):

We try to use type hints as much as possible, which will be useful for our new CLI that automatically handles the types of the variables.

Suggested change
def __init__(self, n_clusters, max_iter=10):
def __init__(self, n_clusters: int, max_iter:int = 10):


#thise line returns labels and centoids of the results,
#alternative to Sklearn's cluster_centers_ & labels_ attributes
return self.cluster_assignments, self.centroids

Would it be possible to return self, similar to the scikit-learn implementation?
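A minimal sketch of the scikit-learn convention being suggested here, using scikit-learn's attribute names; illustrative only, not the final implementation:

def fit(self, X):
    # ... run the k-means iterations, populating self.labels_ and self.cluster_centers_ ...
    return self  # enables chaining, e.g. KMeans(n_clusters=3).fit(X).predict(X)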

@aadhamm commented Apr 4, 2023

Dear Samet,

Thank you for putting such effort into the review and feedback.

Of course, all of your requested changes are possible. I will start working on them and commit the changes ASAP, and I will also provide you with the full code/notebook I used for the comparison between the two implementations.

In the meantime, I am looking to address some of the other TODO issues as well, either the other two issues in the cfa model or issues in other models, if that's possible and appropriate.

Thank you.

@aadhamm commented Apr 10, 2023

Dear @samet-akcay,

Please forgive me for the late reply; I have had several engagements over the last few days.

Here is my full approach to comparing the sklearn and PyTorch KMeans implementations using the silhouette score, a metric that measures the goodness of a clustering technique; its value ranges from -1 to 1.
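For reference, the per-sample silhouette value in the standard formulation (the one scikit-learn's silhouette_score implements) is

s = (b - a) / max(a, b)

where a is the mean distance from a sample to the other points in its own cluster and b is the mean distance to the points in the nearest other cluster; the reported score is the mean of s over all samples.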

It contains the modifications you asked for. This is a preview; I will commit the changes as required, and I may add a running notebook with the following comparison code. If it's not necessary, I will delete it.

Thank you

# Import necessary libraries
import torch
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score
# sklearn's KMeans is imported later (just before it is used), so its class
# is not confused with the PyTorch KMeans class defined below

class KMeans:
    def __init__(self, n_clusters: int, max_iter:int = 10):
        
        """
        Initializes the KMeans object.

        Parameters:
            n_clusters: The number of clusters to create.
            max_iter: The maximum number of iterations to run the algorithm for.
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

    def fit(self, X):
        """
        Runs the k-means algorithm on input data X.

        Parameters:
            X: A tensor of shape (N, D) containing the input data.
            N is the number of data points 
            D is the dimensionality of the data points.
        """
        N, D = X.shape

        # Initialize centroids randomly from the data points
        centroid_indices = torch.randint(0, N, (self.n_clusters,))
        self.cluster_centers_ = X[centroid_indices]

        # Run the k-means algorithm for max_iter iterations
        for i in range(self.max_iter):
            # Compute the distance between each data point and each centroid
            distances = torch.cdist(X, self.cluster_centers_)

            # Assign each data point to the closest centroid
            self.labels_ = torch.argmin(distances, dim=1)

            # Update the centroids to be the mean of the data points assigned to them
            for j in range(self.n_clusters):
                mask = self.labels_ == j
                if mask.any():
                    self.cluster_centers_[j] = X[mask].mean(dim=0)
                    
        #thise line returns labels and centoids of the results,          
        return self.labels_, self.cluster_centers_

    def predict(self, X):
        """
        Assigns each data point in X to its closest centroid.

        Parameters:
            X: A tensor of shape (N, D) containing the input data.

        Returns:
            A tensor of shape (N,) containing the index of the closest centroid for each data point.
        """
        distances = torch.cdist(X, self.cluster_centers_)
        return torch.argmin(distances, dim=1)



# Generate sample data
X, _ = make_blobs(n_samples=1000, centers=5, n_features=10, random_state=42)

X_torch = torch.tensor(X, dtype=torch.float32)


# Set parameters
n_clusters = 3
max_iter = 3000

# Run KMeans using Torch implementation
kmeans_torch = KMeans(n_clusters=n_clusters, max_iter=max_iter)
kmeans_torch.fit(X_torch)


from sklearn.cluster import KMeans
# Run KMeans using scikit-learn implementation
kmeans_sklearn = KMeans(n_clusters=n_clusters, max_iter=max_iter)
kmeans_sklearn.fit(X)



sklearn_labels = kmeans_sklearn.labels_
pytorch_labels = kmeans_torch.labels_

sklearn_centers = kmeans_sklearn.cluster_centers_
pytorch_centers = kmeans_torch.cluster_centers_



# Print both sets of centroids and labels for a visual comparison
print("Scikit-learn centroids:\n", sklearn_centers)
print("Torch centroids:\n", pytorch_centers)

print("Scikit-learn labels:\n", sklearn_labels)
print("Torch labels:\n", pytorch_labels) 




# Calculate the silhouette score for the scikit-learn labels
silhouette_sklearn = silhouette_score(X, sklearn_labels)

# Calculate the silhouette score for the PyTorch labels
silhouette_torch = silhouette_score(X_torch, pytorch_labels)


# Print the comparison results
print(f"Silhouette score for Sklearn: {silhouette_sklearn}")
print(f"Silhouette score for Pytorch: {silhouette_torch}")

@review-notebook-app

Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

@aadhamm commented Apr 24, 2023

@samet-akcay, I wanted to ask whether this TODO task, "TODO: Replace this with the new torchfx feature extractor", has already been done, because I am willing to continue with the rest of them in the same CFA implementation.

As far as I know, create_feature_extractor() is the new fx feature extractor.
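For context, a minimal sketch of how torchvision's FX-based extractor is typically used; the backbone and node names here are arbitrary examples, not the cfa code:

import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

# Wrap a backbone so the forward pass returns the requested intermediate nodes
extractor = create_feature_extractor(resnet18(), return_nodes=["layer2", "layer3"])
features = extractor(torch.rand(1, 3, 224, 224))  # dict: node name -> feature tensor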

@aadhamm commented Apr 24, 2023

@samet-akcay, if it's already been done, I may contribute to this task: # TODO: Automatically infer the number of dims.

@samet-akcay

@samet-akcay, I wanted to ask whether this TODO task, "TODO: Replace this with the new torchfx feature extractor", has already been done, because I am willing to continue with the rest of them in the same CFA implementation.

As far as I know, create_feature_extractor() is the new fx feature extractor.

@ashwinvaidya17, can you help here to show how @aadhamm can utilize the new feature extractor for the cfa model?

@ashwinvaidya17

@samet-akcay, I wanted to ask whether this TODO task, "TODO: Replace this with the new torchfx feature extractor", has already been done, because I am willing to continue with the rest of them in the same CFA implementation.

As far as I know, create_feature_extractor() is the new fx feature extractor.

You can have a look here to see how it is done. You can adapt your code to call this class.

self.feature_extractor = TorchFXFeatureExtractor(
    backbone="efficientnet_b5", weights=EfficientNet_B5_Weights.DEFAULT, return_nodes=["features.6.8"]
)

If you still have any questions, feel free to ask.
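To fill in the surrounding pieces, a sketch of how that snippet might be exercised; the TorchFXFeatureExtractor import path, the input size, and the dict-shaped output are assumptions about the anomalib version in use:

import torch
from torchvision.models import EfficientNet_B5_Weights
from anomalib.models.components import TorchFXFeatureExtractor  # assumed import path

feature_extractor = TorchFXFeatureExtractor(
    backbone="efficientnet_b5", weights=EfficientNet_B5_Weights.DEFAULT, return_nodes=["features.6.8"]
)
features = feature_extractor(torch.rand(1, 3, 256, 256))  # assumed: dict keyed by node name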


@ashwinvaidya17 ashwinvaidya17 left a comment


Thanks for the efforts. I have a few comments.

  • Can you remove the Jupyter notebook from the components folder? If you want to confirm the outputs, you can move this code into a file and turn it into a unit test.
  • Can you also address the issues raised by codacy?
  • The Sphinx parser is configured to follow the Google docstring format. Can you update the docstrings to conform to this format (a sketch follows after this list)? Here is an example: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
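For instance, the __init__ docstring from the KMeans class above might look like this in Google style; this is a sketch of the format only, not the final wording:

def __init__(self, n_clusters: int, max_iter: int = 10) -> None:
    """Initialize the KMeans object.

    Args:
        n_clusters (int): Number of clusters to create.
        max_iter (int, optional): Maximum number of iterations to run the algorithm for. Defaults to 10.
    """
    self.n_clusters = n_clusters
    self.max_iter = max_iter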

Comment on lines 21 to 24
Parameters:
X: A tensor of shape (N, D) containing the input data.
N is the number of data points
D is the dimensionality of the data points.

Are you sure these are the right parameters?


According to more than one source, N is the number of samples and D is the number of features. D is used to initialize the centroids randomly with the same dimensionality as the input data, but it is not actually used afterward in the code, which created issues with the codacy tests.

I resolved this later using your suggestion.

K-means clustering - PyTorch API
Fast Pytorch Kmeans

Comment on lines 10 to 12
Parameters:
n_clusters: The number of clusters to create.
max_iter: The maximum number of iterations to run the algorithm for.

We follow the Google docstring format, so this should be Args. Can you also add the types of the variables to the docstrings as well?

N is the number of data points
D is the dimensionality of the data points.
"""
N, D = X.shape

Maybe batch_size, _ = inputs.shape?

mask = self.labels_ == j
if mask.any():
self.cluster_centers_[j] = X[mask].mean(dim=0)
#thise line returns labels and centoids of the results,

Small typo: "this line".

"""
Assigns each data point in X to its closest centroid.

Parameters:

This should be Args as well

@aadhamm commented Apr 24, 2023

@ashwinvaidya17 Thank you for your detailed review. I will work on all the required changes and update you as soon as possible. And yes, you are right, there are variables that could be reduced, which will also help with the codacy tests.

@aadhamm commented Apr 29, 2023

@samet-akcay @ashwinvaidya17 Hello everyone, I have made the required changes and addressed the issues raised by codacy.

If you could take a look and give me your feedback, I would appreciate it a lot.

Thank you.

@aadhamm commented Aug 30, 2023

@samet-akcay @ashwinvaidya17 Hello everyone, I have made the required changes and all the checks have passed. Could you take a look and see whether it is ready to merge?

Thank you.


@ashwinvaidya17 ashwinvaidya17 left a comment


Sorry for not getting back to this sooner. Thanks for all the efforts.

@ashwinvaidya17 ashwinvaidya17 enabled auto-merge (squash) September 11, 2023 12:02
@ashwinvaidya17 ashwinvaidya17 merged commit ed4d1a1 into openvinotoolkit:main Sep 11, 2023
7 checks passed