[One Shot] Specify recipe_args on CLI (#1902)
* add cli arg for recipe_args

* style

* fix for output recipe

* adding recipe args to finetuning

* distillation fix
Satrat committed Jan 9, 2024
1 parent 266599a commit c35206d
Showing 4 changed files with 41 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/sparseml/core/recipe/modifier.py
@@ -113,4 +113,4 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
         """
         :return: the dictionary representation of the modifier
         """
-        return {self.type: self.args, "group": f"{self.group}_modifiers"}
+        return {self.type: self.args_evaluated, "group": f"{self.group}_modifiers"}
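A hedged reading of this one-liner, tied to the "fix for output recipe" bullet above (the attribute names come from the diff; the substitution behavior is an assumption, not confirmed by this commit):

    # hypothetical illustration, not SparseML's actual data model:
    # a modifier's raw args may still reference recipe variables, while
    # args_evaluated holds concrete values after recipe_args overrides
    args = {"sparsity": "eval(target_sparsity)"}  # as authored in the recipe
    args_evaluated = {"sparsity": 0.5}            # after --recipe_args target_sparsity=0.5
    # serializing args_evaluated means a re-exported (output) recipe records
    # the values that were actually used for the run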
12 changes: 12 additions & 0 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -104,6 +104,13 @@ def parse_args(**kwargs):
     else:
         model_args, data_args, training_args = parser.parse_dict(kwargs)
 
+    if training_args.recipe_args is not None:
+        arg_dict = {}
+        for recipe_arg in training_args.recipe_args:
+            key, value = recipe_arg.split("=")
+            arg_dict[key] = value
+        training_args.recipe_args = arg_dict
+
     return model_args, data_args, training_args
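A minimal standalone sketch of the conversion this hunk performs (values are illustrative; the parsed values remain strings here, and any type coercion is assumed to happen later during recipe evaluation):

    recipe_args = ["sparsity=0.5", "num_epochs=2"]  # as collected from the CLI
    arg_dict = {}
    for recipe_arg in recipe_args:
        key, value = recipe_arg.split("=")
        arg_dict[key] = value
    print(arg_dict)  # {'sparsity': '0.5', 'num_epochs': '2'}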


@@ -166,6 +173,10 @@ def main(
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
+    teacher_config = AutoConfig.from_pretrained(
+        training_args.distill_teacher,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
 
     model_kwargs = {
         "config": config,
@@ -175,6 +186,7 @@
         "torch_dtype": parse_dtype(model_args.precision),
     }
     teacher_kwargs = {
+        "config": teacher_config,
         "cache_dir": model_args.cache_dir,
         "use_auth_token": True if model_args.use_auth_token else None,
     }
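This pair of hunks is presumably the "distillation fix" from the commit message: teacher_kwargs previously carried no config entry, so the distillation teacher is now loaded with its own AutoConfig built from training_args.distill_teacher rather than implicitly sharing the student's config.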
11 changes: 8 additions & 3 deletions src/sparseml/transformers/finetune/training_args.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional
 
 from transformers import TrainingArguments as HFTrainingArgs
 
@@ -51,9 +51,14 @@ class TrainingArguments(HFTrainingArgs):
             ),
         },
     )
-    recipe_args: Optional[str] = field(
+    recipe_args: Optional[List[str]] = field(
         default=None,
-        metadata={"help": "Recipe arguments to be overwritten"},
+        metadata={
+            "help": (
+                "List of recipe arguments to evaluate, of the format key1=value1 "
+                "key2=value2"
+            )
+        },
     )
     do_oneshot: Optional[bool] = field(
         default=False,
21 changes: 20 additions & 1 deletion src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -16,7 +16,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Optional
+from typing import Dict, Optional
 
 from torch.nn import Module
 from transformers import AutoConfig
@@ -56,6 +56,7 @@ def one_shot(
     deploy_dir: Optional[str] = ".",
     recipe_file: Optional[str] = None,
     precision: str = "auto",
+    recipe_args: Optional[Dict] = None,
     eval_data: Optional[str] = None,
     do_save: Optional[bool] = False,
 ) -> Module:
@@ -70,6 +71,7 @@
     :param deploy_dir: The output directory to save the model to
     :param recipe_file: recipe containing SparseGPT configuration
     :param precision: precision to load model as, either auto, half or full
+    :param recipe_args: additional arguments to use for recipe evaluation
     :param eval_data: dataset to use for perplexity evaluation, or none to skip
     :param do_save: whether to save the output model to disk
@@ -144,6 +146,7 @@ def one_shot(
         start=-1,
         device=device,
         copy_data=False,
+        recipe_args=recipe_args,
     )
 
     if do_save:
@@ -166,6 +169,15 @@ def one_shot(
     return model
 
 
+class KeyValue(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, dict())
+
+        for value in values:
+            key, value = value.split("=")
+            getattr(namespace, self.dest)[key] = value
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
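The action can be exercised standalone; a self-contained sketch follows (the maxsplit=1 guard is an addition here, not in the committed code, to tolerate values that themselves contain "="):

    import argparse


    class KeyValue(argparse.Action):  # same shape as the action added above
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for pair in values:
                # split only on the first "=", so values like "path=a=b" survive
                key, value = pair.split("=", maxsplit=1)
                getattr(namespace, self.dest)[key] = value


    parser = argparse.ArgumentParser()
    parser.add_argument("--recipe_args", nargs="*", action=KeyValue)
    args = parser.parse_args(["--recipe_args", "sparsity=0.5", "num_epochs=2"])
    print(args.recipe_args)  # {'sparsity': '0.5', 'num_epochs': '2'}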

@@ -195,6 +207,12 @@ def one_shot(
         default="auto",
         help="Precision to cast model weights to, default to auto",
     )
+    parser.add_argument(
+        "--recipe_args",
+        nargs="*",
+        action=KeyValue,
+        help="Recipe arguments to evaluate, of the format key1=value1 key2=value2",
+    )
     parser.add_argument(
         "--eval", type=str, default=None, help="Optional dataset for perplexity eval"
     )
@@ -213,6 +231,7 @@ def one_shot(
         device=args.device,
         recipe_file=args.recipe,
         precision=args.precision,
+        recipe_args=args.recipe_args,
         eval_data=args.eval,
         do_save=args.save,
     )
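Taken together, a hypothetical one-shot invocation might look like the following (the positional arguments and the recipe variable names are placeholders, not taken from this diff):

    python src/sparseml/transformers/sparsification/obcq/obcq.py <model> <dataset> \
        --recipe recipe.yaml \
        --recipe_args sparsity=0.5 num_samples=512

The finetuning entry point gains the same flag through TrainingArguments.recipe_args.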
