Skip to content

Commit

Permalink
Add the cv, gv, lv, sg and sn functions
Browse files Browse the repository at this point in the history
Public methods exposing SS3's ``cv``, ``gv``, ``lv``, ``sg`` and ``sn``
functions have been added to the SS3 class.

These functions were first defined in Section 3.2.2 of the
original paper: https://arxiv.org/pdf/1905.08772.pdf
  • Loading branch information
sergioburdisso committed Feb 16, 2020
1 parent bc6873c commit ef35b25
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
130 changes: 130 additions & 0 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,20 @@ def __cv_norm_gv_xai__(self, ngram, icat, cache=True):
except TypeError:
return 0

def __apply_fn__(self, fn, ngram, cat):
    """
    Private method used by the cv, gv, lv, sg and sn functions.

    Translate ``cat`` into its category index and ``ngram`` into a list of
    word indexes, then apply ``fn`` to them.

    :param fn: the function to apply; it is called as ``fn(word_indexes, icat)``
    :type fn: callable
    :param ngram: the word or word n-gram
    :type ngram: str
    :param cat: the category label
    :type cat: str
    :returns: the value returned by ``fn``, or 0 when the n-gram is blank,
              contains no words at all, or contains an unknown word
    :raises: InvalidCategoryError
    """
    # The category is validated first so that an invalid category is always
    # reported, even when the n-gram itself is empty.
    icat = self.get_category_index(cat)
    if icat == IDX_UNKNOWN_CATEGORY:
        raise InvalidCategoryError

    if ngram.strip() == '':
        return 0

    # re.split yields empty strings around leading, trailing or adjacent
    # delimiters; those are not words and are dropped.
    word_indexes = [self.get_word_index(w)
                    for w in re.split(self.__word_delimiter__, ngram)
                    if w]

    # An n-gram made up only of delimiter characters leaves no words behind;
    # it is worth 0 like any other n-gram the model knows nothing about, and
    # fn is never called with an empty n-gram.
    if not word_indexes or IDX_UNKNOWN_WORD in word_indexes:
        return 0
    return fn(word_indexes, icat)

