def save(self, path):
    """Store the fitted :class:`DTWKNN` classifier in an HDF5 file.

    .. note::
        As :math:`k`-NN is a non-parametric classification algorithm, saving the
        classifier simply saves all of the training observation sequences and
        labels (along with the ``k`` and ``radius`` hyper-parameters). The distance
        metric is not written to the file; it is supplied again when calling
        :meth:`load`.

    Parameters
    ----------
    path: str
        File path (with or without `.h5` extension) to store the HDF5-serialized
        :class:`DTWKNN` object.

    Raises
    ------
    RuntimeError
        If the classifier has no training data yet (``_X``/``_y`` are not set).

    See Also
    --------
    load: Deserializes the HDF5 file written by this method.
    """

    try:
        sequences, labels = self._X, self._y
    except AttributeError as e:
        raise RuntimeError('The classifier needs to be fitted before it can be saved') from e

    with h5py.File(path, 'w') as f:
        # Store hyper-parameters (k, radius)
        params = f.create_group('params')
        params.create_dataset('k', data=self._k)
        params.create_dataset('radius', data=self._radius)

        # Store every training sequence as its own dataset, named after its index,
        # so that sequences of different lengths can live side by side.
        data = f.create_group('data')
        X = data.create_group('X')
        for i, x in enumerate(sequences):
            X.create_dataset(str(i), data=x)

        # Labels are stored as a byte-string array. ``np.bytes_`` is used instead
        # of the ``np.string_`` alias, which no longer exists in NumPy 2.0.
        # NOTE(review): this conversion presumably only handles ASCII labels -
        # confirm whether non-ASCII labels need an explicit encode step.
        data.create_dataset('y', data=np.asarray(labels, dtype=np.bytes_))


@classmethod
def load(cls, path, encoding='utf-8', metric=euclidean):
    """Deserialize a HDF5-serialized :class:`DTWKNN` object.

    Parameters
    ----------
    path: str
        File path of the serialized HDF5 data generated by the :meth:`save` method.

    encoding: str
        The encoding used to decode the stored training labels back into `str`.
        Any codec name accepted by :meth:`bytes.decode` may be used.

    metric: callable
        Distance metric for FastDTW.

    Returns
    -------
    deserialized: :class:`DTWKNN`
        The deserialized DTWKNN classifier object.

    See Also
    --------
    save: Serializes a :class:`DTWKNN` into a HDF5 file.
    """

    with h5py.File(path, 'r') as f:
        # Deserialize the model hyper-parameters
        params = f['params']
        clf = cls(k=int(params['k'][()]), radius=int(params['radius'][()]), metric=metric)

        # Deserialize the training data and labels
        X, y = f['data']['X'], f['data']['y']

        # The sequence datasets are named '0', '1', ..., by save(). h5py iterates
        # group members in name order by default ('10' sorts before '2'), so the
        # names are sorted numerically to keep each sequence aligned with its label.
        clf._X = [np.array(X[name]) for name in sorted(X.keys(), key=int)]
        clf._y = [label.decode(encoding) for label in y]

    return clf
def as_dict(self):
    """Serialize the fitted :class:`HMM` into a `dict`, ready to be stored in JSON format.

    Returns
    -------
    serialized: dict
        JSON-ready serialization of the :class:`HMM` object.

    Raises
    ------
    AttributeError
        If the model has not been fitted yet.
    ValueError
        If the serialized model parameters contain NaN values.
    """

    try:
        self._model
    except AttributeError as e:
        raise AttributeError('The model needs to be fitted before it can be exported to a dict') from e

    model = self._model.to_json()

    # The check is a plain substring search on the serialized JSON text, so it
    # rejects the model before any NaN can end up in the exported dict.
    if 'NaN' in model:
        raise ValueError('Encountered NaN value(s) in HMM parameters')

    return {
        'label': self._label,
        'n_states': self._n_states,
        # Only two topology names are emitted: anything that is not ergodic is
        # recorded as left-right.
        'topology': 'ergodic' if isinstance(self._topology, _ErgodicTopology) else 'left-right',
        'model': {
            'initial': self._initial.tolist(),
            'transitions': self._transitions.tolist(),
            'n_seqs': self._n_seqs,
            'n_features': self._n_features,
            'hmm': json.loads(model)
        }
    }


