A Python implementation of Gaussian Mixture Model (GMM)

reutregev/gmm-em

GMM-EM

A Python implementation of a Gaussian Mixture Model (GMM).
A GMM is a probabilistic clustering model in which each cluster is described by a Gaussian distribution.
The underlying assumption is that the data were generated by a mixture of these distributions, so each data point has a probability of belonging to each of the clusters.
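
For reference, the mixture assumption can be written out as follows. This is the standard textbook form rather than something taken from the repository's code; the symbols K (number of clusters), \pi_k (weights), \mu_k (means) and \Sigma_k (covariances) are notation chosen here, and z_{ik} is the per-point cluster probability that the E-step description calls z.

```latex
p(x_i) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k),
\qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1

z_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}
              {\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}
```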

The Expectation Maximization (EM) algorithm is used to find maximum likelihood estimates of the model parameters (for a GMM, these are the mixture weights, means, and covariances).
The algorithm consists of two steps:

  • E-step: Estimate the probability that each data point belongs to each distribution (denoted as z)
  • M-step: Update the model parameters based on the estimate of z from the E-step

These two steps are repeated until convergence.
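
The two steps can be sketched with NumPy and SciPy as below. This is a minimal illustration written for this README, not the contents of em_algorithm.py: the function names (e_step, m_step, fit_gmm), the random-point initialisation, the tol stopping threshold and the small reg term added to the covariances are all assumptions. It also works with raw densities rather than in log space, so it may be numerically fragile on high-dimensional data.

```python
import numpy as np
from scipy.stats import multivariate_normal


def e_step(X, weights, means, covs):
    """Return z (n_samples, n_clusters) and the total log-likelihood."""
    n_clusters = len(weights)
    z = np.empty((X.shape[0], n_clusters))
    for k in range(n_clusters):
        # Weighted density of every point under component k.
        z[:, k] = weights[k] * multivariate_normal.pdf(X, mean=means[k], cov=covs[k])
    likelihood = z.sum(axis=1, keepdims=True)
    # Normalise each row so the memberships of a point sum to 1.
    return z / likelihood, float(np.log(likelihood).sum())


def m_step(X, z, reg=1e-6):
    """Re-estimate weights, means and covariances from the memberships z."""
    n_samples, n_features = X.shape
    nk = z.sum(axis=0)  # effective number of points per cluster
    weights = nk / n_samples
    means = (z.T @ X) / nk[:, None]
    covs = np.empty((len(nk), n_features, n_features))
    for k in range(len(nk)):
        diff = X - means[k]
        # Membership-weighted covariance; reg keeps the matrix invertible (assumption).
        covs[k] = (z[:, k, None] * diff).T @ diff / nk[k] + reg * np.eye(n_features)
    return weights, means, covs


def fit_gmm(X, n_clusters, max_iter=100, tol=1e-6, seed=0):
    """Alternate E and M steps until the log-likelihood gain drops below tol."""
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    weights = np.full(n_clusters, 1.0 / n_clusters)
    # Assumed initialisation: random data points as means, data covariance for all clusters.
    means = X[rng.choice(n_samples, n_clusters, replace=False)]
    base_cov = np.cov(X, rowvar=False) + 1e-6 * np.eye(n_features)
    covs = np.array([base_cov] * n_clusters)
    prev_ll = -np.inf
    for _ in range(max_iter):
        z, ll = e_step(X, weights, means, covs)
        weights, means, covs = m_step(X, z)
        if ll - prev_ll < tol:
            break
        prev_ll = ll
    # z holds the memberships from the last E-step that was run.
    return weights, means, covs, z
```

Calling fit_gmm(X, 4) on an array X of shape [n_samples, n_features] returns the fitted weights, means and covariances together with z. Like any EM fit, the result depends on the initialisation and may be a local optimum.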

The algorithm is implemented using NumPy and SciPy in em_algorithm.py.

An example of the algorithm applied to a generated 2D dataset can be found in GMM_example.ipynb.

Usage

python main.py (--find_num_clusters_range | --n_clusters) [--data_path] [OPTIONS]

Arguments:
    --find_num_clusters_range   Two numbers that specify the range of search for the number of clusters. Must be > 1
    --n_clusters                Use the given number of clusters (> 1) as a parameter for EM algorithm
    --data_path                 Path to a binary file in NumPy .npy format, containing an array of shape [n_samples, n_features]

Options:
    --max_iter                  Max iterations for EM algorithm, default is 100
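
As an illustration of the expected input, the snippet below writes a small array in .npy format with np.save and reads it back. The toy data and the file name toy_data.npy are hypothetical and chosen only for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: two 2D blobs of 100 points each.
blob_a = rng.normal(loc=[0.0, 0.0], scale=1.0, size=(100, 2))
blob_b = rng.normal(loc=[6.0, 6.0], scale=1.0, size=(100, 2))
data = np.vstack([blob_a, blob_b])  # shape (n_samples, n_features) = (200, 2)

# Write the array in NumPy's binary .npy format (hypothetical file name).
np.save("toy_data.npy", data)

# Read it back to confirm the layout.
loaded = np.load("toy_data.npy")
```

The resulting file can then be passed to main.py through --data_path.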

Example

main.py reads data from --data_path.
Then, there are two options to run GMM on the data:

  1. Use the given --n_clusters as a parameter for GMM:
    python main.py --n_clusters 4 --data_path /path/to/file
    
  2. Search the --find_num_clusters_range range for the number of clusters that best fits the data:
    python main.py --find_num_clusters_range 2 10 --data_path /path/to/file
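
This README does not state which criterion main.py uses to decide which number of clusters fits best. One common choice, shown here purely as an assumption, is the Bayesian Information Criterion (BIC), which penalises the log-likelihood by the number of free parameters. The helper names below (gmm_num_params, bic, select_n_clusters) are hypothetical, and the sketch assumes a full covariance matrix per cluster.

```python
import numpy as np


def gmm_num_params(n_clusters, n_features):
    """Number of free parameters in a full-covariance GMM."""
    weights = n_clusters - 1  # weights sum to 1, so one is determined by the rest
    means = n_clusters * n_features
    covs = n_clusters * n_features * (n_features + 1) // 2  # symmetric matrices
    return weights + means + covs


def bic(log_likelihood, n_clusters, n_samples, n_features):
    """BIC score; lower is better."""
    penalty = gmm_num_params(n_clusters, n_features) * np.log(n_samples)
    return penalty - 2.0 * log_likelihood


def select_n_clusters(log_likelihoods, n_samples, n_features):
    """Pick the candidate with the lowest BIC.

    log_likelihoods maps each candidate n_clusters to the total
    log-likelihood of a GMM fitted with that many clusters.
    """
    scores = {
        k: bic(ll, k, n_samples, n_features)
        for k, ll in log_likelihoods.items()
    }
    best = min(scores, key=scores.get)
    return best, scores
```

In this sketch, a model would be fitted once for every candidate in the range, its final log-likelihood recorded, and select_n_clusters called on the collected values.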
    
