Skip to content

Commit

Permalink
Add metrics configuration callback to benchmarking (#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinvaidya17 committed Jun 3, 2022
1 parent a996f2f commit f2cf458
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
23 changes: 21 additions & 2 deletions anomalib/utils/sweep/helpers/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
# and limitations under the License.


from typing import List
from typing import List, Union

from omegaconf import DictConfig, ListConfig
from pytorch_lightning import Callback

from anomalib.utils.callbacks import MetricsConfigurationCallback
from anomalib.utils.callbacks.timer import TimerCallback


def get_sweep_callbacks() -> List[Callback]:
def get_sweep_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
"""Gets callbacks relevant to sweep.
Args:
Expand All @@ -32,5 +34,22 @@ def get_sweep_callbacks() -> List[Callback]:
List[Callback]: List of callbacks
"""
callbacks: List[Callback] = [TimerCallback()]
# Add metric configuration to the model via MetricsConfigurationCallback
image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None
pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None
image_threshold = (
config.metrics.threshold.image_default if "image_default" in config.metrics.threshold.keys() else None
)
pixel_threshold = (
config.metrics.threshold.pixel_default if "pixel_default" in config.metrics.threshold.keys() else None
)
metrics_callback = MetricsConfigurationCallback(
config.metrics.threshold.adaptive,
image_threshold,
pixel_threshold,
image_metric_names,
pixel_metric_names,
)
callbacks.append(metrics_callback)

return callbacks
2 changes: 1 addition & 1 deletion tools/benchmarking/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi
datamodule = get_datamodule(model_config)
model = get_model(model_config)

callbacks = get_sweep_callbacks()
callbacks = get_sweep_callbacks(model_config)

trainer = Trainer(**model_config.trainer, logger=None, callbacks=callbacks)

Expand Down

0 comments on commit f2cf458

Please sign in to comment.