Skip to content

Commit

Permalink
[add:lib] Add (de)serialization support for all classifiers (#80)
Browse files Browse the repository at this point in the history
* Finish tests for HMM and HMMClassifier

* Finish serialization documentation

* Ensure no NaNs in tests

* Ensure Nans in HMM.test_as_dict_with_nan

* Ensure NaNs in HMM.test_as_dict_with_nan

* Ensure NaNs in HMM.test_as_dict_with_nan

* Remove HMM NaN serialization test
  • Loading branch information
eonu authored May 11, 2020
1 parent 3ba706b commit 79767ea
Show file tree
Hide file tree
Showing 7 changed files with 630 additions and 18 deletions.
74 changes: 73 additions & 1 deletion lib/sequentia/classifiers/dtwknn/dtwknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tqdm.auto
import random
import numpy as np
import h5py
from joblib import Parallel, delayed
from multiprocessing import cpu_count
from fastdtw import fastdtw
Expand Down Expand Up @@ -155,4 +156,75 @@ def evaluate(self, X, y, labels=None, verbose=True, n_jobs=1):
predictions = self.predict(X, verbose=verbose, n_jobs=n_jobs)
cm = confusion_matrix(y, predictions, labels=labels)

return np.sum(np.diag(cm)) / np.sum(cm), cm
return np.sum(np.diag(cm)) / np.sum(cm), cm

def save(self, path):
    """Stores the :class:`DTWKNN` object into a `HDF5 <https://support.hdfgroup.org/HDF5/doc/H5.intro.html>`_ file.

    .. note::
        As :math:`k`-NN is a non-parametric classification algorithm, saving the classifier simply saves
        all of the training observation sequences and labels (along with the hyper-parameters).

    Parameters
    ----------
    path: str
        File path (with or without `.h5` extension) to store the HDF5-serialized :class:`DTWKNN` object.

    Raises
    ------
    RuntimeError
        If the classifier has not been fitted (no training data to store).
    """

    try:
        sequences, labels = self._X, self._y
    except AttributeError as e:
        raise RuntimeError('The classifier needs to be fitted before it can be saved') from e

    # Encode the labels explicitly as UTF-8 byte strings. This avoids the
    # ``np.string_`` alias (removed in NumPy 2.0) and matches the default
    # encoding that :meth:`load` uses to decode them again.
    encoded_labels = np.array(
        [label if isinstance(label, bytes) else str(label).encode('utf-8') for label in labels],
        dtype='S'
    )

    with h5py.File(path, 'w') as f:
        # Store hyper-parameters (k, radius)
        params = f.create_group('params')
        params.create_dataset('k', data=self._k)
        params.create_dataset('radius', data=self._radius)

        # Store training data and labels (X, y).
        # Sequences may differ in length, so each one gets its own dataset,
        # named after its position in the training set.
        data = f.create_group('data')
        X = data.create_group('X')
        for i, x in enumerate(sequences):
            X.create_dataset(str(i), data=x)
        data.create_dataset('y', data=encoded_labels)

@classmethod
def load(cls, path, encoding='utf-8', metric=euclidean):
    """Deserializes a HDF5-serialized :class:`DTWKNN` object.

    Parameters
    ----------
    path: str
        File path of the serialized HDF5 data generated by the :meth:`save` method.
    encoding: str
        The encoding used to represent training labels when decoding the HDF5 file.

        .. note::
            Supported string encodings in Python can be found `here <https://docs.python.org/3/library/codecs.html#standard-encodings>`_.
    metric: callable
        Distance metric for FastDTW.

    Returns
    -------
    deserialized: :class:`DTWKNN`
        The deserialized DTWKNN classifier object.

    See Also
    --------
    save: Serializes a :class:`DTWKNN` into a HDF5 file.
    """

    with h5py.File(path, 'r') as f:
        # Deserialize the model hyper-parameters
        params = f['params']
        clf = cls(k=int(params['k'][()]), radius=int(params['radius'][()]), metric=metric)

        # Deserialize the training data and labels
        X, y = f['data']['X'], f['data']['y']
        # :meth:`save` names each sequence dataset after its index. HDF5 groups
        # iterate in alphanumeric name order ('0', '1', '10', ..., '2'), so the
        # keys must be sorted numerically to keep sequences aligned with labels.
        clf._X = [np.array(X[key]) for key in sorted(X.keys(), key=int)]
        clf._y = [label.decode(encoding) for label in y]

    return clf
95 changes: 94 additions & 1 deletion lib/sequentia/classifiers/hmm/hmm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import numpy as np
import pomegranate as pg
from .topologies.ergodic import _ErgodicTopology
Expand Down Expand Up @@ -174,4 +175,96 @@ def transitions(self):
@transitions.setter
def transitions(self, probabilities):
    """Sets the transition probabilities after validating them against the topology."""
    # Validation is delegated to the topology object; anything it raises
    # propagates to the caller and the stored value is left unchanged.
    self._topology.validate_transitions(probabilities)
    self._transitions = probabilities

def as_dict(self):
    """Serializes the :class:`HMM` object into a `dict`, ready to be stored in JSON format.

    Returns
    -------
    serialized: dict
        JSON-ready serialization of the :class:`HMM` object.

    Raises
    ------
    AttributeError
        If the model has not been fitted yet.
    ValueError
        If the JSON form of the underlying model contains ``NaN``.
    """

    try:
        fitted = self._model
    except AttributeError as e:
        raise AttributeError('The model needs to be fitted before it can be exported to a dict') from e

    hmm_json = fitted.to_json()

    # Refuse to export a model whose JSON form contains NaN
    if 'NaN' in hmm_json:
        raise ValueError('Encountered NaN value(s) in HMM parameters')

    if isinstance(self._topology, _ErgodicTopology):
        topology = 'ergodic'
    else:
        topology = 'left-right'

    parameters = {
        'initial': self._initial.tolist(),
        'transitions': self._transitions.tolist(),
        'n_seqs': self._n_seqs,
        'n_features': self._n_features,
        'hmm': json.loads(hmm_json)
    }

    return {
        'label': self._label,
        'n_states': self._n_states,
        'topology': topology,
        'model': parameters
    }

def save(self, path):
    """Converts the :class:`HMM` object into a `dict` and stores it in a JSON file.

    Parameters
    ----------
    path: str
        File path (with or without `.json` extension) to store the JSON-serialized :class:`HMM` object.

    See Also
    --------
    as_dict: Generates the `dict` that is stored in the JSON file.
    """

    # Serialize first, so that a failure in as_dict() never touches the file
    serialized = self.as_dict()

    with open(path, 'w') as handle:
        json.dump(serialized, handle, indent=4)

@classmethod
def load(cls, data, random_state=None):
    """Deserializes either a `dict` or JSON serialized :class:`HMM` object.

    Parameters
    ----------
    data: str or dict
        - File path of the serialized JSON data generated by the :meth:`save` method.
        - `dict` representation of the :class:`HMM`, generated by the :meth:`as_dict` method.
    random_state: numpy.random.RandomState, int, optional
        A random state object or seed for reproducible randomness.

    Returns
    -------
    deserialized: :class:`HMM`
        The deserialized HMM object.

    Raises
    ------
    TypeError
        If ``data`` is neither a `dict` nor a `str` file path.

    See Also
    --------
    save: Serializes a :class:`HMM` into a JSON file.
    as_dict: Generates a `dict` representation of the :class:`HMM`.
    """

    # Load the serialized HMM data
    if isinstance(data, str):
        with open(data, 'r') as f:
            data = json.load(f)
    elif not isinstance(data, dict):
        # Fail early with a clear message instead of an obscure
        # subscripting error further down
        raise TypeError(
            'Expected `data` to be a dict or a str file path, got {}'.format(type(data).__name__)
        )

    # Deserialize the data into a HMM object
    model = data['model']
    hmm = cls(data['label'], data['n_states'], data['topology'], random_state=random_state)
    hmm._initial = np.array(model['initial'])
    hmm._transitions = np.array(model['transitions'])
    hmm._n_seqs = model['n_seqs']
    hmm._n_features = model['n_features']
    # The underlying model was stored as parsed JSON, so it is dumped back
    # to a string before being handed to HiddenMarkovModel.from_json
    hmm._model = pg.HiddenMarkovModel.from_json(json.dumps(model['hmm']))

    return hmm
76 changes: 75 additions & 1 deletion lib/sequentia/classifiers/hmm/hmm_classifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import numpy as np
from .hmm import HMM
from sklearn.metrics import confusion_matrix
Expand Down Expand Up @@ -121,4 +122,77 @@ def evaluate(self, X, y, prior=True, labels=None):
predictions = self.predict(X, prior, return_scores=False)
cm = confusion_matrix(y, predictions, labels=labels)

return np.sum(np.diag(cm)) / np.sum(cm), cm
return np.sum(np.diag(cm)) / np.sum(cm), cm

def as_dict(self):
    """Serializes the :class:`HMMClassifier` object into a `dict`, ready to be stored in JSON format.

    .. note::
        Serializing a :class:`HMMClassifier` implicitly serializes the internal :class:`HMM` objects
        by calling :meth:`HMM.as_dict` and storing all of the model data in a single `dict`.

    Returns
    -------
    serialized: dict
        JSON-ready serialization of the :class:`HMMClassifier` object.

    Raises
    ------
    AttributeError
        If the classifier has not been fitted yet.

    See Also
    --------
    HMM.as_dict: The serialization function used for individual :class:`HMM` objects.
    """

    try:
        fitted_models = self._models
    except AttributeError as e:
        raise AttributeError('The classifier needs to be fitted before it can be exported to a dict') from e

    # Serialize each model in turn, preserving their order
    serialized = []
    for hmm in fitted_models:
        serialized.append(hmm.as_dict())

    return {'models': serialized}

def save(self, path):
    """Converts the :class:`HMMClassifier` object into a `dict` and stores it in a JSON file.

    Parameters
    ----------
    path: str
        File path (with or without `.json` extension) to store the JSON-serialized :class:`HMMClassifier` object.

    See Also
    --------
    as_dict: Generates the `dict` that is stored in the JSON file.
    """

    # Serialize first, so that a failure in as_dict() never touches the file
    serialized = self.as_dict()

    with open(path, 'w') as handle:
        json.dump(serialized, handle, indent=4)

@classmethod
def load(cls, path, random_state=None):
    """Deserializes either a `dict` or JSON serialized :class:`HMMClassifier` object.

    Parameters
    ----------
    path: str or dict
        - File path of the serialized JSON data generated by the :meth:`save` method.
        - `dict` representation of the :class:`HMMClassifier`, generated by the :meth:`as_dict` method.
    random_state: numpy.random.RandomState, int, optional
        A random state object or seed for reproducible randomness.

    Returns
    -------
    deserialized: :class:`HMMClassifier`
        The deserialized HMM classifier object.

    See Also
    --------
    save: Serializes a :class:`HMMClassifier` into a JSON file.
    as_dict: Generates a `dict` representation of the :class:`HMMClassifier`.
    """

    if isinstance(path, dict):
        # Already-deserialized representation, as produced by as_dict()
        data = path
    else:
        # Load the serialized HMM classifier as JSON
        with open(path, 'r') as f:
            data = json.load(f)

    clf = cls()
    # Each entry is a dict, which HMM.load deserializes directly
    clf._models = [HMM.load(model, random_state=random_state) for model in data['models']]

    return clf
75 changes: 73 additions & 2 deletions lib/test/lib/classifiers/dtwknn/test_dtwknn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pytest
import warnings
import os
import h5py
import numpy as np
from copy import deepcopy
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
from sequentia.classifiers import DTWKNN
from ....support import assert_equal, assert_not_equal
from ....support import assert_equal, assert_all_equal, assert_not_equal

# Set seed for reproducible randomness
seed = 0
Expand Down Expand Up @@ -213,4 +215,73 @@ def test_evaluate_with_no_labels_k3_r10_no_verbose(capsys):
acc, cm = clfs[2].evaluate(X, y, labels=None, verbose=False)
assert 'Classifying examples' not in capsys.readouterr().err
assert isinstance(acc, float)
assert isinstance(cm, np.ndarray)
assert isinstance(cm, np.ndarray)

# ============= #
# DTWKNN.save() #
# ============= #

def test_save_directory():
    """Saving a DTWKNN classifier to a directory path should raise an OSError"""
    directory = '.'
    with pytest.raises(OSError):
        clfs[2].save(directory)

def test_save_no_extension():
    """Saving a DTWKNN classifier to a path without an extension should create that file"""
    target = 'test'
    try:
        clfs[2].save(target)
        assert os.path.isfile(target)
    finally:
        # Clean up the file written to the working directory
        os.remove(target)

def test_save_with_extension():
    """Saving a DTWKNN classifier to a path with a .h5 extension should create that file"""
    target = 'test.h5'
    try:
        clfs[2].save(target)
        assert os.path.isfile(target)
    finally:
        # Clean up the file written to the working directory
        os.remove(target)

# ============= #
# DTWKNN.load() #
# ============= #

def test_load_invalid_path():
    """Loading a DTWKNN classifier from a directory path should raise an OSError"""
    directory = '.'
    with pytest.raises(OSError):
        DTWKNN.load(directory)

def test_load_inexistent_path():
    """Loading a DTWKNN classifier from a path that does not exist should raise an OSError"""
    missing = 'test'
    with pytest.raises(OSError):
        DTWKNN.load(missing)

def test_load_invalid_format():
    """Loading a DTWKNN classifier from a file that is not valid HDF5 should raise an OSError"""
    bogus = 'test'
    try:
        # Write plain text where a HDF5 file is expected
        with open(bogus, 'w') as handle:
            handle.write('illegal')
        with pytest.raises(OSError):
            DTWKNN.load(bogus)
    finally:
        os.remove(bogus)

def test_load_path():
    """Round-trip a DTWKNN classifier through a HDF5 file and check the restored state"""
    target = 'test'
    try:
        clfs[2].save(target)
        restored = DTWKNN.load(target)

        # Hyper-parameters
        assert isinstance(restored, DTWKNN)
        assert restored._k == 3
        assert restored._radius == 10

        # Training sequences
        assert isinstance(restored._X, list)
        assert len(restored._X) == len(X)
        assert isinstance(restored._X[0], np.ndarray)
        assert_all_equal(restored._X, X)

        # Training labels
        assert isinstance(restored._y, list)
        assert len(restored._y) == len(y)
        assert isinstance(restored._y[0], str)
        assert all(loaded == expected for loaded, expected in zip(restored._y, y))
    finally:
        os.remove(target)
Loading

0 comments on commit 79767ea

Please sign in to comment.