diff --git a/src/sparsezoo/objects/recipes.py b/src/sparsezoo/objects/recipes.py index 12f19e4a..d3f4b523 100644 --- a/src/sparsezoo/objects/recipes.py +++ b/src/sparsezoo/objects/recipes.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from pathlib import Path from typing import Dict, List, Optional, Union from sparsezoo.objects import File @@ -49,6 +50,15 @@ def __init__( if custom_default is not None: self._default_recipe_name = "recipe_" + custom_default + @property + def available(self) -> Optional[List[str]]: + """ + :return: List of all recipe names, or None if none are available + """ + if len(self._recipes) == 0: + return None + return [Path(recipe.name).stem for recipe in self._recipes] + @property def recipes(self) -> List: """ @@ -66,9 +76,12 @@ def default(self) -> File: return recipe # fallback to first recipe in list + if len(self._recipes) == 0: + raise ValueError("No recipes found, could not retrieve a default.") + _LOGGER.warning( - "No default recipe {self._default_recipe_name} found, falling back to" - "first listed recipe" + f"No default recipe {self._default_recipe_name} found, falling back to " + f"first listed recipe {self._recipes[0].name}" ) return self._recipes[0] diff --git a/tests/sparsezoo/objects/test_recipes.py b/tests/sparsezoo/objects/test_recipes.py index 8a2a005e..264468c8 100644 --- a/tests/sparsezoo/objects/test_recipes.py +++ b/tests/sparsezoo/objects/test_recipes.py @@ -14,6 +14,8 @@ import tempfile +import pytest + from sparsezoo.model import Model @@ -35,6 +37,11 @@ def test_recipe_getters(): found_by_name = model.recipes.get_recipe_by_name("does_not_exist.md") assert found_by_name is None + available_recipes = model.recipes.available + assert len(available_recipes) == 4 + assert "recipe_transfer_token_classification" in available_recipes + assert "recipe" in available_recipes + def test_custom_default(): custom_default_name = "transfer_text_classification" @@ -50,3 +57,16 @@ def test_custom_default(): default_recipe = model.recipes.default assert default_recipe.name == expected_default_name + + available_recipes = model.recipes.available + assert len(available_recipes) == 1 + assert available_recipes[0] == "recipe_transfer_text_classification" + + +def test_fail_default_on_empty(): + false_recipe_stub = "zoo:bert-base-wikipedia_bookcorpus-pruned90?recipe=nope" + temp_dir = tempfile.TemporaryDirectory(dir="/tmp") + model = Model(false_recipe_stub, temp_dir.name) + + with pytest.raises(ValueError): + _ = model.recipes.default