Skip to content

Commit

Permalink
Restore "default" Property for Recipes (#366)
Browse files Browse the repository at this point in the history
* class for recipes

* fix failing test

* check for Recipes instead of list

* missed list instance

* allow for parsing of custom default recipes

* update comment

* PR comments

* PR comments
  • Loading branch information
Satrat authored Sep 15, 2023
1 parent 47b67c7 commit f2cf607
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 14 deletions.
19 changes: 12 additions & 7 deletions src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
File,
NumpyDirectory,
OnnxGz,
Recipes,
SelectDirectory,
is_directory,
)
Expand Down Expand Up @@ -150,9 +151,10 @@ def __init__(self, source: str, download_path: Optional[str] = None):

self.logs: Directory = self._directory_from_files(files, display_name="logs")

self.recipes = self._file_from_files(files, display_name="^recipe", regex=True)
if isinstance(self.recipes, File):
self.recipes = [self.recipes]
recipe_file_list = self._file_from_files(
files, display_name="^recipe", regex=True
)
self.recipes = Recipes(recipe_file_list, stub_params=self.stub_params)

self._onnx_gz: OnnxGz = self._directory_from_files(
files, directory_class=OnnxGz, display_name="model.onnx.tar.gz"
Expand Down Expand Up @@ -652,7 +654,9 @@ def _directory_from_files(
return directories_found[0]

def _download(
self, file: Union[File, List[File], Dict[Any, File]], directory_path: str
self,
file: Union[File, Recipes, Dict[Any, File]],
directory_path: str,
) -> bool:

if isinstance(file, File):
Expand All @@ -670,9 +674,10 @@ def _download(
)
return False

elif isinstance(file, list):
validations = (self._download(_file, directory_path) for _file in file)
return all(validations)
elif isinstance(file, Recipes):
validations = (
self._download(_file, directory_path) for _file in file.recipes
)

else:
validations = (
Expand Down
6 changes: 3 additions & 3 deletions src/sparsezoo/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
ThroughputResults,
ValidationResult,
)
from sparsezoo.objects import Directory, File, NumpyDirectory, OnnxGz
from sparsezoo.objects import Directory, File, NumpyDirectory, OnnxGz, Recipes
from sparsezoo.utils import BASE_API_URL, convert_to_bool, save_numpy


Expand Down Expand Up @@ -562,8 +562,8 @@ def _copy_file_contents(
_copy_and_overwrite(file.path, copy_path, shutil.copytree)
else:
# for the structured directories/files
if isinstance(file, list):
for _file in file:
if isinstance(file, Recipes):
for _file in file.recipes:
copy_path = os.path.join(output_dir, os.path.basename(_file.path))
_copy_and_overwrite(_file.path, copy_path, shutil.copyfile)
elif isinstance(file, OnnxGz):
Expand Down
1 change: 1 addition & 0 deletions src/sparsezoo/objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@

# flake8: noqa
from .file import *
from .recipes import *
87 changes: 87 additions & 0 deletions src/sparsezoo/objects/recipes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List, Optional, Union

from sparsezoo.objects import File


# Public API of this module: only the Recipes class is exported.
__all__ = ["Recipes"]

# Module-level logger, named after this module's import path.
_LOGGER = logging.getLogger(__name__)


class Recipes:
    """
    Object to store the list of recipes for a downloaded model and select the
    default one

    :param recipes: a single recipe File, a list of recipe Files, or None for an
        empty collection
    :param stub_params: optional dictionary that may contain a custom default
        recipe name under the "recipe_type" or "recipe" key ("recipe_type" takes
        precedence when both are set)
    """

    _RECIPE_DEFAULT_NAME = "recipe.md"

    def __init__(
        self,
        recipes: Optional[Union[File, List[File]]],
        stub_params: Optional[Dict[str, str]] = None,
    ):
        # normalize the input so that self._recipes is always a list
        if recipes is None:
            recipes = []
        elif isinstance(recipes, File):
            recipes = [recipes]
        self._recipes = recipes

        # None (not {}) is the default so that no mutable dict is shared between
        # calls that omit the argument
        if stub_params is None:
            stub_params = {}

        self._default_recipe_name = self._RECIPE_DEFAULT_NAME
        custom_default = stub_params.get("recipe_type") or stub_params.get("recipe")
        if custom_default is not None:
            # custom defaults are matched by the "recipe_<name>" filename prefix
            self._default_recipe_name = "recipe_" + custom_default

    @property
    def recipes(self) -> List:
        """
        :return: The full list of recipes
        """
        return self._recipes

    @property
    def default(self) -> File:
        """
        :return: The default recipe in the recipe list: the first recipe whose
            name starts with the default recipe name, otherwise the first
            listed recipe
        :raises IndexError: if no recipes are stored
        """
        for recipe in self._recipes:
            if recipe.name.startswith(self._default_recipe_name):
                return recipe

        if not self._recipes:
            # nothing to fall back to; fail with an explicit message instead of
            # a warning followed by a bare "list index out of range"
            raise IndexError(
                "No recipes available, unable to select a default recipe"
            )

        # fallback to first recipe in list
        _LOGGER.warning(
            "No default recipe %s found, falling back to first listed recipe",
            self._default_recipe_name,
        )
        return self._recipes[0]

    def get_recipe_by_name(self, recipe_name: str) -> Union[File, None]:
        """
        Returns the File for the recipe matching the name recipe_name if it exists

        :param recipe_name: recipe filename to search for
        :return: File with the name recipe_name, or None if it doesn't exist
        """
        for recipe in self._recipes:
            if recipe.name == recipe_name:
                return recipe

        return None  # no matching recipe found
14 changes: 10 additions & 4 deletions src/sparsezoo/validation/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class objects are valid
import os
from typing import Callable, Dict, Optional, Set, Tuple, Union

from sparsezoo.objects import Directory, File
from sparsezoo.objects import Directory, File, Recipes
from sparsezoo.validation import (
validate_cv_classification,
validate_cv_detection,
Expand Down Expand Up @@ -110,8 +110,10 @@ def validate(self, minimal_validation: bool) -> bool:
_file.validate() for _file in file.values()
)
# checker for Recipes-type file (a collection of recipe files)
elif isinstance(file, list):
validations[file.__repr__()] = all(_file.validate() for _file in file)
elif isinstance(file, Recipes):
validations[file.__repr__()] = all(
_file.validate() for _file in file.recipes
)
else:
# checker for File/Directory class objects
if file.name == "training":
Expand Down Expand Up @@ -150,7 +152,11 @@ def validate_structure(self) -> bool:
"""
if self.minimal_validation:
for file in self.model.files:
if isinstance(file, list) or isinstance(file, dict) or (file is None):
if (
isinstance(file, Recipes)
or isinstance(file, dict)
or (file is None)
):
continue
else:
self.required_files.discard(file.name)
Expand Down
52 changes: 52 additions & 0 deletions tests/sparsezoo/objects/test_recipes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile

from sparsezoo.model import Model


def test_recipe_getters():
    """
    Check the Recipes accessors on a stub that ships multiple recipes: the
    default recipe, the full recipe list, and lookup by filename.
    """
    stub_with_multiple_recipes = "zoo:bert-base-wikipedia_bookcorpus-pruned90"

    # the context manager removes the download directory even when an assertion
    # fails; the platform default temp location is used instead of a hard-coded
    # "/tmp" so the test is not tied to one OS layout
    with tempfile.TemporaryDirectory() as temp_dir:
        model = Model(stub_with_multiple_recipes, temp_dir)

        default_recipe = model.recipes.default
        assert default_recipe.name == "recipe.md"

        all_recipes = model.recipes.recipes
        assert len(all_recipes) == 4

        recipe_name = "recipe_transfer_text_classification.md"
        found_by_name = model.recipes.get_recipe_by_name(recipe_name)
        assert found_by_name.name == recipe_name

        # unknown filenames yield None rather than raising
        found_by_name = model.recipes.get_recipe_by_name("does_not_exist.md")
        assert found_by_name is None


def test_custom_default():
    """
    Check that a "recipe" stub parameter overrides which recipe is returned as
    the default.
    """
    custom_default_name = "transfer_text_classification"
    stub_with_multiple_recipes = (
        "zoo:bert-base-wikipedia_bookcorpus-pruned90?recipe={}".format(
            custom_default_name
        )
    )
    expected_default_name = "recipe_" + custom_default_name + ".md"

    # the context manager removes the download directory even when the assertion
    # fails; the platform default temp location is used instead of a hard-coded
    # "/tmp" so the test is not tied to one OS layout
    with tempfile.TemporaryDirectory() as temp_dir:
        model = Model(stub_with_multiple_recipes, temp_dir)

        default_recipe = model.recipes.default
        assert default_recipe.name == expected_default_name

0 comments on commit f2cf607

Please sign in to comment.