ROC curves for multi-class (What-If Tool) #2755
By considering one binarized problem for each class (the class vs. all the rest), we build one ROC curve per class. This was done for the sliced case first so that the grouped case could be treated as a simple extension of it. The data is stored for each model in `inferenceStats_`, under the key `allThresholds`.
Extending the previous sliced implementation to the grouped case by simply treating it as a single slice with an empty-string key.
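A minimal sketch of that idea, assuming examples are plain objects and that the slicing feature is passed by name. The helper names `sliceKeyFor` and `groupBySlice` are hypothetical and do not come from the PR diff; the point is only that the grouped (unsliced) case needs no special code path once every example maps to the key `''`.

```javascript
// Hypothetical helpers for illustration; not the What-If Tool's actual code.

// Returns the slice key for an example. When no slicing feature is
// selected, every example falls into a single slice keyed by ''.
function sliceKeyFor(example, sliceFeature) {
  if (!sliceFeature) {
    return '';
  }
  return String(example[sliceFeature]);
}

// Groups examples by slice key, so the grouped case is just the sliced
// case with exactly one slice.
function groupBySlice(examples, sliceFeature) {
  const slices = {};
  for (const example of examples) {
    const key = sliceKeyFor(example, sliceFeature);
    if (!slices[key]) {
      slices[key] = [];
    }
    slices[key].push(example);
  }
  return slices;
}
```

Any code that iterates over slices can then handle both cases uniformly, at the cost of reserving the empty string as a slice key.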
I think the plots need some right padding (4 or 8px, and maybe then you can remove some left padding); otherwise the right-most one can hug the right edge of WIT with no spacing.
plotThresholds,
regenInferenceStats,
true
);
should really do ROC and PR curves for each class just like with binary
ok! That should be trivial now ;)
@@ -4854,7 +5015,7 @@ <h2>Show similarity to selected datapoint</h2>
this.featureValueThresholds = [];
this.featureValueThresholds = this.sortFeatureValues(tempArray);

// ROC curves should only exist for the binary case
// ROC curves for the binary case
update comment since also does PR curves
👍 also renamed a couple of functions that only made reference to PR.
return this.getRocChartId(index) + '-' + label;
},

getLabelVocab: function(index) {
call this getLabel or getLabelForIndex for clarity
Done. Picked `getLabel` to be consistent with the other methods like `getRocChartId`, `getPrChartId`.
Just improving clarity.
Analogous to the binary case, using the same data and structures as the ROC curves.
These methods are being used to determine whether to display ROC and PR curves, but only make reference to PR curves. Renaming them to something more generic that can take both (and possibly others) into account.
thanks @grovina, just the one comment about padding now
also make sure to run lint
Working on the padding...
For multi-class models, let's put the curves for each class on a separate line, so that they are easier to follow. For both binary and multi-class models, we adjust the margins and the position of the axis labels, and center the plots. I also renamed and cleaned up some CSS classes along the way.
This was added by mistake in tensorflow#2755.
Show ROC curves for multi-class models.
For each class, we can plot a ROC curve considering the class in question as the positive one and all the others as negatives (what is called a binarized version of the problem).
To achieve this, we iterate through all examples, populating an object with the following structure:
> model > slice > label > threshold > classification stats
Here the classification stats contain the counts and rates of true and false positives and negatives. They are computed for every threshold value (in 1% steps from 0% to 100%), for every class label, for each feature slice (each value for categorical features, each interval for numeric features), and for every model.
The overall case (not sliced by any feature) is trivially treated as a single slice.
When displaying the ROC curves, we end up with the same case as the ordinary ROC curves from binary classification problems, once for each slice.
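The steps above can be sketched roughly as follows for a single model. This is an illustration under stated assumptions, not the code from this PR: the function names (`computeThresholdStats`, `rocPoints`), the example shape (`{label, scores}`), and the choice to store only counts and derive the rates afterwards are all hypothetical.

```javascript
// Illustrative sketch only: names and data shapes are assumptions,
// not the What-If Tool's actual implementation.

const NUM_THRESHOLDS = 101;  // thresholds 0%, 1%, ..., 100%

function emptyCounts() {
  return {TP: 0, FP: 0, TN: 0, FN: 0};
}

// Builds stats[slice][label][thresholdPercent] for one model.
// Each example is assumed to look like {label: 'a', scores: {a: 0.9, b: 0.1}}.
// sliceOf maps an example to its slice key ('' for the unsliced case).
function computeThresholdStats(examples, labels, sliceOf) {
  const stats = {};
  for (const example of examples) {
    const slice = sliceOf(example);
    stats[slice] = stats[slice] || {};
    for (const label of labels) {
      if (!stats[slice][label]) {
        stats[slice][label] =
            Array.from({length: NUM_THRESHOLDS}, emptyCounts);
      }
      // Binarized problem: this label is positive, all others negative.
      const actualPositive = example.label === label;
      for (let t = 0; t < NUM_THRESHOLDS; t++) {
        const predictedPositive = example.scores[label] >= t / 100;
        const counts = stats[slice][label][t];
        if (predictedPositive && actualPositive) {
          counts.TP++;
        } else if (predictedPositive) {
          counts.FP++;
        } else if (actualPositive) {
          counts.FN++;
        } else {
          counts.TN++;
        }
      }
    }
  }
  return stats;
}

// Turns the per-threshold counts for one (slice, label) pair into ROC
// points, one per threshold. Rates default to 0 when undefined.
function rocPoints(thresholdCounts) {
  return thresholdCounts.map(({TP, FP, TN, FN}) => ({
    fpr: FP + TN > 0 ? FP / (FP + TN) : 0,
    tpr: TP + FN > 0 ? TP / (TP + FN) : 0,
  }));
}
```

Calling `rocPoints(stats[slice][label])` then yields one curve per class label in each slice, which can be plotted exactly like a binary ROC curve.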
The overall case:
The sliced feature case:
Open performance tab in the iris demo, and check that:
Repeat for sliced features, and also for different numbers of buckets
We could think of averaged ROC curves (like https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html). I personally think they are a bit harder to interpret.
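For reference, one way such an averaged curve could be computed is macro-averaging: interpolate each per-class curve onto a common grid of false positive rates and average the true positive rates. The sketch below is not part of this PR. The helper names `tprAt` and `macroAverageRoc` are hypothetical, and it assumes each curve is an array of `{fpr, tpr}` points sorted by increasing `fpr` (then `tpr`), starting at `fpr = 0`.

```javascript
// Hypothetical sketch of a macro-averaged ROC curve; not part of this PR.

// Linearly interpolates the TPR of one curve at the given FPR.
// Assumes the curve is sorted by increasing fpr (ties by increasing tpr)
// and that its first point has fpr === 0.
function tprAt(curve, fpr) {
  let prev = curve[0];
  for (const point of curve) {
    if (point.fpr > fpr) {
      const span = point.fpr - prev.fpr;  // > 0 because prev.fpr <= fpr
      return prev.tpr + (point.tpr - prev.tpr) * (fpr - prev.fpr) / span;
    }
    prev = point;
  }
  return prev.tpr;  // fpr is at or beyond the last point
}

// Averages the per-class curves on an evenly spaced FPR grid.
// gridSize is assumed to be at least 2.
function macroAverageRoc(curves, gridSize) {
  const averaged = [];
  for (let i = 0; i < gridSize; i++) {
    const fpr = i / (gridSize - 1);
    const sum = curves.reduce((acc, curve) => acc + tprAt(curve, fpr), 0);
    averaged.push({fpr, tpr: sum / curves.length});
  }
  return averaged;
}
```

The result is a single curve per slice, which is more compact but hides how individual classes behave.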
Another possibility is merging all curves into a single plot. Although more compact, it wouldn't be as clear that each curve refers to a distinct version of the problem, and it could also get a bit too crowded with many classes or multiple models.