
Use torchmetrics in baal.metrics #156

Closed
Dref360 opened this issue Oct 4, 2021 · 4 comments · Fixed by #230
Labels
enhancement New feature or request good first issue Good for newcomers

Comments

Dref360 (Member) commented Oct 4, 2021

We should swap most of our metrics for torchmetrics; this would make our code more robust.

I think only ECE is special, and we should keep our own implementation.
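For context, a generic Expected Calibration Error computation looks roughly like the sketch below (this is an illustrative pure-torch version, not Baal's actual implementation): bin samples by top-1 confidence, then average the gap between accuracy and mean confidence per bin, weighted by bin size.

```python
import torch


def ece(probs: torch.Tensor, target: torch.Tensor, n_bins: int = 10) -> float:
    """Expected Calibration Error: bin samples by top-1 confidence and
    average |accuracy - mean confidence| per bin, weighted by bin size."""
    conf, pred = probs.max(dim=1)          # top-1 confidence and predicted class
    correct = pred.eq(target).float()
    edges = torch.linspace(0, 1, n_bins + 1)
    total = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)  # samples whose confidence lands in (lo, hi]
        if mask.any():
            weight = mask.float().mean().item()
            gap = (correct[mask].mean() - conf[mask].mean()).abs().item()
            total += weight * gap
    return total


# A perfectly confident, perfectly correct model has ECE 0.
print(ece(torch.eye(4), torch.arange(4)))  # → 0.0
```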

@Dref360 Dref360 added the enhancement New feature or request label Oct 4, 2021
@Dref360 Dref360 added the good first issue Good for newcomers label Feb 3, 2022
GeorgePearse (Collaborator) commented:

I'd like to take this if no one's started on it yet?

Dref360 (Member, Author) commented Jun 6, 2022

Sounds good. What I was thinking is the following:

  • Add torchmetrics as a dependency (poetry add torchmetrics)
  • Make sure that baal.modelwrapper.ModelWrapper._update_metrics and _reset_metrics work with torchmetrics and ours, for backward compatibility.
  • Update the experiment scripts to match the new metrics API.
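The backward-compatibility point in the second bullet could be handled with a small dispatch helper. A minimal sketch, assuming duck typing on `compute` — `metric_value` and both classes below are illustrative stand-ins, not Baal's or torchmetrics' actual classes:

```python
def metric_value(metric):
    """Read a metric's result regardless of API flavour.

    torchmetrics computes lazily via `compute()`, while Baal's own
    metrics expose an accumulated `value` property; duck-typing on
    `compute` keeps this helper free of hard dependencies.
    """
    if hasattr(metric, "compute"):
        return metric.compute()
    return metric.value


# Illustrative stand-ins for the two APIs (not the real classes).
class TorchmetricsStyle:
    def __init__(self):
        self.hits, self.n = 0, 0

    def update(self, pred, target):
        self.hits += int(pred == target)
        self.n += 1

    def compute(self):
        return self.hits / self.n


class BaalStyle:
    def __init__(self):
        self.hits, self.n = 0, 0

    def update(self, pred, target):
        self.hits += int(pred == target)
        self.n += 1

    @property
    def value(self):
        return self.hits / self.n


for metric in (TorchmetricsStyle(), BaalStyle()):
    metric.update(1, 1)
    metric.update(0, 1)
    print(metric_value(metric))  # → 0.5 for both
```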

@GeorgePearse GeorgePearse self-assigned this Jun 6, 2022
Dref360 (Member, Author) commented Jun 10, 2022

I just wrote the following Baal metric that leverages torchmetrics.

We have very similar APIs, so I think it's a good idea to scrap Baal metrics in favor of torchmetrics. What do you think?

```python
from typing import List

import numpy as np
import torch
from torchmetrics import Accuracy
from torchmetrics.functional import auc

from baal.utils.metrics import Metrics


class AccuracyAUC(Metrics):
    """Area under the accuracy-vs-threshold curve."""

    def __init__(self, thresholds: List[float], **kwargs):
        self.thresholds = thresholds
        # One torchmetrics Accuracy per decision threshold.
        self._accs = [Accuracy(threshold=th) for th in thresholds]
        super().__init__(**kwargs)

    def reset(self):
        for acc in self._accs:
            acc.reset()

    def update(self, output=None, target=None):
        for acc in self._accs:
            acc.update(output, target)

    @property
    def value(self):
        # Integrate accuracy over the thresholds (trapezoidal rule);
        # thresholds are already sorted, so reorder=False.
        return auc(
            torch.FloatTensor(self.thresholds),
            torch.FloatTensor([acc.compute() for acc in self._accs]),
            reorder=False,
        ).item()


random_probs = torch.softmax(torch.randn(100, 10), dim=1)
random_target = torch.randint(0, 10, (100,))

met = AccuracyAUC(np.linspace(0, 1, 10))
met.update(random_probs, random_target)
print(met.value)
```

Dref360 (Member, Author) commented Jul 10, 2022

Hello, I would have time to work on it this week if you haven't already started.

@Dref360 Dref360 assigned Dref360 and unassigned GeorgePearse Aug 27, 2022
Dref360 added a commit that referenced this issue Aug 27, 2022
@Dref360 Dref360 mentioned this issue Aug 27, 2022
Dref360 added a commit that referenced this issue Oct 2, 2022