Skip to content

Commit

Permalink
Add multi-label classification to predict
Browse files Browse the repository at this point in the history
Now ``SS3.predict`` takes an extra argument, ``multilabel``. When
``multilabel=True`` (False by default), ``predict`` performs
multi-label classification on the given documents and thus, for
each document, returns a list of labels (instead of a single label).
  • Loading branch information
sergioburdisso committed Feb 9, 2020
1 parent bf313ba commit c5ac946
Showing 1 changed file with 11 additions and 28 deletions.
39 changes: 11 additions & 28 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,7 +1837,7 @@ def predict_proba(self, x_test, prep=True, leave_pbar=True):

def predict(
self, x_test, def_cat=STR_MOST_PROBABLE,
labels=True, prep=True, leave_pbar=True
labels=True, multilabel=False, prep=True, leave_pbar=True
):
"""
Classify a list of documents.
Expand All @@ -1851,6 +1851,10 @@ def predict(
:param labels: whether to return the list of category names or just
category indexes
:type labels: bool
:param multilabel: whether to perform multi-label classification or not.
If enabled, a ``list`` of labels is returned for each document
instead of a single label (``str``).
:type multilabel: bool
:param prep: enables input preprocessing (default: True)
:type prep: bool
:param leave_pbar: controls whether to leave the progress bar or
Expand All @@ -1864,52 +1868,31 @@ def predict(
if not self.__categories__:
raise EmptyModelError

categories = self.get_categories()

if not def_cat or def_cat == STR_UNKNOWN:
def_cat = len(self.__categories__)
categories.append(STR_UNKNOWN_CATEGORY)
Print.info(
"default category was set to 'unknown' (its index will be %d)"
%
def_cat,
% self.get_category_index(STR_UNKNOWN_CATEGORY),
offset=1
)
else:
if def_cat == STR_MOST_PROBABLE:
def_cat = self.__get_most_probable_category__()
Print.info(
"default category was automatically set to '%s' "
"(the most probable one)"
%
self.get_category_name(def_cat),
"(the most probable one)" % self.get_most_probable_category(),
offset=1
)
else:
def_cat = self.get_category_index(def_cat)
Print.info(
"default category was set to '%s'"
%
self.get_category_name(def_cat),
offset=1
)
Print.info("default category was set to '%s'" % def_cat, offset=1)

stime = time()
Print.info("about to start classifying test documents", offset=1)
classify = self.classify
classify = self.classify_label if not multilabel else self.classify_multilabel

y_pred = [
r[0][0] if r[0][1] else def_cat
for r in [
classify(doc, prep=prep)
for doc in tqdm(
x_test, desc=" Classification", leave=leave_pbar
)
]
classify(doc, def_cat=def_cat, labels=labels, prep=prep)
for doc in tqdm(x_test, desc=" Classification", leave=leave_pbar)
]

if labels:
y_pred[:] = [categories[y] for y in y_pred]
Print.info("finished --time: %.1fs" % (time() - stime), offset=1)
return y_pred

Expand Down

0 comments on commit c5ac946

Please sign in to comment.