diff --git a/pyss3/__init__.py b/pyss3/__init__.py index 6abca2b..1462874 100644 --- a/pyss3/__init__.py +++ b/pyss3/__init__.py @@ -1959,7 +1959,7 @@ def classify_label(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=True) """ r = self.classify(doc, sort=True, prep=prep) - if not r[0][1]: + if not r or not r[0][1]: if not def_cat or def_cat == STR_UNKNOWN: cat = STR_UNKNOWN_CATEGORY elif def_cat == STR_MOST_PROBABLE: @@ -1999,7 +1999,7 @@ def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep= """ r = self.classify(doc, sort=True, prep=prep) - if not r[0][1]: + if not r or not r[0][1]: if not def_cat or def_cat == STR_UNKNOWN: cat = STR_UNKNOWN_CATEGORY elif def_cat == STR_MOST_PROBABLE: diff --git a/tests/test_pyss3.py b/tests/test_pyss3.py index e5be0cd..e0eb5ca 100644 --- a/tests/test_pyss3.py +++ b/tests/test_pyss3.py @@ -139,6 +139,10 @@ def perform_tests_with(clf, cv_test, stopwords=True): assert clf.classify_label(x_test[0]) == y_test[0] assert clf.classify_label(x_test[0], labels=False) == clf.get_category_index(y_test[0]) + assert clf.classify_label('') == most_prob_cat + assert clf.classify_label('', def_cat=STR_UNKNOWN) == STR_UNKNOWN_CATEGORY + assert clf.classify_label('', def_cat=def_cat) == def_cat + assert clf.classify_label(doc_unknown) == most_prob_cat assert clf.classify_label(doc_unknown, def_cat=STR_UNKNOWN) == STR_UNKNOWN_CATEGORY assert clf.classify_label(doc_unknown, def_cat=def_cat) == def_cat @@ -156,6 +160,10 @@ def perform_tests_with(clf, cv_test, stopwords=True): assert len(multilabel_labels) == len(r) assert r[0] in multilabel_idxs and r[1] in multilabel_idxs + assert clf.classify_multilabel('') == [most_prob_cat] + assert clf.classify_multilabel('', def_cat=STR_UNKNOWN) == [pyss3.STR_UNKNOWN_CATEGORY] + assert clf.classify_multilabel('', def_cat=def_cat) == [def_cat] + assert clf.classify_multilabel(doc_unknown) == [most_prob_cat] assert clf.classify_multilabel(doc_unknown, def_cat=STR_UNKNOWN) == [pyss3.STR_UNKNOWN_CATEGORY] assert clf.classify_multilabel(doc_unknown, def_cat=def_cat) == [def_cat]