Enable One-Shot Launch from Finetuning Script #1907

Merged (9 commits) on Dec 21, 2023
60 changes: 60 additions & 0 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -32,6 +32,9 @@
"apply_recipe_structure_to_model",
"reload_model_state",
"reload_model_from_checkpoint",
"save_model_and_recipe",
"fallback_to_cpu",
"parse_dtype",
]

_LOGGER = logging.getLogger(__name__)
@@ -173,3 +176,60 @@ def reload_model_from_checkpoint(model: Module, checkpoint: Optional[str] = None
# reload the state dict for the model from the checkpoint
if reload_model_state(model, checkpoint, orig_state_dict):
_LOGGER.info(f"Reloaded model state from checkpoint {checkpoint}")


def save_model_and_recipe(
model: Module,
save_path: str,
tokenizer: Optional[Any] = None,
):
"""
Save a model, tokenizer and the currently loaded recipe to file

:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
"""
model.save_pretrained(save_path)
if tokenizer is not None:
tokenizer.save_pretrained(save_path)

_LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))

recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
session = session_manager.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)


def fallback_to_cpu(device: str) -> str:
"""
Takes in a device string and forces it to cpu if cuda is not available

:param device: device id to check
:return: device modified for CUDA status
"""
if "cuda" in device and not torch.cuda.is_available():
_LOGGER.warning(
f"Requested {device} but CUDA is not available, falling back to CPU"
)
return "cpu"

return device


def parse_dtype(dtype_arg: str) -> torch.dtype:
"""
:param dtype_arg: dtype string to parse
:return: torch.dtype parsed from input string
"""
dtype = "auto" # get precision from model by default
if dtype_arg == "half" or dtype_arg == "float16":
dtype = torch.float16
elif dtype_arg == "bfloat16":
dtype = torch.bfloat16
elif dtype_arg == "full" or dtype_arg == "float32":
dtype = torch.float32

return dtype
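
As a quick illustration of the new helpers (a minimal sketch, not part of this diff; the model name and output path below are placeholders), fallback_to_cpu guards the requested device and save_model_and_recipe bundles weights, tokenizer files, and the serialized recipe into one directory:

# Hypothetical usage sketch; "facebook/opt-125m" and "./oneshot_out" are placeholders
from transformers import AutoModelForCausalLM, AutoTokenizer

from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe

device = fallback_to_cpu("cuda:0")  # returns "cpu" (with a warning) if CUDA is unavailable

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Assumes a sparsification session with a loaded recipe is active; writes the model
# weights, tokenizer files, and the serialized recipe into the output directory.
save_model_and_recipe(model=model, save_path="./oneshot_out", tokenizer=tokenizer)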
4 changes: 4 additions & 0 deletions src/sparseml/transformers/finetune/data/data_args.py
@@ -58,6 +58,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "Optional percentages of each split to download"},
)
num_calibration_samples: Optional[int] = field(
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached preprocessed datasets or not."},
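
For context (not part of this diff), the new field behaves like any other dataclass argument; the dataset name below is a placeholder:

from sparseml.transformers.finetune.data.data_args import DataTrainingArguments

# Defaults to 512 one-shot calibration samples; override per run as needed
data_args = DataTrainingArguments(dataset_name="wikitext", num_calibration_samples=256)
print(data_args.num_calibration_samples)  # 256

If the launch script parses these dataclasses with HfArgumentParser, the same value can be supplied on the command line as --num_calibration_samples 256.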
14 changes: 12 additions & 2 deletions src/sparseml/transformers/finetune/data/data_helpers.py
@@ -38,7 +38,11 @@ def get_raw_dataset(data_args, cache_dir: Optional[str] = None, **kwargs) -> Dat


def make_dataset_splits(
tokenized_datasets: Dict[str, Any], do_train: bool, do_eval: bool, do_predict: bool
tokenized_datasets: Dict[str, Any],
do_train: bool = False,
do_eval: bool = False,
do_predict: bool = False,
do_oneshot: bool = False,
) -> Dict[str, Dataset]:
"""
Restructures the datasets dictionary based on what tasks will be run
@@ -48,14 +52,15 @@
:param do_train: Whether to store the train dataset
:param do_eval: Whether to store the validation dataset
:param do_predict: Whether to store the test dataset
:param do_oneshot: Whether to store the calibration dataset
:return: Datasets to be used by the requested tasks
"""

# handles case where all splits are contained in a single dataset
if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
tokenized_datasets = tokenized_datasets.get("all")

train_split = eval_split = predict_split = None
train_split = eval_split = predict_split = calib_split = None
if do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
@@ -68,10 +73,15 @@
if "test" not in tokenized_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_split = tokenized_datasets["test"]
if do_oneshot:
if "calibration" not in tokenized_datasets:
raise ValueError("--do_oneshot requires a calibration dataset")
calib_split = tokenized_datasets["calibration"]

split_datasets = {
"train": train_split,
"validation": eval_split,
"test": predict_split,
"calibration": calib_split,
}
return split_datasets
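
A small sketch (not part of this diff) of how the new do_oneshot flag routes a calibration split; the toy datasets stand in for the tokenized datasets produced upstream:

from datasets import Dataset

from sparseml.transformers.finetune.data.data_helpers import make_dataset_splits

# Toy tokenized splits standing in for the real preprocessed datasets
train_ds = Dataset.from_dict({"input_ids": [[1, 2, 3]]})
calib_ds = Dataset.from_dict({"input_ids": [[4, 5, 6]]})

splits = make_dataset_splits(
    {"train": train_ds, "calibration": calib_ds},
    do_train=True,
    do_oneshot=True,
)
assert splits["calibration"] is calib_ds
assert splits["test"] is None  # do_predict was not requested, so no test split is kept

Omitting the "calibration" key while do_oneshot=True raises the new ValueError, mirroring the existing checks for the other flows.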
4 changes: 4 additions & 0 deletions src/sparseml/transformers/finetune/model_args.py
@@ -64,3 +64,7 @@ class ModelArguments:
"(necessary to use this script with private models)"
},
)
precision: str = field(
default="auto",
metadata={"help": "Precision to cast model weights to, default to auto"},
)
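
A sketch (not part of this diff) of how the new precision argument is expected to flow through parse_dtype when loading the model; the model name is a placeholder and the model_name_or_path field is assumed to already exist on ModelArguments:

from transformers import AutoModelForCausalLM

from sparseml.pytorch.model_load.helpers import parse_dtype
from sparseml.transformers.finetune.model_args import ModelArguments

# "facebook/opt-125m" is a placeholder; model_name_or_path is assumed to already
# exist on ModelArguments alongside the new precision field
model_args = ModelArguments(model_name_or_path="facebook/opt-125m", precision="bfloat16")

torch_dtype = parse_dtype(model_args.precision)  # torch.bfloat16; unknown strings give "auto"
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path, torch_dtype=torch_dtype
)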
203 changes: 203 additions & 0 deletions src/sparseml/transformers/finetune/runner.py
@@ -0,0 +1,203 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import AutoTokenizer

import sparseml.core.session as session_manager
from sparseml.core.framework import Framework
from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe
from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.data_helpers import make_dataset_splits
from sparseml.transformers.finetune.model_args import ModelArguments


_LOGGER: logging.Logger = logging.getLogger(__name__)


class StageRunner:
"""
Launcher class for train, eval and one_shot flows. Manages the data splits and
configuration for each flow. In the future this class will also handle alternating
between the different flows.

LifeCycle
- populate_datasets()
- set_trainer()
- train() / evaluate() / predict()

:param model_args: Arguments pertaining to model/config/tokenizer
:param data_args: Arguments pertaining to what data to use for different flows
:param training_args: Arguments pertaining to training loop configuration
:param model: unwrapped model to run flows on
"""

def __init__(
self,
data_args: "DataTrainingArguments",
model_args: "ModelArguments",
training_args: "TrainingArguments",
model: Module,
):
self._data_args = data_args
self._model_args = model_args
self._training_args = training_args

self.datasets = {}
self.model = model
self.trainer = None
self.tokenizer = None

def populate_datasets(self, tokenizer: "AutoTokenizer"):
"""
Loads datasets for each flow based on data_args, stores a Dataset for each
enabled flow in self.datasets

:param tokenizer: tokenizer to use for dataset tokenization
"""
splits = self._data_args.splits
tokenized_datasets = {}
if self._data_args.splits is None:
splits = {"all": None}
for split_name, split_str in splits.items():
dataset_manager = TextGenerationDataset.load_from_registry(
self._data_args.dataset_name,
data_args=self._data_args,
split=split_str,
tokenizer=tokenizer,
)
raw_dataset = dataset_manager.get_raw_dataset(self._model_args.cache_dir)
tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset)
tokenized_datasets[split_name] = tokenized_dataset

self.datasets = make_dataset_splits(
tokenized_datasets,
self._training_args.do_train,
self._training_args.do_eval,
self._training_args.do_predict,
self._training_args.do_oneshot,
)
self.tokenizer = tokenizer

def set_trainer(self, trainer: Trainer):
"""
:param trainer: trainer to use for the active flows
"""
self.trainer = trainer

def set_model(self, model: Module):
"""
:param model: pytorch model to use for the active flows
"""
self.model = model

def get_dataset_split(self, split_name: str) -> Dataset:
"""
Retrieve a dataset split by name

:param split_name: name of dataset split to return
:return: dataset split labeled by split_name
"""
return self.datasets.get(split_name)

def format_calibration_data(self) -> List[torch.Tensor]:
"""
Creates a dataloader out of the calibration dataset split, trimming it to
the desired number of calibration samples

:return: list of trimmed calibration data tensors
"""
oneshot_dataset = self.get_dataset_split("calibration")

dataloader_params = {
"batch_size": 1,
"sampler": RandomSampler(oneshot_dataset),
"collate_fn": self.trainer.data_collator,
}

calib_dataloader = DataLoader(oneshot_dataset, **dataloader_params)
parsed_calib_data = [inp["input_ids"] for inp in calib_dataloader]
return parsed_calib_data[
: min(self._data_args.num_calibration_samples, len(parsed_calib_data))
]

def one_shot(self):
"""
Run oneshot calibration on the active model
"""
_LOGGER.info("*** One Shot ***")

calib_data = self.format_calibration_data()
oneshot_device = fallback_to_cpu(self._training_args.oneshot_device)
session_manager.apply(
framework=Framework.pytorch,
recipe=self._training_args.recipe,
model=self.model,
calib_data=calib_data,
start=-1,
device=oneshot_device,
copy_data=False,
)

save_model_and_recipe(
model=self.model,
save_path=self._training_args.output_dir,
tokenizer=self.tokenizer,
)

def train(self, checkpoint: str):
"""
Run trainer's training loop on train_dataset, saving the resulting model to
output_dir

:param checkpoint: Optional checkpoint to resume from
"""
train_result = self.trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(self.get_dataset_split("train"))
self.trainer.log_metrics("train", metrics)
self.trainer.save_metrics("train", metrics)

# this includes saving the state, optimizer and scheduler
self.trainer.save_model()

def evaluate(self):
"""
Run trainer's evaluation loop on eval_dataset, logging the desired metrics
"""
_LOGGER.info("*** Evaluate ***")
metrics = self.trainer.evaluate(self.get_dataset_split("validation"))

metrics["eval_samples"] = len(self.get_dataset_split("validation"))
self.trainer.log_metrics("eval", metrics)
self.trainer.save_metrics("eval", metrics)

def predict(self):
"""
Run trainer's prediction loop on predict_dataset, logging the desired metrics
"""
_LOGGER.info("*** Predict ***")
results = self.trainer.predict(self.get_dataset_split("test"))
metrics = results.metrics

metrics["predict_samples"] = len(self.dataset["test"])
self.trainer.log_metrics("predict", metrics)
self.trainer.save_metrics("predict", metrics)
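
To tie the pieces together, a rough end-to-end sketch of the lifecycle named in the StageRunner docstring (populate_datasets, set_trainer, then a flow). This is not part of the diff; the argument values, the recipe path, and the way the Trainer is constructed are all assumptions for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer

from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.model_args import ModelArguments
from sparseml.transformers.finetune.runner import StageRunner

model_args = ModelArguments(model_name_or_path="facebook/opt-125m")  # placeholder model
data_args = DataTrainingArguments(dataset_name="wikitext", num_calibration_samples=256)
training_args = TrainingArguments(
    output_dir="./oneshot_out",
    do_oneshot=True,            # enables the new one-shot flow
    recipe="recipe.yaml",       # placeholder recipe path
    oneshot_device="cuda:0",    # falls back to CPU if CUDA is unavailable
)

model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

runner = StageRunner(
    data_args=data_args,
    model_args=model_args,
    training_args=training_args,
    model=model,
)
runner.populate_datasets(tokenizer=tokenizer)

# Trainer construction is schematic; the real launch script also wires in the recipe
runner.set_trainer(Trainer(model=model, args=training_args, tokenizer=tokenizer))
runner.one_shot()  # calibrates on the "calibration" split, then saves model + recipe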