Skip to content

Commit

Permalink
Add get_ngrams_length to SS3 class
Browse files Browse the repository at this point in the history
This method can be used to return the length of longest learned n-gram.
  • Loading branch information
sergioburdisso committed Mar 1, 2020
1 parent 6d5d942 commit b4f8827
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,15 @@ def get_most_probable_category(self):
"""
return self.get_category_name(self.__get_most_probable_category__())

def get_ngrams_length(self):
"""
Return the length of longest learned n-gram.
:returns: the length of longest learned n-gram.
:rtype: int
"""
return len(self.__max_fr__[0]) if len(self.__max_fr__) > 0 else 0

def get_category_index(self, name):
"""
Given its name, return the category index.
Expand Down
5 changes: 5 additions & 0 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def test_pyss3_ss3(mockers):
assert clf.get_category_index("a_category") == IDX_UNKNOWN_CATEGORY
assert clf.get_category_name(0) == STR_UNKNOWN_CATEGORY
assert clf.get_category_name(-1) == STR_UNKNOWN_CATEGORY
assert clf.get_ngrams_length() == 0

with pytest.raises(pyss3.EmptyModelError):
clf.predict(x_test)
Expand All @@ -263,6 +264,8 @@ def test_pyss3_ss3(mockers):
# cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
clf.fit(x_train, y_train)

assert clf.get_ngrams_length() == 1

perform_tests_with(clf, [.00114, .00295, 0, 0, 0, .00016, .01894, 8.47741])
perform_tests_on(clf.cv, 0.4307)
perform_tests_on(clf.gv, 0.2148)
Expand Down Expand Up @@ -309,6 +312,8 @@ def test_pyss3_ss3(mockers):

clf.fit(x_train, y_train, n_grams=3)

assert clf.get_ngrams_length() == 3

# update_values
clf.set_l(.3)
clf.update_values()
Expand Down

0 comments on commit b4f8827

Please sign in to comment.