Skip to content

Commit

Permalink
🗑 Remove compute_on_step argument (#342)
Browse files Browse the repository at this point in the history
  • Loading branch information
djdameln committed Jun 1, 2022
1 parent 4d02e44 commit 379e311
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ def metric_collection_from_names(metric_names: List[str], prefix: Optional[str])
AnomalibMetricCollection: Collection of metrics.
"""
metrics_module = importlib.import_module("anomalib.utils.metrics")
metrics = AnomalibMetricCollection([], prefix=prefix, compute_groups=False)
metrics = AnomalibMetricCollection([], prefix=prefix)
for metric_name in metric_names:
if hasattr(metrics_module, metric_name):
metric_cls = getattr(metrics_module, metric_name)
metrics.add_metrics(metric_cls(compute_on_step=False))
metrics.add_metrics(metric_cls())
elif hasattr(torchmetrics, metric_name):
try:
metric_cls = getattr(torchmetrics, metric_name)
metrics.add_metrics(metric_cls(compute_on_step=False))
metrics.add_metrics(metric_cls())
except TypeError:
warnings.warn(f"Incorrect constructor arguments for {metric_name} metric from TorchMetrics package.")
else:
Expand Down
2 changes: 1 addition & 1 deletion anomalib/utils/metrics/adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class AdaptiveThreshold(Metric):
def __init__(self, default_value: float = 0.5, **kwargs):
super().__init__(**kwargs)

self.precision_recall_curve = PrecisionRecallCurve(num_classes=1, compute_on_step=False)
self.precision_recall_curve = PrecisionRecallCurve(num_classes=1)
self.add_state("value", default=torch.tensor(default_value), persistent=True) # pylint: disable=not-callable
self.value = torch.tensor(default_value) # pylint: disable=not-callable

Expand Down
2 changes: 1 addition & 1 deletion anomalib/utils/metrics/optimal_f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class OptimalF1(Metric):
def __init__(self, num_classes: int, **kwargs):
super().__init__(**kwargs)

self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes, compute_on_step=False)
self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes)

self.threshold: torch.Tensor

Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ omegaconf>=2.1.1
opencv-python>=4.5.3.56
pandas>=1.1.0
pytorch-lightning>=1.6.0
torchmetrics==0.8.0
torchmetrics==0.9.0
torchvision>=0.9.1
torchtext>=0.9.1
wandb==0.12.17
Expand Down

0 comments on commit 379e311

Please sign in to comment.