Skip to content

Commit

Permalink
Change the Live Test multi-label classification
Browse files Browse the repository at this point in the history
I've changed the way in which the "Live Test" tool selects the categories.
The previous method was ad hoc and didn't produce good results in
general. For the new version, I decided to implement the K-Means
algorithm to divide the categories into two clusters, one for
"select as classified" and another for "ignore". The key idea is to
partition the categories according to the obtained confidence value.
  • Loading branch information
sergioburdisso committed Feb 8, 2020
1 parent f78dd9f commit 046f9f4
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions pyss3/resources/visual_classifier/js/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ app.controller("mainCtrl", function($scope) {
}

$scope.is_cat_active = function (cat_info) {
return ($scope.ss3.cvns[0][2] < 1 && cat_info[2] >= .75) ||
($scope.ss3.cvns[0][2] > 1 && cat_info[1] >= .35 && cat_info[2] > 1);
return active_cats.indexOf(cat_info[0]) != -1;
}

$scope.need_space = function(words, i, sents, j){
Expand Down Expand Up @@ -234,16 +233,11 @@ app.controller("mainCtrl", function($scope) {
$scope.ss3 = data;
$scope.scat = -1;
active_cats = [];
for (var i in data.cvns){
var cvn = data.cvns[i];
if (cvn[1] >= .35 && cvn[2] > 1){
active_cats.push(cvn[0]);
}
}
if (active_cats.length < 2){
if (active_cats.length == 0)
active_cats.push(data.cvns[0][0]);
if (data.cvns.length == 2){
active_cats.push(data.cvns[0][0]);
active_cats.push(data.cvns[1][0]);
}else{
active_cats = __k_means_active_cat__(data.cvns);
}
__create_chart_values__();

Expand All @@ -269,6 +263,32 @@ app.controller("mainCtrl", function($scope) {
);
}

/**
 * Selects the "active" categories by running a one-dimensional 2-means
 * clustering over the value stored at index 2 of each entry (the
 * confidence value used for the partition).
 *
 * The "pos" centroid is seeded from the first entry and the "neg" centroid
 * from the last one, so the result is most meaningful when `cats` is ordered
 * by descending confidence (presumably how callers pass it -- TODO confirm).
 * Entries that are equally close to both centroids go to the "pos" cluster.
 *
 * @param {Array<Array>} cats - entries whose element 0 is the category
 *     identifier and element 2 is the numeric value to cluster on.
 * @returns {Array} identifiers (element 0) of the entries assigned to the
 *     "pos" cluster, in input order; an empty array when `cats` is empty.
 */
function __k_means_active_cat__(cats){
  if (!cats || cats.length === 0){
    return [];
  }

  // Mean of a cluster, or `fallback` when the cluster is empty. Without the
  // fallback an empty cluster yields 0/0 = NaN, and since NaN never equals
  // itself the convergence test below could never succeed (infinite loop).
  var mean = function (values, fallback){
    if (values.length === 0){
      return fallback;
    }
    return values.reduce((a, b) => a + b, 0) / values.length;
  };

  var cent = {neg: 0, pos: 0};     // centroids (one for each "group")
  var clust = {neg: [], pos: []};  // clusters (one for each "group")
  var new_cent_neg = cats[cats.length - 1][2];
  var new_cent_pos = cats[0][2];
  var active_cats = [];

  // do/while guarantees at least one assignment pass, whatever the seed
  // values are (a sentinel-based `while` would skip it if the seeds happened
  // to equal the sentinel).
  do {
    cent.neg = new_cent_neg;
    cent.pos = new_cent_pos;
    clust.neg = [];
    clust.pos = [];
    active_cats = [];
    for (var i = 0, cat_cv; i < cats.length; i++){
      cat_cv = cats[i][2];
      if (Math.abs(cent.neg - cat_cv) < Math.abs(cent.pos - cat_cv)){
        clust.neg.push(cat_cv);
      }else{
        clust.pos.push(cat_cv);
        active_cats.push(cats[i][0]);
      }
    }
    // An empty cluster keeps its previous centroid.
    new_cent_neg = mean(clust.neg, cent.neg);
    new_cent_pos = mean(clust.pos, cent.pos);
  } while (cent.pos !== new_cent_pos || cent.neg !== new_cent_neg);

  return active_cats;
}

function __create_chart_values__(){
var $ss3 = $scope.ss3;
var $chart = $scope.chart;
Expand Down

0 comments on commit 046f9f4

Please sign in to comment.