Skip to content

Commit

Permalink
fix tests with help from sara
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Jan 2, 2024
1 parent e4770c8 commit 7b28881
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
1 change: 0 additions & 1 deletion src/sparseml/transformers/sparsification/obcq/obcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def one_shot(
sequence_length=sequence_length,
torch_dtype=torch_dtype,
config=config,
recipe=recipe_file,
device=device,
)

Expand Down
48 changes: 25 additions & 23 deletions src/sparseml/transformers/utils/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,32 +219,34 @@ def resolve_recipe_application(
:param model_path: the path to the model to load
:return: the resolved recipe
"""
recipe_is_file = True
if recipe:

if recipe is None:
# if recipe is None -> still look for recipe.yaml in the model_path
recipe = os.path.join(model_path, RECIPE_NAME)
if os.path.isfile(recipe):
# recipe is a path to a recipe file
pass
elif os.path.isfile(os.path.join(model_path, recipe)):
# recipe is a name of a recipe file
recipe = os.path.join(model_path, recipe)
else:
# recipe is a string containing the recipe
recipe_is_file = False
_LOGGER.debug(
"Applying the recipe string directly to the model, without "
"checking for a potential existing recipe in the model_path."
)
else:
_LOGGER.info(
"No recipe requested and no default recipe "
f"found in {model_path}. Skipping recipe application."
)
return None
return recipe

elif os.path.isfile(recipe):
# recipe is a path to a recipe file
return _resolve_recipe_file(recipe, model_path)

if recipe_is_file:
# if recipe is a file, resolve it to a path
elif os.path.isfile(os.path.join(model_path, recipe)):
# recipe is a name of a recipe file
recipe = os.path.join(model_path, recipe)
return _resolve_recipe_file(recipe, model_path)
return recipe
elif isinstance(recipe, str):
# recipe is a string containing the recipe
_LOGGER.debug(
"Applying the recipe string directly to the model, without "
"checking for a potential existing recipe in the model_path."
)
return recipe

_LOGGER.info(
"No recipe requested and no default recipe "
f"found in {model_path}. Skipping recipe application."
)
return None


def _resolve_recipe_file(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparseml/transformers/obcq/test_obcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_lm_head_target():
model = initialize_sparse_model(
model_path=tiny_model_path,
device=device,
task="text-classification",
task="text-generation",
config=config,
)

Expand Down

0 comments on commit 7b28881

Please sign in to comment.