Skip to content

Commit

Permalink
[fix] match metric names ignoring separators (#310) (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran committed May 8, 2023
1 parent c577c8c commit 21e11cd
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions src/sparsezoo/deployment_package/utils/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging
from types import MappingProxyType
from typing import Optional
from typing import List, Optional, Union

from sparsezoo import Model

Expand Down Expand Up @@ -75,21 +75,41 @@ def _accuracy(model: Model, metric_name=None) -> float:

if metric_name is not None:
for result in validation_results:
if metric_name in result.recorded_units.lower():
if _metric_name_matches(metric_name, result.recorded_units.lower()):
return result.recorded_value
_LOGGER.info(f"metric name {metric_name} not found for model {model}")

# fallback to if any accuracy metric found
accuracy_metrics = ["accuracy", "f1", "recall", "map", "top1 accuracy"]
for result in validation_results:
if result.recorded_units.lower() in accuracy_metrics:
if _metric_name_matches(result.recorded_units.lower(), accuracy_metrics):
return result.recorded_value

raise ValueError(
f"Could not find any accuracy metric {accuracy_metrics} for model {model}"
)


def _metric_name_matches(
metric_name: str, target_metrics: Union[str, List[str]]
) -> bool:
# returns true if metric name is included in the target metrics
if isinstance(target_metrics, str):
target_metrics = [target_metrics]
return any(
_standardized_str_eq(metric_name, target_metric)
for target_metric in target_metrics
)


def _standardized_str_eq(str_1: str, str_2: str) -> bool:
# strings are equal if lowercase, striped of spaces, -, and _ are equal
def _standardize(string):
return string.lower().replace(" ", "").replace("-", "").replace("_", "")

return _standardize(str_1) == _standardize(str_2)


EXTRACTORS = MappingProxyType(
{
"compression": _size,
Expand Down

0 comments on commit 21e11cd

Please sign in to comment.