Skip to content

Commit

Permalink
Fix IndexError in classify_(multi)label
Browse files Browse the repository at this point in the history
When ``classify_multilabel`` or ``classify_label`` were called with an
empty document (''), an IndexError exaption was thrown.
  • Loading branch information
sergioburdisso committed Feb 19, 2020
1 parent 3e2e3d5 commit fa91952
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1959,7 +1959,7 @@ def classify_label(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=True)
"""
r = self.classify(doc, sort=True, prep=prep)

if not r[0][1]:
if not r or not r[0][1]:
if not def_cat or def_cat == STR_UNKNOWN:
cat = STR_UNKNOWN_CATEGORY
elif def_cat == STR_MOST_PROBABLE:
Expand Down Expand Up @@ -1999,7 +1999,7 @@ def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=
"""
r = self.classify(doc, sort=True, prep=prep)

if not r[0][1]:
if not r or not r[0][1]:
if not def_cat or def_cat == STR_UNKNOWN:
cat = STR_UNKNOWN_CATEGORY
elif def_cat == STR_MOST_PROBABLE:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def perform_tests_with(clf, cv_test, stopwords=True):
assert clf.classify_label(x_test[0]) == y_test[0]
assert clf.classify_label(x_test[0], labels=False) == clf.get_category_index(y_test[0])

assert clf.classify_label('') == most_prob_cat
assert clf.classify_label('', def_cat=STR_UNKNOWN) == STR_UNKNOWN_CATEGORY
assert clf.classify_label('', def_cat=def_cat) == def_cat

assert clf.classify_label(doc_unknown) == most_prob_cat
assert clf.classify_label(doc_unknown, def_cat=STR_UNKNOWN) == STR_UNKNOWN_CATEGORY
assert clf.classify_label(doc_unknown, def_cat=def_cat) == def_cat
Expand All @@ -156,6 +160,10 @@ def perform_tests_with(clf, cv_test, stopwords=True):
assert len(multilabel_labels) == len(r)
assert r[0] in multilabel_idxs and r[1] in multilabel_idxs

assert clf.classify_multilabel('') == [most_prob_cat]
assert clf.classify_multilabel('', def_cat=STR_UNKNOWN) == [pyss3.STR_UNKNOWN_CATEGORY]
assert clf.classify_multilabel('', def_cat=def_cat) == [def_cat]

assert clf.classify_multilabel(doc_unknown) == [most_prob_cat]
assert clf.classify_multilabel(doc_unknown, def_cat=STR_UNKNOWN) == [pyss3.STR_UNKNOWN_CATEGORY]
assert clf.classify_multilabel(doc_unknown, def_cat=def_cat) == [def_cat]
Expand Down

0 comments on commit fa91952

Please sign in to comment.