Skip to content

Commit

Permalink
Add def_cat argument to Live_Test.run (#9)
Browse files Browse the repository at this point in the history
Now the Live Test Tool for multi-label classification can be launched
with a default category label to be assigned when no tags could be
predicted by SS3. To do that, the ``def_cat`` argument was added;
``Live_Test.run`` can therefore now be called as follows:

```
Live_Test.run(clf, x_test, y_test, def_cat="most-probable")
Live_Test.run(clf, x_test, y_test, def_cat="aCategoryLabel")
Live_Test.run(clf, x_test, y_test)  # default, predict no label ([])
```
  • Loading branch information
sergioburdisso committed May 20, 2020
1 parent 5f5c055 commit b617bb7
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 20 deletions.
10 changes: 10 additions & 0 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,16 @@ def __get_category_vocab__(self, icat):
)
return sorted(vocab_icat, key=lambda k: -k[-1])

def __get_def_cat__(self, def_cat):
    """
    Resolve the ``def_cat`` argument into the default category value.

    :param def_cat: ``None``, the ``STR_MOST_PROBABLE`` keyword, the
                    ``STR_UNKNOWN`` keyword, or a category name.
    :returns: ``None`` when ``def_cat`` is ``None`` or the ``STR_UNKNOWN``
              keyword; the value of ``get_most_probable_category()`` for
              the ``STR_MOST_PROBABLE`` keyword; otherwise ``def_cat``
              itself, unchanged.
    :raises: ValueError
    """
    # Nothing to resolve: no default category was requested.
    if def_cat is None:
        return None

    # Keyword values are handled first so they are never looked up as
    # category names.
    if def_cat == STR_MOST_PROBABLE:
        return self.get_most_probable_category()
    if def_cat == STR_UNKNOWN:
        return None

    # Anything else must be a category name known to the classifier.
    if self.get_category_index(def_cat) == IDX_UNKNOWN_CATEGORY:
        raise ValueError(
            "the default category must be 'most-probable', 'unknown', or a category name."
        )
    return def_cat

