ROC curves for multi-class (What-If Tool) #2755

Merged
merged 8 commits into from
Oct 11, 2019

@grovina grovina commented Oct 10, 2019

  • Motivation for features / changes

Show ROC curves for multi-class models.

  • Technical description of changes

For each class, we can plot a ROC curve considering the class in question as the positive one and all the others as negatives (what is called a binarized version of the problem).
To achieve this, we iterate through all examples, populating an object with the following structure:
> model > slice > label > threshold > classification stats
Here, the classification stats contain the counts and rates of true and false positives and negatives. They are computed for every threshold value (every 1% from 0% to 100%), for every class label, for each feature slice (each value for categorical features, each interval for numeric features), and for every model.
The overall case (not sliced by any feature) is trivially treated as a single slice.
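A minimal sketch of how such a structure could be populated for a single model. The function name `buildThresholdStats`, the example shape (`label`, `scores`, `slice`) and the count field names are assumptions made for illustration, not the identifiers used in the What-If Tool source. Only counts are stored here; the rates can be derived from them.

```javascript
// Sketch only: buildThresholdStats and the example shape are hypothetical,
// not the names used in the What-If Tool source.
// Each example looks like {label: 'a', scores: {a: 0.9, b: 0.1}, slice: 'x'}.
function buildThresholdStats(examples, labels, numThresholds = 101) {
  const stats = {}; // slice -> label -> array indexed by threshold step
  for (const ex of examples) {
    // An example without a slice falls into the single overall slice ''.
    const sliceKey = ex.slice === undefined ? '' : String(ex.slice);
    if (!stats[sliceKey]) stats[sliceKey] = {};
    for (const label of labels) {
      if (!stats[sliceKey][label]) {
        stats[sliceKey][label] = Array.from(
          {length: numThresholds},
          () => ({TP: 0, FP: 0, TN: 0, FN: 0})
        );
      }
      // One-vs-rest: this label is positive, every other label is negative.
      const isPositive = ex.label === label;
      const score = ex.scores[label] || 0;
      for (let t = 0; t < numThresholds; t++) {
        const threshold = t / (numThresholds - 1); // 0, 0.01, ..., 1
        const cell = stats[sliceKey][label][t];
        if (score >= threshold) {
          if (isPositive) cell.TP++; else cell.FP++;
        } else {
          if (isPositive) cell.FN++; else cell.TN++;
        }
      }
    }
  }
  return stats;
}
```

With this layout, `stats[''][label][50]` would hold the counts for the overall case at a 50% threshold.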

When displaying the ROC curves, each class then reduces to the ordinary ROC curve of a binary classification problem, drawn once for each slice.
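As an illustration of that reduction, the per-threshold counts for one class in one slice can be turned into ROC points exactly as in the binary case. The helper name `rocPoints` and the count field names (`TP`, `FP`, `TN`, `FN`) are assumptions for this sketch.

```javascript
// Sketch only: rocPoints is a hypothetical helper, not part of the PR.
// Input: one array of per-threshold counts for a single class and slice.
// Output: one {fpr, tpr} point per threshold.
function rocPoints(thresholdStats) {
  return thresholdStats.map(({TP, FP, TN, FN}) => ({
    // False positive rate: share of actual negatives predicted positive.
    fpr: FP + TN > 0 ? FP / (FP + TN) : 0,
    // True positive rate: share of actual positives predicted positive.
    tpr: TP + FN > 0 ? TP / (TP + FN) : 0,
  }));
}
```

The guards against empty denominators matter for small slices, which may contain no positives or no negatives for a given class.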

  • Screenshots of UI changes

The overall case:
Screenshot from 2019-10-10 14-34-36

The sliced feature case:
Screenshot from 2019-10-10 14-34-08

  • Detailed steps to verify changes work correctly (as executed by you)
    Open performance tab in the iris demo, and check that:
    • the correct number of ROC curves appear (same number as number of classes in the problem)
    • the values and shape of the ROC curve are coherent
    • the curve respects the data

Repeat for sliced features, and also for different numbers of buckets

  • Alternate designs / implementations considered

We could think of averaged ROC curves (like https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html). I personally think they are a bit harder to interpret.
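For reference, a rough sketch of what a macro-averaged curve would involve: interpolate each per-class curve onto a shared FPR grid and average the TPR values. This alternative is not what the PR implements, and the names `interpolate` and `macroAverageRoc` are hypothetical.

```javascript
// Sketch only: not part of this PR. Linear interpolation of y at x, given
// points sorted by ascending xs; values outside the range are clamped.
function interpolate(x, xs, ys) {
  if (x <= xs[0]) return ys[0];
  if (x >= xs[xs.length - 1]) return ys[ys.length - 1];
  let i = 1;
  while (xs[i] < x) i++;
  const span = xs[i] - xs[i - 1];
  return ys[i - 1] + ((x - xs[i - 1]) / span) * (ys[i] - ys[i - 1]);
}

// curves: one array of {fpr, tpr} points per class.
function macroAverageRoc(curves, gridSize = 101) {
  const sorted = curves.map((curve) =>
    [...curve].sort((p, q) => p.fpr - q.fpr || p.tpr - q.tpr));
  const result = [];
  for (let i = 0; i < gridSize; i++) {
    const fpr = i / (gridSize - 1);
    let sum = 0;
    for (const curve of sorted) {
      sum += interpolate(fpr, curve.map((p) => p.fpr), curve.map((p) => p.tpr));
    }
    // Unweighted mean over classes: every class counts equally.
    result.push({fpr, tpr: sum / sorted.length});
  }
  return result;
}
```

The averaged points no longer correspond to any single threshold or class, which is one way to see why such a curve can be harder to interpret.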

Another possibility is merging all curves into a single plot. Although more compact, it wouldn't be as clear that each curve refers to a distinct version of the problem, and it could also get a bit too crowded with many classes or multiple models.

By considering one binarized problem for each class (the class vs. all
the rest), we build one ROC curve for each class.
This was done for the sliced case first so that the grouped case could
be treated as a simple extension of this one.
The data is stored for each model in `inferenceStats_`, in the key
`allThresholds`.
Extending previous sliced implementation to grouped case by simply
considering it as a single slice with an empty string key.
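
A small sketch of the empty-string-key idea, assuming a hypothetical helper `groupBySlice` and handling categorical features only (numeric features would first need to be bucketed into intervals):

```javascript
// Sketch only: groupBySlice is a hypothetical helper, not code from this PR.
// With no slice feature selected, every example lands in one slice keyed ''.
function groupBySlice(examples, sliceFeature) {
  const groups = new Map();
  for (const ex of examples) {
    const key = sliceFeature ? String(ex[sliceFeature]) : '';
    if (!groups.has(key)) groups.set(key, []);
    groups.get(key).push(ex);
  }
  return groups;
}
```

Downstream code can then iterate over the slices without a separate branch for the unsliced case.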

grovina commented Oct 10, 2019

cc @jameswex and @tolga-b

@jameswex jameswex left a comment

i think the plots need some right padding (4 or 8px, and maybe then you can remove some left padding) otherwise the right-most one can hug the right edge of WIT with no spacing.

plotThresholds,
regenInferenceStats,
true
);
Contributor

should really do ROC and PR curves for each class just like with binary

Contributor Author

ok! That should be trivial now ;)

@@ -4854,7 +5015,7 @@ <h2>Show similarity to selected datapoint</h2>
this.featureValueThresholds = [];
this.featureValueThresholds = this.sortFeatureValues(tempArray);

- // ROC curves should only exist for the binary case
+ // ROC curves for the binary case
Contributor

update comment since also does PR curves

Contributor Author

👍 also renamed a couple of functions that only made reference to PR.

return this.getRocChartId(index) + '-' + label;
},

getLabelVocab: function(index) {
Contributor

call this getLabel or getLabelForIndex for clarity

Contributor Author

Done. Picked getLabel to be consistent with the other methods like getRocChartId, getPrChartId.

Just improving clarity.
Analogous to the binary case, using the same data and structures as the
ROC curves.
These methods are being used to determine whether to display ROC and PR
curves, but only make reference to PR curves. Renaming them to something
more generic that can take both (and possibly others) into account.
@jameswex
Contributor

thanks @grovina , just the one comment about padding now

@jameswex
Contributor

also make sure to run lint

@grovina grovina left a comment

Working on the padding...

For multi-class models, let's put the curves for each class on a
separate line, so that they are easier to follow.

For both binary and multi-class models, we adjust the margins and the
position of the axis labels, and center the plots.

I also renamed and cleaned up some CSS classes on the way.

grovina commented Oct 11, 2019

Adjusted some CSS and fixed lint. Here are screenshots for:

  • the multi-class case:
    Screenshot from 2019-10-11 17-38-12

  • the binary case:
    Screenshot from 2019-10-11 17-37-53

@jameswex jameswex merged commit 2de6882 into tensorflow:master Oct 11, 2019
grovina added a commit to grovina/tensorboard that referenced this pull request Oct 11, 2019
This was added by mistake in tensorflow#2755.