def save(self, path):
    """Convert the :class:`HMM` object into a `dict` and store it in a JSON file.

    Parameters
    ----------
    path: str
        File path (with or without `.json` extension) to store the JSON-serialized :class:`HMM` object.

    See Also
    --------
    as_dict: Generates the `dict` that is stored in the JSON file.
    """

    # Serialize first so that an unfitted/invalid model raises before the
    # target file is opened (and truncated) for writing.
    data = self.as_dict()
    with open(path, 'w') as f:
        json.dump(data, f, indent=4)


@classmethod
def load(cls, data, random_state=None):
    """Deserialize either a `dict` or JSON serialized :class:`HMM` object.

    Parameters
    ----------
    data: str or dict
        - File path of the serialized JSON data generated by the :meth:`save` method.
        - `dict` representation of the :class:`HMM`, generated by the :meth:`as_dict` method.

    random_state: numpy.random.RandomState, int, optional
        A random state object or seed for reproducible randomness.

    Returns
    -------
    deserialized: :class:`HMM`
        The deserialized HMM object.

    Raises
    ------
    TypeError
        If ``data`` is neither a file path (`str`) nor a `dict`-like mapping.
    KeyError
        If the serialized data is missing an expected key.

    See Also
    --------
    save: Serializes a :class:`HMM` into a JSON file.
    as_dict: Generates a `dict` representation of the :class:`HMM`.
    """

    # Local import: only this method needs it.
    from collections.abc import Mapping

    # Load the serialized HMM data
    if isinstance(data, str):
        with open(data, 'r') as f:
            data = json.load(f)
    elif not isinstance(data, Mapping):
        # Fail here with a clear message instead of letting an unsupported
        # input fall through to an unrelated "not subscriptable" error below.
        raise TypeError(
            'Expected a file path (str) or a dict, but got an object of type {}'.format(type(data).__name__)
        )

    # Deserialize the data into a HMM object
    hmm = cls(data['label'], data['n_states'], data['topology'], random_state=random_state)
    model = data['model']
    hmm._initial = np.array(model['initial'])
    hmm._transitions = np.array(model['transitions'])
    hmm._n_seqs = model['n_seqs']
    hmm._n_features = model['n_features']
    # The pomegranate model is rebuilt from the JSON text of the nested 'hmm' entry
    hmm._model = pg.HiddenMarkovModel.from_json(json.dumps(model['hmm']))

    return hmm
def as_dict(self):
    """Serialize the :class:`HMMClassifier` object into a `dict`, ready to be stored in JSON format.

    .. note::
        Serializing a :class:`HMMClassifier` implicitly serializes the internal :class:`HMM` objects
        by calling :meth:`HMM.as_dict` and storing all of the model data in a single `dict`.

    Returns
    -------
    serialized: dict
        JSON-ready serialization of the :class:`HMMClassifier` object.

    Raises
    ------
    AttributeError
        If the classifier has not been fitted yet.

    See Also
    --------
    HMM.as_dict: The serialization function used for individual :class:`HMM` objects.
    """

    try:
        models = self._models
    except AttributeError as e:
        raise AttributeError('The classifier needs to be fitted before it can be exported to a dict') from e

    return {'models': [model.as_dict() for model in models]}


def save(self, path):
    """Convert the :class:`HMMClassifier` object into a `dict` and store it in a JSON file.

    Parameters
    ----------
    path: str
        File path (with or without `.json` extension) to store the JSON-serialized :class:`HMMClassifier` object.

    See Also
    --------
    as_dict: Generates the `dict` that is stored in the JSON file.
    """

    # Serialize first so that an unfitted classifier raises before the target
    # file is opened (and truncated) for writing.
    data = self.as_dict()
    with open(path, 'w') as f:
        json.dump(data, f, indent=4)


@classmethod
def load(cls, path, random_state=None):
    """Deserialize either a `dict` or JSON serialized :class:`HMMClassifier` object.

    Parameters
    ----------
    path: str or dict
        - File path of the serialized JSON data generated by the :meth:`save` method.
        - `dict` representation of the :class:`HMMClassifier`, generated by the :meth:`as_dict` method.

    random_state: numpy.random.RandomState, int, optional
        A random state object or seed for reproducible randomness.

    Returns
    -------
    deserialized: :class:`HMMClassifier`
        The deserialized HMM classifier object.

    Raises
    ------
    KeyError
        If the serialized data has no ``'models'`` entry.

    See Also
    --------
    save: Serializes a :class:`HMMClassifier` into a JSON file.
    as_dict: Generates a `dict` representation of the :class:`HMMClassifier`.
    """

    if isinstance(path, dict):
        # Already-deserialized data, e.g. the output of as_dict()
        data = path
    else:
        # Anything else is treated as a file location and handed to open()
        with open(path, 'r') as f:
            data = json.load(f)

    clf = cls()
    # Each entry is the dict form of one HMM, which HMM.load rebuilds
    clf._models = [HMM.load(model, random_state=random_state) for model in data['models']]

    return clf
