From b2bd3b237f64b303d9e980331f36abfaaf4503b7 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 22 Sep 2023 15:25:27 -0400 Subject: [PATCH 1/2] add available attribute needed by sparsify, and tests --- src/sparsezoo/objects/recipes.py | 14 ++++++++++++-- tests/sparsezoo/objects/test_recipes.py | 9 +++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/sparsezoo/objects/recipes.py b/src/sparsezoo/objects/recipes.py index 12f19e4a..6f81376e 100644 --- a/src/sparsezoo/objects/recipes.py +++ b/src/sparsezoo/objects/recipes.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from pathlib import Path from typing import Dict, List, Optional, Union from sparsezoo.objects import File @@ -49,6 +50,15 @@ def __init__( if custom_default is not None: self._default_recipe_name = "recipe_" + custom_default + @property + def available(self) -> Optional[List[str]]: + """ + :return: List of all recipe names, or None if none are available + """ + if len(self._recipes) == 0: + return None + return [Path(recipe.name).stem for recipe in self._recipes] + @property def recipes(self) -> List: """ @@ -67,8 +77,8 @@ def default(self) -> File: # fallback to first recipe in list _LOGGER.warning( - "No default recipe {self._default_recipe_name} found, falling back to" - "first listed recipe" + f"No default recipe {self._default_recipe_name} found, falling back to " + f"first listed recipe {self._recipes[0].name}" ) return self._recipes[0] diff --git a/tests/sparsezoo/objects/test_recipes.py b/tests/sparsezoo/objects/test_recipes.py index 8a2a005e..3e48abe5 100644 --- a/tests/sparsezoo/objects/test_recipes.py +++ b/tests/sparsezoo/objects/test_recipes.py @@ -35,6 +35,11 @@ def test_recipe_getters(): found_by_name = model.recipes.get_recipe_by_name("does_not_exist.md") assert found_by_name is None + available_recipes = model.recipes.available + assert len(available_recipes) == 4 + assert "recipe_transfer_token_classification" in available_recipes + assert "recipe" in available_recipes + def test_custom_default(): custom_default_name = "transfer_text_classification" @@ -50,3 +55,7 @@ def test_custom_default(): default_recipe = model.recipes.default assert default_recipe.name == expected_default_name + + available_recipes = model.recipes.available + assert len(available_recipes) == 1 + assert available_recipes[0] == "recipe_transfer_text_classification" From b5654138fc13dcf76089ab082e7a8a530924ec81 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 22 Sep 2023 15:41:58 -0400 Subject: [PATCH 2/2] raise error on missing default --- src/sparsezoo/objects/recipes.py | 3 +++ tests/sparsezoo/objects/test_recipes.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/src/sparsezoo/objects/recipes.py b/src/sparsezoo/objects/recipes.py index 6f81376e..d3f4b523 100644 --- a/src/sparsezoo/objects/recipes.py +++ b/src/sparsezoo/objects/recipes.py @@ -76,6 +76,9 @@ def default(self) -> File: return recipe # fallback to first recipe in list + if len(self._recipes) == 0: + raise ValueError("No recipes found, could not retrieve a default.") + _LOGGER.warning( f"No default recipe {self._default_recipe_name} found, falling back to " f"first listed recipe {self._recipes[0].name}" diff --git a/tests/sparsezoo/objects/test_recipes.py b/tests/sparsezoo/objects/test_recipes.py index 3e48abe5..264468c8 100644 --- a/tests/sparsezoo/objects/test_recipes.py +++ b/tests/sparsezoo/objects/test_recipes.py @@ -14,6 +14,8 @@ import tempfile +import pytest + from sparsezoo.model import Model @@ -59,3 +61,12 @@ def test_custom_default(): available_recipes = model.recipes.available assert len(available_recipes) == 1 assert available_recipes[0] == "recipe_transfer_text_classification" + + +def test_fail_default_on_empty(): + false_recipe_stub = "zoo:bert-base-wikipedia_bookcorpus-pruned90?recipe=nope" + temp_dir = tempfile.TemporaryDirectory(dir="/tmp") + model = Model(false_recipe_stub, temp_dir.name) + + with pytest.raises(ValueError): + _ = model.recipes.default