diff --git a/README.md b/README.md index 0877f28f..f19bc4cd 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Sequentia offers the use of multivariate observation sequences with varying dura - [x] Hidden Markov Models (via [Pomegranate](https://github.com/jmschrei/pomegranate) [[1]](#references)) - [x] Multivariate Gaussian Emissions - - [ ] Gaussian Mixture Model Emissions (_soon!_) + - [x] Gaussian Mixture Model Emissions (full and diagonal covariances) - [x] Left-Right and Ergodic Topologies - [x] Approximate Dynamic Time Warping k-Nearest Neighbors (implemented with [FastDTW](https://github.com/slaypni/fastdtw) [[2]](#references)) - [ ] Long Short-Term Memory Networks (_soon!_) @@ -67,8 +67,6 @@ Sequentia offers the use of multivariate observation sequences with varying dura - [x] Multi-processing for DTW k-NN predictions -> **Disclaimer**: The package currently remains largely untested and is still in its early stages – _use with caution_! - ## Installation ```console diff --git a/docs/_includes/examples/classifiers/gmmhmm.py b/docs/_includes/examples/classifiers/gmmhmm.py new file mode 100644 index 00000000..966306ec --- /dev/null +++ b/docs/_includes/examples/classifiers/gmmhmm.py @@ -0,0 +1,11 @@ +import numpy as np +from sequentia.classifiers import GMMHMM + +# Create some sample data +X = [np.random.random((10 * i, 3)) for i in range(1, 4)] + +# Create and fit a left-right HMM with random transitions and initial state distribution +hmm = GMMHMM(label='class1', n_states=5, n_components=3, covariance='diagonal', topology='left-right') +hmm.set_random_initial() +hmm.set_random_transitions() +hmm.fit(X) \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index fc8c831d..610de59b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,7 +40,7 @@ 'm2r' ] -autodoc_member_order = 'bysource' +# autodoc_member_order = 'bysource' autosummary_generate = True numpydoc_show_class_members = False diff --git a/docs/sections/classifiers/hmm.rst 
b/docs/sections/classifiers/hmm.rst index b302559e..15411265 100644 --- a/docs/sections/classifiers/hmm.rst +++ b/docs/sections/classifiers/hmm.rst @@ -91,10 +91,47 @@ API reference .. autoclass:: sequentia.classifiers.hmm.HMM :members: +Hidden Markov Model with Gaussian Mixture Emissions (``GMMHMM``) +================================================================ + +The assumption that a single multivariate Gaussian emission distribution +is accurate and representative enough to model the probability of observation +vectors of any state of a HMM is often a very strong and naive one. + +Instead, a more powerful approach is to represent the emission distribution as +a mixture of multiple multivariate Gaussian densities. An emission distribution +for state :math:`m`, formed by a mixture of :math:`G` multivariate Gaussian densities is defined as: + +.. math:: + b_m(\mathbf{o}^{(t)}) = \sum_{g=1}^G c_g^{(m)} \mathcal{N}\big(\mathbf{o}^{(t)}\ ;\ \boldsymbol\mu_g^{(m)}, \Sigma_g^{(m)}\big) + +where :math:`\mathbf{o}^{(t)}` is an observation vector at time :math:`t`, +:math:`c_g^{(m)}` is a *mixing coefficient* such that :math:`\sum_{g=1}^G c_g^{(m)} = 1` +and :math:`\boldsymbol\mu_g^{(m)}` and :math:`\Sigma_g^{(m)}` are the mean vector +and covariance matrix of the :math:`g^\text{th}` mixture component of the :math:`m^\text{th}` +state, respectively. + +Even in the case that multiple Gaussian densities are not needed, the mixing coefficients +can be adjusted so that irrelevant Gaussians are omitted and only a single Gaussian remains. + +Example +------- + +.. literalinclude:: ../../_includes/examples/classifiers/gmmhmm.py + :language: python + :linenos: + +API reference +------------- + +.. autoclass:: sequentia.classifiers.hmm.GMMHMM + :inherited-members: + :members: + Hidden Markov Model Classifier (``HMMClassifier``) ================================================== -Multiple HMMs can be combined to form a multi-class classifier. 
+Multiple HMMs (and/or GMMHMMs) can be combined to form a multi-class classifier. To classify a new observation sequence :math:`O'`, this works by: 1. | Creating and training the HMMs :math:`\lambda_1, \lambda_2, \ldots, \lambda_N`. diff --git a/lib/sequentia/classifiers/__init__.py b/lib/sequentia/classifiers/__init__.py index b8c42468..a77871d2 100644 --- a/lib/sequentia/classifiers/__init__.py +++ b/lib/sequentia/classifiers/__init__.py @@ -1,5 +1,5 @@ from .hmm import ( - HMM, HMMClassifier, + HMM, GMMHMM, HMMClassifier, _Topology, _LeftRightTopology, _ErgodicTopology, _StrictLeftRightTopology ) from .dtwknn import DTWKNN \ No newline at end of file diff --git a/lib/sequentia/classifiers/hmm/__init__.py b/lib/sequentia/classifiers/hmm/__init__.py index 44131c0f..c8f88f1c 100644 --- a/lib/sequentia/classifiers/hmm/__init__.py +++ b/lib/sequentia/classifiers/hmm/__init__.py @@ -1,3 +1,4 @@ from .hmm import HMM +from .gmmhmm import GMMHMM from .hmm_classifier import HMMClassifier from .topologies import _Topology, _LeftRightTopology, _ErgodicTopology, _StrictLeftRightTopology \ No newline at end of file diff --git a/lib/sequentia/classifiers/hmm/gmmhmm.py b/lib/sequentia/classifiers/hmm/gmmhmm.py new file mode 100644 index 00000000..15402522 --- /dev/null +++ b/lib/sequentia/classifiers/hmm/gmmhmm.py @@ -0,0 +1,197 @@ +import numpy as np, pomegranate as pg, json +from .hmm import HMM + +class GMMHMM(HMM): + """A hidden Markov model representing an isolated temporal sequence class, + with mixtures of multivariate Gaussian components representing state emission distributions. + + Parameters + ---------- + label: str + A label for the model, corresponding to the class being represented. + + n_states: int + The number of states for the model. + + n_components: int + The number of mixture components used in the emission distribution for each state. + + covariance: {'diagonal', 'full'} + The covariance matrix type. 
+ + topology: {'ergodic', 'left-right', 'strict-left-right'} + The topology for the model. + + random_state: numpy.random.RandomState, int, optional + A random state object or seed for reproducible randomness. + + Attributes + ---------- + label: str + The label for the model. + + n_states: int + The number of states for the model. + + n_seqs: int + The number of observation sequences used to train the model. + + initial: numpy.ndarray + The initial state distribution of the model. + + transitions: numpy.ndarray + The transition matrix of the model. + """ + + def __init__(self, label, n_states, n_components, covariance='diagonal', topology='left-right', random_state=None): + super().__init__(label, n_states, topology, random_state) + self._n_components = self._val.restricted_integer( + n_components, lambda x: x > 1, desc='number of mixture components', expected='greater than one') + self._covariance = self._val.one_of(covariance, ['diagonal', 'full'], desc='covariance matrix type') + + def fit(self, X, n_jobs=1): + """Fits the HMM to observation sequences assumed to be labeled as the class that the model represents. + + Parameters + ---------- + X: List[numpy.ndarray] + Collection of multivariate observation sequences, each of shape :math:`(T \\times D)` where + :math:`T` may vary per observation sequence. + + n_jobs: int + | The number of jobs to run in parallel. + | Setting this to -1 will use all available CPU cores. 
+ """ + X = self._val.observation_sequences(X) + self._val.restricted_integer(n_jobs, lambda x: x == -1 or x > 0, 'number of jobs', '-1 or greater than zero') + + try: + (self._initial, self._transitions) + except AttributeError as e: + raise AttributeError('Must specify initial state distribution and transitions before the HMM can be fitted') from e + + self._n_seqs = len(X) + self._n_features = X[0].shape[1] + + # Create a mixture distribution of multivariate Gaussian emission components using combined samples for initial parameter estimation + concat = np.concatenate(X) + if self._covariance == 'diagonal': + # Use diagonal covariance matrices + dist = pg.GeneralMixtureModel( + [pg.MultivariateGaussianDistribution(concat.mean(axis=0), concat.std(axis=0) * np.eye(self._n_features)) for _ in range(self._n_components)], + self._random_state.dirichlet(np.ones(self._n_components) + ) + ) + else: + # Use full covariance matrices + dist = pg.GeneralMixtureModel.from_samples(pg.MultivariateGaussianDistribution, self._n_components, concat) + + # Create the HMM object + self._model = pg.HiddenMarkovModel.from_matrix( + name=self._label, + transition_probabilities=self._transitions, + distributions=[dist.copy() for _ in range(self._n_states)], + starts=self._initial + ) + + # Perform the Baum-Welch algorithm to fit the model to the observations + self._model.fit(X, n_jobs=n_jobs) + + # Update the initial state distribution and transitions to reflect the updated parameters + inner_tx = self._model.dense_transition_matrix()[:, :self._n_states] + self._initial = inner_tx[self._n_states] + self._transitions = inner_tx[:self._n_states] + + @property + def n_components(self): + return self._n_components + + @property + def covariance(self): + return self._covariance + + def as_dict(self): + """Serializes the :class:`GMMHMM` object into a `dict`, ready to be stored in JSON format. + + Returns + ------- + serialized: dict + JSON-ready serialization of the :class:`GMMHMM` object. 
+ """ + + try: + self._model + except AttributeError as e: + raise AttributeError('The model needs to be fitted before it can be exported to a dict') from e + + model = self._model.to_json() + + if 'NaN' in model: + raise ValueError('Encountered NaN value(s) in HMM parameters') + else: + return { + 'type': 'GMMHMM', + 'label': self._label, + 'n_states': self._n_states, + 'n_components': self._n_components, + 'covariance': self._covariance, + 'topology': self._topologies[self._topology.__class__], + 'model': { + 'initial': self._initial.tolist(), + 'transitions': self._transitions.tolist(), + 'n_seqs': self._n_seqs, + 'n_features': self._n_features, + 'hmm': json.loads(model) + } + } + + @classmethod + def load(cls, data, random_state=None): + """Deserializes either a `dict` or JSON serialized :class:`GMMHMM` object. + + Parameters + ---------- + data: str or dict + - File path of the serialized JSON data generated by the :meth:`save` method. + - `dict` representation of the :class:`GMMHMM`, generated by the :meth:`as_dict` method. + + random_state: numpy.random.RandomState, int, optional + A random state object or seed for reproducible randomness. + + Returns + ------- + deserialized: :class:`GMMHMM` + The deserialized HMM object. + + See Also + -------- + save: Serializes a :class:`GMMHMM` into a JSON file. + as_dict: Generates a `dict` representation of the :class:`GMMHMM`. 
+ """ + + # Load the serialized GMMHMM data + if isinstance(data, dict): + pass + elif isinstance(data, str): + with open(data, 'r') as f: + data = json.load(f) + else: + pass + + # Check that JSON is in the "correct" format + if data['type'] == 'HMM': + raise ValueError('You must use the HMM class to deserialize a stored HMM model') + elif data['type'] == 'GMMHMM': + pass + else: + raise ValueError("Attempted to deserialize an invalid model - expected 'type' field to be 'GMMHMM'") + + # Deserialize the data into a GMMHMM object + gmmhmm = cls(data['label'], data['n_states'], data['n_components'], data['covariance'], data['topology'], random_state=random_state) + gmmhmm._initial = np.array(data['model']['initial']) + gmmhmm._transitions = np.array(data['model']['transitions']) + gmmhmm._n_seqs = data['model']['n_seqs'] + gmmhmm._n_features = data['model']['n_features'] + gmmhmm._model = pg.HiddenMarkovModel.from_json(json.dumps(data['model']['hmm'])) + + return gmmhmm \ No newline at end of file diff --git a/lib/sequentia/classifiers/hmm/hmm.py b/lib/sequentia/classifiers/hmm/hmm.py index 81825280..1a4071f1 100644 --- a/lib/sequentia/classifiers/hmm/hmm.py +++ b/lib/sequentia/classifiers/hmm/hmm.py @@ -194,6 +194,7 @@ def as_dict(self): raise ValueError('Encountered NaN value(s) in HMM parameters') else: return { + 'type': 'HMM', 'label': self._label, 'n_states': self._n_states, 'topology': self._topologies[self._topology.__class__], @@ -256,6 +257,14 @@ def load(cls, data, random_state=None): else: pass + # Check that JSON is in the "correct" format + if data['type'] == 'HMM': + pass + elif data['type'] == 'GMMHMM': + raise ValueError('You must use the GMMHMM class to deserialize a stored GMMHMM model') + else: + raise ValueError("Attempted to deserialize an invalid model - expected 'type' field to be 'HMM'") + # Deserialize the data into a HMM object hmm = cls(data['label'], data['n_states'], data['topology'], random_state=random_state) hmm._initial = 
np.array(data['model']['initial']) diff --git a/lib/sequentia/classifiers/hmm/hmm_classifier.py b/lib/sequentia/classifiers/hmm/hmm_classifier.py index 249b0ff5..99c69c06 100644 --- a/lib/sequentia/classifiers/hmm/hmm_classifier.py +++ b/lib/sequentia/classifiers/hmm/hmm_classifier.py @@ -1,20 +1,22 @@ import numpy as np, json from .hmm import HMM +from .gmmhmm import GMMHMM from sklearn.metrics import confusion_matrix from ...internals import _Validator class HMMClassifier: - """A classifier that combines individual :class:`~HMM` objects, which model isolated sequences from different classes.""" + """A classifier that combines individual :class:`~HMM` and/or :class:`~GMMHMM` objects, + which model isolated sequences from different classes.""" def __init__(self): self._val = _Validator() def fit(self, models): - """Fits the classifier with a collection of :class:`~HMM` objects. + """Fits the classifier with a collection of :class:`~HMM` and/or :class:`~GMMHMM` objects. Parameters ---------- - models: List[HMM] or Dict[Any, HMM] + models: List[HMM, GMMHMM] or Dict[Any, HMM/GMMHMM] A collection of :class:`~HMM` objects to use for classification. """ if isinstance(models, list): @@ -127,8 +129,8 @@ def as_dict(self): """Serializes the :class:`HMMClassifier` object into a `dict`, ready to be stored in JSON format. .. note:: - Serializing a :class:`HMMClassifier` implicitly serializes the internal :class:`HMM` objects - by calling :meth:`HMM.as_dict` and storing all of the model data in a single `dict`. + Serializing a :class:`HMMClassifier` implicitly serializes the internal :class:`HMM` or :class:`GMMHMM` objects + by calling :meth:`HMM.as_dict` or :meth:`GMMHMM.as_dict` and storing all of the model data in a single `dict`. Returns ------- @@ -138,6 +140,7 @@ def as_dict(self): See Also -------- HMM.as_dict: The serialization function used for individual :class:`HMM` objects. + GMMHMM.as_dict: The serialization function used for individual :class:`GMMHMM` objects. 
""" try: @@ -192,6 +195,18 @@ def load(cls, path, random_state=None): data = json.load(f) clf = cls() - clf._models = [HMM.load(model, random_state=random_state) for model in data['models']] + clf._models = [] + + for model in data['models']: + # Retrieve the type of HMM + if model['type'] == 'HMM': + hmm = HMM + elif model['type'] == 'GMMHMM': + hmm = GMMHMM + else: + raise ValueError("Expected 'type' field to be either 'HMM' or 'GMMHMM'") + + # Deserialize the HMM and add it to the classifier + clf._models.append(hmm.load(model, random_state=random_state)) return clf \ No newline at end of file diff --git a/lib/test/lib/classifiers/hmm/test_gmmhmm.py b/lib/test/lib/classifiers/hmm/test_gmmhmm.py new file mode 100644 index 00000000..bb6bfd25 --- /dev/null +++ b/lib/test/lib/classifiers/hmm/test_gmmhmm.py @@ -0,0 +1,201 @@ +import pytest, warnings, os, json, numpy as np +from copy import deepcopy +with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + import pomegranate as pg +from sequentia.classifiers import HMM, GMMHMM, _LeftRightTopology, _ErgodicTopology, _StrictLeftRightTopology +from ....support import assert_equal, assert_not_equal + +# Set seed for reproducible randomness +seed = 0 +np.random.seed(seed) +rng = np.random.RandomState(seed) + +# Create some sample data +X = [rng.random((10 * i, 3)) for i in range(1, 4)] +x = rng.random((15, 3)) + +# Unparameterized HMMs +hmm_diag = GMMHMM(label='c1', n_states=5, n_components=5, covariance='diagonal', random_state=rng) +hmm_full = GMMHMM(label='c1', n_states=5, n_components=5, covariance='full', random_state=rng) + +# ============================== # +# GMMHMM.n_components (property) # +# ============================== # + +def test_n_components(): + assert deepcopy(hmm_diag).n_components == 5 + +# ============================ # +# GMMHMM.covariance (property) # +# ============================ # + +def test_covariance(): + assert deepcopy(hmm_diag).covariance == 
'diagonal' + +# ================ # +# GMMHMM.as_dict() # +# ================ # + +def test_as_dict_unfitted(): + """Export an unfitted GMMHMM to dict""" + hmm = deepcopy(hmm_diag) + with pytest.raises(AttributeError) as e: + hmm.as_dict() + assert str(e.value) == 'The model needs to be fitted before it can be exported to a dict' + +def test_as_dict_fitted(): + """Export a fitted GMMHMM to dict""" + hmm = deepcopy(hmm_diag) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + before = hmm.initial, hmm.transitions + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + d = hmm.as_dict() + + assert d['type'] == 'GMMHMM' + assert d['label'] == 'c1' + assert d['n_states'] == 5 + assert d['n_components'] == 5 + assert d['covariance'] == 'diagonal' + assert d['topology'] == 'left-right' + assert_not_equal(d['model']['initial'], before[0]) + assert_not_equal(d['model']['transitions'], before[1]) + assert d['model']['n_seqs'] == 3 + assert d['model']['n_features'] == 3 + assert isinstance(d['model']['hmm'], dict) + +# ============= # +# GMMHMM.save() # +# ============= # + +def test_save_directory(): + """Save a GMMHMM into a directory""" + hmm = deepcopy(hmm_diag) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + with pytest.raises(IsADirectoryError) as e: + hmm.save('.') + assert str(e.value) == "[Errno 21] Is a directory: '.'" + +def test_save_no_extension(): + """Save a GMMHMM into a file without an extension""" + try: + hmm = deepcopy(hmm_diag) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm.save('test') + assert os.path.isfile('test') + finally: + os.remove('test') + +def test_save_with_extension(): + """Save a GMMHMM into a file with a .json extension""" + try: 
+ hmm = deepcopy(hmm_diag) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm.save('test.json') + assert os.path.isfile('test.json') + finally: + os.remove('test.json') + +# ============= # +# GMMHMM.load() # +# ============= # + +def test_load_invalid_dict(): + """Load a GMMHMM from an invalid dict""" + with pytest.raises(KeyError) as e: + GMMHMM.load({}) + +def test_load_dict(): + """Load a GMMHMM from a valid dict""" + hmm = deepcopy(hmm_diag) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + before = hmm.initial, hmm.transitions + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm = GMMHMM.load(hmm.as_dict()) + + assert isinstance(hmm, GMMHMM) + assert hmm._label == 'c1' + assert hmm._n_states == 5 + assert hmm._n_components == 5 + assert hmm._covariance == 'diagonal' + assert isinstance(hmm._topology, _LeftRightTopology) + assert_not_equal(hmm._initial, before[0]) + assert_not_equal(hmm._transitions, before[1]) + assert hmm._n_seqs == 3 + assert hmm._n_features == 3 + assert isinstance(hmm._model, pg.HiddenMarkovModel) + +def test_load_invalid_path(): + """Load a GMMHMM from a directory""" + with pytest.raises(IsADirectoryError) as e: + GMMHMM.load('.') + +def test_load_inexistent_path(): + """Load a GMMHMM from an inexistent path""" + with pytest.raises(FileNotFoundError) as e: + GMMHMM.load('test') + +def test_load_invalid_format(): + """Load a GMMHMM from an illegally formatted file""" + try: + with open('test', 'w') as f: + f.write('illegal') + with pytest.raises(json.decoder.JSONDecodeError) as e: + GMMHMM.load('test') + finally: + os.remove('test') + +def test_load_invalid_json(): + """Load a GMMHMM from an invalid JSON file""" + try: + with open('test', 'w') as f: + f.write("{}") + with pytest.raises(KeyError) as e: + GMMHMM.load('test') + finally: + 
os.remove('test') + +def test_load_path(): + """Load a GMMHMM from a valid JSON file""" + try: + hmm = deepcopy(hmm_diag) + hmm.set_uniform_initial() + hmm.set_uniform_transitions() + before = hmm.initial, hmm.transitions + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + hmm.fit(X) + hmm.save('test') + hmm = GMMHMM.load('test') + + assert isinstance(hmm, GMMHMM) + assert hmm._label == 'c1' + assert hmm._n_states == 5 + assert hmm._n_components == 5 + assert isinstance(hmm._topology, _LeftRightTopology) + assert hmm._covariance == 'diagonal' + assert_not_equal(hmm._initial, before[0]) + assert_not_equal(hmm._transitions, before[1]) + assert hmm._n_seqs == 3 + assert hmm._n_features == 3 + assert isinstance(hmm._model, pg.HiddenMarkovModel) + finally: + os.remove('test') \ No newline at end of file diff --git a/lib/test/lib/classifiers/hmm/test_hmm.py b/lib/test/lib/classifiers/hmm/test_hmm.py index 64a63cb8..f76d908e 100644 --- a/lib/test/lib/classifiers/hmm/test_hmm.py +++ b/lib/test/lib/classifiers/hmm/test_hmm.py @@ -575,6 +575,7 @@ def test_as_dict_fitted(): hmm.fit(X) d = hmm.as_dict() + assert d['type'] == 'HMM' assert d['label'] == 'c1' assert d['n_states'] == 5 assert d['topology'] == 'ergodic' @@ -647,9 +648,10 @@ def test_load_invalid_dict(): def test_load_dict(): """Load a HMM from a valid dict""" - hmm = deepcopy(hmm_e) + hmm = deepcopy(hmm_lr) hmm.set_uniform_initial() hmm.set_uniform_transitions() + before = hmm.initial, hmm.transitions with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) hmm.fit(X) @@ -658,17 +660,9 @@ def test_load_dict(): assert isinstance(hmm, HMM) assert hmm._label == 'c1' assert hmm._n_states == 5 - assert isinstance(hmm._topology, _ErgodicTopology) - assert_equal(hmm._initial, np.array([ - 0.2, 0.2, 0.2, 0.2, 0.2 - ])) - assert_equal(hmm._transitions, np.array([ - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 
0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2] - ])) + assert isinstance(hmm._topology, _LeftRightTopology) + assert_not_equal(hmm._initial, before[0]) + assert_not_equal(hmm._transitions, before[1]) assert hmm._n_seqs == 3 assert hmm._n_features == 3 assert isinstance(hmm._model, pg.HiddenMarkovModel) @@ -706,7 +700,7 @@ def test_load_invalid_json(): def test_load_path(): """Load a HMM from a valid JSON file""" try: - hmm = deepcopy(hmm_e) + hmm = deepcopy(hmm_slr) hmm.set_uniform_initial() hmm.set_uniform_transitions() with warnings.catch_warnings(): @@ -718,16 +712,16 @@ def test_load_path(): assert isinstance(hmm, HMM) assert hmm._label == 'c1' assert hmm._n_states == 5 - assert isinstance(hmm._topology, _ErgodicTopology) + assert isinstance(hmm._topology, _StrictLeftRightTopology) assert_equal(hmm._initial, np.array([ - 0.2, 0.2, 0.2, 0.2, 0.2 + 1.00000000e+00, 8.17583139e-17, 3.68732352e-49, 3.20095727e-33, 4.52958070e-34 ])) assert_equal(hmm._transitions, np.array([ - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2], - [0.2, 0.2, 0.2, 0.2, 0.2] + [0.00000000e+00, 1.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [0.00000000e+00, 6.66666667e-01, 3.33333333e-01, 0.00000000e+00, 0.00000000e+00], + [0.00000000e+00, 0.00000000e+00, 8.84889735e-10, 9.99999999e-01, 0.00000000e+00], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 6.03077553e-01, 3.96922447e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00] ])) assert hmm._n_seqs == 3 assert hmm._n_features == 3