with pytest.raises(OSError) as e: + DTWKNN.load('.') + +def test_load_inexistent_path(): + """Load a DTWKNN classifier from an inexistent path""" + with pytest.raises(OSError) as e: + DTWKNN.load('test') + +def test_load_invalid_format(): + """Load a DTWKNN classifier from an illegally formatted file""" + try: + with open('test', 'w') as f: + f.write('illegal') + with pytest.raises(OSError) as e: + DTWKNN.load('test') + finally: + os.remove('test') + +def test_load_path(): + """Load a DTWKNN classifier from a valid HDF5 file""" + try: + clfs[2].save('test') + clf = DTWKNN.load('test') + + assert isinstance(clf, DTWKNN) + assert clf._k == 3 + assert clf._radius == 10 + assert isinstance(clf._X, list) + assert len(clf._X) == len(X) + assert isinstance(clf._X[0], np.ndarray) + assert_all_equal(clf._X, X) + assert isinstance(clf._y, list) + assert len(clf._y) == len(y) + assert isinstance(clf._y[0], str) + assert all(y1 == y2 for y1, y2 in zip(clf._y, y)) + finally: + os.remove('test') \ No newline at end of file diff --git a/lib/test/lib/classifiers/hmm/test_hmm.py b/lib/test/lib/classifiers/hmm/test_hmm.py index cbabd387..ab3da229 100644 --- a/lib/test/lib/classifiers/hmm/test_hmm.py +++ b/lib/test/lib/classifiers/hmm/test_hmm.py @@ -1,5 +1,7 @@ import pytest import warnings +import os +import json import numpy as np from copy import deepcopy with warnings.catch_warnings(): @@ -21,9 +23,9 @@ hmm_lr = HMM(label='c1', n_states=5, topology='left-right', random_state=rng) hmm_e = HMM(label='c1', n_states=5, topology='ergodic', random_state=rng) -# ========================= # -# HMM.set_uniform_initial() # -# ========================= # +# ================================================== # +# HMM.set_uniform_initial() + HMM.initial (property) # +# ================================================== # def test_left_right_uniform_initial(): """Uniform initial state distribution for a left-right HMM""" @@ -41,9 +43,9 @@ def test_ergodic_uniform_initial(): 0.2, 0.2, 0.2, 
0.2, 0.2 ])) -# ======================== # -# HMM.set_random_initial() # -# ======================== # +# ================================================= # +# HMM.set_random_initial() + HMM.initial (property) # +# ================================================= # def test_left_right_random_initial(): """Random initial state distribution for a left-right HMM""" @@ -61,9 +63,9 @@ def test_ergodic_random_initial(): 0.35029635, 0.13344569, 0.02784745, 0.33782453, 0.15058597 ])) -# ====================================================== # -# HMM.set_uniform_transitions() + HMM.initial (property) # -# ====================================================== # +# ========================================================== # +# HMM.set_uniform_transitions() + HMM.transitions (property) # +# ========================================================== # def test_left_right_uniform_transitions(): """Uniform transition matrix for a left-right HMM""" @@ -387,4 +389,185 @@ def test_ergodic_transitions_ergodic(): topology = _ErgodicTopology(n_states=5, random_state=rng) transitions = topology.random_transitions() hmm.transitions = transitions - assert_equal(hmm.transitions, transitions) \ No newline at end of file + assert_equal(hmm.transitions, transitions) + +# ============= # +# HMM.as_dict() # +# ============= # + +def test_as_dict_unfitted(): + """Export an unfitted HMM to dict""" + hmm = deepcopy(hmm_e) + with pytest.raises(AttributeError) as e: + hmm.as_dict() + assert str(e.value) == 'The model needs to be fitted before it can be exported to a dict' + +def test_as_dict_fitted(): + """Export a fitted HMM to dict""" + hmm = deepcopy(hmm_e) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + d = hmm.as_dict() + + assert d['label'] == 'c1' + assert d['n_states'] == 5 + assert d['topology'] == 'ergodic' + assert_equal(d['model']['initial'], np.array([ + 0.2, 
0.2, 0.2, 0.2, 0.2 + ])) + assert_equal(d['model']['transitions'], np.array([ + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2] + ])) + assert d['model']['n_seqs'] == 3 + assert d['model']['n_features'] == 3 + assert isinstance(d['model']['hmm'], dict) + +# ========== # +# HMM.save() # +# ========== # + +def test_save_directory(): + """Save a HMM into a directory""" + hmm = deepcopy(hmm_e) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + with pytest.raises(IsADirectoryError) as e: + hmm.save('.') + assert str(e.value) == "[Errno 21] Is a directory: '.'" + +def test_save_no_extension(): + """Save a HMM into a file without an extension""" + try: + hmm = deepcopy(hmm_e) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm.save('test') + assert os.path.isfile('test') + finally: + os.remove('test') + +def test_save_with_extension(): + """Save a HMM into a file with a .json extension""" + try: + hmm = deepcopy(hmm_e) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm.save('test.json') + assert os.path.isfile('test.json') + finally: + os.remove('test.json') + +# ========== # +# HMM.load() # +# ========== # + +def test_load_invalid_dict(): + """Load a HMM from an invalid dict""" + with pytest.raises(KeyError) as e: + HMM.load({}) + +def test_load_dict(): + """Load a HMM from a valid dict""" + hmm = deepcopy(hmm_e) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm = 
HMM.load(hmm.as_dict()) + + assert isinstance(hmm, HMM) + assert hmm._label == 'c1' + assert hmm._n_states == 5 + assert isinstance(hmm._topology, _ErgodicTopology) + assert_equal(hmm._initial, np.array([ + 0.2, 0.2, 0.2, 0.2, 0.2 + ])) + assert_equal(hmm._transitions, np.array([ + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2] + ])) + assert hmm._n_seqs == 3 + assert hmm._n_features == 3 + assert isinstance(hmm._model, pg.HiddenMarkovModel) + +def test_load_invalid_path(): + """Load a HMM from a directory""" + with pytest.raises(IsADirectoryError) as e: + HMM.load('.') + +def test_load_inexistent_path(): + """Load a HMM from an inexistent path""" + with pytest.raises(FileNotFoundError) as e: + HMM.load('test') + +def test_load_invalid_format(): + """Load a HMM from an illegally formatted file""" + try: + with open('test', 'w') as f: + f.write('illegal') + with pytest.raises(json.decoder.JSONDecodeError) as e: + HMM.load('test') + finally: + os.remove('test') + +def test_load_invalid_json(): + """Load a HMM from an invalid JSON file""" + try: + with open('test', 'w') as f: + f.write("{}") + with pytest.raises(KeyError) as e: + HMM.load('test') + finally: + os.remove('test') + +def test_load_path(): + """Load a HMM from a valid JSON file""" + try: + hmm = deepcopy(hmm_e) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm.save('test') + hmm = HMM.load('test') + + assert isinstance(hmm, HMM) + assert hmm._label == 'c1' + assert hmm._n_states == 5 + assert isinstance(hmm._topology, _ErgodicTopology) + assert_equal(hmm._initial, np.array([ + 0.2, 0.2, 0.2, 0.2, 0.2 + ])) + assert_equal(hmm._transitions, np.array([ + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 0.2], + [0.2, 0.2, 0.2, 0.2, 
0.2] + ])) + assert hmm._n_seqs == 3 + assert hmm._n_features == 3 + assert isinstance(hmm._model, pg.HiddenMarkovModel) + finally: + os.remove('test') \ No newline at end of file diff --git a/lib/test/lib/classifiers/hmm/test_hmm_classifier.py b/lib/test/lib/classifiers/hmm/test_hmm_classifier.py index ac5623eb..00284cc4 100644 --- a/lib/test/lib/classifiers/hmm/test_hmm_classifier.py +++ b/lib/test/lib/classifiers/hmm/test_hmm_classifier.py @@ -1,8 +1,13 @@ import pytest import warnings +import os +import json import numpy as np from copy import deepcopy -from sequentia.classifiers import HMM, HMMClassifier +with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + import pomegranate as pg +from sequentia.classifiers import HMM, HMMClassifier, _ErgodicTopology from ....support import assert_equal, assert_not_equal # Set seed for reproducible randomness @@ -34,6 +39,14 @@ hmm_clf = HMMClassifier() hmm_clf.fit(hmm_list) +# Fit a classifier (with no NaN values) +hmm_clf_no_nan = HMMClassifier() +hmm = HMM(label='c1', n_states=5, topology='ergodic', random_state=rng) +hmm.set_uniform_initial() +hmm.set_uniform_transitions() +hmm.fit([rng.random((10 * i, 3)) for i in range(1, 4)]) +hmm_clf_no_nan.fit([hmm]) + # =================== # # HMMClassifier.fit() # # =================== # @@ -191,4 +204,109 @@ def test_evaluate_no_prior_no_labels(): """Evaluate with no prior and no confusion matrix labels""" acc, cm = hmm_clf.evaluate(X, Y, prior=False, labels=None) assert isinstance(acc, float) - assert isinstance(cm, np.ndarray) \ No newline at end of file + assert isinstance(cm, np.ndarray) + +# ======================= # +# HMMClassifier.as_dict() # +# ======================= # + +def test_as_dict_unfitted(): + """Export an unfitted HMM classifier to dict""" + with pytest.raises(AttributeError) as e: + HMMClassifier().as_dict() + assert str(e.value) == 'The classifier needs to be fitted before it can be exported to a dict' + +def 
test_as_dict_fitted(): + """Export a fitted HMM classifier to dict""" + d = hmm_clf_no_nan.as_dict() + + assert isinstance(d['models'], list) + assert len(d['models']) == 1 + assert d['models'][0]['label'] == 'c1' + assert d['models'][0]['n_states'] == 5 + assert d['models'][0]['topology'] == 'ergodic' + assert np.array(d['models'][0]['model']['initial']).shape == (5,) + assert np.array(d['models'][0]['model']['transitions']).shape == (5, 5) + assert d['models'][0]['model']['n_seqs'] == 3 + assert d['models'][0]['model']['n_features'] == 3 + assert isinstance(d['models'][0]['model']['hmm'], dict) + +# ==================== # +# HMMClassifier.save() # +# ==================== # + +def test_save_directory(): + """Save a HMM classifier into a directory""" + with pytest.raises(IsADirectoryError) as e: + hmm_clf_no_nan.save('.') + assert str(e.value) == "[Errno 21] Is a directory: '.'" + +def test_save_no_extension(): + """Save a HMM classifier into a file without an extension""" + try: + hmm_clf_no_nan.save('test') + assert os.path.isfile('test') + finally: + os.remove('test') + +def test_save_with_extension(): + """Save a HMM classifier into a file with a .json extension""" + try: + hmm_clf_no_nan.save('test.json') + assert os.path.isfile('test.json') + finally: + os.remove('test.json') + +# ==================== # +# HMMClassifier.load() # +# ==================== # + +def test_load_invalid_path(): + """Load a HMM classifier from a directory""" + with pytest.raises(IsADirectoryError) as e: + HMMClassifier.load('.') + +def test_load_inexistent_path(): + """Load a HMM classifier from an inexistent path""" + with pytest.raises(FileNotFoundError) as e: + HMMClassifier.load('test') + +def test_load_invalid_format(): + """Load a HMM classifier from an illegally formatted file""" + try: + with open('test', 'w') as f: + f.write('illegal') + with pytest.raises(json.decoder.JSONDecodeError) as e: + HMMClassifier.load('test') + finally: + os.remove('test') + +def 
test_load_invalid_json(): + """Load a HMM classifier from an invalid JSON file""" + try: + with open('test', 'w') as f: + f.write("{}") + with pytest.raises(KeyError) as e: + HMMClassifier.load('test') + finally: + os.remove('test') + +def test_load_path(): + """Load a HMM classifier from a valid JSON file""" + try: + hmm_clf_no_nan.save('test') + h = HMMClassifier.load('test') + + assert isinstance(h, HMMClassifier) + assert isinstance(h._models, list) + assert len(h._models) == 1 + assert h._models[0]._label == 'c1' + assert h._models[0]._n_states == 5 + assert isinstance(h._models[0]._topology, _ErgodicTopology) + assert h._models[0]._initial.shape == (5,) + assert h._models[0]._transitions.shape == (5, 5) + assert h._models[0]._n_seqs == 3 + assert h._models[0]._n_features == 3 + assert isinstance(h._models[0]._model, pg.HiddenMarkovModel) + finally: + os.remove('test') \ No newline at end of file diff --git a/setup.py b/setup.py index a09111ab..75cc7b36 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ 'scipy>=1.3,<2', 'scikit-learn>=0.22,<1', 'tqdm>=4.36,<5', - 'joblib>=0.14,<1' + 'joblib>=0.14,<1', + 'h5py>=2.10,<2.11' ] ) \ No newline at end of file