
[Export Refactor][Transformers] Enable loading SparseModels #1921

Merged
28 changes: 24 additions & 4 deletions src/sparseml/export/export.py
@@ -27,7 +27,7 @@
 from sparseml.export.validators import validate_correctness as validate_correctness_
 from sparseml.export.validators import validate_structure as validate_structure_
 from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
-from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
+from sparseml.pytorch.utils.helpers import default_device
 from src.sparseml.integration_helper_functions import (
     IntegrationHelperFunctions,
     resolve_integration,
@@ -46,8 +46,9 @@ def export(
     single_graph_file: bool = True,
     num_export_samples: int = 0,
     batch_size: int = 1,
+    recipe: Optional[Union[Path, str]] = None,
     deployment_directory_name: str = "deployment",
-    device: str = "auto",
+    device: str = "cpu",
     graph_optimizations: Union[str, List[str], None] = "all",
     validate_correctness: bool = False,
     validate_structure: bool = True,
@@ -82,6 +83,9 @@
         the model to. Defaults to 'deepsparse'.
     :param opset: The ONNX opset to use for exporting the model.
         Defaults to the latest supported opset.
+    :param recipe: The path to the recipe to use for exporting the model.
+        Defaults to None. If a recipe is found in the source_path, it will
+        be automatically used for export.
     :param single_graph_file: Whether to save the model as a single
         file. Defaults to True.
     :param num_export_samples: The number of samples to create for
@@ -109,14 +113,24 @@
     :param task: Optional task to use for exporting the model.
         Defaults to None.
     """
+    # TODO: Remove with the following once sparsezoo: #404 lands
+    """
+    from sparsezoo.utils.registry import standardize_lookup_name
+    task = standardize_lookup_name(task)
+    """
+    if task is not None:
+        task = task.replace("_", "-").replace(" ", "-")
+
+    # TODO: Remove once sparsezoo: #404 lands
+    if integration is not None:
+        integration = integration.replace("_", "-").replace(" ", "-")
+
     # create the target path if it doesn't exist
     if not Path(target_path).exists():
         Path(target_path).mkdir(parents=True, exist_ok=True)
 
-    # choose the appropriate device
-    device = default_device() if device == "auto" else device
-    device = use_single_gpu(device) if "cuda" in device else device
-
     # assert the valid deployment target
     if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
@@ -140,8 +154,14 @@
     # that were created along with the model and are needed
     # for the export
     model, loaded_model_kwargs = helper_functions.create_model(
-        source_path, device=device, task=task, batch_size=batch_size, **kwargs
+        source_path,
+        device=device,
+        task=task,
+        batch_size=batch_size,
+        recipe=recipe,
+        **kwargs,
     )
+    model.eval()
 
     if loaded_model_kwargs:
         _LOGGER.info(
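
For orientation, a minimal sketch of how a caller might exercise the new recipe argument (assuming export is importable from sparseml.export; the paths and task below are hypothetical placeholders, a recipe found in source_path is picked up automatically when recipe is left as None, and device now defaults to "cpu" rather than "auto"):

    from sparseml.export import export

    # Hypothetical paths and task, shown only to illustrate the new keyword
    export(
        source_path="./my_sparse_model",
        target_path="./exported",  # created automatically if missing
        task="text-generation",  # underscores/spaces are normalized to dashes
        recipe="./my_sparse_model/recipe.yaml",  # optional; defaults to None
    )
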
4 changes: 0 additions & 4 deletions src/sparseml/integration_helper_functions.py
@@ -51,10 +51,6 @@ def resolve_integration(
         will attempt to infer it from the source_path.
     :return: The name of the integration to use for exporting the model.
     """
-
-    if integration is not None:
-        integration = integration.replace("_", "-")
-
     from sparseml.pytorch.image_classification.utils.helpers import (
         is_image_classification_model,
     )
43 changes: 41 additions & 2 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -25,7 +25,7 @@
 from sparseml.pytorch.sparsification.quantization.helpers import (
     initialize_channel_wise_scale_zp,
 )
-from sparseml.transformers.utils import SparseAutoModel
+from sparseml.pytorch.utils import ModuleSparsificationInfo
 
 
 __all__ = [
@@ -39,6 +39,45 @@
 RECIPE_FILE_NAME = "recipe.yaml"
 
 
+def log_model_load(
+    model: Module, model_name_or_path: str, model_type: str, delayed_load: bool
+):
+    """
+    Log the state of a loaded model including sparsity and
+    prunable params information.
+
+    :param model: the loaded model
+    :param model_name_or_path: the original name of or path to the model that was loaded
+    :param model_type: specify the type of model loaded for logging;
+        e.g. one of [model, student, teacher]
+    :param delayed_load: True if this model load was delayed until after
+        recipe instantiation due to QAT or other architectural state changes
+    """
+    if delayed_load:
+        _LOGGER.info(
+            f"Delayed load of model {model_name_or_path} detected. "
+            f"Will print out model information once SparseML recipes have loaded"
+        )
+        return
+
+    sparsification_info = ModuleSparsificationInfo(model)
+
+    _LOGGER.info(
+        f"Loaded {model_type} from {model_name_or_path} "
+        f"with {sparsification_info.params_total} total params. "
+        f"Of those there are {sparsification_info.params_prunable_total} prunable "
+        f"params which have {sparsification_info.params_prunable_sparse_percent} "
+        "avg sparsity."
+    )
+    model_type = (
+        "sparse" if sparsification_info.params_prunable_sparse_percent > 5 else "dense"
+    )
+    _LOGGER.info(
+        f"{model_type} model detected, "
+        f"all sparsification info: {sparsification_info}"
+    )
+
+
 def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path: str):
     """
     Takes a loaded Pytorch model and applies any structural changes such as quantization
@@ -149,7 +188,7 @@ def reload_model_state(
     _LOGGER.info(
         f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
     )
-    SparseAutoModel.log_model_load(
+    log_model_load(
         model,
         load_path,
         model_type="student",
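
A quick sketch of calling the relocated helper directly (the toy module is purely illustrative, and this assumes ModuleSparsificationInfo accepts any torch Module):

    import torch.nn as nn

    from sparseml.pytorch.model_load.helpers import log_model_load

    # Toy stand-in for a real checkpoint; model_type is one of
    # ["model", "student", "teacher"] per the docstring
    model = nn.Linear(16, 4)
    log_model_load(model, "toy-linear", model_type="model", delayed_load=False)
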
9 changes: 0 additions & 9 deletions src/sparseml/pytorch/utils/helpers.py
@@ -137,15 +137,6 @@ def default_device() -> str:
     return "cuda:{}".format(",".join(device_ids))
 
 
-def use_single_gpu(device: str) -> str:
-    """
-    return: the first gpu in the device string if multiple are available
-    """
-    if "cuda" not in device:
-        raise ValueError("use_single_gpu should only be called on cuda devices")
-    return device.split(",")[0]
-
-
 def device_of(inputs: Any):
     if isinstance(inputs, Tensor):
         return inputs.device
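
With use_single_gpu removed, callers that previously relied on it to collapse a multi-GPU device string can do the same inline; a sketch of the equivalent one-liner:

    # Equivalent of the removed helper: take the first device from a
    # comma-separated cuda string such as "cuda:0,1"
    device = "cuda:0,1"
    device = device.split(",")[0] if "cuda" in device else device  # -> "cuda:0"
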
17 changes: 13 additions & 4 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -169,11 +169,20 @@ def main(
         "use_auth_token": True if model_args.use_auth_token else None,
     }
     # this calls from_pretrained under the hood so should be FSDP safe
-    model, teacher = SparseAutoModel.text_generation_from_pretrained_distil(
+    model = SparseAutoModel.text_classification_from_pretrained(
         model_name_or_path=model_path,
-        teacher_name_or_path=training_args.distill_teacher,
-        model_kwargs=model_kwargs,
-        teacher_kwargs=teacher_kwargs,
+        model_type="student" if training_args.distill_teacher else "model",
+        **model_kwargs,
     )
+    teacher = (
+        SparseAutoModel.text_classification_from_pretrained(
+            model_name_or_path=training_args.distill_teacher,
+            model_type="teacher",
+            **teacher_kwargs,
+        )
+        if training_args.distill_teacher
+        and training_args.distill_teacher not in ["self", "disable"]
+        else training_args.distill_teacher
+    )
 
     # initialize structure of input model from recipe if needed
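
The teacher-resolution expression above can be read as a small helper; this sketch (resolve_teacher is a hypothetical name) restates its logic, assuming distill_teacher may be a model path, "self", "disable", or None:

    def resolve_teacher(distill_teacher, load_fn, **teacher_kwargs):
        # Load a real teacher only for a concrete model path; the sentinel
        # strings "self"/"disable" (and None) pass through unchanged
        if distill_teacher and distill_teacher not in ["self", "disable"]:
            return load_fn(
                model_name_or_path=distill_teacher,
                model_type="teacher",
                **teacher_kwargs,
            )
        return distill_teacher
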
15 changes: 9 additions & 6 deletions src/sparseml/transformers/integration_helper_functions.py
@@ -27,6 +27,7 @@
     NLG_TOKENIZER_FILES,
     OPTIONAL_DEPLOYMENT_FILES,
     TaskNames,
+    resolve_sequence_length,
 )
 from sparseml.transformers.utils.load_task_dataset import load_task_dataset
 from sparseml.transformers.utils.optimizations import apply_kv_cache_injection
@@ -38,10 +39,9 @@
 from src.sparseml.transformers.utils.initializers import (
     _parse_data_args,
     initialize_config,
-    initialize_model,
+    initialize_sparse_model,
     initialize_tokenizer,
     initialize_trainer,
-    resolve_sequence_length,
 )
 
 
@@ -52,6 +52,7 @@ def create_model(
     source_path: Union[Path, str],
     device: Optional[str] = None,
     task: Optional[str] = None,
+    recipe: Optional[str] = None,
     **kwargs,
 ) -> Tuple[torch.nn.Module, Dict[str, Any]]:
     """
@@ -61,6 +62,8 @@
     :param source_path: The path to the model
     :param device: The device to use for the model and dataloader instantiation
     :param task: The task to use for the model and dataloader instantiation
+    :param recipe: The recipe to use for the model and dataloader instantiation.
+        If None, attempt to use the default recipe
 
     :return: A tuple of the
         - torch model
@@ -74,7 +77,6 @@
 
     if task is None:
         raise ValueError("To create a transformer model, a task must be specified")
-    task = task.replace("_", "-")
 
     if not trust_remote_code:
         _LOGGER.warning(
@@ -85,11 +87,13 @@
     config = initialize_config(source_path, trust_remote_code, **config_args)
     sequence_length = sequence_length or resolve_sequence_length(config)
     tokenizer = initialize_tokenizer(source_path, sequence_length, task)
-    model = initialize_model(
+    model = initialize_sparse_model(
         model_path=source_path,
         task=task,
         config=config,
         trust_remote_code=trust_remote_code,
+        recipe=recipe,
+        sequence_length=sequence_length,
         device=device,
     )
 
@@ -108,9 +112,8 @@
     else:
         validation_dataset = None
 
-    model.train()
     trainer = initialize_trainer(model, source_path, validation_dataset)
-    model.eval()
+    # TODO: Parse out dataloader from the trainer
 
     return model, dict(
         trainer=trainer,
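
For reference, a hedged sketch of invoking the updated create_model with the new recipe keyword (import path mirrors the repo's own src.sparseml convention; the source path is a placeholder, and recipe=None falls back to whatever default recipe the loader discovers):

    from src.sparseml.transformers.integration_helper_functions import create_model

    # Placeholder path; task names use the dash-separated convention
    model, extras = create_model(
        source_path="./my_sparse_model",
        task="question-answering",
        recipe=None,  # use the default recipe found alongside the model
    )
    trainer = extras["trainer"]  # returned among the loaded-model kwargs
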
10 changes: 5 additions & 5 deletions src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -34,7 +34,7 @@
     llama_forward,
     opt_forward,
 )
-from sparseml.transformers.utils.model import SparseCausalLM
+from sparseml.transformers.utils.sparse_model import SparseAutoModel
 
 
 __all__ = ["one_shot"]
@@ -91,18 +91,18 @@ def one_shot(
     model_loader_fn = None
     forward_fn = None
     if "opt" in model_type:
-        model_loader_fn = SparseCausalLM.opt_model_from_pretrained
+        model_loader_fn = SparseAutoModel.text_classification_from_pretrained
         forward_fn = opt_forward
     elif "llama" in model_type or "mistral" in model_type:
-        model_loader_fn = SparseCausalLM.auto_model_from_pretrained
+        model_loader_fn = SparseAutoModel.text_classification_from_pretrained
         forward_fn = llama_forward
     else:
         _LOGGER.warning(
             f"A supported model type({SUPPORTED_MODELS}) could not be "
             f"parsed from model_path={model_path}. Defaulting to "
-            "AutoModelForCausalLM loading. "
+            "SparseAutoModel loading. "

Comment on lines 100 to +103

Contributor: Now that we're using SparseAutoModel for everything, this warning message doesn't make much sense. The only thing we lose from not using a supported model is perplexity evaluation, and that's more of a debugging feature anyway. Can we change the message to this:

    f"A supported model type({SUPPORTED_MODELS}) could not be "
    f"parsed from model_path={model_path}. Perplexity evaluation will only work for supported models"

Contributor: I think we can get rid of the if statement altogether, and just condition on the model for setting forward_fn.

-        model_loader_fn = SparseCausalLM.auto_model_from_pretrained
+        model_loader_fn = SparseAutoModel.text_classification_from_pretrained
         forward_fn = llama_forward
     torch_dtype = _parse_dtype(precision)
     model = model_loader_fn(
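
Applying the second reviewer's suggestion would collapse the branching to something like this sketch, loading through SparseAutoModel unconditionally and conditioning only the forward function:

    # Sketch of the reviewer's proposal: one loader, branch only on forward_fn
    model_loader_fn = SparseAutoModel.text_classification_from_pretrained
    forward_fn = opt_forward if "opt" in model_type else llama_forward
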
4 changes: 2 additions & 2 deletions src/sparseml/transformers/sparsification/trainer.py
@@ -37,6 +37,7 @@
 from transformers.trainer_pt_utils import reissue_pt_warnings
 from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint
 
+from sparseml.pytorch.model_load.helpers import log_model_load
 from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
 from sparseml.pytorch.sparsification.quantization.helpers import (
     initialize_channel_wise_scale_zp,
@@ -47,7 +48,6 @@
     TensorBoardLogger,
     WANDBLogger,
 )
-from sparseml.transformers.utils import SparseAutoModel
 from sparseml.transformers.utils.helpers import RECIPE_NAME
 
 
@@ -638,7 +638,7 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
         _LOGGER.info(
             f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
         )
-        SparseAutoModel.log_model_load(
+        log_model_load(
             self.model,
             self.model_state_path,
             model_type="student" if self.teacher else "model",
2 changes: 1 addition & 1 deletion src/sparseml/transformers/utils/__init__.py
@@ -19,4 +19,4 @@
 # flake8: noqa
 from .helpers import *
 from .metrics import *
-from .model import *
+from .sparse_model import *
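
Because the package re-exports via star imports, the module rename should be transparent to downstream code; assuming SparseAutoModel is exported from sparse_model.py, this caller-side import keeps working unchanged:

    from sparseml.transformers.utils import SparseAutoModel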