Add multilabel support to show_best and get_best
Now ``Evaluation.get_best_hyperparameters()`` and
``Evaluation.show_best()`` support the new multi-label classification
metrics.
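
For reference, a usage sketch of how the new metrics can be requested after this change (not part of the commit; it assumes a multi-label classifier has already been attached to ``Evaluation`` and that k-fold evaluations have been run and cached):

    from pyss3.util import Evaluation

    # 'exact-match' and 'hamming-loss' are global metrics, so no
    # metric_target is needed (unlike 'precision', 'recall', 'f1-score').
    # Both are only meaningful for multi-label evaluations.
    best_hparams = Evaluation.get_best_hyperparameters(metric="exact-match")
    best_hparams = Evaluation.get_best_hyperparameters(metric="hamming-loss")

    # Per-category metrics still require a metric_target:
    best_hparams = Evaluation.get_best_hyperparameters(metric="f1-score",
                                                       metric_target="macro avg")

    # For multi-label evaluations, show_best() now also reports the best
    # hamming loss and exact match ratio (see the last hunk below):
    Evaluation.show_best()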
sergioburdisso committed May 15, 2020
1 parent 6c7d8d9 commit ef2419b
Showing 1 changed file with 37 additions and 14 deletions.
51 changes: 37 additions & 14 deletions pyss3/util.py
@@ -51,8 +51,9 @@ def multilabel_confusion_matrix(*args):
STR_ACCURACY, STR_PRECISION = "accuracy", "precision"
STR_RECALL, STR_F1 = "recall", "f1-score"
STR_HAMMING_LOSS, STR_EXACT_MATCH = "hamming-loss", "exact-match"
GLOBAL_METRICS = [STR_ACCURACY, STR_HAMMING_LOSS, STR_EXACT_MATCH]
METRICS = [STR_PRECISION, STR_RECALL, STR_F1]
EXCP_METRICS = [STR_ACCURACY, STR_HAMMING_LOSS, "confusion_matrix", "categories"]
EXCP_METRICS = GLOBAL_METRICS + ["confusion_matrix", "categories"]
AVGS = ["micro avg", "macro avg", "weighted avg", "samples avg"]

STR_TEST, STR_FOLD = 'test', 'fold'
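
For clarity, the metric constants read as follows once this hunk is applied (reassembled from the lines above; nothing here is new code):

    STR_ACCURACY, STR_PRECISION = "accuracy", "precision"
    STR_RECALL, STR_F1 = "recall", "f1-score"
    STR_HAMMING_LOSS, STR_EXACT_MATCH = "hamming-loss", "exact-match"
    GLOBAL_METRICS = [STR_ACCURACY, STR_HAMMING_LOSS, STR_EXACT_MATCH]  # new
    METRICS = [STR_PRECISION, STR_RECALL, STR_F1]
    EXCP_METRICS = GLOBAL_METRICS + ["confusion_matrix", "categories"]  # replaces the old literal list
    AVGS = ["micro avg", "macro avg", "weighted avg", "samples avg"]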
@@ -211,7 +212,7 @@ def __cache_save_result__(
if cache["accuracy"]["best"]["value"] == {}:
cache["accuracy"]["best"]["value"] = -1
if hamming_loss is not None:
cache["hamming_loss"]["best"]["value"] = -1
cache["hamming-loss"]["best"]["value"] = -1
for metric, avg in product(METRICS, AVGS):
if avg in report: # scikit-learn > 0.20 does not include 'micro avg' in report
cache[metric][avg]["best"]["value"] = -1
@@ -223,7 +224,7 @@ def __cache_save_result__(
if cache["accuracy"]["fold_values"][s][l][p][a] == {}:
cache["accuracy"]["fold_values"][s][l][p][a] = [0] * k_fold
if hamming_loss is not None:
cache["hamming_loss"]["fold_values"][s][l][p][a] = [0] * k_fold
cache["hamming-loss"]["fold_values"][s][l][p][a] = [0] * k_fold
cache["confusion_matrix"][s][l][p][a] = [None] * k_fold
for metric, avg in product(METRICS, AVGS):
if avg in report:
@@ -235,7 +236,7 @@ def __cache_save_result__(
# saving fold results
cache["accuracy"]["fold_values"][s][l][p][a][i_fold] = rf(accuracy)
if hamming_loss is not None:
cache["hamming_loss"]["fold_values"][s][l][p][a][i_fold] = rf(1 - hamming_loss)
cache["hamming-loss"]["fold_values"][s][l][p][a][i_fold] = rf(1 - hamming_loss)
for metric, avg in product(METRICS, AVGS):
if avg in report:
cache[metric][avg]["fold_values"][s][l][p][a][i_fold] = rf(report[avg][metric])
@@ -258,10 +259,10 @@ def __cache_save_result__(
best_acc["p"], best_acc["a"] = p, a

if hamming_loss is not None:
hamloss_avg = rf(mean(cache["hamming_loss"]["fold_values"][s][l][p][a]))
cache["hamming_loss"]["value"][s][l][p][a] = hamloss_avg
hamloss_avg = rf(mean(cache["hamming-loss"]["fold_values"][s][l][p][a]))
cache["hamming-loss"]["value"][s][l][p][a] = hamloss_avg

best_haml = cache["hamming_loss"]["best"]
best_haml = cache["hamming-loss"]["best"]
if hamloss_avg > best_haml["value"]:
best_haml["value"] = hamloss_avg
best_haml["s"], best_haml["l"] = s, l
@@ -621,7 +622,7 @@ def __evaluation_result__(
hammingloss = None

if metric == STR_HAMMING_LOSS and not multilabel:
raise ValueError("the '%s' metric is only allowed when in multi-label classification."
raise ValueError("the '%s' metric is only allowed in multi-label classification."
% STR_HAMMING_LOSS)

if not multilabel:
@@ -939,15 +940,18 @@ def get_best_hyperparameters(
ones matching the last performed evaluation.
Available metrics are: 'accuracy', 'f1-score', 'precision', and
'recall'.
'recall'. In addition, in multi-label classification, 'hamming-loss'
and 'exact-match' are also available.
Except for accuracy, a ``metric_target`` option must also be supplied
along with the ``metric`` indicating the target we aim at measuring,
that is, whether we want to measure some averaging performance or the
performance on a particular category.
:param metric: the evaluation metric, options are: 'accuracy', 'f1-score',
'precision', or 'recall' (default: 'accuracy').
'precision', or 'recall'. In addition, in multi-label
classification, 'hamming-loss' and 'exact-match' are also
available (default: 'accuracy').
:type metric: str
:param metric_target: the target we aim at measuring with the given
metric. Options are: 'macro avg', 'micro avg',
@@ -973,9 +977,11 @@
if not Evaluation.__clf__:
raise ValueError(ERROR_CNA)

if metric != STR_ACCURACY and metric not in METRICS:
if metric not in METRICS + GLOBAL_METRICS:
raise KeyError(ERROR_NAM % str(metric))

metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY

l_tag, l_method, l_def_cat = Evaluation.__get_last_evaluation__()
tag, method, def_cat = tag or l_tag, method or l_method, def_cat or l_def_cat
cache = Evaluation.__cache__[tag][method][def_cat]
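
A more explicit, equivalent form of the aliasing added in this hunk (a restatement, not additional code): in multi-label mode the exact match ratio is computed and cached under "accuracy", so the metric name is simply remapped before the cache lookup:

    if metric == STR_EXACT_MATCH:   # "exact-match"
        metric = STR_ACCURACY       # results live under the "accuracy" cache entry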
@@ -985,7 +991,7 @@

c_metric = cache[metric]

if metric == STR_ACCURACY:
if metric in GLOBAL_METRICS:
best = c_metric["best"]
else:
if metric_target in AVGS:
@@ -1014,7 +1020,9 @@ def show_best(tag=None, method=None, def_cat=None, metric=None, avg=None):
'most-probable', 'unknown' or a category label (optional).
:type def_cat: str
:param metric: an evaluation metric, options are: 'accuracy', 'f1-score',
'precision', and 'recall' (optional).
'precision', and 'recall'. In addition, in multi-label
classification, 'hamming-loss' and 'exact-match' are also
available (optional).
:type metric: str
:param avg: an averaging method, options are: 'macro avg', 'micro avg',
and 'weighted avg' (optional).
@@ -1047,17 +1055,32 @@ def show_best(tag=None, method=None, def_cat=None, metric=None, avg=None):
ps.blue(t), ps.blue(dc)
), end='')

multilabel = STR_HAMMING_LOSS in cache[t][md][dc]

evl = cache[t][md][dc]["accuracy"]["value"]
n_evl = len([
a for s in evl for l in evl[s]
for p in evl[s][l] for a in evl[s][l][p]
])
print("(%d evaluations)" % n_evl)

if multilabel:
best = cache[t][md][dc]["hamming-loss"]["best"]
print(
" Best %s: %s %s" % (
ps.green("hamming loss"),
ps.warning(round_fix(1 - best["value"])),
ps.blue("(s %s l %s p %s a %s)") % (
best["s"], best["l"], best["p"], best["a"]
)
)
)

best = cache[t][md][dc]["accuracy"]["best"]
print(
" Best %s: %s %s" % (
ps.green("accuracy"), ps.warning(best["value"]),
ps.green("accuracy" if not multilabel else "exact match ratio"),
ps.warning(best["value"]),
ps.blue("(s %s l %s p %s a %s)") % (
best["s"], best["l"], best["p"], best["a"]
)
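
With the branch above, the "best" summary printed for a multi-label evaluation would look roughly like this (hyperparameter and metric values are purely illustrative; note that the hamming loss is recovered as ``1 - best["value"]``, since the cache stores its complement):

    (12 evaluations)
     Best hamming loss: 0.05 (s 0.32 l 1.24 p 1.1 a 0.0)
     Best exact match ratio: 0.81 (s 0.32 l 1.24 p 1.1 a 0.0)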
