diff --git a/lib/sequentia/classifiers/hmm/hmm.py b/lib/sequentia/classifiers/hmm/hmm.py index 2bdb6e09..81825280 100644 --- a/lib/sequentia/classifiers/hmm/hmm.py +++ b/lib/sequentia/classifiers/hmm/hmm.py @@ -46,11 +46,9 @@ def __init__(self, label, n_states, topology='left-right', random_state=None): n_states, lambda x: x > 0, desc='number of states', expected='greater than zero') self._val.one_of(topology, ['ergodic', 'left-right', 'strict-left-right'], desc='topology') self._random_state = self._val.random_state(random_state) - self._topology = { - 'ergodic': _ErgodicTopology, - 'left-right': _LeftRightTopology, - 'strict-left-right': _StrictLeftRightTopology - }[topology](self._n_states, self._random_state) + self._topologies = {'ergodic': _ErgodicTopology, 'left-right': _LeftRightTopology, 'strict-left-right': _StrictLeftRightTopology} + self._topologies.update(dict([reversed(i) for i in self._topologies.items()])) + self._topology = self._topologies[topology](self._n_states, self._random_state) def set_uniform_initial(self): """Sets a uniform initial state distribution.""" @@ -198,7 +196,7 @@ def as_dict(self): return { 'label': self._label, 'n_states': self._n_states, - 'topology': 'ergodic' if isinstance(self._topology, _ErgodicTopology) else 'left-right', + 'topology': self._topologies[self._topology.__class__], 'model': { 'initial': self._initial.tolist(), 'transitions': self._transitions.tolist(),