diff --git a/src/sparsify/auto/scripts/main.py b/src/sparsify/auto/scripts/main.py index a06ce301..131ae4ff 100644 --- a/src/sparsify/auto/scripts/main.py +++ b/src/sparsify/auto/scripts/main.py @@ -32,7 +32,7 @@ request_student_teacher_configs, save_history, ) -from sparsify.schemas import APIArgs, Metrics, SparsificationTrainingConfig +from sparsify.schemas import APIArgs, Metrics, RunMode, SparsificationTrainingConfig from tensorboard.program import TensorBoard @@ -122,7 +122,7 @@ def main(api_args: Optional[APIArgs] = None): student_config = SparsificationTrainingConfig(**config_dict) else: teacher_config = SparsificationTrainingConfig(**config_dict) - if not api_args.teacher_only: + if RunMode(api_args.run_mode) != RunMode.teacher_only: student_config = api_request_config(api_args) save_directory = api_args.resume diff --git a/src/sparsify/auto/utils/nm_api.py b/src/sparsify/auto/utils/nm_api.py index 671568ab..e3470a73 100644 --- a/src/sparsify/auto/utils/nm_api.py +++ b/src/sparsify/auto/utils/nm_api.py @@ -21,7 +21,7 @@ import requests from sparsify.login import import_sparsifyml_authenticated -from sparsify.schemas import APIArgs, Metrics, SparsificationTrainingConfig +from sparsify.schemas import APIArgs, Metrics, RunMode, SparsificationTrainingConfig from sparsify.utils import get_base_url, strtobool @@ -86,14 +86,14 @@ def request_student_teacher_configs( student_config, teacher_config = None, None - if api_args.teacher_only: + if RunMode(api_args.run_mode) == RunMode.teacher_only: teacher_config = SparsificationTrainingConfig(**api_request_config(api_args)) else: student_config = SparsificationTrainingConfig(**api_request_config(api_args)) if student_config.distill_teacher == "auto": teacher_input_args = api_args.copy(deep=True) - teacher_input_args.teacher_only = True + teacher_input_args.run_mode = RunMode.teacher_only teacher_config = SparsificationTrainingConfig( **api_request_config(teacher_input_args) ) diff --git a/src/sparsify/cli/run.py b/src/sparsify/cli/run.py index daa5ee0a..1210e2b4 100644 --- a/src/sparsify/cli/run.py +++ b/src/sparsify/cli/run.py @@ -67,7 +67,7 @@ def sparse_transfer(**kwargs): from sparsify import auto # recipe arg should be a sparse transfer recipe - auto.main(_parse_run_args_to_auto(**kwargs)) + auto.main(_parse_run_args_to_auto(sparse_transfer=True, **kwargs)) @main.command() @@ -83,10 +83,10 @@ def training_aware(**kwargs): from sparsify import auto # recipe arg should be a training aware recipe - auto.main(_parse_run_args_to_auto(**kwargs)) + auto.main(_parse_run_args_to_auto(sparse_transfer=False, **kwargs)) -def _parse_run_args_to_auto(**kwargs): +def _parse_run_args_to_auto(sparse_transfer: bool, **kwargs): from sparsify.schemas import APIArgs if kwargs["eval_metric"] == "kl": @@ -109,7 +109,7 @@ def _parse_run_args_to_auto(**kwargs): teacher_kwargs={}, tuning_parameters=None, teacher_tuning_parameters=None, - teacher_only=False, + run_mode="sparse_transfer" if sparse_transfer else "training_aware", ) diff --git a/src/sparsify/schemas/auto_api.py b/src/sparsify/schemas/auto_api.py index ffe6ee8c..cd330e51 100644 --- a/src/sparsify/schemas/auto_api.py +++ b/src/sparsify/schemas/auto_api.py @@ -20,6 +20,7 @@ import argparse import json import os +from enum import Enum from functools import total_ordering from typing import Any, Dict, List, Optional, Union @@ -32,6 +33,7 @@ __all__ = [ "APIArgs", + "RunMode", "SparsificationTrainingConfig", "Metrics", "DEFAULT_OUTPUT_DIRECTORY", @@ -40,6 +42,16 @@ DEFAULT_OUTPUT_DIRECTORY = "./output" +class RunMode(Enum): + """ + Class defining options for auto runner modes + """ + + training_aware = "training_aware" # full sparsification run from dense baseline + sparse_transfer = "sparse_transfer" + teacher_only = "teacher_only" # only train dense upstream teacher + + class APIArgs(BaseModel): """ Class containing the front-end arguments for Sparsify.Auto @@ -163,10 +175,13 @@ class APIArgs(BaseModel): "settings. See example tuning config output for expected format", default=None, ) - teacher_only: bool = Field( - title="teacher_only", - description=("set to True to only auto tune the teacher"), - default=False, + run_mode: RunMode = Field( + title="run_mode", + description=( + "training run mode objective - 'sparse_transfer', 'training_aware', or " + "'teacher_only'. Default is 'sparse_transfer'" + ), + default=RunMode.sparse_transfer, ) @validator("task")