Skip to content

Commit

Permalink
Add "available" attribute to recipes (#368)
Browse files Browse the repository at this point in the history
* add available attribute needed by sparsify, and tests

* raise error on missing default
  • Loading branch information
Satrat committed Sep 22, 2023
1 parent f2cf607 commit c6ff218
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/sparsezoo/objects/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
from pathlib import Path
from typing import Dict, List, Optional, Union

from sparsezoo.objects import File
Expand Down Expand Up @@ -49,6 +50,15 @@ def __init__(
if custom_default is not None:
self._default_recipe_name = "recipe_" + custom_default

@property
def available(self) -> Optional[List[str]]:
"""
:return: List of all recipe names, or None if none are available
"""
if len(self._recipes) == 0:
return None
return [Path(recipe.name).stem for recipe in self._recipes]

@property
def recipes(self) -> List:
"""
Expand All @@ -66,9 +76,12 @@ def default(self) -> File:
return recipe

# fallback to first recipe in list
if len(self._recipes) == 0:
raise ValueError("No recipes found, could not retrieve a default.")

_LOGGER.warning(
"No default recipe {self._default_recipe_name} found, falling back to"
"first listed recipe"
f"No default recipe {self._default_recipe_name} found, falling back to "
f"first listed recipe {self._recipes[0].name}"
)
return self._recipes[0]

Expand Down
20 changes: 20 additions & 0 deletions tests/sparsezoo/objects/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import tempfile

import pytest

from sparsezoo.model import Model


Expand All @@ -35,6 +37,11 @@ def test_recipe_getters():
found_by_name = model.recipes.get_recipe_by_name("does_not_exist.md")
assert found_by_name is None

available_recipes = model.recipes.available
assert len(available_recipes) == 4
assert "recipe_transfer_token_classification" in available_recipes
assert "recipe" in available_recipes


def test_custom_default():
custom_default_name = "transfer_text_classification"
Expand All @@ -50,3 +57,16 @@ def test_custom_default():

default_recipe = model.recipes.default
assert default_recipe.name == expected_default_name

available_recipes = model.recipes.available
assert len(available_recipes) == 1
assert available_recipes[0] == "recipe_transfer_text_classification"


def test_fail_default_on_empty():
false_recipe_stub = "zoo:bert-base-wikipedia_bookcorpus-pruned90?recipe=nope"
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
model = Model(false_recipe_stub, temp_dir.name)

with pytest.raises(ValueError):
_ = model.recipes.default

0 comments on commit c6ff218

Please sign in to comment.