Use torchmetrics in baal.metrics #156
I'd like to take this, if no one has started on it yet.
Sounds good. What I was thinking was the following:
I just wrote the following Baal metric that leverages torchmetrics. We have very similar APIs, so I think it's a good idea to scrap Baal metrics in favor of torchmetrics. What do you think?

    from typing import List

    import numpy as np
    import torch
    from torchmetrics import Accuracy  # older (0.x) API: Accuracy(threshold=...) without `task`
    from torchmetrics.functional import auc

    from baal.utils.metrics import Metrics


    class AccuracyAUC(Metrics):
        """Area under the accuracy-vs-threshold curve."""

        def __init__(self, thresholds: List[float], **kwargs):
            self.thresholds = thresholds
            # One torchmetrics Accuracy per threshold. Built before
            # super().__init__(), which may call reset().
            self._accs = [Accuracy(threshold=th) for th in thresholds]
            super().__init__(**kwargs)

        def reset(self):
            for acc in self._accs:
                acc.reset()

        def update(self, output=None, target=None):
            for acc in self._accs:
                acc.update(output, target)

        @property
        def value(self):
            # Integrate accuracy over the thresholds, which are already sorted.
            accs = torch.stack([acc.compute() for acc in self._accs]).float()
            return auc(torch.FloatTensor(self.thresholds), accs, reorder=False).item()


    random_probs = torch.softmax(torch.randn(100, 10), dim=1)
    random_target = torch.randint(0, 10, (100,))
    met = AccuracyAUC(np.linspace(0, 1, 10))
    met.update(random_probs, random_target)
    print(met.value)
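For readers without torchmetrics installed, here is a dependency-free sketch of what an accuracy-AUC metric like the one above computes: one accuracy accumulator per threshold (following the `update`/`compute`/`reset` pattern the two libraries share), combined with the trapezoidal rule. It assumes binary probabilities and counts a prediction as positive when `prob >= threshold`; torchmetrics' exact comparison may differ, and for multiclass inputs the prediction is typically the argmax, in which case a threshold may have no effect at all. The names `ThresholdAccuracy`, `trapezoid_auc` and `AccuracyAUCSketch` are hypothetical.

```python
from typing import List, Sequence


class ThresholdAccuracy:
    """Hypothetical accumulator mimicking the update/compute/reset pattern."""

    def __init__(self, threshold: float):
        self.threshold = threshold
        self.reset()

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, probs: Sequence[float], targets: Sequence[int]) -> None:
        # Assumption: a sample is predicted positive when prob >= threshold.
        for p, t in zip(probs, targets):
            pred = 1 if p >= self.threshold else 0
            self.correct += int(pred == t)
            self.total += 1

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def trapezoid_auc(xs: List[float], ys: List[float]) -> float:
    # Trapezoidal rule; assumes xs are already sorted (no reordering).
    return sum((xs[i + 1] - xs[i]) * (ys[i] + ys[i + 1]) / 2 for i in range(len(xs) - 1))


class AccuracyAUCSketch:
    """Hypothetical Baal-style wrapper that exposes the area as `value`."""

    def __init__(self, thresholds: List[float]):
        self.thresholds = list(thresholds)
        self._accs = [ThresholdAccuracy(th) for th in self.thresholds]

    def reset(self) -> None:
        for acc in self._accs:
            acc.reset()

    def update(self, output, target) -> None:
        for acc in self._accs:
            acc.update(output, target)

    @property
    def value(self) -> float:
        return trapezoid_auc(self.thresholds, [acc.compute() for acc in self._accs])


met = AccuracyAUCSketch([0.0, 0.25, 0.5, 0.75, 1.0])
met.update([0.1, 0.4, 0.6, 0.9], [0, 0, 1, 1])
print(met.value)  # 0.75
```

Here the per-threshold accuracies are 0.5, 0.75, 1.0, 0.75 and 0.5, so the area is 0.75. Keeping the accumulators separate from the `value` property is what makes it easy to back them with torchmetrics objects instead.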
Hello, I would have time to work on it this week if you haven't already started.
We should swap most of our metrics for torchmetrics; this would make our code more robust. Only ECE is special, I think, and we should keep our own implementation.
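For reference, Expected Calibration Error (ECE) is commonly defined by binning predictions by confidence and summing, over bins, the gap between accuracy and mean confidence weighted by the fraction of samples in the bin: ECE = sum_b |B_b|/n * |acc(B_b) - conf(B_b)|. The sketch below assumes equal-width bins and one top-class confidence per sample; it is not Baal's implementation, whose binning details may differ, and `expected_calibration_error` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def expected_calibration_error(
    confidences: Sequence[float], correct: Sequence[bool], n_bins: int = 10
) -> float:
    """ECE with equal-width bins [k/n_bins, (k+1)/n_bins)."""
    n = len(confidences)
    if n == 0:
        return 0.0
    bins: List[List[Tuple[float, bool]]] = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        # A confidence of exactly 1.0 is clamped into the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, ok))
    ece = 0.0
    for b in bins:
        if not b:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in b) / len(b)
        acc = sum(1 for _, ok in b if ok) / len(b)
        ece += len(b) / n * abs(acc - avg_conf)
    return ece


print(expected_calibration_error([0.9, 0.9, 0.6, 0.6], [True, False, True, True]))  # ≈ 0.4
```

The binning step is the part that varies between implementations (number of bins, open or closed edges, equal-width or equal-mass bins), which is one reason a project may prefer to keep its own version.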