def __get_next_iwords__(self, sent, icat):
"""Return the list of possible following words' indexes."""
if not self.get_category_name(icat):
Expand Down
7 changes: 5 additions & 2 deletions pyss3/resources/live_test/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ <h5 style="margin-top: 0;" ng-if="multilabel">
<span ng-if="multilabel" ng-repeat="cv_i in ss3.cvns" ng-show="is_cat_active(cv_i)">
<div class="chip pointer" ng-class="{'label-ok': is_in_golden_true(cv_i), 'label-nok': !is_in_golden_true(cv_i)}" ng-click="select_cat(cv_i[0])">{{ss3.ci[cv_i[0]]}}</div>
</span>
<span ng-if="get_n_active_cats() == 0">N/A</span>
<span ng-if="get_n_active_cats() == 0">
<span ng-if="!info.def_cat">N/A</span>
<div class="chip pointer" ng-if="info.def_cat" ng-class="{'label-ok': is_in_golden_true(info.def_cat), 'label-nok': !is_in_golden_true(info.def_cat)}" ng-click="select_cat(ss3.ci.indexOf(info.def_cat))">{{info.def_cat}}</div>
</span>
</h5>
<h5 style="margin-top: 0;" ng-if="!multilabel">
<!-- <span style="color: black" id="hashtag">Main Category: </span> -->
Expand Down Expand Up @@ -375,7 +378,7 @@ <h5 class="blue-text text-darken-2">Topic: </h5>
</div>
<div ng-repeat="cv_i in ss3.cvns" class="menu_item waves-effect waves-red" ng-click="select_cat_menu(cv_i[0])" ng-class="{active:scat==cv_i[0]}" style="display: block">
<a href="" class="menu_item" ng-class="{'disabled':!is_cat_active(cv_i)}" ng-style="{'border-color':get_cat_rgb(cv_i[0])}" style="border-width: 7px;">
{{ss3.ci[cv_i[0]]|uppercase}} <label title="confidence value">({{cv_i[2]|number:1}}cv)</label>
{{ss3.ci[cv_i[0]]|uppercase}} <label title="confidence value">({{cv_i[2]|number:2}}cv)</label>
</a>
</div>
</div>
Expand Down
17 changes: 13 additions & 4 deletions pyss3/resources/live_test/js/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ app.controller("mainCtrl", function($scope) {
}

$scope.get_n_active_cats = function(){
if ($scope.ss3 == null) return;
var c = 0;
for (let i=0; i < $scope.ss3.cvns.length; i++){
if ($scope.is_cat_active($scope.ss3.cvns[i]))
Expand All @@ -210,12 +211,20 @@ app.controller("mainCtrl", function($scope) {
}

$scope.is_cat_active = function (cat_info) {
var icat = Number.isInteger(cat_info)? cat_info : cat_info[0];
return active_cats.indexOf(icat) != -1;
if (!$scope.ss3 || !cat_info)
return false;

if (active_cats.length == 0){
return $scope.info.def_cat && $scope.ss3.ci[cat_info] == $scope.info.def_cat;
}else{
var icat = Number.isInteger(cat_info)? cat_info : cat_info[0];
return active_cats.indexOf(icat) != -1;
}
}

$scope.is_in_golden_true = function(cv_i){
return $scope.info.docs['']['true_labels'][$scope.i_doc].indexOf($scope.ss3.ci[cv_i[0]]) != -1;
$scope.is_in_golden_true = function(l){
l = (typeof l === 'string' || l instanceof String)? l : $scope.ss3.ci[l[0]];
return $scope.info.docs['']['true_labels'][$scope.i_doc].indexOf(l) != -1;
}

$scope.on_chart_change = function(){
Expand Down
43 changes: 31 additions & 12 deletions pyss3/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class Server:
__folder_label__ = None
__preprocess__ = None
__default_prep__ = None
__default_cat__ = None

@staticmethod
def __send_as_json__(sock, data):
Expand Down Expand Up @@ -251,7 +252,8 @@ def __do_get_info__(sock):
"model_name": clf.get_name(),
"hps": clf.get_hyperparameters(),
"categories": clf.get_categories(all=True) + ["[unknown]"],
"docs": Server.__docs__
"docs": Server.__docs__,
"def_cat": Server.__default_cat__
})
Print.info("sending classifier info...")

Expand Down Expand Up @@ -368,19 +370,28 @@ def set_model(clf):
Server.__clear_testset__()

@staticmethod
def set_testset(x_test, y_test=None):
def set_testset(x_test, y_test=None, def_cat=None):
"""
Assign the test set to visualize.
:param x_test: the list of documents to classify and visualize
:type x_test: list (of str)
:param y_label: the list of category labels
:type y_label: list (of str)
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
"most-probable", "unknown" or a given category name.
(default: "most-probable", or "unknown" for
multi-label classification)
:type def_cat: str
:raises: ValueError
"""
clf = Server.__clf__

Server.__clear_testset__()
Server.__x_test__ = x_test
Server.__default_cat__ = clf.__get_def_cat__(def_cat)

clf = Server.__clf__
classify = clf.classify
docs = Server.__docs__
unkwon_cat_i = len(Server.__clf__.get_categories())
Expand Down Expand Up @@ -415,13 +426,13 @@ def set_testset(x_test, y_test=None):
)

if multilabel:
y_pred = [[ci for ci, _ in r[:kmean_multilabel_size(r)]]
for r in docs[y_test[0]]["clf_result"]]
if Server.__default_cat__ is not None:
y_pred = [labels if labels else [clf.get_category_index(Server.__default_cat__)]
for labels in y_pred]
t = membership_matrix(clf, y_test_labels).todense()
p = membership_matrix(
clf,
[[ci for ci, _ in r[:kmean_multilabel_size(r)]]
for r in docs[y_test[0]]["clf_result"]],
labels=False
).todense()
p = membership_matrix(clf, y_pred, labels=False).todense()
accuracy = (t & p).sum(axis=1) / (t | p).sum(axis=1)
accuracy[np.isnan(accuracy)] = 1
docs[y_test[0]]["true_labels"] = y_test_labels
Expand Down Expand Up @@ -502,7 +513,7 @@ def start_listening(port=0):
@staticmethod
def serve(
clf=None, x_test=None, y_test=None, port=0, browser=True,
quiet=True, prep=True, prep_func=None
quiet=True, prep=True, prep_func=None, def_cat=None
):
"""
Wait for classification requests and serve them.
Expand All @@ -529,8 +540,16 @@ def serve(
If not given, the default preprocessing function will
be used
:type prep_func: function
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
"most-probable", "unknown" or a given category name.
(default: "most-probable", or "unknown" for
multi-label classification)
:type def_cat: str
:raises: ValueError
"""
Server.__clf__ = clf or Server.__clf__
clf = clf or Server.__clf__
Server.__clf__ = clf
Server.__preprocess__ = prep_func
Server.__default_prep__ = prep

Expand All @@ -543,7 +562,7 @@ def serve(

if x_test is not None:
if y_test is None or len(y_test) == len(x_test):
Server.set_testset(x_test, y_test)
Server.set_testset(x_test, y_test, def_cat)
else:
Print.error("y_test must have the same length as x_test")
return
Expand Down
10 changes: 8 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def mockers(mocker):
"parse_args").return_value = MockCmdLineArgs


@pytest.fixture(params=[0, 1, 2, 3, 4, 5])
@pytest.fixture(params=[0, 1, 2, 3, 4, 5, 6, 7])
def test_case(request, mocker):
"""Argument values generator for test_live_test(test_case)."""
mocker.patch("webbrowser.open")
Expand Down Expand Up @@ -140,6 +140,12 @@ def test_live_test(test_case):
serve_args["y_test"] += [["labelC"]] * (len(x_train) // 2)
elif test_case == 5:
serve_args["y_test"] = ["label"]
elif test_case == 6:
serve_args["y_test"] = y_train
serve_args["def_cat"] = 'most-probable'
elif test_case == 7:
serve_args["y_test"] = y_train
serve_args["def_cat"] = 'xxxxx' # raise ValueError

if PYTHON3:
threading.Thread(target=LT.serve, kwargs=serve_args, daemon=True).start()
Expand Down Expand Up @@ -197,7 +203,7 @@ def test_live_test(test_case):
assert "<html>" in r


def test_main(mockers):
def test_main(mockers, mocker):
"""Test the main() function."""
if not PYTHON3:
return
Expand Down
3 changes: 3 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,12 @@ def test_evaluation(mocker):
assert kfold_validation(clf_ml, x_data_ml, y_data_ml, plot=PY3) > 0
assert kfold_validation(clf, x_data, y_data, plot=PY3) > 0
s, l, p, a = clf.get_hyperparameters()
s, l, p, a = clf.get_hyperparameters()
s0, l0, p0, a0 = Evaluation.grid_search(clf, x_data, y_data)
s1, l1, p1, a1 = Evaluation.get_best_hyperparameters()
s2, l2, p2, a2 = Evaluation.get_best_hyperparameters("recall")
Evaluation.__last_eval_tag__ = None
s1, l1, p1, a1 = Evaluation.get_best_hyperparameters()
assert s0 == s and l0 == l and p0 == p and a0 == a
assert s0 == s1 and l0 == l1 and p0 == p1 and a0 == a1
assert s0 == s2 and l0 == l2 and p0 == p2 and a0 == a2
Expand Down

0 comments on commit b617bb7

Please sign in to comment.