Skip to content

Commit

Permalink
add available attribute needed by sparsify, and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Satrat committed Sep 22, 2023
1 parent f2cf607 commit b2bd3b2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/sparsezoo/objects/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
from pathlib import Path
from typing import Dict, List, Optional, Union

from sparsezoo.objects import File
Expand Down Expand Up @@ -49,6 +50,15 @@ def __init__(
if custom_default is not None:
self._default_recipe_name = "recipe_" + custom_default

@property
def available(self) -> Optional[List[str]]:
"""
:return: List of all recipe names, or None if none are available
"""
if len(self._recipes) == 0:
return None
return [Path(recipe.name).stem for recipe in self._recipes]

@property
def recipes(self) -> List:
"""
Expand All @@ -67,8 +77,8 @@ def default(self) -> File:

# fallback to first recipe in list
_LOGGER.warning(
"No default recipe {self._default_recipe_name} found, falling back to"
"first listed recipe"
f"No default recipe {self._default_recipe_name} found, falling back to "
f"first listed recipe {self._recipes[0].name}"
)
return self._recipes[0]

Expand Down
9 changes: 9 additions & 0 deletions tests/sparsezoo/objects/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def test_recipe_getters():
found_by_name = model.recipes.get_recipe_by_name("does_not_exist.md")
assert found_by_name is None

available_recipes = model.recipes.available
assert len(available_recipes) == 4
assert "recipe_transfer_token_classification" in available_recipes
assert "recipe" in available_recipes


def test_custom_default():
custom_default_name = "transfer_text_classification"
Expand All @@ -50,3 +55,7 @@ def test_custom_default():

default_recipe = model.recipes.default
assert default_recipe.name == expected_default_name

available_recipes = model.recipes.available
assert len(available_recipes) == 1
assert available_recipes[0] == "recipe_transfer_text_classification"

0 comments on commit b2bd3b2

Please sign in to comment.