MVP for Alternating Flow (#1912)
* initial recipe re-loading

* loading for input recipe

* persist structure across recipe loads

* clean up fn names

* clean up duplicated code

* delete extra file

* unit tests

* fix failing test

* quantization edge cases

* quant tests

* fixes for stage name clashes

* clean up documentation

* setup StageRunner class

* running one_shot from text_gen script

* cleanup helper fns

* precision support

* formatting

* WIP for alternating

* fixing device issue

* MVP for alternating flows

* add apply flag during finalization as well

* clarity comments

* clean up docstrings

* fix unit test

* add finetuning README

* cleaning up stage logic

* quality
Satrat committed Jan 9, 2024
1 parent 8683a06 commit f592037
Showing 9 changed files with 210 additions and 25 deletions.
25 changes: 24 additions & 1 deletion src/sparseml/core/recipe/stage.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import Field, root_validator
@@ -23,14 +24,20 @@
from sparseml.core.recipe.modifier import RecipeModifier


__all__ = ["RecipeStage"]
__all__ = ["RecipeStage", "StageRunType"]


class StageRunType(Enum):
TRAIN = "train"
ONESHOT = "oneshot"


class RecipeStage(RecipeBase):
"""
Represents a stage in a recipe.
:param group: Name of the current stage
:param run_type: Whether this is a oneshot or training stage
:param args: Optional RecipeArgs to use for this stage
:param enabled: True to enable the stage, False otherwise
:param modifiers: list of RecipeModifiers that are a part of this stage
@@ -40,12 +47,28 @@ class RecipeStage(RecipeBase):
"""

group: Optional[str] = None
run_type: Optional[StageRunType] = None
args: Optional[RecipeArgs] = None
enabled: bool = True
modifiers: List[RecipeModifier] = Field(default_factory=list)
exclude_default: bool = False
args_evaluated: Optional[RecipeArgs] = None

def infer_run_type(self) -> Optional[StageRunType]:
"""
Infers the stage run type from the run_type attribute or the stage name,
falling back to None
:return: the inferred StageRunType (TRAIN or ONESHOT), or None if the
run type cannot be inferred
"""
if self.run_type == StageRunType.TRAIN or self.run_type == StageRunType.ONESHOT:
return self.run_type
if StageRunType.TRAIN.value in self.group:
return StageRunType.TRAIN
if StageRunType.ONESHOT.value in self.group:
return StageRunType.ONESHOT
return None

def calculate_start(self) -> int:
"""
:return: the start epoch for the stage, atleast one modifier
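The new `StageRunType` enum and `infer_run_type` helper let a stage declare how it should be run, either explicitly through `run_type` or implicitly through its name. A minimal sketch of the inference behavior, assuming `RecipeStage` can be constructed directly with only the fields shown above (the stage names are hypothetical):

```python
from sparseml.core.recipe.stage import RecipeStage, StageRunType

# an explicit run_type wins over whatever the stage is named
stage = RecipeStage(group="calibration", run_type=StageRunType.ONESHOT)
assert stage.infer_run_type() == StageRunType.ONESHOT

# otherwise the run type is inferred from the stage name
assert RecipeStage(group="sparsity_oneshot_stage").infer_run_type() == StageRunType.ONESHOT
assert RecipeStage(group="quant_finetune_train_stage").infer_run_type() == StageRunType.TRAIN

# no run_type and no "train"/"oneshot" hint in the name -> None (rejected in run_stages mode)
assert RecipeStage(group="custom_stage").infer_run_type() is None
```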
7 changes: 7 additions & 0 deletions src/sparseml/core/session.py
@@ -269,6 +269,13 @@ def reset(self):
"""
self._lifecycle.reset()

def reset_stage(self):
"""
Reset the session to start a new stage; the recipe and model stay intact
"""
self.lifecycle.initialized_ = False
self.lifecycle.finalized = False

def get_serialized_recipe(self) -> str:
"""
:return: serialized string of the current compiled recipe
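`reset_stage` is deliberately narrower than the existing `reset`: it clears only the initialized/finalized flags so the next stage can call `initialize` again, while the compiled recipe and the registered model are left in place. A small sketch of the distinction, assuming sparseml is installed and the session API behaves as shown above:

```python
import sparseml.core.session as session_manager

# the same pattern run_sequential_stages uses between stages
session = session_manager.active_session()

# clear only the lifecycle flags so the next stage can initialize again
session.reset_stage()
assert session.lifecycle.initialized_ is False
assert session.lifecycle.finalized is False

# by contrast, reset() tears down the whole lifecycle, recipe included
session.reset()
```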
17 changes: 17 additions & 0 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -60,6 +60,10 @@ def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path:
model=model, recipe=recipe_path, framework=Framework.pytorch
)

# no need to reload if no recipe was applied
if recipe_path is None:
return

session = session_manager.active_session()
num_stages = len(session.lifecycle.recipe_container.compiled_recipe.stages)
msg = (
@@ -233,3 +237,16 @@ def parse_dtype(dtype_arg: str) -> torch.dtype:
dtype = torch.float32

return dtype


def get_session_model() -> Module:
"""
:return: pytorch module stored by the active SparseSession, or None if no session
is active
"""
session = session_manager.active_session()
if not session:
return None

active_model = session.state.model.model
return active_model
69 changes: 59 additions & 10 deletions src/sparseml/transformers/finetune/runner.py
@@ -13,7 +13,8 @@
# limitations under the License.

import logging
from typing import List
import os
from typing import List, Optional

import torch
from torch.nn import Module
@@ -22,6 +23,7 @@

import sparseml.core.session as session_manager
from sparseml.core.framework import Framework
from sparseml.core.recipe import Recipe, StageRunType
from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe
from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data import TextGenerationDataset
@@ -65,6 +67,7 @@ def __init__(
self.model = model
self.trainer = None
self.tokenizer = None
self._output_dir = self._training_args.output_dir

def populate_datasets(self, tokenizer: "AutoTokenizer"):
"""
@@ -95,10 +98,10 @@ def populate_datasets(self, tokenizer: "AutoTokenizer"):

self.datasets = make_dataset_splits(
tokenized_datasets,
self._training_args.do_train,
self._training_args.do_eval,
self._training_args.do_predict,
self._training_args.do_oneshot,
do_train=self._training_args.do_train or self._training_args.run_stages,
do_eval=self._training_args.do_eval,
do_predict=self._training_args.do_predict,
do_oneshot=self._training_args.do_oneshot or self._training_args.run_stages,
)
self.tokenizer = tokenizer

@@ -144,9 +147,11 @@ def format_calibration_data(self) -> List[torch.Tensor]:
: min(self._data_args.num_calibration_samples, len(parsed_calib_data))
]

def one_shot(self):
def one_shot(self, stage: Optional[str] = None):
"""
Run oneshot calibration on the active model
:param stage: which stage of the recipe to run, or None to run the whole recipe
"""
_LOGGER.info("*** One Shot ***")

@@ -155,6 +160,7 @@ def one_shot(self):
session_manager.apply(
framework=Framework.pytorch,
recipe=self._training_args.recipe,
recipe_stage=stage,
model=self.model,
calib_data=calib_data,
start=-1,
@@ -164,25 +170,29 @@

save_model_and_recipe(
model=self.model,
save_path=self._training_args.output_dir,
save_path=self._output_dir,
tokenizer=self.tokenizer,
)

def train(self, checkpoint: str):
def train(self, checkpoint: str, stage: Optional[str] = None):
"""
Run trainer's training loop on train_dataset, saving the resulting model to
output_dir
:param checkpoint: Optional checkpoint to resume from
:param stage: which stage of the recipe to run, or None to run the whole recipe
"""
train_result = self.trainer.train(resume_from_checkpoint=checkpoint)
_LOGGER.info("*** Train ***")
train_result = self.trainer.train(
resume_from_checkpoint=checkpoint, stage=stage
)
metrics = train_result.metrics
metrics["train_samples"] = len(self.get_dataset_split("train"))
self.trainer.log_metrics("train", metrics)
self.trainer.save_metrics("train", metrics)

# this includes saving the state, optimizer and scheduler
self.trainer.save_model()
self.trainer.save_model(output_dir=self._output_dir)

def evaluate(self):
"""
@@ -206,3 +216,42 @@ def predict(self):
metrics["predict_samples"] = len(self.dataset["test"])
self.trainer.log_metrics("predict", metrics)
self.trainer.save_metrics("predict", metrics)

def run_sequential_stages(self):
"""
Run the recipe stage by stage, allowing for alternating between one-shot and
finetuning flows. Optionally save the model output at the end of each stage
"""

recipe_obj = Recipe.create_instance(self._training_args.recipe)

for stage in recipe_obj.stages:
# validate stage
stage_name = stage.group
run_type = stage.infer_run_type()
if not run_type:
raise ValueError(
f"a valid stage type ({[e.value for e in StageRunType]}) "
"must be provided in run_stages mode. Either add a run_type "
"attribute to each stage in the recipe or include it as part of "
"the stage name."
)

# setup checkpoint dir, TODO: this should be optional
self._output_dir = os.path.join(
self._training_args.output_dir, "stage_" + stage_name
)
if not os.path.exists(self._output_dir):
os.makedirs(self._output_dir)

# run stage
if run_type is StageRunType.ONESHOT:
self.one_shot(stage=stage_name)
elif run_type is StageRunType.TRAIN:
self.train(checkpoint=None, stage=stage_name)

# setup for next stage
session = session_manager.active_session()
session.reset_stage()

self.trainer.log_model_sparsification()
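Put together, `run_sequential_stages` walks the recipe's stages in order, validates that each one can be classified, routes it to `one_shot` or `train`, writes its output under a per-stage `stage_<name>` directory, and resets the session lifecycle before the next stage. A dry-run sketch of the validation and dispatch, using hypothetical stage names in place of a real loaded recipe (`Recipe.create_instance(...)` would normally supply the stages):

```python
import os

from sparseml.core.recipe.stage import RecipeStage, StageRunType

output_dir = "./alternating_output"  # hypothetical base output directory

# hypothetical stages standing in for Recipe.create_instance(recipe).stages
stages = [
    RecipeStage(group="sparsity_oneshot_stage"),
    RecipeStage(group="recovery_train_stage"),
]

for stage in stages:
    run_type = stage.infer_run_type()
    if run_type is None:
        raise ValueError(
            f"stage '{stage.group}' needs a run_type attribute or a "
            "'train'/'oneshot' hint in its name to run in run_stages mode"
        )

    # mirrors run_sequential_stages: each stage writes to its own checkpoint dir
    stage_dir = os.path.join(output_dir, f"stage_{stage.group}")
    action = "one-shot calibration" if run_type == StageRunType.ONESHOT else "finetuning"
    print(f"{stage.group}: {action} -> {stage_dir}")
```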
22 changes: 14 additions & 8 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -53,7 +53,6 @@ class SessionManagerMixIn:
Mix-In class to extend the Hugging Face Trainer class to support SparseML recipes
for one-shot and finetuning flows.
:param model: PyTorch model to run training on
:param model_state_path: path to Pytorch model checkpoint or saved model
:param recipe: path to recipe file to apply during training
:param recipe_args: additional kwargs to use for evaluating recipe
@@ -64,9 +63,8 @@

def __init__(
self,
model: Module,
model_state_path: str,
recipe: Optional[str],
recipe: Optional[str] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
metadata_args: Optional[List[str]] = None,
data_args: Optional["DataTrainingArguments"] = None, # noqa: F821
@@ -96,7 +94,7 @@ def __init__(
session_manager.create_session()

# call Trainer initialization
super().__init__(model=model, **kwargs)
super().__init__(**kwargs)

# setup callbacks and loss
self.optim_callbacks = TrainingLoopCallbacks(self)
@@ -105,7 +103,7 @@ def __init__(
self.callback_handler.add_callback(self.callback_disable_fp16)
self.criterion = torch.nn.CrossEntropyLoss()

model_signature = inspect.signature(model.forward)
model_signature = inspect.signature(self.model.forward)
self._model_signature_columns = list(model_signature.parameters.keys())

if self.teacher is not None and teacher not in ("disable", "self"):
@@ -114,14 +112,20 @@
else:
self._teacher_signature_columns = None

def initialize_session(self, epoch: float, checkpoint: Optional[str]):
def initialize_session(
self,
epoch: float,
checkpoint: Optional[str] = None,
stage: Optional[str] = None,
):
"""
Initialize the SparseSession from the specified epoch, evaluates the recipe
and initialized the modifiers for the training session
:param epoch: Epoch to initialize session from, usually 0 unless loading
from a checkpoint
:param checkpoint: Optional checkpoint to initialize from to continue training
:param stage: Optional stage of recipe to run, or None to run all stages
"""
session = session_manager.active_session()
if session.lifecycle.initialized_ or session.lifecycle.finalized:
@@ -135,6 +139,7 @@ def initialize_session(self, epoch: float, checkpoint: Optional[str]):
model=self.model,
teacher_model=self.teacher, # TODO: what about for self/disable?
recipe=self.recipe,
recipe_stage=stage,
recipe_args=self.recipe_args,
framework=Framework.pytorch,
train_data=train_data,
@@ -305,19 +310,20 @@ def prediction_step(
)
return model_outputs

def train(self, *args, **kwargs):
def train(self, *args, stage: Optional[str] = None, **kwargs):
"""
Run a sparsification training cycle. Runs initialization for the sparse session
before calling super().train() and finalization of the session after.
Logs sparsification details for the trained model.
:param args: positional args to pass to super().train()
:param stage: Optional stage of recipe to run, or None to run all stages
:param kwargs: keyword args to pass to super().train()
:return: the output from super.train()
"""
checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
self.initialize_session(epoch=epoch, checkpoint=checkpoint)
self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)
self.callback_disable_fp16.check_disable(epoch, force=True)
self.accelerator.wait_for_everyone()
output = super().train(*args, **kwargs)
15 changes: 12 additions & 3 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -34,6 +34,7 @@

from sparseml.pytorch.model_load.helpers import (
apply_recipe_structure_to_model,
get_session_model,
parse_dtype,
)
from sparseml.transformers.finetune import Trainer, TrainingArguments
@@ -192,6 +193,7 @@ def main(
else:
if not os.path.exists(recipe_path):
_LOGGER.warning(f"No recipes were applied for {model_path}.")
apply_recipe_structure_to_model(model, None, model_path)
else:
_LOGGER.warning(f"Applying recipe {recipe_path} to {model_path}")
apply_recipe_structure_to_model(model, recipe_path, model_path)
@@ -233,21 +235,28 @@

# Initialize our Trainer
trainer = Trainer(
model=model,
model_init=get_session_model,
teacher=teacher,
model_state_path=model_path,
recipe=training_args.recipe,
metadata_args=metadata_args,
recipe_args=training_args.recipe_args,
args=training_args,
data_args=data_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
stage_runner.set_trainer(trainer)

# alternating Training/One-shot
if training_args.run_stages:
stage_runner.run_sequential_stages()

# exit immediately
return

# Training
if training_args.do_train:
checkpoint = None
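Handing the Trainer `model_init=get_session_model` instead of a concrete `model` follows Hugging Face's `model_init` convention of a zero-argument callable that produces the model on demand; here it returns whatever module the active SparseSession is holding (the `get_session_model` helper added in `model_load/helpers.py` above), so a train stage that follows a one-shot stage continues from the already-modified weights rather than a fresh checkpoint load. A toy sketch of that contract only, with a plain module-level slot standing in for the session (hypothetical, not sparseml's API):

```python
from typing import Callable, Optional

import torch

# hypothetical stand-in for the SparseSession's model slot
_session_model: Optional[torch.nn.Module] = None


def get_model() -> Optional[torch.nn.Module]:
    # shaped like get_session_model(): hand back whatever the "session" holds
    return _session_model


def run_train_stage(model_init: Callable[[], Optional[torch.nn.Module]]) -> None:
    # the Trainer calls model_init() itself when it was built without a model
    model = model_init()
    assert model is not None, "a prior stage should have registered a model"
    print("training", type(model).__name__)


# a one-shot stage "registers" its modified module; the train stage picks it up
_session_model = torch.nn.Linear(4, 4)
run_train_stage(get_model)
```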