Enable One-Shot Launch from Finetuning Script (#1907)
* setup StageRunner class

* running one_shot from text_gen script

* cleanup helper fns

* precision support

* formatting
Satrat committed Dec 21, 2023
1 parent 3bba1b6 commit f088321
Showing 8 changed files with 330 additions and 149 deletions.
60 changes: 60 additions & 0 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -32,6 +32,9 @@
"apply_recipe_structure_to_model",
"reload_model_state",
"reload_model_from_checkpoint",
"save_model_and_recipe",
"fallback_to_cpu",
"parse_dtype",
]

_LOGGER = logging.getLogger(__name__)
@@ -173,3 +176,60 @@ def reload_model_from_checkpoint(model: Module, checkpoint: Optional[str] = None
# reload the state dict for the model from the checkpoint
if reload_model_state(model, checkpoint, orig_state_dict):
_LOGGER.info(f"Reloaded model state from checkpoint {checkpoint}")


def save_model_and_recipe(
model: Module,
save_path: str,
tokenizer: Optional[Any] = None,
):
"""
Save a model, tokenizer and the currently loaded recipe to file
:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
"""
model.save_pretrained(save_path)
if tokenizer is not None:
tokenizer.save_pretrained(save_path)

_LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))

recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
session = session_manager.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)


def fallback_to_cpu(device: str) -> str:
"""
Takes in a device string and forces it to "cpu" if CUDA is not available
:param device: device id to check
:return: device modified for CUDA status
"""
if "cuda" in device and not torch.cuda.is_available():
_LOGGER.warning(
f"Requested {device} but CUDA is not available, falling back to CPU"
)
return "cpu"

return device


def parse_dtype(dtype_arg: str) -> torch.dtype:
"""
:param dtype_arg: dtype string to parse
:return: torch.dtype parsed from the input string, or "auto" if the string is
not recognized, so precision can be inferred from the model
"""
dtype = "auto" # get precision from model by default
if dtype_arg == "half" or dtype_arg == "float16":
dtype = torch.float16
elif dtype_arg == "bfloat16":
dtype = torch.bfloat16
elif dtype_arg == "full" or dtype_arg == "float32":
dtype = torch.float32

return dtype
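
A minimal usage sketch for the two new load/save helpers; the model name and output path are placeholders, and `save_model_and_recipe` assumes an active sparseml session holding a recipe, since it calls `get_serialized_recipe()` on it:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe

# "cuda:0" silently becomes "cpu" on machines without CUDA, with a warning
device = fallback_to_cpu("cuda:0")

model_name = "facebook/opt-125m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# ... apply a recipe through the sparseml session here ...

# writes the model, the tokenizer, and the serialized recipe from the
# active session into the output directory
save_model_and_recipe(model=model, save_path="./output", tokenizer=tokenizer)
```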
4 changes: 4 additions & 0 deletions src/sparseml/transformers/finetune/data/data_args.py
@@ -58,6 +58,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "Optional percentages of each split to download"},
)
num_calibration_samples: Optional[int] = field(
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached preprocessed datasets or not."},
14 changes: 12 additions & 2 deletions src/sparseml/transformers/finetune/data/data_helpers.py
@@ -38,7 +38,11 @@ def get_raw_dataset(data_args, cache_dir: Optional[str] = None, **kwargs) -> Dat


def make_dataset_splits(
tokenized_datasets: Dict[str, Any], do_train: bool, do_eval: bool, do_predict: bool
tokenized_datasets: Dict[str, Any],
do_train: bool = False,
do_eval: bool = False,
do_predict: bool = False,
do_oneshot: bool = False,
) -> Dict[str, Dataset]:
"""
Restructures the datasets dictionary based on what tasks will be run
@@ -48,14 +52,15 @@
:param do_train: Whether to store the train dataset
:param do_eval: Whether to store the validation dataset
:param do_predict: Whether to store the test dataset
:param do_oneshot: Whether to store the calibration dataset
:return: Datasets to be used by the requested tasks
"""

# handles case where all splits are contained in a single dataset
if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
tokenized_datasets = tokenized_datasets.get("all")

train_split = eval_split = predict_split = None
train_split = eval_split = predict_split = calib_split = None
if do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
@@ -68,10 +73,15 @@
if "test" not in tokenized_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_split = tokenized_datasets["test"]
if do_oneshot:
if "calibration" not in tokenized_datasets:
raise ValueError("--do_oneshot requires a calibration dataset")
calib_split = tokenized_datasets["calibration"]

split_datasets = {
"train": train_split,
"validation": eval_split,
"test": predict_split,
"calibration": calib_split,
}
return split_datasets
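
A small illustration of the new `do_oneshot` handling, using toy `datasets.Dataset` objects in place of the tokenized splits the dataset registry would normally produce:

```python
from datasets import Dataset

from sparseml.transformers.finetune.data.data_helpers import make_dataset_splits

tokenized = {
    "train": Dataset.from_dict({"input_ids": [[1, 2, 3]]}),
    "calibration": Dataset.from_dict({"input_ids": [[4, 5, 6]]}),
}

splits = make_dataset_splits(tokenized, do_train=True, do_oneshot=True)
assert splits["calibration"] is tokenized["calibration"]
assert splits["validation"] is None  # do_eval was not requested
# requesting a flow without its split raises, e.g.:
# make_dataset_splits(tokenized, do_eval=True)  -> ValueError
```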
4 changes: 4 additions & 0 deletions src/sparseml/transformers/finetune/model_args.py
@@ -64,3 +64,7 @@ class ModelArguments:
"(necessary to use this script with private models)"
},
)
precision: str = field(
default="auto",
metadata={"help": "Precision to cast model weights to, default to auto"},
)
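
A sketch of how the new `precision` field could feed the `parse_dtype` helper added above; the `model_name_or_path` field is assumed to exist as in the standard Hugging Face argument dataclasses:

```python
from transformers import AutoModelForCausalLM

from sparseml.pytorch.model_load.helpers import parse_dtype
from sparseml.transformers.finetune.model_args import ModelArguments

model_args = ModelArguments(
    model_name_or_path="facebook/opt-125m",  # placeholder checkpoint
    precision="bfloat16",
)

# "half"/"float16" -> torch.float16, "bfloat16" -> torch.bfloat16,
# "full"/"float32" -> torch.float32; anything else falls back to "auto",
# letting from_pretrained infer the dtype from the checkpoint
torch_dtype = parse_dtype(model_args.precision)

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path, torch_dtype=torch_dtype
)
```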
203 changes: 203 additions & 0 deletions src/sparseml/transformers/finetune/runner.py
@@ -0,0 +1,203 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import AutoTokenizer

import sparseml.core.session as session_manager
from sparseml.core.framework import Framework
from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe
from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.data_helpers import make_dataset_splits
from sparseml.transformers.finetune.model_args import ModelArguments


_LOGGER: logging.Logger = logging.getLogger(__name__)


class StageRunner:
"""
Launcher class for train, eval, and one-shot flows. Manages data splits and
configuration for each flow; in the future this class will also handle
alternating between the different flows.

LifeCycle
- populate_datasets()
- set_trainer()
- train() / evaluate() / predict() / one_shot()

:param model_args: Arguments pertaining to model/config/tokenizer
:param data_args: Arguments pertaining to what data to use for different flows
:param training_args: Arguments pertaining to training loop configuration
:param model: unwrapped model to run flows on
"""

def __init__(
self,
data_args: "DataTrainingArguments",
model_args: "ModelArguments",
training_args: "TrainingArguments",
model: Module,
):
self._data_args = data_args
self._model_args = model_args
self._training_args = training_args

self.datasets = {}
self.model = model
self.trainer = None
self.tokenizer = None

def populate_datasets(self, tokenizer: "AutoTokenizer"):
"""
Loads datasets for each flow based on data_args, stores a Dataset for each
enabled flow in self.datasets
:param tokenizer: tokenizer to use for dataset tokenization
"""
splits = self._data_args.splits
tokenized_datasets = {}
if self._data_args.splits is None:
splits = {"all": None}
for split_name, split_str in splits.items():
dataset_manager = TextGenerationDataset.load_from_registry(
self._data_args.dataset_name,
data_args=self._data_args,
split=split_str,
tokenizer=tokenizer,
)
raw_dataset = dataset_manager.get_raw_dataset(self._model_args.cache_dir)
tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset)
tokenized_datasets[split_name] = tokenized_dataset

self.datasets = make_dataset_splits(
tokenized_datasets,
self._training_args.do_train,
self._training_args.do_eval,
self._training_args.do_predict,
self._training_args.do_oneshot,
)
self.tokenizer = tokenizer

def set_trainer(self, trainer: Trainer):
"""
:param trainer: trainer instance to use for the requested flows
"""
self.trainer = trainer

def set_model(self, model: Module):
"""
:param model: PyTorch model to run flows on
"""
self.model = model

def get_dataset_split(self, split_name: str) -> Dataset:
"""
Retrieve a dataset split by name
:param split_name: name of dataset split to return
:return: dataset split labeled by split_name
"""
return self.datasets.get(split_name)

def format_calibration_data(self) -> List[torch.Tensor]:
"""
Creates a dataloader out of the calibration dataset split, trimming it to
the desired number of calibration samples
:return: list of trimmed calibration data tensors
"""
oneshot_dataset = self.get_dataset_split("calibration")

dataloader_params = {
"batch_size": 1,
"sampler": RandomSampler(oneshot_dataset),
"collate_fn": self.trainer.data_collator,
}

calib_dataloader = DataLoader(oneshot_dataset, **dataloader_params)
parsed_calib_data = [inp["input_ids"] for inp in calib_dataloader]
return parsed_calib_data[
: min(self._data_args.num_calibration_samples, len(parsed_calib_data))
]

def one_shot(self):
"""
Run oneshot calibration on the active model
"""
_LOGGER.info("*** One Shot ***")

calib_data = self.format_calibration_data()
oneshot_device = fallback_to_cpu(self._training_args.oneshot_device)
session_manager.apply(
framework=Framework.pytorch,
recipe=self._training_args.recipe,
model=self.model,
calib_data=calib_data,
start=-1,
device=oneshot_device,
copy_data=False,
)

save_model_and_recipe(
model=self.model,
save_path=self._training_args.output_dir,
tokenizer=self.tokenizer,
)

def train(self, checkpoint: str):
"""
Run trainer's training loop on train_dataset, saving the resulting model to
output_dir
:param checkpoint: Optional checkpoint to resume from
"""
train_result = self.trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(self.get_dataset_split("train"))
self.trainer.log_metrics("train", metrics)
self.trainer.save_metrics("train", metrics)

# this includes saving the state, optimizer and scheduler
self.trainer.save_model()

def evaluate(self):
"""
Run trainer's evaluation loop on eval_dataset, logging the desired metrics
"""
_LOGGER.info("*** Evaluate ***")
metrics = self.trainer.evaluate(self.get_dataset_split("validation"))

metrics["eval_samples"] = len(self.get_dataset_split("validation"))
self.trainer.log_metrics("eval", metrics)
self.trainer.save_metrics("eval", metrics)

def predict(self):
"""
Run trainer's prediction loop on predict_dataset, logging the desired metrics
"""
_LOGGER.info("*** Predict ***")
results = self.trainer.predict(self.get_dataset_split("test"))
metrics = results.metrics

metrics["predict_samples"] = len(self.dataset["test"])
self.trainer.log_metrics("predict", metrics)
self.trainer.save_metrics("predict", metrics)