
MVP for Alternating Flow #1912

Merged Jan 9, 2024 (44 commits)

Commits:
d5abe8e  initial recipe re-loading (Satrat, Nov 16, 2023)
ec0e180  Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 5, 2023)
2d7b5b7  loading for input recipe (Satrat, Dec 7, 2023)
2cc9e16  Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 7, 2023)
356bd81  persist structure across recipe loads (Satrat, Dec 7, 2023)
1b67b6f  clean up fn names (Satrat, Dec 7, 2023)
f06ed8a  Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 12, 2023)
ab5a464  clean up duplicated code (Satrat, Dec 12, 2023)
11f4efe  delete extra file (Satrat, Dec 12, 2023)
7e960a3  unit tests (Satrat, Dec 12, 2023)
ebb5407  fix failing test (Satrat, Dec 12, 2023)
6a394d7  quantization edge cases (Satrat, Dec 12, 2023)
d7974bf  quant tests (Satrat, Dec 13, 2023)
4b9014d  Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 13, 2023)
701ab2c  fixes for stage name clashes (Satrat, Dec 13, 2023)
5812488  Merge branch 'sparse_auto_recipe' of github.com:neuralmagic/sparseml … (Satrat, Dec 13, 2023)
21473aa  clean up documentation (Satrat, Dec 13, 2023)
485501b  setup StageRunner class (Satrat, Dec 13, 2023)
2d536a3  running one_shot from text_gen script (Satrat, Dec 13, 2023)
a4406ae  cleanup helper fns (Satrat, Dec 14, 2023)
4576a80  precision support (Satrat, Dec 14, 2023)
27467e3  formatting (Satrat, Dec 14, 2023)
10a0fed  Merge branch 'main' into alternate_flows (Satrat, Dec 14, 2023)
7c754e0  WIP for alternating (Satrat, Dec 15, 2023)
0eb06bf  fixing device issue (Satrat, Dec 15, 2023)
f45326d  Merge branch 'sparse_auto_recipe' into alternating_flow_pt2 (Satrat, Dec 15, 2023)
e46dd96  Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 15, 2023)
d308987  MVP for alternating flows (Satrat, Dec 15, 2023)
fe9af83  add apply flag during finalization as well (Satrat, Dec 15, 2023)
5f6e854  clarity comments (Satrat, Dec 15, 2023)
4588eb2  Merge branch 'sparse_auto_recipe' of github.com:neuralmagic/sparseml … (Satrat, Dec 15, 2023)
391350d  clean up docstrings (Satrat, Dec 15, 2023)
f7fb65a  fix unit test (Satrat, Dec 15, 2023)
e429929  Merge branch 'main' into alternate_flows (Satrat, Dec 15, 2023)
7336453  Merge branch 'sparse_auto_recipe' into alternating_flow_pt2 (Satrat, Dec 15, 2023)
2968171  Merge branch 'alternate_flows' into alternating_flow_pt2 (Satrat, Dec 15, 2023)
ee1ee2d  add finetuning README (Satrat, Dec 20, 2023)
9004da6  Merge branch 'main' of github.com:neuralmagic/sparseml (Satrat, Dec 21, 2023)
180a24d  Merge branch 'main' into alternating_flow_pt2 (Satrat, Dec 21, 2023)
a8760eb  cleaning up stage logic (Satrat, Dec 21, 2023)
8eba7dd  Merge branch 'main' into alternating_flow_pt2 (Satrat, Jan 2, 2024)
9ef0d4c  quality (Satrat, Jan 2, 2024)
c4562c0  Merge branch 'main' into alternating_flow_pt2 (Satrat, Jan 8, 2024)
797413a  Merge branch 'main' into alternating_flow_pt2 (Satrat, Jan 9, 2024)
25 changes: 24 additions & 1 deletion src/sparseml/core/recipe/stage.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import Field, root_validator
@@ -23,14 +24,20 @@
from sparseml.core.recipe.modifier import RecipeModifier


-__all__ = ["RecipeStage"]
+__all__ = ["RecipeStage", "StageRunType"]


class StageRunType(Enum):
TRAIN = "train"
ONESHOT = "oneshot"


class RecipeStage(RecipeBase):
"""
Represents a stage in a recipe.

:param group: Name of the current stage
:param run_type: Whether this is a oneshot or training stage
:param args: Optional RecipeArgs to use for this stage
:param enabled: True to enable the stage, False otherwise
:param modifiers: list of RecipeModifiers that are a part of this stage
@@ -40,12 +47,28 @@ class RecipeStage(RecipeBase):
"""

group: Optional[str] = None
run_type: Optional[StageRunType] = None
args: Optional[RecipeArgs] = None
enabled: bool = True
modifiers: List[RecipeModifier] = Field(default_factory=list)
exclude_default: bool = False
args_evaluated: Optional[RecipeArgs] = None

def infer_run_type(self) -> Optional[StageRunType]:
"""
Infers the stage type from the type attribute or stage name, falls back to None

:return: string representing stage type, either train or oneshot, or None if
stage cannot be inferred
"""
if self.run_type == StageRunType.TRAIN or self.run_type == StageRunType.ONESHOT:
return self.run_type
if StageRunType.TRAIN.value in self.group:
return StageRunType.TRAIN
if StageRunType.ONESHOT.value in self.group:
return StageRunType.ONESHOT
return None

def calculate_start(self) -> int:
"""
:return: the start epoch for the stage, at least one modifier
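For reference, a minimal standalone sketch of the naming convention that infer_run_type falls back on when no explicit run_type is set; the stage names below are hypothetical and the helper simply mirrors the substring check above rather than importing sparseml:

from enum import Enum
from typing import Optional


class StageRunType(Enum):
    TRAIN = "train"
    ONESHOT = "oneshot"


def infer_run_type_from_name(stage_name: str) -> Optional[StageRunType]:
    # mirrors the fallback in RecipeStage.infer_run_type: match on substrings
    # of the stage name when no explicit run_type attribute is provided
    if StageRunType.TRAIN.value in stage_name:
        return StageRunType.TRAIN
    if StageRunType.ONESHOT.value in stage_name:
        return StageRunType.ONESHOT
    return None


print(infer_run_type_from_name("sparsity_oneshot_stage"))  # StageRunType.ONESHOT
print(infer_run_type_from_name("finetuning_train_stage"))  # StageRunType.TRAIN
print(infer_run_type_from_name("unlabeled_stage"))         # None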
7 changes: 7 additions & 0 deletions src/sparseml/core/session.py
@@ -269,6 +269,13 @@ def reset(self):
"""
self._lifecycle.reset()

def reset_stage(self):
"""
Reset the session for starting a new stage; recipe and model stay intact
"""
self.lifecycle.initialized_ = False
self.lifecycle.finalized = False

def get_serialized_recipe(self) -> str:
"""
:return: serialized string of the current compiled recipe
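A short usage sketch of the new reset_stage hook, assuming sparseml is installed and a session has already been created; it only clears the initialized/finalized lifecycle flags so the same recipe and model can move on to the next stage:

import sparseml.core.session as session_manager

session_manager.create_session()
session = session_manager.active_session()

# ... run a one-shot or training stage here ...

session.reset_stage()  # lifecycle flags cleared; recipe container and model are kept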
17 changes: 17 additions & 0 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -60,6 +60,10 @@ def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path:
model=model, recipe=recipe_path, framework=Framework.pytorch
)

# no need to reload if no recipe was applied
if recipe_path is None:
return

session = session_manager.active_session()
num_stages = len(session.lifecycle.recipe_container.compiled_recipe.stages)
msg = (
@@ -233,3 +237,16 @@ def parse_dtype(dtype_arg: str) -> torch.dtype:
dtype = torch.float32

return dtype


def get_session_model() -> Module:
"""
:return: pytorch module stored by the active SparseSession, or None if no session
is active
"""
session = session_manager.active_session()
if not session:
return None

active_model = session.state.model.model
return active_model
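A minimal usage sketch of get_session_model (assumes sparseml is installed); it returns None when no SparseSession is active, and it is wired in as the Trainer's model_init callable later in this PR:

from sparseml.pytorch.model_load.helpers import get_session_model

model = get_session_model()
if model is None:
    print("no active SparseSession")
else:
    print(type(model).__name__)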
69 changes: 59 additions & 10 deletions src/sparseml/transformers/finetune/runner.py
@@ -13,7 +13,8 @@
# limitations under the License.

import logging
-from typing import List
+import os
+from typing import List, Optional

import torch
from torch.nn import Module
@@ -22,6 +23,7 @@

import sparseml.core.session as session_manager
from sparseml.core.framework import Framework
from sparseml.core.recipe import Recipe, StageRunType
from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe
from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data import TextGenerationDataset
@@ -65,6 +67,7 @@ def __init__(
self.model = model
self.trainer = None
self.tokenizer = None
self._output_dir = self._training_args.output_dir

def populate_datasets(self, tokenizer: "AutoTokenizer"):
"""
@@ -95,10 +98,10 @@ def populate_datasets(self, tokenizer: "AutoTokenizer"):

self.datasets = make_dataset_splits(
tokenized_datasets,
-self._training_args.do_train,
-self._training_args.do_eval,
-self._training_args.do_predict,
-self._training_args.do_oneshot,
+do_train=self._training_args.do_train or self._training_args.run_stages,
+do_eval=self._training_args.do_eval,
+do_predict=self._training_args.do_predict,
+do_oneshot=self._training_args.do_oneshot or self._training_args.run_stages,
)
self.tokenizer = tokenizer

@@ -144,9 +147,11 @@ def format_calibration_data(self) -> List[torch.Tensor]:
: min(self._data_args.num_calibration_samples, len(parsed_calib_data))
]

-def one_shot(self):
+def one_shot(self, stage: Optional[str] = None):
"""
Run oneshot calibration on the active model

:param stage: which stage of the recipe to run, or None to run whole recipe
"""
_LOGGER.info("*** One Shot ***")

@@ -155,6 +160,7 @@ def one_shot(self):
session_manager.apply(
framework=Framework.pytorch,
recipe=self._training_args.recipe,
recipe_stage=stage,
model=self.model,
calib_data=calib_data,
start=-1,
@@ -164,25 +170,29 @@

save_model_and_recipe(
model=self.model,
-save_path=self._training_args.output_dir,
+save_path=self._output_dir,
tokenizer=self.tokenizer,
)

-def train(self, checkpoint: str):
+def train(self, checkpoint: str, stage: Optional[str] = None):
"""
Run trainer's training loop on train_dataset, saving the resulting model to
output_dir

:param checkpoint: Optional checkpoint to resume from
:param stage: which stage of the recipe to run, or None to run whole recipe
"""
-train_result = self.trainer.train(resume_from_checkpoint=checkpoint)
+_LOGGER.info("*** Train ***")
+train_result = self.trainer.train(
+resume_from_checkpoint=checkpoint, stage=stage
+)
metrics = train_result.metrics
metrics["train_samples"] = len(self.get_dataset_split("train"))
self.trainer.log_metrics("train", metrics)
self.trainer.save_metrics("train", metrics)

# this includes saving the state, optimizer and scheduler
-self.trainer.save_model()
+self.trainer.save_model(output_dir=self._output_dir)

def evaluate(self):
"""
@@ -206,3 +216,42 @@ def predict(self):
metrics["predict_samples"] = len(self.dataset["test"])
self.trainer.log_metrics("predict", metrics)
self.trainer.save_metrics("predict", metrics)

def run_sequential_stages(self):
"""
Run the recipe stage by stage, allowing for alternating between one-shot and
finetuning flows. Optionally save the model output at the end of each stage
"""

recipe_obj = Recipe.create_instance(self._training_args.recipe)

for stage in recipe_obj.stages:
# validate stage
stage_name = stage.group
run_type = stage.infer_run_type()
if not run_type:
raise ValueError(
f"a valid stage type ({[e.value for e in StageRunType]}) "
"must be provided in run_stages mode. Either add a run_type "
"attribute to each stage in the recipe or include it as part of "
"the stage name."
)

# setup checkpoint dir, TODO: this should be optional
self._output_dir = os.path.join(
self._training_args.output_dir, "stage_" + stage_name
)
if not os.path.exists(self._output_dir):
os.makedirs(self._output_dir)

# run stage
if run_type is StageRunType.ONESHOT:
self.one_shot(stage=stage_name)
elif run_type is StageRunType.TRAIN:
self.train(checkpoint=None, stage=stage_name)

# setup for next stage
session = session_manager.active_session()
session.reset_stage()

self.trainer.log_model_sparsification()
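The stage loop above can be exercised outside the Trainer as well; a hedged sketch, where recipe.yaml is a hypothetical multi-stage recipe whose stage names (or run_type attributes) identify them as oneshot or train:

from sparseml.core.recipe import Recipe, StageRunType

recipe = Recipe.create_instance("recipe.yaml")  # hypothetical recipe path
for stage in recipe.stages:
    run_type = stage.infer_run_type()
    if run_type is StageRunType.ONESHOT:
        print(f"{stage.group}: would run one-shot calibration")
    elif run_type is StageRunType.TRAIN:
        print(f"{stage.group}: would run a finetuning stage")
    else:
        raise ValueError(f"stage {stage.group} has no recognizable run type")

In the runner itself, each stage additionally gets its own checkpoint subdirectory under output_dir before the one-shot or train call, and the session is reset between stages.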
22 changes: 14 additions & 8 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -53,7 +53,6 @@ class SessionManagerMixIn:
Mix-In class to extend the Hugging Face Trainer class to support SparseML recipes
for one-shot and finetuning flows.

-:param model: PyTorch model to run training on
:param model_state_path: path to Pytorch model checkpoint or saved model
:param recipe: path to recipe file to apply during training
:param recipe_args: additional kwargs to use for evaluating recipe
@@ -64,9 +63,8 @@

def __init__(
self,
-model: Module,
model_state_path: str,
-recipe: Optional[str],
+recipe: Optional[str] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
metadata_args: Optional[List[str]] = None,
data_args: Optional["DataTrainingArguments"] = None, # noqa: F821
@@ -96,7 +94,7 @@ def __init__(
session_manager.create_session()

# call Trainer initialization
-super().__init__(model=model, **kwargs)
+super().__init__(**kwargs)

# setup callbacks and loss
self.optim_callbacks = TrainingLoopCallbacks(self)
@@ -105,7 +103,7 @@ def __init__(
self.callback_handler.add_callback(self.callback_disable_fp16)
self.criterion = torch.nn.CrossEntropyLoss()

-model_signature = inspect.signature(model.forward)
+model_signature = inspect.signature(self.model.forward)
self._model_signature_columns = list(model_signature.parameters.keys())

if self.teacher is not None and teacher not in ("disable", "self"):
Expand All @@ -114,14 +112,20 @@ def __init__(
else:
self._teacher_signature_columns = None

-def initialize_session(self, epoch: float, checkpoint: Optional[str]):
+def initialize_session(
+self,
+epoch: float,
+checkpoint: Optional[str] = None,
+stage: Optional[str] = None,
+):
"""
Initialize the SparseSession from the specified epoch, evaluate the recipe,
and initialize the modifiers for the training session

:param epoch: Epoch to initialize session from, usually 0 unless loading
from a checkpoint
:param checkpoint: Optional checkpoint to initialize from to continue training
:param stage: Optional stage of recipe to run, or None to run all stages
"""
session = session_manager.active_session()
if session.lifecycle.initialized_ or session.lifecycle.finalized:
@@ -135,6 +139,7 @@ def initialize_session(self, epoch: float, checkpoint: Optional[str]):
model=self.model,
teacher_model=self.teacher, # TODO: what about for self/disable?
recipe=self.recipe,
recipe_stage=stage,
recipe_args=self.recipe_args,
framework=Framework.pytorch,
train_data=train_data,
@@ -305,19 +310,20 @@ def prediction_step(
)
return model_outputs

-def train(self, *args, **kwargs):
+def train(self, *args, stage: Optional[str] = None, **kwargs):
"""
Run a sparsification training cycle. Runs initialization for the sparse session
before calling super().train() and finalization of the session after.

Logs sparsification details for the trained model.

:param args: positional args to pass to super().train()
:param stage: Optional stage of recipe to run, or None to run all stages
:param kwargs: keyword args to pass to super().train()
:return: the output from super.train()
"""
checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
-self.initialize_session(epoch=epoch, checkpoint=checkpoint)
+self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)
self.callback_disable_fp16.check_disable(epoch, force=True)
self.accelerator.wait_for_everyone()
output = super().train(*args, **kwargs)
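A sketch of the stage pass-through added here, assuming trainer is an already-constructed SparseML Trainer and the stage name matches one defined in the loaded recipe; train() hands the name to initialize_session(), which forwards it as recipe_stage so only that stage's modifiers are activated:

def run_single_train_stage(trainer, stage_name: str = "finetune_train_stage"):
    # stage_name is hypothetical; passing stage=None instead would run the
    # whole recipe, matching the previous behavior
    return trainer.train(resume_from_checkpoint=None, stage=stage_name)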
15 changes: 12 additions & 3 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -34,6 +34,7 @@

from sparseml.pytorch.model_load.helpers import (
apply_recipe_structure_to_model,
get_session_model,
parse_dtype,
)
from sparseml.transformers.finetune import Trainer, TrainingArguments
@@ -192,6 +193,7 @@ def main(
else:
if not os.path.exists(recipe_path):
_LOGGER.warning(f"No recipes were applied for {model_path}.")
apply_recipe_structure_to_model(model, None, model_path)
else:
_LOGGER.warning(f"Applying recipe {recipe_path} to {model_path}")
apply_recipe_structure_to_model(model, recipe_path, model_path)
@@ -233,21 +235,28 @@

# Initialize our Trainer
trainer = Trainer(
-model=model,
+model_init=get_session_model,
teacher=teacher,
model_state_path=model_path,
recipe=training_args.recipe,
metadata_args=metadata_args,
recipe_args=training_args.recipe_args,
args=training_args,
data_args=data_args,
-train_dataset=train_dataset if training_args.do_train else None,
-eval_dataset=eval_dataset if training_args.do_eval else None,
+train_dataset=train_dataset,
+eval_dataset=eval_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
)
stage_runner.set_trainer(trainer)

# alternating Training/One-shot
if training_args.run_stages:
stage_runner.run_sequential_stages()

# exit immediately
return

# Training
if training_args.do_train:
checkpoint = None