[Export Refactor] Feature Branch #1858

Merged 21 commits, Jan 10, 2024

Commits (21):
5058120  initial commit (dbogunowicz, Nov 28, 2023)
bb5887e  respond to PR comments (dbogunowicz, Dec 4, 2023)
2930dea  [Export Refactor][Image Classification] `create_model` function (#1878) (dbogunowicz, Dec 11, 2023)
5ac4d35  [Export Refactor][Image Classification] `create_dummy_input` function… (dbogunowicz, Dec 11, 2023)
59f3f5a  [Export Refactor][Image Classification] `export_model` function (#1883) (dbogunowicz, Dec 11, 2023)
9096b0d  [Export Refactor][Image Classification] `apply_optimizations` functio… (dbogunowicz, Dec 11, 2023)
16c9bf3  [Export Refactor][Image Classification] `export_sample_inputs_outputs… (dbogunowicz, Dec 11, 2023)
4d66402  remove duplicated function (dbogunowicz, Dec 11, 2023)
ed04d3f  [Export Refactor][Image Classification] `create_deployment_folder` fu… (dbogunowicz, Dec 12, 2023)
571aed8  [Export Refactor][Image Classification] `validate_correctness` functi… (dbogunowicz, Dec 12, 2023)
627ddd6  [Export Refactor] End to end testing (#1898) (dbogunowicz, Dec 14, 2023)
3da5f23  [Export Refactor] Prepare the module to be more general (before inclu… (dbogunowicz, Dec 19, 2023)
c65ab6e  [Export Refactor][Transformers] Enable loading SparseModels (#1921) (dbogunowicz, Dec 21, 2023)
e4770c8  Fix the tests (dbogunowicz, Dec 29, 2023)
7b28881  fix tests with help from sara (dbogunowicz, Jan 2, 2024)
6179cb2  [Export][Transformers] Enable loading `text-generation` datasets (#1938) (dbogunowicz, Jan 5, 2024)
7f166a1  tests fixed (dbogunowicz, Jan 6, 2024)
c3c90a4  fix test (dbogunowicz, Jan 6, 2024)
57a4dd0  [Export refactor] final manual testing fixes (#1948) (bfineran, Jan 10, 2024)
ee78625  Export Refactor CLI (#1949) (bfineran, Jan 10, 2024)
8c647b8  Merge branch 'main' into feature/damian/feature_branch_export (bfineran, Jan 10, 2024)
28 changes: 24 additions & 4 deletions src/sparseml/export/export.py
@@ -27,7 +27,7 @@
from sparseml.export.validators import validate_correctness as validate_correctness_
from sparseml.export.validators import validate_structure as validate_structure_
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
from sparseml.pytorch.utils.helpers import default_device
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
resolve_integration,
@@ -46,8 +46,9 @@ def export(
single_graph_file: bool = True,
num_export_samples: int = 0,
batch_size: int = 1,
recipe: Optional[Union[Path, str]] = None,
deployment_directory_name: str = "deployment",
device: str = "auto",
device: str = "cpu",
graph_optimizations: Union[str, List[str], None] = "all",
validate_correctness: bool = False,
validate_structure: bool = True,
@@ -82,6 +83,9 @@
the model to. Defaults to 'deepsparse'.
:param opset: The ONNX opset to use for exporting the model.
Defaults to the latest supported opset.
:param recipe: The path to the recipe to use for exporting the model.
Defaults to None. If a recipe is found in the source_path, it will
be automatically used for export.
:param single_graph_file: Whether to save the model as a single
file. Defaults to True.
:param num_export_samples: The number of samples to create for
@@ -109,14 +113,24 @@
:param task: Optional task to use for exporting the model.
Defaults to None.
"""
# TODO: Remove with the following once sparsezoo #404 lands
"""
from sparsezoo.utils.registry import standardize_lookup_name
task = standardize_lookup_name(task)
"""
if task is not None:
task = task.replace("_", "-").replace(" ", "-")

# TODO: Remove once sparsezoo: #404 lands
if integration is not None:
integration = integration.replace("_", "-").replace(" ", "-")

# create the target path if it doesn't exist
if not Path(target_path).exists():
Path(target_path).mkdir(parents=True, exist_ok=True)

# choose the appropriate device
device = default_device() if device == "auto" else device
device = use_single_gpu(device) if "cuda" in device else device

# assert that the deployment target is valid
if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
@@ -140,8 +154,14 @@
# that were created along with the model and are needed
# for the export
model, loaded_model_kwargs = helper_functions.create_model(
source_path, device=device, task=task, batch_size=batch_size, **kwargs
source_path,
device=device,
task=task,
batch_size=batch_size,
recipe=recipe,
**kwargs,
)
model.eval()

if loaded_model_kwargs:
_LOGGER.info(
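Taken together, the export.py changes add an optional recipe argument and switch the default device from "auto" to "cpu". A minimal usage sketch of the refactored entrypoint, assuming illustrative paths and task (only the keyword names come from the signature in the diff above):

    from sparseml.export.export import export

    export(
        source_path="./training_output",  # assumed: directory holding the model
        target_path="./exported",         # created automatically if missing
        recipe=None,                      # new arg: None -> recipe auto-discovered
                                          # in source_path when present
        device="cpu",                     # new default (was "auto")
        task="image-classification",     # "_" and " " are normalized to "-"
        validate_structure=True,
    )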
4 changes: 0 additions & 4 deletions src/sparseml/integration_helper_functions.py
@@ -51,10 +51,6 @@ def resolve_integration(
will attempt to infer it from the source_path.
:return: The name of the integration to use for exporting the model.
"""

if integration is not None:
integration = integration.replace("_", "-")

from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)
43 changes: 41 additions & 2 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -25,7 +25,7 @@
from sparseml.pytorch.sparsification.quantization.helpers import (
initialize_channel_wise_scale_zp,
)
from sparseml.transformers.utils import SparseAutoModel
from sparseml.pytorch.utils import ModuleSparsificationInfo


__all__ = [
@@ -42,6 +42,45 @@
RECIPE_FILE_NAME = "recipe.yaml"


def log_model_load(
model: Module, model_name_or_path: str, model_type: str, delayed_load: bool
):
"""
Log the state of a loaded model including sparsity and
prunable params information.

:param model: the loaded model
:param model_name_or_path: the original name of, or path to, the loaded model
:param model_type: specify the type of model loaded for logging;
e.g. one of [model, student, teacher]
:param delayed_load: True if this model load was delayed until after
recipe instantiation due to QAT or other architectural state changes
"""
if delayed_load:
_LOGGER.info(
f"Delayed load of model {model_name_or_path} detected. "
f"Will print out model information once SparseML recipes have loaded"
)
return

sparsification_info = ModuleSparsificationInfo(model)

_LOGGER.info(
f"Loaded {model_type} from {model_name_or_path} "
f"with {sparsification_info.params_total} total params. "
f"Of those there are {sparsification_info.params_prunable_total} prunable "
f"params which have {sparsification_info.params_prunable_sparse_percent} "
"avg sparsity."
)
model_type = (
"sparse" if sparsification_info.params_prunable_sparse_percent > 5 else "dense"
)
_LOGGER.info(
f"{model_type} model detected, "
f"all sparsification info: {sparsification_info}"
)


def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path: str):
"""
Takes a loaded Pytorch model and applies any structural changes such as quantization
@@ -152,7 +191,7 @@ def reload_model_state(
_LOGGER.info(
f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
)
SparseAutoModel.log_model_load(
log_model_load(
model,
load_path,
model_type="student",
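Since log_model_load moved from SparseAutoModel into sparseml.pytorch.model_load.helpers, it can now be called on any torch module. A short sketch against the signature added above (the checkpoint path is a placeholder):

    import torch.nn as nn

    from sparseml.pytorch.model_load.helpers import log_model_load

    model = nn.Linear(16, 4)  # stand-in module for illustration

    # Logs total and prunable param counts plus average sparsity, then labels
    # the model "sparse" or "dense" (sparse if prunable sparsity > 5, per the
    # code above).
    log_model_load(
        model,
        model_name_or_path="./checkpoint",  # placeholder path
        model_type="model",                 # one of: model, student, teacher
        delayed_load=False,                 # True defers logging until recipes load
    )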
9 changes: 0 additions & 9 deletions src/sparseml/pytorch/utils/helpers.py
@@ -137,15 +137,6 @@ def default_device() -> str:
return "cuda:{}".format(",".join(device_ids))


def use_single_gpu(device: str) -> str:
"""
return: the first gpu in the device string if multiple are available
"""
if "cuda" not in device:
raise ValueError("use_single_gpu should only be called on cuda devices")
return device.split(",")[0]


def device_of(inputs: Any):
if isinstance(inputs, Tensor):
return inputs.device
17 changes: 13 additions & 4 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -178,11 +178,20 @@ def main(
"use_auth_token": True if model_args.use_auth_token else None,
}
# this calls from_pretrained under the hood so should be FSDP safe
model, teacher = SparseAutoModel.text_generation_from_pretrained_distil(
model = SparseAutoModel.text_classification_from_pretrained(
model_name_or_path=model_path,
teacher_name_or_path=training_args.distill_teacher,
model_kwargs=model_kwargs,
teacher_kwargs=teacher_kwargs,
model_type="student" if training_args.distill_teacher else "model",
**model_kwargs,
)
teacher = (
SparseAutoModel.text_classification_from_pretrained(
model_name_or_path=training_args.distill_teacher,
model_type="teacher",
**teacher_kwargs,
)
if training_args.distill_teacher
and training_args.distill_teacher not in ["self", "disable"]
else training_args.distill_teacher
)

# initialize structure of input model from recipe if needed
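The combined text_generation_from_pretrained_distil call is replaced by two independent loads, and a teacher is only instantiated when distill_teacher is a real name or path; the "self" and "disable" sentinels (and None) pass through unchanged. A condensed sketch of that branching, with names taken from the diff:

    def resolve_teacher(distill_teacher, load_fn, teacher_kwargs):
        # Load a teacher only for a real model name/path; otherwise return
        # the sentinel (None, "self", or "disable") as-is.
        if distill_teacher and distill_teacher not in ["self", "disable"]:
            return load_fn(
                model_name_or_path=distill_teacher,
                model_type="teacher",
                **teacher_kwargs,
            )
        return distill_teacher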
15 changes: 9 additions & 6 deletions src/sparseml/transformers/integration_helper_functions.py
@@ -27,6 +27,7 @@
NLG_TOKENIZER_FILES,
OPTIONAL_DEPLOYMENT_FILES,
TaskNames,
resolve_sequence_length,
)
from sparseml.transformers.utils.load_task_dataset import load_task_dataset
from sparseml.transformers.utils.optimizations import apply_kv_cache_injection
@@ -38,10 +39,9 @@
from src.sparseml.transformers.utils.initializers import (
_parse_data_args,
initialize_config,
initialize_model,
initialize_sparse_model,
initialize_tokenizer,
initialize_trainer,
resolve_sequence_length,
)


@@ -52,6 +52,7 @@ def create_model(
source_path: Union[Path, str],
device: Optional[str] = None,
task: Optional[str] = None,
recipe: Optional[str] = None,
**kwargs,
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
"""
@@ -61,6 +62,8 @@
:param source_path: The path to the model
:param device: The device to use for the model and dataloader instantiation
:param task: The task to use for the model and dataloader instantiation
:param recipe: The recipe to use for the model and dataloader instantiation.
If None, attempt to use the default recipe

:return: A tuple of the
- torch model
@@ -74,7 +77,6 @@

if task is None:
raise ValueError("To create a transformer model, a task must be specified")
task = task.replace("_", "-")

if not trust_remote_code:
_LOGGER.warning(
@@ -85,11 +87,13 @@
config = initialize_config(source_path, trust_remote_code, **config_args)
sequence_length = sequence_length or resolve_sequence_length(config)
tokenizer = initialize_tokenizer(source_path, sequence_length, task)
model = initialize_model(
model = initialize_sparse_model(
model_path=source_path,
task=task,
config=config,
trust_remote_code=trust_remote_code,
recipe=recipe,
sequence_length=sequence_length,
device=device,
)

@@ -108,9 +112,8 @@
else:
validation_dataset = None

model.train()
trainer = initialize_trainer(model, source_path, validation_dataset)
model.eval()
# TODO: Parse out dataloader from the trainer

return model, dict(
trainer=trainer,
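With recipe now threaded through, the transformers create_model helper can be driven roughly as follows (the source path and task are placeholders; the keyword names and the returned trainer entry come from the diff above):

    # Illustrative call only; create_model is the function defined above.
    model, aux = create_model(
        source_path="./sparse-bert",   # placeholder model directory
        device="cpu",
        task="text-classification",
        recipe=None,                   # None -> fall back to the default recipe
    )
    trainer = aux["trainer"]           # returned alongside the model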
10 changes: 5 additions & 5 deletions src/sparseml/transformers/sparsification/obcq/obcq.py
@@ -36,7 +36,7 @@
llama_forward,
opt_forward,
)
from sparseml.transformers.utils.model import SparseCausalLM
from sparseml.transformers.utils.sparse_model import SparseAutoModel


__all__ = ["one_shot"]
@@ -93,18 +93,18 @@ def one_shot(
model_loader_fn = None
forward_fn = None
if "opt" in model_type:
model_loader_fn = SparseCausalLM.opt_model_from_pretrained
model_loader_fn = SparseAutoModel.text_classification_from_pretrained
forward_fn = opt_forward
elif "llama" in model_type or "mistral" in model_type:
model_loader_fn = SparseCausalLM.auto_model_from_pretrained
model_loader_fn = SparseAutoModel.text_classification_from_pretrained
forward_fn = llama_forward
else:
_LOGGER.warning(
f"A supported model type({SUPPORTED_MODELS}) could not be "
f"parsed from model_path={model_path}. Defaulting to "
"AutoModelForCausalLM loading. "
"SparseAutoModel loading. "
)
model_loader_fn = SparseCausalLM.auto_model_from_pretrained
model_loader_fn = SparseAutoModel.text_classification_from_pretrained
forward_fn = llama_forward
torch_dtype = parse_dtype(precision)
model = model_loader_fn(
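After this change every branch resolves to the same SparseAutoModel loader and only the forward function varies by architecture, so the dispatch reduces to roughly the following (the warning for unrecognized model types is omitted):

    model_loader_fn = SparseAutoModel.text_classification_from_pretrained
    forward_fn = opt_forward if "opt" in model_type else llama_forward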
4 changes: 2 additions & 2 deletions src/sparseml/transformers/sparsification/trainer.py
@@ -37,6 +37,7 @@
from transformers.trainer_pt_utils import reissue_pt_warnings
from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint

from sparseml.pytorch.model_load.helpers import log_model_load
from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
from sparseml.pytorch.sparsification.quantization.helpers import (
initialize_channel_wise_scale_zp,
@@ -47,7 +48,6 @@
TensorBoardLogger,
WANDBLogger,
)
from sparseml.transformers.utils import SparseAutoModel
from sparseml.transformers.utils.helpers import RECIPE_NAME


@@ -638,7 +638,7 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
_LOGGER.info(
f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
)
SparseAutoModel.log_model_load(
log_model_load(
self.model,
self.model_state_path,
model_type="student" if self.teacher else "model",
2 changes: 1 addition & 1 deletion src/sparseml/transformers/utils/__init__.py
@@ -19,4 +19,4 @@
# flake8: noqa
from .helpers import *
from .metrics import *
from .model import *
from .sparse_model import *