Skip to content

Commit

Permalink
training aware and sparse transfer run mode support (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran committed Apr 28, 2023
1 parent 47afad9 commit d5d9b48
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/sparsify/auto/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
request_student_teacher_configs,
save_history,
)
from sparsify.schemas import APIArgs, Metrics, SparsificationTrainingConfig
from sparsify.schemas import APIArgs, Metrics, RunMode, SparsificationTrainingConfig
from tensorboard.program import TensorBoard


Expand Down Expand Up @@ -122,7 +122,7 @@ def main(api_args: Optional[APIArgs] = None):
student_config = SparsificationTrainingConfig(**config_dict)
else:
teacher_config = SparsificationTrainingConfig(**config_dict)
if not api_args.teacher_only:
if RunMode(api_args.run_mode) != RunMode.teacher_only:
student_config = api_request_config(api_args)

save_directory = api_args.resume
Expand Down
6 changes: 3 additions & 3 deletions src/sparsify/auto/utils/nm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import requests

from sparsify.login import import_sparsifyml_authenticated
from sparsify.schemas import APIArgs, Metrics, SparsificationTrainingConfig
from sparsify.schemas import APIArgs, Metrics, RunMode, SparsificationTrainingConfig
from sparsify.utils import get_base_url, strtobool


Expand Down Expand Up @@ -86,14 +86,14 @@ def request_student_teacher_configs(

student_config, teacher_config = None, None

if api_args.teacher_only:
if RunMode(api_args.run_mode) == RunMode.teacher_only:
teacher_config = SparsificationTrainingConfig(**api_request_config(api_args))

else:
student_config = SparsificationTrainingConfig(**api_request_config(api_args))
if student_config.distill_teacher == "auto":
teacher_input_args = api_args.copy(deep=True)
teacher_input_args.teacher_only = True
teacher_input_args.run_mode = RunMode.teacher_only
teacher_config = SparsificationTrainingConfig(
**api_request_config(teacher_input_args)
)
Expand Down
8 changes: 4 additions & 4 deletions src/sparsify/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def sparse_transfer(**kwargs):
from sparsify import auto

# recipe arg should be a sparse transfer recipe
auto.main(_parse_run_args_to_auto(**kwargs))
auto.main(_parse_run_args_to_auto(sparse_transfer=True, **kwargs))


@main.command()
Expand All @@ -83,10 +83,10 @@ def training_aware(**kwargs):
from sparsify import auto

# recipe arg should be a training aware recipe
auto.main(_parse_run_args_to_auto(**kwargs))
auto.main(_parse_run_args_to_auto(sparse_transfer=False, **kwargs))


def _parse_run_args_to_auto(**kwargs):
def _parse_run_args_to_auto(sparse_transfer: bool, **kwargs):
from sparsify.schemas import APIArgs

if kwargs["eval_metric"] == "kl":
Expand All @@ -109,7 +109,7 @@ def _parse_run_args_to_auto(**kwargs):
teacher_kwargs={},
tuning_parameters=None,
teacher_tuning_parameters=None,
teacher_only=False,
run_mode="sparse_transfer" if sparse_transfer else "training_aware",
)


Expand Down
23 changes: 19 additions & 4 deletions src/sparsify/schemas/auto_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import argparse
import json
import os
from enum import Enum
from functools import total_ordering
from typing import Any, Dict, List, Optional, Union

Expand All @@ -32,6 +33,7 @@

__all__ = [
"APIArgs",
"RunMode",
"SparsificationTrainingConfig",
"Metrics",
"DEFAULT_OUTPUT_DIRECTORY",
Expand All @@ -40,6 +42,16 @@
DEFAULT_OUTPUT_DIRECTORY = "./output"


class RunMode(Enum):
"""
Class defining options for auto runner modes
"""

training_aware = "training_aware" # full sparsification run from dense baseline
sparse_transfer = "sparse_transfer"
teacher_only = "teacher_only" # only train dense upstream teacher


class APIArgs(BaseModel):
"""
Class containing the front-end arguments for Sparsify.Auto
Expand Down Expand Up @@ -163,10 +175,13 @@ class APIArgs(BaseModel):
"settings. See example tuning config output for expected format",
default=None,
)
teacher_only: bool = Field(
title="teacher_only",
description=("set to True to only auto tune the teacher"),
default=False,
run_mode: RunMode = Field(
title="run_mode",
description=(
"training run mode objective - 'sparse_transfer', 'training_aware', or "
"'teacher_only'. Default is 'sparse_transfer'"
),
default=RunMode.sparse_transfer,
)

@validator("task")
Expand Down

0 comments on commit d5d9b48

Please sign in to comment.