def __classify_ngram__(self, ngram):
"""Classify the given n-gram."""
cv = [
Expand Down Expand Up @@ -2124,6 +2138,122 @@ def predict(
Print.info("finished --time: %.1fs" % (time() - stime), offset=1)
return y_pred

def cv(self, ngram, cat):
    """
    Return the "confidence value" of a word n-gram for a category.

    The confidence value is the result of applying a final transformation
    to the global value (gv) of the n-gram [*]. The transformation used is
    chosen when the SS3 instance is created (see the ``cv_m`` argument of
    the SS3 class constructor for more information).

    [*] the gv function is defined in Section 3.2.2 of the original paper:
        https://arxiv.org/pdf/1905.08772.pdf

    Example:

    >>> clf.cv("chicken", "food")
    >>> clf.cv("roast chicken", "food")
    >>> clf.cv("chicken", "sports")

    :param ngram: the word or word n-gram
    :type ngram: str
    :param cat: the category label
    :type cat: str
    :returns: the confidence value
    :rtype: float
    :raises: InvalidCategoryError
    """
    apply_fn = self.__apply_fn__
    return apply_fn(self.__cv__, ngram, cat)

def gv(self, ngram, cat):
    """
    Return the "global value" of a word n-gram for a category.

    The gv function is defined in Section 3.2.2 of the original paper:
    https://arxiv.org/pdf/1905.08772.pdf

    Example:

    >>> clf.gv("chicken", "food")
    >>> clf.gv("roast chicken", "food")
    >>> clf.gv("chicken", "sports")

    :param ngram: the word or word n-gram
    :type ngram: str
    :param cat: the category label
    :type cat: str
    :returns: the global value
    :rtype: float
    :raises: InvalidCategoryError
    """
    apply_fn = self.__apply_fn__
    return apply_fn(self.__gv__, ngram, cat)

def lv(self, ngram, cat):
    """
    Return the "local value" of a word n-gram for a category.

    The lv function is defined in Section 3.2.2 of the original paper:
    https://arxiv.org/pdf/1905.08772.pdf

    Example:

    >>> clf.lv("chicken", "food")
    >>> clf.lv("roast chicken", "food")
    >>> clf.lv("chicken", "sports")

    :param ngram: the word or word n-gram
    :type ngram: str
    :param cat: the category label
    :type cat: str
    :returns: the local value
    :rtype: float
    :raises: InvalidCategoryError
    """
    apply_fn = self.__apply_fn__
    return apply_fn(self.__lv__, ngram, cat)

def sg(self, ngram, cat):
    """
    Return the "significance factor" of a word n-gram for a category.

    The sg function is defined in Section 3.2.2 of the original paper:
    https://arxiv.org/pdf/1905.08772.pdf

    Example:

    >>> clf.sg("chicken", "food")
    >>> clf.sg("roast chicken", "food")
    >>> clf.sg("chicken", "sports")

    :param ngram: the word or word n-gram
    :type ngram: str
    :param cat: the category label
    :type cat: str
    :returns: the significance factor
    :rtype: float
    :raises: InvalidCategoryError
    """
    apply_fn = self.__apply_fn__
    return apply_fn(self.__sg__, ngram, cat)

def sn(self, ngram, cat):
    """
    Return the "sanction factor" of a word n-gram for a category.

    The sn function is defined in Section 3.2.2 of the original paper:
    https://arxiv.org/pdf/1905.08772.pdf

    Example:

    >>> clf.sn("chicken", "food")
    >>> clf.sn("roast chicken", "food")
    >>> clf.sn("chicken", "sports")

    :param ngram: the word or word n-gram
    :type ngram: str
    :param cat: the category label
    :type cat: str
    :returns: the sanction factor
    :rtype: float
    :raises: InvalidCategoryError
    """
    apply_fn = self.__apply_fn__
    return apply_fn(self.__sn__, ngram, cat)


class EmptyModelError(Exception):
"""Exception to be thrown when the model is empty."""
Expand Down
28 changes: 28 additions & 0 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ def perform_tests_with(clf, cv_test, stopwords=True):
clf.set_block_delimiters(parag=PARA_DELTR, sent=SENT_DELTR, word=WORD_DELTR)


def perform_tests_on(fn, value, ngram="chicken", cat="food"):
    """
    Check one of the value functions (cv, gv, lv, sn or sg) of a classifier.

    :param fn: the function under test, called as ``fn(ngram, cat)``
    :param value: the expected result of ``fn(ngram, cat)``, rounded to 4 decimals
    :param ngram: the word or word n-gram whose value is checked
    :param cat: the category label used for the checks
    """
    assert round(fn(ngram, cat), 4) == value

    # n-grams that are expected to be worth 0 for any valid category
    for worthless_ngram in ("xxx", "the xxx chicken", ""):
        assert round(fn(worthless_ngram, cat), 4) == 0

    # an invalid category must be rejected, even when the n-gram is empty
    for some_ngram, bad_cat in (("chicken", "xxx"), ("chicken", ""), ("", "")):
        with pytest.raises(pyss3.InvalidCategoryError):
            fn(some_ngram, bad_cat)


def test_pyss3_functions():
"""Test pyss3 functions."""
assert pyss3.sigmoid(1, 0) == 0
Expand Down Expand Up @@ -242,6 +256,13 @@ def test_pyss3_ss3(mockers):
clf.fit(x_train, y_train)

perform_tests_with(clf, [.00114, .00295, 0, 0, 0, .00016, .01894, 8.47741])
perform_tests_on(clf.cv, 0.4307)
perform_tests_on(clf.gv, 0.2148)
perform_tests_on(clf.lv, 0.2148)
perform_tests_on(clf.sg, 1)
perform_tests_on(clf.sn, 1)
perform_tests_on(clf.cv, 0, "video games", "science&technology")
perform_tests_on(clf.gv, 0, "video games", "science&technology")

# cv_m=STR_NORM_GV, sn_m=STR_XAI
clf = SS3(
Expand All @@ -251,6 +272,7 @@ def test_pyss3_ss3(mockers):
clf.fit(x_train, y_train)

perform_tests_with(clf, [0.00114, 0.00295, 0, 0, 0, 0.00016, 0.01894, 8.47741])
perform_tests_on(clf.cv, 0.4307)

# cv_m=STR_GV, sn_m=STR_XAI
clf = SS3(
Expand All @@ -260,6 +282,7 @@ def test_pyss3_ss3(mockers):
clf.fit(x_train, y_train)

perform_tests_with(clf, [0.00062, 0.00109, 0, 0, 0, 0.00014, 0.01894, 6.31228])
assert clf.cv("chicken", "food") == clf.gv("chicken", "food")

# cv_m=STR_NORM_GV_XAI, sn_m=STR_VANILLA
clf = SS3(
Expand Down Expand Up @@ -288,6 +311,11 @@ def test_pyss3_ss3(mockers):
clf.update_values()

perform_tests_with(clf, [.00074, .00124, 0, 0, 0, .00028, .00202, 9.19105])
perform_tests_on(clf.cv, 1.5664, "video games", "science&technology")
perform_tests_on(clf.gv, 0.6697, "video games", "science&technology")
perform_tests_on(clf.lv, 0.6697, "video games", "science&technology")
perform_tests_on(clf.sg, 1, "video games", "science&technology")
perform_tests_on(clf.sn, 1, "video games", "science&technology")

# n-gram recognition tests
pred = clf.classify("android mobile and video games", json=True)
Expand Down

0 comments on commit ef35b25

Please sign in to comment.