Skip to content

Commit

Permalink
Add multilabel support to the Live Test tool (#9)
Browse files Browse the repository at this point in the history
The Live Test Tool now supports multi-label classification when
``Live_Test.run(clf, x_test, y_test)`` is called with a ``y_test`` that
contains multiple labels per document. In addition, the Live Test can now
be called without a ``y_test``, for instance, to try out some documents
with no labels. That is, these two calls are now valid:

```
docs = ["the 1st document", "the 2nd document", "the 3rd document"]

Live_Test.run(clf, docs)  # no labels (y_test)
```

and

```
x_test = ["the 1st document", "the 2nd document", "the 3rd document"]
y_test = [["labelA", "labelB"], ["labelA"], ["labelC", "labelB"]]

Live_Test.run(clf, x_test, y_test)  # multi-label y_test
```

Both calls are now supported. The multi-label support was implemented
following the steps given in #9; the current implementation covers
Steps 1 and 2. The document percentages described in Step 2 are computed
using the label-based accuracy (a.k.a. the "hamming score").

Resolves: #9
  • Loading branch information
sergioburdisso committed May 19, 2020
1 parent 4ec009a commit 15657ee
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 73 deletions.
1 change: 1 addition & 0 deletions examples/extract_insight.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@
"metadata": {},
"source": [
"<br>\n",
"\n",
"And that's all! is it? want to go a little bit deeper? the following section will show some more advanced features the ``extract_insight`` method has, just in case some of them can be useful to you.\n",
"\n",
"---"
Expand Down
2 changes: 1 addition & 1 deletion examples/movie_review.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@
"metadata": {},
"outputs": [],
"source": [
"clf.set_hyperparameters(s=.44, l=.48, p=0.5)"
"clf.set_hyperparameters(s=.44, l=.48, p=.5)"
]
},
{
Expand Down
17 changes: 3 additions & 14 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import numpy as np

from io import open
from sys import version_info
from time import time
from tqdm import tqdm
from math import pow, tanh
from .util import Print, VERBOSITY, Preproc as Pp
from .util import is_a_collection, Print, VERBOSITY, Preproc as Pp

# python 2 and 3 compatibility
from functools import reduce
Expand Down Expand Up @@ -59,10 +58,6 @@
NOISE_FR = 1
MIN_MAD_SD = .03

PY2 = version_info[0] == 2
if not PY2:
basestring = None # to avoid the Flake8 "F821 undefined name" error


class SS3:
"""
Expand Down Expand Up @@ -1195,7 +1190,7 @@ def get_a(self):
"""
return self.__a__

def get_categories(self):
def get_categories(self, all=False):
"""
Get the list of category names.
Expand All @@ -1205,7 +1200,7 @@ def get_categories(self):
return [
self.get_category_name(ci)
for ci in range(len(self.__categories__))
if self.get_category_name(ci) != STR_OTHERS_CATEGORY
if all or self.get_category_name(ci) != STR_OTHERS_CATEGORY
]

def get_most_probable_category(self):
Expand Down Expand Up @@ -2630,12 +2625,6 @@ def list_hash(str_list):
return m.hexdigest()


def is_a_collection(o):
    """Return True when ``o`` supports item access and is not a string type."""
    # Anything without ``__getitem__`` is never treated as a collection.
    if not hasattr(o, "__getitem__"):
        return False
    # Strings are indexable too, so they are excluded explicitly; which types
    # count as "string" depends on the interpreter's major version.
    text_types = basestring if PY2 else (str, bytes)
    return not isinstance(o, text_types)


def vsum(v0, v1):
    """Return a new list holding the element-wise sum of ``v0`` and ``v1``."""
    total = []
    # The number of positions visited is driven by ``v0``.
    for i in xrange(len(v0)):
        total.append(v0[i] + v1[i])
    return total
Expand Down
20 changes: 10 additions & 10 deletions pyss3/resources/evaluation_plot/model_evaluation.html
Original file line number Diff line number Diff line change
Expand Up @@ -334,18 +334,19 @@ <h3 ng-cloak>
var DOM_INFO_LINE = document.getElementById("info-line");
var DOM_INFO_CIRC = document.getElementById("info-circle");
var DOM_INFO_PANEL = document.getElementById("point-info");
var DATA = null;
var DEF_METRIC = "accuracy";
var GLOBAL_METRIC = ["accuracy", "hamming-loss", "exact-match"];
var INVERSE_METRIC = ["hamming-loss"];
var MULTILABEL = false;
var CM_MAX_SIZE = 50;
var CM_MAX_FOLD_SIZE = 90;
var DATA = null;
var CM_N_CATS = null;
var CM = {// confusion matrix
matrix: null,
p: null, // path
m: null, // method
dc: null // default category
dc: null, // default category
};
var PI = {cmd: {}, code: {}};
var app = angular.module("ss3", ["ngAnimate"]);
Expand Down Expand Up @@ -638,7 +639,7 @@ <h3 ng-cloak>
cmd_eval += hparams;

PI.cmd.cmd_set = 'set ' + hparams;
PI.code.cmd_set = code_set + "Live_test.run(clf, x_test, y_test)";
PI.code.cmd_set = code_set + "Live_Test.run(clf, x_test, y_test)";
PI.cmd.cmd_eval = cmd_eval;
PI.code.cmd_eval = code_set + code_eval;
PI.cmd.cmd_remove = `evaluations remove ${C.p} ${C.m}${def_cat} ` + hparams;
Expand Down Expand Up @@ -961,7 +962,6 @@ <h3 ng-cloak>

function update_confusion_matrix(){
var C = $scope.c, keys = Object.keys;
var cats_n = null;
var cm_size = CM_MAX_FOLD_SIZE;

CM = $results[C.p][C.m][C.dc]["confusion_matrix"];
Expand All @@ -984,8 +984,8 @@ <h3 ng-cloak>

if (!MULTILABEL){
CM[s][l][p][a] = CM[s][l][p][a].map(cm_fold => {
if (cats_n === null){
cats_n = cm_fold.length
if (CM_N_CATS === null){
CM_N_CATS = cm_fold.length
}
return cm_fold.map(r => {
let r_sum = sum(r);
Expand All @@ -997,8 +997,8 @@ <h3 ng-cloak>
}else{
CM[s][l][p][a] = CM[s][l][p][a].map(cm_fold => {
return cm_fold.map(cm => {
if (cats_n === null){
cats_n = cm.length
if (CM_N_CATS === null){
CM_N_CATS = cm.length
}
return cm.map(r => {
let r_sum = sum(r);
Expand Down Expand Up @@ -1030,8 +1030,8 @@ <h3 ng-cloak>
CM.dc = $scope.c.dc;

cm_size /= Math.min(CM["folds"], 3);
// $scope.conf_matrix_cell_size = Math.min(cm_size, (cats_n > 2? CM_MAX_SIZE : 30)) / cats_n;
$scope.conf_matrix_cell_size = Math.min(cm_size, CM_MAX_SIZE) / cats_n;
// $scope.conf_matrix_cell_size = Math.min(cm_size, (CM_N_CATS > 2? CM_MAX_SIZE : 30)) / CM_N_CATS;
$scope.conf_matrix_cell_size = Math.min(cm_size, CM_MAX_SIZE) / CM_N_CATS;
}

$scope.update(true, true);
Expand Down
34 changes: 26 additions & 8 deletions pyss3/resources/live_test/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@
a.sidenav-trigger.top-nav i {
font-size: 32px;
}
.label-ok{
color: #26a69a;
}
.label-nok{
color: #e53935;
}
@media only screen and (max-width: 992px){ header, main, footer {
padding-left: 0;
}}
Expand All @@ -192,17 +198,18 @@
<a href="#" data-target="slide-out" class="top-nav sidenav-trigger waves-effect waves-light circle hide-on-large-only"><i class="material-icons">menu</i></a>
</div>
<ul id="slide-out" class="sidenav sidenav-fixed" ng-show="keys(info.docs).length">
<li><a class="subheader">Test Documents by Category</a></li>
<li><a class="subheader">Test Documents<span ng-if="!info.docs.hasOwnProperty('')">by Category</span></a></li>
<ul class="collapsible">
<li ng-repeat="c in keys(info.docs)" ng-cloak ng-class="{active:keys(info.docs).length == 1}">
<div class="collapsible-header"><i class="material-icons right" ng-style="{'color': get_cat_rgb(info.categories.indexOf(c))}">
style</i>{{c}} ({{info.docs[c].file.length}}) <small class="recall blue-text text-darken-2 right" title="recall (% of hits)" ng-if="info.categories.indexOf(c) != -1">{{get_accuracy(c)|number:1}}%</small>
<div class="collapsible-header" ng-if="!info.docs.hasOwnProperty('')"><i class="material-icons right" ng-style="{'color': get_cat_rgb(info.categories.indexOf(c))}">
style</i>{{c}} ({{info.docs[c].file.length}}) <small class="recall blue-text text-darken-2 right" title="recall (% of hits)" ng-if="info.categories.indexOf(c) != -1">{{get_recall(c)|number:1}}%</small>
</div>
<div class="collapsible-body">
<ul>
<li ng-repeat="f in info.docs[c].file track by $index" ng-class="{'active': $index == i_doc && c == c_doc}">
<a class="waves-effect" href="#" ng-click="get_doc(c, $index)">
{{f}}
{{f}}
<small class="recall text-darken-2 right" title="hamming score (label-based accuracy)" ng-class="get_accuracy_color(info.docs[c]['labels_recall'][$index], $index)" ng-if="multilabel">{{info.docs[c]['labels_recall'][$index] * 100 | number:1}}%</small>
<i class="material-icons right misclass" ng-if="info.categories.indexOf(c) != -1 && get_clf_result(c, $index) != c" title="misclassified as '{{get_clf_result(c, $index)}}'">priority_high</i>
</a>
</li>
Expand Down Expand Up @@ -258,8 +265,19 @@ <h4 class="light" ng-cloak><b>MODEL:</b> {{info.model_name|uppercase}}</h4>
Edit Text<i class="material-icons right">mode_edit</i>
</button>
<div class="pink-text text-darken-2" style="margin-bottom: 0;">
<h5 class="blue-text" ng-if="f_doc && document">Document: <span style="color:black">{{f_doc}}</span> <span style="color:gray">({{c_doc}})</span></h5>
<h5 style="margin-top: 0;">
<h5 class="blue-text" ng-if="f_doc && document">
Document: <span style="color:black">{{f_doc}}</span> <span style="color:gray" ng-if="c_doc">({{c_doc}})</span>
<span ng-if="multilabel" ng-repeat="label in info.docs['']['true_labels'][i_doc]">
<div class="chip pointer" ng-class="{'label-ok': is_cat_active(ss3.ci.indexOf(label)), 'label-nok': !is_cat_active(ss3.ci.indexOf(label))}" ng-click="select_cat(ss3.ci.indexOf(label))">{{label}}</div>
</span>
</h5>
<h5 style="margin-top: 0;" ng-if="multilabel">
<span class="blue-text" id="hashtag">Classification Result: </span>
<span ng-if="multilabel" ng-repeat="cv_i in ss3.cvns" ng-show="is_cat_active(cv_i)">
<div class="chip pointer" ng-class="{'label-ok': is_in_golden_true(cv_i), 'label-nok': !is_in_golden_true(cv_i)}" ng-click="select_cat(cv_i[0])">{{ss3.ci[cv_i[0]]}}</div>
</span>
</h5>
<h5 style="margin-top: 0;" ng-if="!multilabel">
<!-- <span style="color: black" id="hashtag">Main Category: </span> -->
<span class="blue-text" id="hashtag">Classification Result: </span>
<span ng-hide="ss3 || loading">N/A</span>
Expand All @@ -272,13 +290,13 @@ <h5 style="margin-top: 0;">
<span class="pointer pulse" ng-style="{'opacity': ss3.cvns[0][1]}" ng-click="select_cat(ss3.cvns[0][0])">{{ss3.ci[ss3.cvns[0][0]]}}</span>
</span>
</h5>
<h5 ng-if="ss3 && ncats > 2 && is_cat_active(ss3.cvns[1]) && ss3.cvns[0][1]">
<h5 ng-if="!multilabel && ss3 && ncats > 2 && is_cat_active(ss3.cvns[1]) && ss3.cvns[0][1] && ss3.cvns[0][2] >= .75">
<span style="color: black">
<span ng-show="ss3.cvns[1][1] < .25">Hmmm...Just Guessing...</span>
<span ng-show="ss3.cvns[1][1] < .5">Maybe </span>
Also<span ng-show="ss3.cvns[1][1] >= .4">:</span>
</span>
<span style="display:inline-block; padding-right:" ng-repeat="cv_i in ss3.cvns" ng-hide="$first" ng-if="is_cat_active(cv_i) && $index < 3">
<span style="display:inline-block; padding-right:" ng-repeat="cv_i in ss3.cvns" ng-hide="$first" ng-if="is_cat_active(cv_i)">
<span ng-show="$index > 1" class="black-text">,</span>
<span class="pointer" ng-style="{'opacity': cv_i[1]}" ng-click="select_cat(cv_i[0])">{{ss3.ci[cv_i[0]]}}</span>
</span>
Expand Down
25 changes: 22 additions & 3 deletions pyss3/resources/live_test/js/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ app.controller("mainCtrl", function($scope) {
$scope.scat = -1;
$scope.ncats = 0;
$scope.levels = [false, true, true];
$scope.multilabel = false;
$scope.info = null;
$scope.document = '';
$scope.document_original = null;
Expand Down Expand Up @@ -71,7 +72,8 @@ app.controller("mainCtrl", function($scope) {
}

$scope.get_clf_result = function(c, i){
return $scope.info.categories[$scope.info.docs[c]["clf_result"][i]];
return "clf_result" in $scope.info.docs[c]?
$scope.info.categories[$scope.info.docs[c]["clf_result"][i]] : c;
}

$scope.get_doc = function(c, i){
Expand All @@ -92,13 +94,23 @@ app.controller("mainCtrl", function($scope) {
);
}

$scope.get_accuracy = function(c){
$scope.get_recall = function(c){
var hits = $scope.info.docs[c]["file"].map(function(_, i_doc){
return $scope.get_clf_result(c, i_doc) == c;
});
return sum(hits) / $scope.info.docs[c]["file"].length * 100;
}

$scope.get_accuracy_color = function(v, i){
  // Chooses the text-color class name for a score ``v`` shown at list
  // position ``i``: the entry whose index equals ``$scope.i_doc`` is always
  // white; otherwise 0 maps to red, 1 to green and any other value to blue.
  var is_current_doc = (i == $scope.i_doc);
  if (is_current_doc) {
    return 'white-text';
  }
  return v == 0 ? 'red-text' : (v == 1 ? 'green-text' : 'blue-text');
}

$scope.get_cat_rgb = function(icat){
  // Looks up the entry stored at position ``icat`` of ``__rgb_colors__``.
  var rgb = __rgb_colors__[icat];
  return rgb;
}
Expand Down Expand Up @@ -127,6 +139,7 @@ app.controller("mainCtrl", function($scope) {

$scope.clear = function(){
__reset__();
$scope.i_doc = -1;
$scope.document= "";
__update_textarea__();
__goto__(0);
Expand Down Expand Up @@ -186,7 +199,12 @@ app.controller("mainCtrl", function($scope) {
}

$scope.is_cat_active = function (cat_info) {
return active_cats.indexOf(cat_info[0]) != -1;
var icat = Number.isInteger(cat_info)? cat_info : cat_info[0];
return active_cats.indexOf(icat) != -1;
}

$scope.is_in_golden_true = function(cv_i){
  // Tells whether the category name found at ``$scope.ss3.ci[cv_i[0]]`` is
  // listed among the 'true_labels' entry of the document at ``$scope.i_doc``.
  var true_labels = $scope.info.docs['']['true_labels'][$scope.i_doc];
  var cat_name = $scope.ss3.ci[cv_i[0]];
  return true_labels.indexOf(cat_name) != -1;
}

$scope.on_chart_change = function(){
Expand Down Expand Up @@ -464,6 +482,7 @@ app.controller("mainCtrl", function($scope) {
$server.submit("get_info", '', function(data){
$scope.info = data;
var cats = data.categories.slice(0, -1);
$scope.multilabel = ('' in data.docs) && ("true_labels" in data.docs['']);
$scope.ncats = cats.length;
__rgba_colors__ = cats.map(function(_, i){
return "rgba(" + __get_rgb__(i, cats.length) + ',';
Expand Down
Loading

0 comments on commit 15657ee

Please sign in to comment.