Skip to content

Commit

Permalink
Add explicit multi-label classification support
Browse files Browse the repository at this point in the history
Two methods have been added to the ``SS3`` class. One to perform
multi-label classification (``classify_multilabel``) using k-means
clustering on the confidence vector to select the proper (cluster
with) category labels. Additionally, to be consistent, a second
method (``classify_label``) was added to provide the (standard)
single label classification counterpart.

Thus, calling ``ss3.classify_multilabel(aDocument)`` would return
the list of labels (.g. ["technology", "business"]) to be associated
with this document (``aDocument``) according to SS3. Whereas
``ss3.classify_label(aDocument)`` would return only a single label
(e.g. "technology").
  • Loading branch information
sergioburdisso committed Feb 9, 2020
1 parent 0984f2e commit 0759bca
Showing 1 changed file with 112 additions and 9 deletions.
121 changes: 112 additions & 9 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,11 +1621,11 @@ def classify(self, doc, prep=True, sort=True, json=False):
:type prep: bool
:param sort: sort the classification result (from best to worst)
:type sort: bool
:param json: return the result in JSON format
:param json: return a debugging version of the result in JSON format.
:type json: bool
:returns: the document confidence vector if ``sort`` is False.
If ``sort`` is True, a list of pairs
(category index, confidence value) ordered by cv.
(category index, confidence value) ordered by confidence value.
:rtype: list
"""
if not self.__categories__ or not doc:
Expand Down Expand Up @@ -1687,6 +1687,82 @@ def classify(self, doc, prep=True, sort=True, json=False):
"ci": [self.get_category_name(ic) for ic in xrange(nbr_cats)]
}

def classify_label(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=True):
"""
Classify a given document returning the category label.
:param doc: the content of the document
:type doc: str
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
"most-probable", "unknown" or a given category name.
(default: "most-probable")
:type def_cat: str
:param labels: whether to return the category label or just the
category index (default: True)
:type labels: bool
:param prep: enables input preprocessing (default: True)
:type prep: bool
:returns: the category label or the category index.
:rtype: str or int
"""
r = self.classify(doc, sort=True, prep=prep)

categories = self.get_categories()

if not def_cat or def_cat == STR_UNKNOWN:
def_cat = len(self.__categories__)
categories.append(STR_UNKNOWN_CATEGORY)
elif def_cat == STR_MOST_PROBABLE:
def_cat = self.__get_most_probable_category__()
else:
def_cat = self.get_category_index(def_cat)

cat_i = r[0][0] if r[0][1] else def_cat

return categories[cat_i] if labels else cat_i

def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=True):
"""
Classify a given document returning multiple category labels.
This method could be used to perform multi-label classification. Internally, it
uses k-mean clustering on the confidence vector to select the proper group of
labels.
:param doc: the content of the document
:type doc: str
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
"most-probable", "unknown" or a given category name.
(default: "most-probable")
:type def_cat: str
:param labels: whether to return the category labels or just the
category indexes (default: True)
:type labels: bool
:param prep: enables input preprocessing (default: True)
:type prep: bool
:returns: the list of category labels (or indexes).
:rtype: list (of str or int)
"""
r = self.classify(doc, sort=True, prep=prep)

categories = self.get_categories()

if not def_cat or def_cat == STR_UNKNOWN:
def_cat = len(self.__categories__)
categories.append(STR_UNKNOWN_CATEGORY)
elif def_cat == STR_MOST_PROBABLE:
def_cat = self.__get_most_probable_category__()
else:
def_cat = self.get_category_index(def_cat)

if not r[0][1]:
return [categories[def_cat]] if labels else [def_cat]

result = [cat_i for cat_i, _ in r[:kmean_multilabel_size(r)]]
return [categories[cat_i] for cat_i in result] if labels else result

def fit(self, x_train, y_train, n_grams=1, prep=True, leave_pbar=True):
"""
Train the model given a list of documents and category labels.
Expand Down Expand Up @@ -1780,7 +1856,7 @@ def predict(
:returns: if ``labels`` is True, the list of category names,
otherwise, the list of category indexes.
:rtype: list (of int or str)
:raises: EmptyModelError
:raises: EmptyModelError, InvalidCategoryError
"""
if not self.__categories__:
raise EmptyModelError
Expand All @@ -1807,12 +1883,7 @@ def predict(
offset=1
)
else:
try:
def_cat = self.get_category_index(def_cat.lower())
except AttributeError:
def_cat = self.get_category_index(def_cat)
# if not def_cat:
# Print.error("Not a valid category")
def_cat = self.get_category_index(def_cat)
Print.info(
"default category was set to '%s'"
%
Expand Down Expand Up @@ -1862,6 +1933,38 @@ def __init__(self, msg=''):
)


def kmean_multilabel_size(res):
"""
Use k-means to tell where to split the ``SS3.classify'''s output.
Given a ``SS3.classify``'s output (``res``), tell where to partition it
into 2 clusters so that one of the cluster holds the category labels that
the classifier should output when performing multi-label classification.
To achieve this, implement k-means (i.e. 2-means) clustering over the
category confidence values in ``res``.
:param res: the classification output of ``SS3.classify``
:type res: list (of sorted pairs (category, confidence value))
:returns: a positive integer indicating where to split ``res``
:rtype: int
"""
cent = {"neg": -1, "pos": -1} # centroids (2 clusters: "pos" and "neg")
clust = {"neg": [], "pos": []} # clusters (2 clusters: "pos" and "neg")
new_cent_neg = res[-1][1]
new_cent_pos = res[0][1]
while (cent["pos"] != new_cent_pos) or (cent["neg"] != new_cent_neg):
cent["neg"], cent["pos"] = new_cent_neg, new_cent_pos
clust["neg"], clust["pos"] = [], []
for _, cat_cv in res:
if abs(cent["neg"] - cat_cv) < abs(cent["pos"] - cat_cv):
clust["neg"].append(cat_cv)
else:
clust["pos"].append(cat_cv)
new_cent_neg = sum(clust["neg"]) / len(clust["neg"])
new_cent_pos = sum(clust["pos"]) / len(clust["pos"])
return len(clust["pos"])


def sigmoid(v, l):
"""A sigmoid function."""
try:
Expand Down

0 comments on commit 0759bca

Please sign in to comment.