
mim-solutions/embedding-sum

embedding-sum

In this repository, we open-source a method for training interpretable models, originally developed for ransomware attack detection (see Sagenso/MIRAD).

Key features:

  • model and prediction interpretability
  • concise implementation based on NumPy and PyTorch
  • efficiency: training can run on a CPU, and evaluation is fast
  • extensibility: the model structure makes it easy to enforce task-specific requirements through additional regularization terms

How it works

The entire implementation can be found in lib/embedding_sum.py and consists of three classes:

  • Digitizer
  • EmbeddingSumModule
  • EmbeddingSumClassifier

The method can be summarized as follows:

  1. For each feature, the Digitizer defines quantile-based bins and encodes feature values as bin ordinals.
  2. The EmbeddingSumModule defines a trainable parameter for each bin of each feature and computes its output as the sum of the parameters selected by the bin ordinals in the input vector.
  3. The EmbeddingSumClassifier trains the EmbeddingSumModule with gradient descent on a compound loss that combines binary cross-entropy with regularization terms encouraging desirable properties of the model.
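The three steps above can be sketched in plain NumPy as follows. This is not the code from lib/embedding_sum.py: the function names (`fit_bin_edges`, `digitize`, `logits`, `train`) are hypothetical, gradients are computed by hand rather than with PyTorch, and the smoothness penalty on adjacent bin parameters is an assumed stand-in for the repository's regularization terms, which this README does not specify.

```python
import numpy as np


def fit_bin_edges(X, n_bins):
    """Step 1 (hypothetical helper): quantile-based inner bin edges for each column of X."""
    qs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]  # inner quantiles only
    # np.unique drops repeated edges, which occur when a feature has few distinct values
    return [np.unique(np.quantile(X[:, j], qs)) for j in range(X.shape[1])]


def digitize(X, edges):
    """Step 1: encode each feature value as the ordinal of its bin (0 .. len(edges[j]))."""
    return np.stack([np.digitize(X[:, j], e) for j, e in enumerate(edges)], axis=1)


def logits(codes, weights, bias):
    """Step 2: sum of the trainable parameters selected by the bin ordinals."""
    return bias + sum(w[codes[:, j]] for j, w in enumerate(weights))


def train(codes, y, n_bins_per_feature, lr=0.5, smooth=0.01, epochs=500):
    """Step 3 (simplified): full-batch gradient descent on mean BCE + smoothness penalty.

    The penalty smooth * sum_b (w[b+1] - w[b])**2 is an ASSUMED example regularizer;
    the repository's actual regularization terms may differ.
    """
    n = len(y)
    weights = [np.zeros(k) for k in n_bins_per_feature]
    bias = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-logits(codes, weights, bias)))
        g = (p - y) / n  # derivative of mean BCE w.r.t. each sample's logit
        bias -= lr * g.sum()
        for j, w in enumerate(weights):
            # each sample activates one parameter per feature, so the BCE gradient
            # is the per-bin sum of the residuals
            grad = np.bincount(codes[:, j], weights=g, minlength=len(w))
            # gradient of the smoothness penalty on adjacent bins
            d = np.diff(w)
            grad[:-1] -= 2 * smooth * d
            grad[1:] += 2 * smooth * d
            w -= lr * grad  # in-place update of the array stored in `weights`
    return weights, bias


# Usage on synthetic data: the label depends only on feature 0
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = (X[:, 0] > 0).astype(float)

edges = fit_bin_edges(X, n_bins=8)
codes = digitize(X, edges)
weights, bias = train(codes, y, [len(e) + 1 for e in edges])

accuracy = np.mean((logits(codes, weights, bias) > 0) == (y == 1))
print(f"training accuracy: {accuracy:.3f}")
```

Since every input activates exactly one parameter per feature, the model is linear in its parameters given the bin ordinals, which is why training stays cheap enough for a CPU.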

In other words, we train an additive model composed of step functions, one per feature. The model can be interpreted by plotting these step functions. A usage example can be found in the tutorial notebook.
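To illustrate what interpreting such a model can look like, here is a minimal sketch, assuming a trained model is represented as per-feature bin edges plus one parameter array per feature. The helpers `step_function` and `explain` are hypothetical and not part of the repository: the first lists the pieces of one feature's step function, the second returns the per-feature contributions to a single prediction's logit.

```python
import numpy as np


def step_function(edges, w):
    """Return (lower, upper, contribution) for each bin of one feature.

    Bin b covers the half-open interval [lower, upper), matching np.digitize
    with increasing edges.
    """
    bounds = np.concatenate(([-np.inf], edges, [np.inf]))
    return [(bounds[b], bounds[b + 1], w[b]) for b in range(len(w))]


def explain(code_row, weights):
    """Per-feature contributions for one input; they sum to (logit - bias)."""
    return np.array([w[c] for c, w in zip(code_row, weights)])


# Hypothetical trained model with two features and two bins each
edges = [np.array([1.5]), np.array([25.0])]
weights = [np.array([-1.0, 1.0]), np.array([0.5, 2.0])]

for lower, upper, contribution in step_function(edges[0], weights[0]):
    print(f"feature 0 in [{lower}, {upper}): {contribution:+.2f}")

contributions = explain(np.array([1, 0]), weights)
print("contributions:", contributions)
```

Drawing the same pieces, for example with matplotlib's `step` function, gives the per-feature plots; because the model is additive, the contributions returned by `explain` account for the whole prediction apart from the bias.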

Running the notebooks

First, clone the repository. Then either install it directly and run Jupyter:

poetry install  # or: pip install -e .
jupyter notebook

or build and run a Docker image:

docker build -t embedding-sum .
docker run --rm -it -p 8888:8888 embedding-sum
