move the model modification into the quantization modifier logic
dbogunowicz committed Apr 16, 2024
1 parent afbd0f4 commit 2797e07
Showing 21 changed files with 202 additions and 470 deletions.
@@ -11,21 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from torch import nn

from sparseml.transformers.sparsification.modification import modify_model
from sparseml.transformers.sparsification.modification.modification_objects import (
QATLinear,
)


def test_modifying_mobilebert(mobilebert_model):

mobilebert_ = deepcopy(mobilebert_model)
mobilebert = modify_model(mobilebert_model)

assert isinstance(mobilebert_.embeddings.embedding_transformation, nn.Linear)
assert isinstance(mobilebert.embeddings.embedding_transformation, QATLinear)
# flake8: noqa
from .modify_model import modify_model
@@ -14,7 +14,7 @@

"""
Set of helper objects that are used to modify
the HuggingFace transformer models
the quantized models
"""

import torch
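
For context (not part of this commit): the helper objects referenced in the docstring above follow the pattern of exposing functional ops as modules so that quantization observers can be attached to them. Below is a minimal sketch of a quantizable matmul wrapper, assuming only standard PyTorch; the actual QATMatMul in SparseML may differ.

import torch
from torch import nn


class QuantizableMatMul(nn.Module):
    """
    Illustrative stand-in for a quantizable matmul helper: wrapping the
    functional torch.matmul call in an nn.Module gives QAT tooling a
    module boundary where fake-quantization observers can be attached.
    """

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)
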
@@ -13,33 +13,31 @@
# limitations under the License.

import logging
import os

import torch

from sparseml.transformers.sparsification.modification.registry import (
ModificationRegistry,
)
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry


_LOGGER = logging.getLogger(__name__)


def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
def modify_model(model: torch.nn.Module) -> torch.nn.Module:
"""
Modify the original transformers model so that it is
compatible with the SparseML library.
Modify the original model so that it is
compatible with the quantization format required by the
SparseML library.
The model will be modified if there exists a modification
function for the model in the registry of modifications.
Otherwise, the original model will be returned.
:param model: The original HuggingFace transformers model
:return: The potentially modified model
:param model: The original model to be modified
:return: The potentially modified model to support
SparseML quantization
"""
model_name = model.__class__.__name__
NM_DISABLE_TRANSFORMERS_MODIFICATION = os.environ.get(
"NM_DISABLE_TRANSFORMERS_MODIFICATION", "False"
).lower() in ["true", "1"]

try:
modification_func = ModificationRegistry.get_value_from_registry(model_name)
except KeyError:
@@ -50,21 +48,7 @@ def modify_model(model: torch.nn.Module, disable: int = False) -> torch.nn.Module:
)
return model

if NM_DISABLE_TRANSFORMERS_MODIFICATION:
_LOGGER.debug(
"Application of the modification function to model "
"disabled through the environment variable."
)
return model

if disable:
_LOGGER.debug(
"Application of the modification function for to model "
"disabled through the `disable` argument."
)
return model

_LOGGER.info(
f"Modifying the model {model_name} to be compatible with SparseML library"
f"Modifying the model {model_name} to be compatible with SparseML quantization"
)
return modification_func(model)
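
To illustrate the fallback behavior described in the docstring above (a sketch, not part of the commit): modify_model returns any model unchanged when its class name has no entry in the modification registry.

import torch

from sparseml.modifiers.quantization.modification import modify_model

# "Linear" has no registered modification function, so the module is
# returned as-is and a debug message is logged
model = torch.nn.Linear(16, 16)
model = modify_model(model)
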
@@ -11,27 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
)
from sparsezoo.utils.registry import RegistryMixin


class ModificationRegistry(RegistryMixin):
"""
A registry for modification functions that can be applied to models
so that they can be used in the context of sparseml.transformers
so that they can be compatible with the quantization format required by the
SparseML library.
"""

@classmethod
def get_value_from_registry(cls, name: str):
"""
Extends the base class method to check the transformers version after
successfully retrieving the value from the registry. The motivation is
to ensure that the transformers version falls within the supported range
before we proceed with model modification.
"""
retrieved_value = super().get_value_from_registry(name)
check_transformers_version()
return retrieved_value
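
For concreteness, here is a hedged sketch of how a modification function is registered; "MyCustomModel" is a hypothetical class name used only for illustration, and the name passed to register must match model.__class__.__name__ as looked up by modify_model.

from torch import nn

from sparseml.modifiers.quantization.modification.registry import ModificationRegistry


# hypothetical model class name, for illustration only
@ModificationRegistry.register(name="MyCustomModel")
def modify(model: nn.Module) -> nn.Module:
    # swap or wrap submodules here so the model supports SparseML's
    # quantization format, then return the modified model
    return model
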
4 changes: 4 additions & 0 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -20,6 +20,7 @@

from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization.base import QuantizationModifier
from sparseml.modifiers.quantization.modification import modify_model
from sparseml.modifiers.quantization.utils.helpers import (
configure_module_bn_wrappers,
freeze_bn_stats,
@@ -73,6 +74,9 @@ def __init__(self, **kwargs):

def on_initialize_structure(self, state: State, **kwargs):
module = state.model.model
# before the structure is modified to support quantization,
# we need to potentially modify the model architecture
module = modify_model(module)
self._enable_module_qat(module)
state.model.model.apply(torch.quantization.disable_observer)

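
A rough sketch of the ordering this change enforces (illustrative only, using plain torch eager-mode QAT as a stand-in for _enable_module_qat; not actual SparseML API usage):

import torch

from sparseml.modifiers.quantization.modification import modify_model


def sketch_initialize_structure(model: torch.nn.Module) -> torch.nn.Module:
    # 1. architecture tweaks first, e.g. swapping attention blocks for
    #    their quantizable counterparts via the modification registry
    model = modify_model(model)
    # 2. then standard QAT preparation (stand-in for _enable_module_qat)
    model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(model, inplace=True)
    return model
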
1 change: 1 addition & 0 deletions src/sparseml/transformers/sparsification/__init__.py
@@ -19,6 +19,7 @@

# flake8: noqa

from .modification import *
from .question_answering import *
from .sparse_config import *
from .sparse_model import *
20 changes: 13 additions & 7 deletions src/sparseml/transformers/sparsification/modification/__init__.py
@@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
from .modify_model import modify_model
from .modifying_bert import *
from .modifying_distilbert import *
from .modifying_llama import *
from .modifying_mistral import *
from .modifying_mobilebert import *
from .modifying_opt import *
# isort:skip_file

from .base import check_transformers_version

check_transformers_version()

from .modifying_bert import modify
from .modifying_llama import modify
from .modifying_mistral import modify
from .modifying_distilbert import modify
from .modifying_mobilebert import modify
from .modifying_opt import modify
@@ -18,64 +18,49 @@
"""


import logging
import math
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertAttention, BertSelfAttention
from transformers.models.bert.modeling_bert import BertSelfAttention

from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.modification_objects import (
QATMatMul,
)
from sparseml.transformers.sparsification.modification.registry import (
ModificationRegistry,
)


_LOGGER = logging.getLogger(__name__)


@ModificationRegistry.register(name="BertModel", alias=["BertForQuestionAnswering"])
def modify(model: nn.Module) -> nn.Module:
"""
Modify the Bert model to be compatible with SparseML
quantization
1. Replaces the MultiHeadSelfAttention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
Note: This function will not alter any of the alternatives
to the MultiHeadSelfAttention module such as BertAttention
Replaces the attention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
:param model: the original Bert model
:return: the modified Bert model
"""
for name, submodule in model.named_modules():
if type(submodule) is BertSelfAttention:
if isinstance(submodule, BertSelfAttention):
swap_modules(
model, name, BertSelfAttentionWithQuantizableMatmuls(submodule)
)
elif type(submodule) is BertAttention:
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
)
return model


class BertSelfAttentionWithQuantizableMatmuls(BertSelfAttention):
"""
Wrapper around the original BertSelfAttention module to replace the
Wrapper around the original attention module to replace the
matmul operations with quantizable matmul operations
:param bert_self_attention: the original BertSelfAttention module
:param bert_self_attention: the original attention module to be
wrapped and modified
"""

def __init__(self, bert_self_attention: BertSelfAttention):
self.__class__ = type(
bert_self_attention.__class__.__name__,
self.__class__.__name__,
(self.__class__, bert_self_attention.__class__),
{},
)
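
The __init__ above uses a dynamic re-classing trick; below is a minimal sketch of the general pattern (illustrative only, and the takeover of the wrapped module's state is an assumption, since those lines are collapsed in this diff).

from torch import nn


class AttentionWrapperSketch(nn.Module):
    def __init__(self, original_attention: nn.Module):
        super().__init__()
        # create a class inheriting from both the wrapper and the original
        # attention class, so isinstance() checks against the original
        # attention type keep passing after wrapping
        self.__class__ = type(
            self.__class__.__name__,
            (self.__class__, original_attention.__class__),
            {},
        )
        # assumption: adopt the original module's parameters, buffers and
        # submodules; the wrapper's forward would then route matmuls
        # through quantizable matmul modules (omitted here)
        self.__dict__ = original_attention.__dict__
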
@@ -28,13 +28,9 @@
MultiHeadSelfAttention,
)

from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.modification_objects import (
QATMatMul,
)
from sparseml.transformers.sparsification.modification.registry import (
ModificationRegistry,
)


_LOGGER = logging.getLogger(__name__)
@@ -44,40 +40,34 @@
def modify(model: nn.Module) -> nn.Module:
"""
Modify the DistilBert model to be compatible with SparseML
quantization
1. Replaces the MultiHeadSelfAttention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
Note: This function will not alter any of the alternatives
to the MultiHeadSelfAttention module such as DistilBertFlashAttention2
Replaces the attention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
:param model: the original DistilBert model
:return: the modified DistilBert model
"""
for name, submodule in model.named_modules():
if type(submodule) is MultiHeadSelfAttention:
if isinstance(submodule, (MultiHeadSelfAttention, DistilBertFlashAttention2)):
swap_modules(
model, name, MultiHeadSelfAttentionWithQuantizableMatmuls(submodule)
)
if type(submodule) is DistilBertFlashAttention2:
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
)
return model


class MultiHeadSelfAttentionWithQuantizableMatmuls(MultiHeadSelfAttention):
"""
Wrapper around the original MultiHeadSelfAttention module to replace the
matmul operations with quantizable matmul operations
Wrapper around the original attention module to introduce
MultiHeadSelfAttention with quantizable matmul operations
:param mhs_attention: the original MultiHeadSelfAttention module
:param mhs_attention: the original attention module to be
wrapped and modified
"""

def __init__(self, mhs_attention: MultiHeadSelfAttention):
self.__class__ = type(
mhs_attention.__class__.__name__,
self.__class__.__name__,
(self.__class__, mhs_attention.__class__),
{},
)
(diffs for the remaining changed files not loaded)