From 63a369f87fa2680163e17ef09cef2bc4a030c5a8 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 28 May 2024 20:07:52 +0000 Subject: [PATCH 1/5] convert old modifier to legacy --- src/sparseml/modifiers/quantization/base.py | 6 +++--- src/sparseml/modifiers/quantization/gptq/base.py | 2 +- src/sparseml/modifiers/quantization/pytorch.py | 4 ++-- src/sparseml/modifiers/quantization_vllm/base.py | 4 ++-- .../modifiers/quantization_vllm/pytorch.py | 6 +++--- .../sparsification/compressed_tensors_utils.py | 8 ++++---- tests/sparseml/modifiers/quantization/test_base.py | 10 +++++----- .../modifiers/pruning/sparsegpt/test_pytorch.py | 14 +++++++------- .../pytorch/modifiers/quantization/test_pytorch.py | 10 +++++----- .../compression/recipes/new_quant_channel.yaml | 2 +- .../compression/recipes/new_quant_full.yaml | 2 +- .../compression/recipes/new_quant_simple.yaml | 2 +- .../compression/recipes/new_quant_weight.yaml | 2 +- .../compression/recipes/old_quant_channel.yaml | 2 +- .../compression/recipes/old_quant_full.yaml | 2 +- .../compression/recipes/old_quant_weight.yaml | 2 +- .../transformers/finetune/test_quantization.yaml | 2 +- .../repeat_quants/tiny_llama_repeat_quant.yaml | 4 ++-- .../separate_quants/tiny_llama_separate_quant.yaml | 4 ++-- .../recipes/additional_sparsity_with_quant.yaml | 2 +- .../sparseml/transformers/obcq/recipes/quant.yaml | 2 +- .../obcq/recipes/quant_and_sparse.yaml | 2 +- .../obcq/test_obcq_fake_quant_wrapper.py | 2 +- .../modification/test_modifying_llama.py | 2 +- .../modification/test_modifying_mistral.py | 2 +- .../modification/test_modifying_opt.py | 2 +- .../transformers/test_recipe_compatibility.py | 2 +- .../transformers/utils/test_initializers.py | 2 +- 28 files changed, 53 insertions(+), 53 deletions(-) diff --git a/src/sparseml/modifiers/quantization/base.py b/src/sparseml/modifiers/quantization/base.py index e66f5b9ea72..9b9f1569f09 100644 --- a/src/sparseml/modifiers/quantization/base.py +++ b/src/sparseml/modifiers/quantization/base.py @@ -17,17 +17,17 @@ from sparseml.core import Event, Modifier -__all__ = ["QuantizationModifier"] +__all__ = ["LegacyQuantizationModifier"] -class QuantizationModifier(Modifier): +class LegacyQuantizationModifier(Modifier): """ Enables quantization aware training (QAT) for a given module or its submodules After the start epoch, the specified module(s) forward pass will emulate quantized execution and the modifier will be enabled until training is completed. 
| Sample yaml: - | QuantizationModifier: + | LegacyQuantizationModifier: | start: 0.0 | scheme: | input_activations: diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py index cb0023d1919..004fce2ee7a 100644 --- a/src/sparseml/modifiers/quantization/gptq/base.py +++ b/src/sparseml/modifiers/quantization/gptq/base.py @@ -194,7 +194,7 @@ def _build_quant_modifier(self, framework): ) quant_args["config_groups"] = {"config_group_0": default_quant_scheme} _LOGGER.info(f"Building quantization modifier with args: {quant_args}") - vllm_quant_config = {"vLLMQuantizationModifier": quant_args} + vllm_quant_config = {"QuantizationModifier": quant_args} self._build_quant_modifier_from_dict(vllm_quant_config, framework) def compressible_layers(self) -> Dict: diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/pytorch.py index 927d8db79d3..0bedd489e9d 100644 --- a/src/sparseml/modifiers/quantization/pytorch.py +++ b/src/sparseml/modifiers/quantization/pytorch.py @@ -19,7 +19,7 @@ from torch.nn import Module from sparseml.core import Event, EventType, State -from sparseml.modifiers.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization.base import LegacyQuantizationModifier from sparseml.modifiers.quantization.modification import modify_model from sparseml.modifiers.quantization.utils.helpers import ( configure_module_bn_wrappers, @@ -42,7 +42,7 @@ _LOGGER = logging.getLogger(__name__) -class QuantizationModifierPyTorch(QuantizationModifier): +class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier): """ Pytorch-specific implementation of quantization modifier diff --git a/src/sparseml/modifiers/quantization_vllm/base.py b/src/sparseml/modifiers/quantization_vllm/base.py index c8b2522ecee..e6af6485aa3 100644 --- a/src/sparseml/modifiers/quantization_vllm/base.py +++ b/src/sparseml/modifiers/quantization_vllm/base.py @@ -24,10 +24,10 @@ from sparseml.core import Event, Modifier -__all__ = ["vLLMQuantizationModifier"] +__all__ = ["QuantizationModifier"] -class vLLMQuantizationModifier(Modifier): +class QuantizationModifier(Modifier): """ Enables post training quantization (PTQ) and quantization aware training (QAT) for a given module or its submodules. After calibration (PTQ) or the start epoch (QAT), diff --git a/src/sparseml/modifiers/quantization_vllm/pytorch.py b/src/sparseml/modifiers/quantization_vllm/pytorch.py index a6e7f179525..a6b5e1bc288 100644 --- a/src/sparseml/modifiers/quantization_vllm/pytorch.py +++ b/src/sparseml/modifiers/quantization_vllm/pytorch.py @@ -23,16 +23,16 @@ set_module_for_calibration, ) from sparseml.core import Event, EventType, State -from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier +from sparseml.modifiers.quantization_vllm.base import QuantizationModifier from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward _LOGGER = logging.getLogger(__name__) -class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier): +class QuantizationModifierPyTorch(QuantizationModifier): """ - PyTorch specific implementation of vLLMQuantizationModifier + PyTorch specific implementation of QuantizationModifier Enables post training quantization (PTQ) and quantization aware training (QAT) for a given module or its submodules. 
After calibration (PTQ) or the start epoch (QAT), diff --git a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py index c62a1eb9bf9..0a7e914393c 100644 --- a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py +++ b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py @@ -83,10 +83,10 @@ def save_pretrained_wrapper( # check if we are in the old quantization framework if qat_active(model) and not is_model_quantized(model): _LOGGER.info( - "Compression for models quantized with QuantizationModifer is not " - "supported. Save will be run without compression and no sparsity " - "statistics will be calculated. To save a quantized model in a " - "compressed state please use vLLMQuantizationModifier instead." + "Compression for models quantized with LegacyQuantizationModifer " + "is not supported. Save will be run without compression and no " + "sparsity statistics will be calculated. To save a quantized model " + "in a compressed state please use QuantizationModifier instead." ) original_save_pretrained.__get__(model, model_class)( diff --git a/tests/sparseml/modifiers/quantization/test_base.py b/tests/sparseml/modifiers/quantization/test_base.py index 064d8dcb671..d0bd316c534 100644 --- a/tests/sparseml/modifiers/quantization/test_base.py +++ b/tests/sparseml/modifiers/quantization/test_base.py @@ -19,7 +19,7 @@ from sparseml.core.event import Event from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.modifiers.quantization import QuantizationModifier +from sparseml.modifiers.quantization import LegacyQuantizationModifier from tests.sparseml.modifiers.conf import setup_modifier_factory @@ -31,14 +31,14 @@ def setUp(self): def test_quantization_registered(self): quant_obj = ModifierFactory.create( - type_="QuantizationModifier", + type_="LegacyQuantizationModifier", framework=Framework.general, allow_experimental=False, allow_registered=True, **self.kwargs, ) - self.assertIsInstance(quant_obj, QuantizationModifier) + self.assertIsInstance(quant_obj, LegacyQuantizationModifier) @pytest.mark.unit @@ -52,7 +52,7 @@ def setUp(self): def test_end_epochs(self): disable_quant_epoch, freeze_bn_epoch = None, None - obj_modifier = QuantizationModifier( + obj_modifier = LegacyQuantizationModifier( start=self.start, scheme=self.scheme, disable_quantization_observer_epoch=disable_quant_epoch, @@ -68,7 +68,7 @@ def test_end_epochs(self): assert not obj_modifier.check_should_freeze_bn_stats(event) disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0 - obj_modifier = QuantizationModifier( + obj_modifier = LegacyQuantizationModifier( start=self.start, scheme=self.scheme, disable_quantization_observer_epoch=disable_quant_epoch, diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index e52b6e2ef23..7f962d5b017 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -21,8 +21,8 @@ from sparseml.core.model import ModifiableModel from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch -from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch -from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier +from 
sparseml.modifiers.quantization.pytorch import LegacyQuantizationModifierPyTorch +from sparseml.modifiers.quantization_vllm.base import QuantizationModifier from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory from tests.sparseml.pytorch.helpers import LinearNet from tests.testing_utils import requires_torch @@ -92,13 +92,13 @@ def test_create_default_quant_modifier(self): testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, vLLMQuantizationModifier) + assert isinstance(modifier.quantization_modifier_, QuantizationModifier) default_config_group_name = "config_group_0" should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ default_config_group_name ] self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8) - # input activations are symmetric by default in vLLMQuantizationModifier + # input activations are symmetric by default in QuantizationModifier assert should_be_default_quant_scheme.input_activations.symmetric self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8) @@ -120,7 +120,7 @@ def test_set_quant_if_modifer_already_exists(self): ), ) - modifier = QuantizationModifierPyTorch(**kwargs) + modifier = LegacyQuantizationModifierPyTorch(**kwargs) testing_harness = LifecyleTestingHarness(model=model, start=-1) assert not testing_harness.get_state().model.qat_active() @@ -159,7 +159,7 @@ def setUp(self): } } } - self.quant_config = {"vLLMQuantizationModifier": self.quant_kwargs} + self.quant_config = {"QuantizationModifier": self.quant_kwargs} def test_set_quant_in_gptq(self): kwargs = dict(block_size=128, quantize=self.quant_config) @@ -170,7 +170,7 @@ def test_set_quant_in_gptq(self): testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - self.assertIsInstance(modifier.quantization_modifier_, vLLMQuantizationModifier) + self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) dict_scheme = dict(modifier.quantization_modifier_.config_groups) self._check_config( diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py index 6b258b884cb..b8ece5d4180 100644 --- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py @@ -21,7 +21,7 @@ from sparseml.core.event import Event, EventType from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch +from sparseml.modifiers.quantization.pytorch import LegacyQuantizationModifierPyTorch from sparseml.pytorch.sparsification.quantization.quantize import ( is_qat_helper_module, is_quantizable_module, @@ -45,14 +45,14 @@ def setUp(self): def test_quantization_registered(self): quant_obj = ModifierFactory.create( - type_="QuantizationModifier", + type_="LegacyQuantizationModifier", framework=Framework.pytorch, allow_experimental=False, allow_registered=True, **self.kwargs, ) - self.assertIsInstance(quant_obj, QuantizationModifierPyTorch) + self.assertIsInstance(quant_obj, LegacyQuantizationModifierPyTorch) @pytest.mark.unit @@ -71,7 +71,7 @@ def test_quantization_oneshot(self, model_class): state = State(framework=Framework.pytorch, 
start_event=Event()) state.update(model=model, start=-1) - modifier = QuantizationModifierPyTorch(**self.kwargs) + modifier = LegacyQuantizationModifierPyTorch(**self.kwargs) modifier.initialize(state) @@ -108,7 +108,7 @@ def setUp(self): def test_quantization_training(self, model_class): model = model_class() - modifier = QuantizationModifierPyTorch(**self.kwargs) + modifier = LegacyQuantizationModifierPyTorch(**self.kwargs) testing_harness = LifecyleTestingHarness(model=model) modifier.initialize(testing_harness.get_state()) diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml index 48df197537c..2fa7af9d567 100644 --- a/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml +++ b/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - vLLMQuantizationModifier: + QuantizationModifier: ignore: ["lm_head", "model.layers.0.mlp.down_proj"] config_groups: group_0: diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml index 924dcd6e3f6..931f4e80ca5 100644 --- a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml +++ b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - vLLMQuantizationModifier: + QuantizationModifier: ignore: ["lm_head", "model.layers.0.mlp.down_proj"] config_groups: group_0: diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml index 753605fc1dd..b0c7051425d 100644 --- a/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml +++ b/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - vLLMQuantizationModifier: + QuantizationModifier: ignore: ["lm_head"] config_groups: group_0: diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml index 19b9d196e6a..34e0a77e052 100644 --- a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml +++ b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - vLLMQuantizationModifier: + QuantizationModifier: ignore: ["lm_head", "model.layers.0.mlp.down_proj"] config_groups: group_0: diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml index 350d07ce1c2..7d090943915 100644 --- a/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml +++ b/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - model.layers.0.mlp.down_proj - lm_head diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml index 9d67e334fef..2540787d8dd 100644 --- a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml +++ b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - model.layers.0.mlp.down_proj - lm_head 
diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml index 78e49595fe2..adab8340c2e 100644 --- a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml +++ b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - model.layers.0.mlp.down_proj - lm_head diff --git a/tests/sparseml/transformers/finetune/test_quantization.yaml b/tests/sparseml/transformers/finetune/test_quantization.yaml index 89381c31006..eb2d4afdc39 100644 --- a/tests/sparseml/transformers/finetune/test_quantization.yaml +++ b/tests/sparseml/transformers/finetune/test_quantization.yaml @@ -1,6 +1,6 @@ test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml b/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml index 5bef2cae22d..a91b7b4d56a 100644 --- a/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml +++ b/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml @@ -5,7 +5,7 @@ dataset: open_platypus first_recipe: | first_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm @@ -17,7 +17,7 @@ first_recipe: | second_recipe: | second_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml b/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml index 1b7cab983f4..64a43cbd943 100644 --- a/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml +++ b/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml @@ -5,7 +5,7 @@ dataset: open_platypus first_recipe: | first_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm @@ -17,7 +17,7 @@ first_recipe: | second_recipe: | second_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml b/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml index 42538955b5e..72ca3c08fc7 100644 --- a/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml +++ b/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml @@ -6,7 +6,7 @@ test_stage: [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] ] - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/obcq/recipes/quant.yaml b/tests/sparseml/transformers/obcq/recipes/quant.yaml index 756373fcf89..f5436b3873f 100644 --- a/tests/sparseml/transformers/obcq/recipes/quant.yaml +++ b/tests/sparseml/transformers/obcq/recipes/quant.yaml @@ -6,7 +6,7 @@ test_stage: [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], 
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] ] - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml b/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml index b8c9f3451e0..198b32f0e3c 100644 --- a/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml +++ b/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml @@ -6,7 +6,7 @@ test_stage: [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] ] - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py b/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py index 6fafab075b7..ea677db787f 100644 --- a/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py +++ b/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py @@ -37,7 +37,7 @@ def setUp(self): self.recipe = """ first_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - Embedding scheme_overrides: diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py index 9091d28b29e..d1532378c27 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_llama.py @@ -23,7 +23,7 @@ def llama_recipe(): return """test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - MatMulRightInput_QK - MatMulLeftInput_QK diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py index e71364a53e7..f47fafe0749 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_mistral.py @@ -23,7 +23,7 @@ def mistral_recipe(): return """test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - MatMulRightInput_QK - MatMulLeftInput_QK diff --git a/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py b/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py index 411371b0bbf..7af36872500 100644 --- a/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py +++ b/tests/sparseml/transformers/sparsification/modification/test_modifying_opt.py @@ -24,7 +24,7 @@ def opt_recipe(): return """test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - BMMLeftInput_QK - BMMRightInput_QK diff --git a/tests/sparseml/transformers/test_recipe_compatibility.py b/tests/sparseml/transformers/test_recipe_compatibility.py index b0d303b1a4f..e0d7d2708ba 100644 --- a/tests/sparseml/transformers/test_recipe_compatibility.py +++ b/tests/sparseml/transformers/test_recipe_compatibility.py @@ -31,7 +31,7 @@ def model_path(tmp_path): def recipe(): return """test_stage: obcq_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm diff --git a/tests/sparseml/transformers/utils/test_initializers.py 
b/tests/sparseml/transformers/utils/test_initializers.py index 4a85e286d30..f00adb3dd09 100644 --- a/tests/sparseml/transformers/utils/test_initializers.py +++ b/tests/sparseml/transformers/utils/test_initializers.py @@ -34,7 +34,7 @@ def save_recipe_for_text_classification(source_path): recipe = """test_stage: quant_modifiers: - QuantizationModifier: + LegacyQuantizationModifier: post_oneshot_calibration: False scheme_overrides: Embedding: From d83fb2e084172e1ff3369b9d6c0410b96c38a7aa Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 28 May 2024 20:11:15 +0000 Subject: [PATCH 2/5] redo folder structure --- src/sparseml/modifiers/__init__.py | 2 +- src/sparseml/modifiers/quantization/base.py | 107 ++------ .../modifiers/quantization/pytorch.py | 174 +++---------- .../gptq => quantization_legacy}/__init__.py | 0 .../modifiers/quantization_legacy/base.py | 138 ++++++++++ .../gptq}/__init__.py | 0 .../gptq/base.py | 0 .../gptq/pytorch.py | 4 +- .../gptq/utils/__init__.py | 0 .../gptq/utils/gptq_wrapper.py | 0 .../modification/__init__.py | 0 .../modification/modification_objects.py | 0 .../modification/modify_model.py | 2 +- .../modification/registry.py | 0 .../modifiers/quantization_legacy/pytorch.py | 235 ++++++++++++++++++ .../utils/__init__.py | 0 .../utils/constants.py | 0 .../utils/fake_quant_wrapper.py | 0 .../utils/helpers.py | 2 +- .../utils/quantization_scheme.py | 2 +- .../utils/quantize.py | 8 +- .../modifiers/quantization_vllm/base.py | 83 ------- .../modifiers/quantization_vllm/pytorch.py | 141 ----------- .../modification/modifying_bert.py | 4 +- .../modification/modifying_distilbert.py | 4 +- .../modification/modifying_llama.py | 4 +- .../modification/modifying_mistral.py | 4 +- .../modification/modifying_mobilebert.py | 4 +- .../modification/modifying_opt.py | 4 +- .../sparsification/sparse_model.py | 2 +- .../modification/test_modify_model.py | 4 +- .../modifiers/quantization/test_base.py | 2 +- .../pruning/sparsegpt/test_pytorch.py | 6 +- .../modifiers/quantization/test_pytorch.py | 2 +- .../sparsification/modification/conftest.py | 2 +- 35 files changed, 470 insertions(+), 470 deletions(-) rename src/sparseml/modifiers/{quantization/gptq => quantization_legacy}/__init__.py (100%) create mode 100644 src/sparseml/modifiers/quantization_legacy/base.py rename src/sparseml/modifiers/{quantization_vllm => quantization_legacy/gptq}/__init__.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/gptq/base.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/gptq/pytorch.py (97%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/gptq/utils/__init__.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/gptq/utils/gptq_wrapper.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/modification/__init__.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/modification/modification_objects.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/modification/modify_model.py (96%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/modification/registry.py (100%) create mode 100644 src/sparseml/modifiers/quantization_legacy/pytorch.py rename src/sparseml/modifiers/{quantization => quantization_legacy}/utils/__init__.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/utils/constants.py (100%) rename src/sparseml/modifiers/{quantization => 
quantization_legacy}/utils/fake_quant_wrapper.py (100%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/utils/helpers.py (99%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/utils/quantization_scheme.py (99%) rename src/sparseml/modifiers/{quantization => quantization_legacy}/utils/quantize.py (98%) delete mode 100644 src/sparseml/modifiers/quantization_vllm/base.py delete mode 100644 src/sparseml/modifiers/quantization_vllm/pytorch.py diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py index d9f790343d9..4b1b5365146 100644 --- a/src/sparseml/modifiers/__init__.py +++ b/src/sparseml/modifiers/__init__.py @@ -18,5 +18,5 @@ from .logarithmic_equalization import * from .obcq import * from .pruning import * -from .quantization import * +from .quantization_legacy import * from .smoothquant import * diff --git a/src/sparseml/modifiers/quantization/base.py b/src/sparseml/modifiers/quantization/base.py index 9b9f1569f09..e6af6485aa3 100644 --- a/src/sparseml/modifiers/quantization/base.py +++ b/src/sparseml/modifiers/quantization/base.py @@ -12,106 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional +from pydantic import Field + +from compressed_tensors.quantization import ( + QuantizationConfig, + QuantizationScheme, + QuantizationStatus, +) from sparseml.core import Event, Modifier -__all__ = ["LegacyQuantizationModifier"] +__all__ = ["QuantizationModifier"] -class LegacyQuantizationModifier(Modifier): +class QuantizationModifier(Modifier): """ - Enables quantization aware training (QAT) for a given module or its submodules - After the start epoch, the specified module(s) forward pass will emulate - quantized execution and the modifier will be enabled until training is completed. - - | Sample yaml: - | LegacyQuantizationModifier: - | start: 0.0 - | scheme: - | input_activations: - | num_bits: 8 - | symmetric: False - | weights: - | num_bits: 8 - | symmetric: True - | scheme_overrides: - | feature_extractor: "default" - | classifier: - | input_activations: - | num_bits: 8 - | symmetric: False - | weights: null - | Conv2d: - | input_activations: - | num_bits: 8 - | symmetric: True - | ignore: ["ReLU", "input"] - | disable_quantization_observer_epoch: 2.0 - | freeze_bn_stats_epoch: 3.0 - | model_fuse_fn_name: 'fuse_module' - | strict: True - - :param ignore: optional list of module class names or submodule names - to not quantize. Default is None + Enables post training quantization (PTQ) and quantization aware training (QAT) for a + given module or its submodules. After calibration (PTQ) or the start epoch (QAT), + the specified module(s) forward pass will emulate quantized execution and the + modifier will be enabled until training is completed. + + :param config_groups: dictionary specifying quantization schemes to apply to target + modules. Modules not matching a scheme target will NOT be quantized. + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. :param disable_quantization_observer_epoch: Epoch to disable updates to the module quantization observers. At this point, quantized weights and zero points will not be updated. Leave None to not disable observers during QAT. Default is None - :param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. 
Leave - None to not stop tracking batch norm stats during QAT. Default is None - :param model_fuse_fn_name: Name of model function to fuse the model in place prior - to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as - 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. - Default is None - :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed - to the model fusing function :param num_calibration_steps: Number of steps to run post training calibration for. When None, the entire calibration_dataloader is used - :param strict: if True, will raise an error if any module types or submodules in - scheme_overrides or ignore are not found in a given module. Default True """ - ignore: Optional[List[str]] = None + config_groups: Dict[str, QuantizationScheme] + ignore: List[str] = Field(default_factory=list) disable_quantization_observer_epoch: Optional[float] = None - freeze_bn_stats_epoch: Optional[float] = None - model_fuse_fn_name: Optional[str] = None - model_fuse_fn_kwargs: Optional[Dict[str, Any]] = None num_calibration_steps: Optional[int] = None - post_oneshot_calibration: Optional[bool] = False - strict: bool = True - - def __init__(self, **kwargs): - super().__init__(**kwargs) - if self.model_fuse_fn_kwargs is None: - self.model_fuse_fn_kwargs = {} - if self.ignore is None: - self.ignore = [] - def calculate_freeze_bn_stats_epoch(self) -> float: - """ - Get the epoch at which we want to stop updating batch normalization stats - - :return: freeze_bn_stats_epoch if set, else -1 - """ - return ( - self.freeze_bn_stats_epoch if self.freeze_bn_stats_epoch is not None else -1 + def create_init_config(self) -> QuantizationConfig: + return QuantizationConfig( + config_groups=self.config_groups, + quantization_status=QuantizationStatus.INITIALIZED, + ignore=self.ignore, ) - def check_should_freeze_bn_stats(self, event: Event) -> bool: - """ - Given the current index, determine if we should freeze batch normalization stats - - :param event: Event to get index from - :return: True if stats should be frozen, False otherwise - """ - freeze_epoch = self.calculate_freeze_bn_stats_epoch() - if freeze_epoch == -1: - return False - if event.current_index >= freeze_epoch: - return True - return False - def calculate_disable_observer_epoch(self) -> float: """ Get the epoch at which we want to disable to quantization observer diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/pytorch.py index 0bedd489e9d..8761b16007a 100644 --- a/src/sparseml/modifiers/quantization/pytorch.py +++ b/src/sparseml/modifiers/quantization/pytorch.py @@ -13,77 +13,52 @@ # limitations under the License. 
import logging -from typing import Any, Dict, Optional +from typing import Any -import torch from torch.nn import Module -from sparseml.core import Event, EventType, State -from sparseml.modifiers.quantization.base import LegacyQuantizationModifier -from sparseml.modifiers.quantization.modification import modify_model -from sparseml.modifiers.quantization.utils.helpers import ( - configure_module_bn_wrappers, - freeze_bn_stats, - fuse_module_conv_bn_relus, -) -from sparseml.modifiers.quantization.utils.quantization_scheme import ( - QuantizationScheme, - QuantizationSchemeLoadable, -) -from sparseml.modifiers.quantization.utils.quantize import ( - convert_module_qat_from_schemes, - raise_if_torch_quantization_not_available, - set_quantization_schemes, +from compressed_tensors.quantization import ( + apply_quantization_config, + freeze_module_quantization, + set_module_for_calibration, ) +from sparseml.core import Event, EventType, State +from sparseml.modifiers.quantization.base import QuantizationModifier from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward -from sparseml.utils.fsdp.context import summon_full_params_context _LOGGER = logging.getLogger(__name__) -class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier): +class QuantizationModifierPyTorch(QuantizationModifier): """ - Pytorch-specific implementation of quantization modifier - - :param scheme: Default QuantizationScheme to use when enabling quantization - in a module. May also be a dictionary to be loaded into the QuantizationScheme - class. A string alias may also be used, supported aliases: - ['default', 'deepsparse', 'tensorrt']. - If None, the default scheme (`QuantizationScheme()`) will be used. - Default is None - :param scheme_overrides: optional mapping of module type names or submodule type - names to quantization schemes to override them with. If a scheme is mapped to - 'default', then it will use the scheme set in the modifier scheme property + PyTorch specific implementation of QuantizationModifier + + Enables post training quantization (PTQ) and quantization aware training (QAT) for a + given module or its submodules. After calibration (PTQ) or the start epoch (QAT), + the specified module(s) forward pass will emulate quantized execution and the + modifier will be enabled until training is completed. + + :param config_groups: dictionary specifying quantization schemes to apply to target + modules. Modules not matching a scheme target will NOT be quantized. + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param disable_quantization_observer_epoch: Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None + :param num_calibration_steps: Number of steps to run post training calibration for. 
+ When None, the entire calibration_dataloader is used """ - scheme: Optional[QuantizationSchemeLoadable] = None - scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None calibration_dataloader_: Any = None calibration_function_: Any = None - qat_enabled_: bool = False - quantization_observer_disabled_: bool = False - bn_stats_frozen_: bool = False - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.scheme = QuantizationScheme.load(self.scheme) - self.scheme_overrides = _load_quantization_schemes_dict( - self.scheme_overrides, self.scheme - ) def on_initialize_structure(self, state: State, **kwargs): module = state.model.model - # before the structure is modified to support quantization, - # we need to potentially modify the model architecture - module = modify_model(module) - self._enable_module_qat(module) - state.model.model.apply(torch.quantization.disable_observer) + self._apply_modifier_to_model(module) + module.apply(freeze_module_quantization) def on_initialize(self, state: State, **kwargs) -> bool: - raise_if_torch_quantization_not_available() - module = state.model.model - module = modify_model(module) if self.end and self.end != -1: raise ValueError( "end_epoch is disabled for QuantizationModifier and can only be set to" @@ -93,85 +68,39 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.calibration_dataloader_ = state.data.calib module = state.model.model + # intialize quantization in appropriate modules + self._apply_modifier_to_model(module) + if self.calculate_start() == -1: # one-shot - self._enable_module_qat(module) + module.apply(set_module_for_calibration) self._calibrate_if_possible(module) - self._disable_quantization_observer(module) + module.apply(freeze_module_quantization) return True def on_finalize(self, state: State, **kwargs) -> bool: - if self.post_oneshot_calibration: - state.model.model.apply(torch.quantization.enable_observer) - self._calibrate_if_possible(state.model.model) - self._disable_quantization_observer(state.model.model) return True def on_start(self, state: State, event: Event, **kwargs): - if not self.qat_enabled_: - self._enable_module_qat(state.model.model) + module = state.model.model + module.apply(set_module_for_calibration) def on_update(self, state: State, event: Event, **kwargs): if event.type_ == EventType.BATCH_START: - if self.check_should_freeze_bn_stats(event): - self._freeze_bn_stats(state.model.model) if self.check_should_disable_observer(event): - self._disable_quantization_observer(state.model.model) + module = state.model.model + module.apply(freeze_module_quantization) def on_end(self, state: State, event: Event, **kwargs): - self._disable_quantization_observer(state.model.model) + module = state.model.model + module.apply(freeze_module_quantization) def on_event(self, state: State, event: Event, **kwargs): pass - def _freeze_bn_stats(self, model: Module): - model.apply(freeze_bn_stats) - self.bn_stats_frozen_ = True - - def _disable_quantization_observer(self, model: Module): - model.apply(torch.quantization.disable_observer) - self.quantization_observer_disabled_ = True - - def _enable_module_qat(self, module: Module): - module.apply(torch.quantization.enable_observer) - - if not self.qat_enabled_: - with summon_full_params_context(module): - # fuse conv-bn-relu blocks prior to quantization emulation - self._fuse(module) - - # add quantization_schemes to target submodules - set_quantization_schemes( - module, - scheme=self.scheme, - scheme_overrides=self.scheme_overrides, - 
ignore=self.ignore, - strict=self.strict, - ) - - # fix for freezing batchnorm statistics when not fusing BN with convs. - # pytorch only supports freezing batchnorm statistics for fused modules. - # this fix wraps BN modules adding with a new module class that supports - # methods related to freezing/unfreezing BN statistics. - configure_module_bn_wrappers(module) - - # convert target qconfig layers to QAT modules with FakeQuantize - convert_module_qat_from_schemes(module) - - self.qat_enabled_ = True - - def _fuse(self, module: Module): - if self.model_fuse_fn_name in [None, "conv_bn_relus"]: - self.model_fuse_fn_kwargs["inplace"] = True - fuse_module_conv_bn_relus(module, **self.model_fuse_fn_kwargs) - elif self.model_fuse_fn_name != "no_fuse": - module_fuse_fn = getattr(module, self.model_fuse_fn_name, None) - if module_fuse_fn is None or not callable(module_fuse_fn): - raise ValueError( - "Invalid model_fuse_fn_name. " - "Module has no callable function {}".format(self.model_fuse_fn_name) - ) - module_fuse_fn(**self.model_fuse_fn_kwargs) + def _apply_modifier_to_model(self, model: Module): + modifier_as_config = self.create_init_config() + apply_quantization_config(model, modifier_as_config) def _calibrate_if_possible(self, module: Module): if self.num_calibration_steps == 0 and self.calibration_dataloader_: @@ -210,26 +139,3 @@ def _calibrate(self, module: Module): if module_training: module.train() - else: - self._disable_quantization_observer(module) - - -class _QuantizationSchemesDict(dict): - # wrapper class for dict to override the __str__ method for yaml serialization - - def __str__(self): - return str({submodule: scheme.dict() for submodule, scheme in self.items()}) - - -def _load_quantization_schemes_dict( - schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]], - default_scheme: QuantizationScheme, -) -> Dict[str, QuantizationScheme]: - if schemes_dict is None: - return {} - return _QuantizationSchemesDict( - { - submodule: QuantizationScheme.load(scheme, default=default_scheme) - for submodule, scheme in schemes_dict.items() - } - ) diff --git a/src/sparseml/modifiers/quantization/gptq/__init__.py b/src/sparseml/modifiers/quantization_legacy/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization/gptq/__init__.py rename to src/sparseml/modifiers/quantization_legacy/__init__.py diff --git a/src/sparseml/modifiers/quantization_legacy/base.py b/src/sparseml/modifiers/quantization_legacy/base.py new file mode 100644 index 00000000000..9b9f1569f09 --- /dev/null +++ b/src/sparseml/modifiers/quantization_legacy/base.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Optional + +from sparseml.core import Event, Modifier + + +__all__ = ["LegacyQuantizationModifier"] + + +class LegacyQuantizationModifier(Modifier): + """ + Enables quantization aware training (QAT) for a given module or its submodules + After the start epoch, the specified module(s) forward pass will emulate + quantized execution and the modifier will be enabled until training is completed. + + | Sample yaml: + | LegacyQuantizationModifier: + | start: 0.0 + | scheme: + | input_activations: + | num_bits: 8 + | symmetric: False + | weights: + | num_bits: 8 + | symmetric: True + | scheme_overrides: + | feature_extractor: "default" + | classifier: + | input_activations: + | num_bits: 8 + | symmetric: False + | weights: null + | Conv2d: + | input_activations: + | num_bits: 8 + | symmetric: True + | ignore: ["ReLU", "input"] + | disable_quantization_observer_epoch: 2.0 + | freeze_bn_stats_epoch: 3.0 + | model_fuse_fn_name: 'fuse_module' + | strict: True + + :param ignore: optional list of module class names or submodule names + to not quantize. Default is None + :param disable_quantization_observer_epoch: Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None + :param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave + None to not stop tracking batch norm stats during QAT. Default is None + :param model_fuse_fn_name: Name of model function to fuse the model in place prior + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + Default is None + :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed + to the model fusing function + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + :param strict: if True, will raise an error if any module types or submodules in + scheme_overrides or ignore are not found in a given module. 
Default True + """ + + ignore: Optional[List[str]] = None + disable_quantization_observer_epoch: Optional[float] = None + freeze_bn_stats_epoch: Optional[float] = None + model_fuse_fn_name: Optional[str] = None + model_fuse_fn_kwargs: Optional[Dict[str, Any]] = None + num_calibration_steps: Optional[int] = None + post_oneshot_calibration: Optional[bool] = False + strict: bool = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.model_fuse_fn_kwargs is None: + self.model_fuse_fn_kwargs = {} + if self.ignore is None: + self.ignore = [] + + def calculate_freeze_bn_stats_epoch(self) -> float: + """ + Get the epoch at which we want to stop updating batch normalization stats + + :return: freeze_bn_stats_epoch if set, else -1 + """ + return ( + self.freeze_bn_stats_epoch if self.freeze_bn_stats_epoch is not None else -1 + ) + + def check_should_freeze_bn_stats(self, event: Event) -> bool: + """ + Given the current index, determine if we should freeze batch normalization stats + + :param event: Event to get index from + :return: True if stats should be frozen, False otherwise + """ + freeze_epoch = self.calculate_freeze_bn_stats_epoch() + if freeze_epoch == -1: + return False + if event.current_index >= freeze_epoch: + return True + return False + + def calculate_disable_observer_epoch(self) -> float: + """ + Get the epoch at which we want to disable to quantization observer + :return epoch to disable at, or -1 if it is not set + """ + return ( + self.disable_quantization_observer_epoch + if self.disable_quantization_observer_epoch is not None + else -1 + ) + + def check_should_disable_observer(self, event: Event) -> bool: + """ + Given the current index, determine if we should disable the observer + + :param event: Event to get index from + :return: True if observer should be disabled, False otherwise + """ + disable_epoch = self.calculate_disable_observer_epoch() + if disable_epoch == -1: + return False + if event.current_index >= disable_epoch: + return True + return False diff --git a/src/sparseml/modifiers/quantization_vllm/__init__.py b/src/sparseml/modifiers/quantization_legacy/gptq/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization_vllm/__init__.py rename to src/sparseml/modifiers/quantization_legacy/gptq/__init__.py diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization_legacy/gptq/base.py similarity index 100% rename from src/sparseml/modifiers/quantization/gptq/base.py rename to src/sparseml/modifiers/quantization_legacy/gptq/base.py diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py similarity index 97% rename from src/sparseml/modifiers/quantization/gptq/pytorch.py rename to src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py index e9e3f715625..c76382db647 100644 --- a/src/sparseml/modifiers/quantization/gptq/pytorch.py +++ b/src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py @@ -19,8 +19,8 @@ from sparseml.core.model import ModifiableModel from sparseml.core.state import State -from sparseml.modifiers.quantization.gptq.base import GPTQModifier -from sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper +from sparseml.modifiers.quantization_legacy.gptq.base import GPTQModifier +from sparseml.modifiers.quantization_legacy.gptq.utils.gptq_wrapper import GPTQWrapper from sparseml.modifiers.utils.layer_compressor import LayerCompressor from 
sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward from sparseml.utils.fsdp.context import fix_fsdp_module_name diff --git a/src/sparseml/modifiers/quantization/gptq/utils/__init__.py b/src/sparseml/modifiers/quantization_legacy/gptq/utils/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization/gptq/utils/__init__.py rename to src/sparseml/modifiers/quantization_legacy/gptq/utils/__init__.py diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization_legacy/gptq/utils/gptq_wrapper.py similarity index 100% rename from src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py rename to src/sparseml/modifiers/quantization_legacy/gptq/utils/gptq_wrapper.py diff --git a/src/sparseml/modifiers/quantization/modification/__init__.py b/src/sparseml/modifiers/quantization_legacy/modification/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization/modification/__init__.py rename to src/sparseml/modifiers/quantization_legacy/modification/__init__.py diff --git a/src/sparseml/modifiers/quantization/modification/modification_objects.py b/src/sparseml/modifiers/quantization_legacy/modification/modification_objects.py similarity index 100% rename from src/sparseml/modifiers/quantization/modification/modification_objects.py rename to src/sparseml/modifiers/quantization_legacy/modification/modification_objects.py diff --git a/src/sparseml/modifiers/quantization/modification/modify_model.py b/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py similarity index 96% rename from src/sparseml/modifiers/quantization/modification/modify_model.py rename to src/sparseml/modifiers/quantization_legacy/modification/modify_model.py index 1fee2d70c3c..b2dc72cfb83 100644 --- a/src/sparseml/modifiers/quantization/modification/modify_model.py +++ b/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py @@ -15,7 +15,7 @@ import logging import os -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry _LOGGER = logging.getLogger(__name__) diff --git a/src/sparseml/modifiers/quantization/modification/registry.py b/src/sparseml/modifiers/quantization_legacy/modification/registry.py similarity index 100% rename from src/sparseml/modifiers/quantization/modification/registry.py rename to src/sparseml/modifiers/quantization_legacy/modification/registry.py diff --git a/src/sparseml/modifiers/quantization_legacy/pytorch.py b/src/sparseml/modifiers/quantization_legacy/pytorch.py new file mode 100644 index 00000000000..34d8e9ac54e --- /dev/null +++ b/src/sparseml/modifiers/quantization_legacy/pytorch.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Any, Dict, Optional + +import torch +from torch.nn import Module + +from sparseml.core import Event, EventType, State +from sparseml.modifiers.quantization_legacy.base import LegacyQuantizationModifier +from sparseml.modifiers.quantization_legacy.modification import modify_model +from sparseml.modifiers.quantization_legacy.utils.helpers import ( + configure_module_bn_wrappers, + freeze_bn_stats, + fuse_module_conv_bn_relus, +) +from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import ( + QuantizationScheme, + QuantizationSchemeLoadable, +) +from sparseml.modifiers.quantization_legacy.utils.quantize import ( + convert_module_qat_from_schemes, + raise_if_torch_quantization_not_available, + set_quantization_schemes, +) +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward +from sparseml.utils.fsdp.context import summon_full_params_context + + +_LOGGER = logging.getLogger(__name__) + + +class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier): + """ + Pytorch-specific implementation of quantization modifier + + :param scheme: Default QuantizationScheme to use when enabling quantization + in a module. May also be a dictionary to be loaded into the QuantizationScheme + class. A string alias may also be used, supported aliases: + ['default', 'deepsparse', 'tensorrt']. + If None, the default scheme (`QuantizationScheme()`) will be used. + Default is None + :param scheme_overrides: optional mapping of module type names or submodule type + names to quantization schemes to override them with. If a scheme is mapped to + 'default', then it will use the scheme set in the modifier scheme property + """ + + scheme: Optional[QuantizationSchemeLoadable] = None + scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None + calibration_dataloader_: Any = None + calibration_function_: Any = None + qat_enabled_: bool = False + quantization_observer_disabled_: bool = False + bn_stats_frozen_: bool = False + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.scheme = QuantizationScheme.load(self.scheme) + self.scheme_overrides = _load_quantization_schemes_dict( + self.scheme_overrides, self.scheme + ) + + def on_initialize_structure(self, state: State, **kwargs): + module = state.model.model + # before the structure is modified to support quantization, + # we need to potentially modify the model architecture + module = modify_model(module) + self._enable_module_qat(module) + state.model.model.apply(torch.quantization.disable_observer) + + def on_initialize(self, state: State, **kwargs) -> bool: + raise_if_torch_quantization_not_available() + module = state.model.model + module = modify_model(module) + if self.end and self.end != -1: + raise ValueError( + "end_epoch is disabled for QuantizationModifier and can only be set to" + " -1 or None. 
Given {}".format(self.end) + ) + + self.calibration_dataloader_ = state.data.calib + module = state.model.model + + if self.calculate_start() == -1: # one-shot + self._enable_module_qat(module) + self._calibrate_if_possible(module) + self._disable_quantization_observer(module) + + return True + + def on_finalize(self, state: State, **kwargs) -> bool: + if self.post_oneshot_calibration: + state.model.model.apply(torch.quantization.enable_observer) + self._calibrate_if_possible(state.model.model) + self._disable_quantization_observer(state.model.model) + return True + + def on_start(self, state: State, event: Event, **kwargs): + if not self.qat_enabled_: + self._enable_module_qat(state.model.model) + + def on_update(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.BATCH_START: + if self.check_should_freeze_bn_stats(event): + self._freeze_bn_stats(state.model.model) + if self.check_should_disable_observer(event): + self._disable_quantization_observer(state.model.model) + + def on_end(self, state: State, event: Event, **kwargs): + self._disable_quantization_observer(state.model.model) + + def on_event(self, state: State, event: Event, **kwargs): + pass + + def _freeze_bn_stats(self, model: Module): + model.apply(freeze_bn_stats) + self.bn_stats_frozen_ = True + + def _disable_quantization_observer(self, model: Module): + model.apply(torch.quantization.disable_observer) + self.quantization_observer_disabled_ = True + + def _enable_module_qat(self, module: Module): + module.apply(torch.quantization.enable_observer) + + if not self.qat_enabled_: + with summon_full_params_context(module): + # fuse conv-bn-relu blocks prior to quantization emulation + self._fuse(module) + + # add quantization_schemes to target submodules + set_quantization_schemes( + module, + scheme=self.scheme, + scheme_overrides=self.scheme_overrides, + ignore=self.ignore, + strict=self.strict, + ) + + # fix for freezing batchnorm statistics when not fusing BN with convs. + # pytorch only supports freezing batchnorm statistics for fused modules. + # this fix wraps BN modules adding with a new module class that supports + # methods related to freezing/unfreezing BN statistics. + configure_module_bn_wrappers(module) + + # convert target qconfig layers to QAT modules with FakeQuantize + convert_module_qat_from_schemes(module) + + self.qat_enabled_ = True + + def _fuse(self, module: Module): + if self.model_fuse_fn_name in [None, "conv_bn_relus"]: + self.model_fuse_fn_kwargs["inplace"] = True + fuse_module_conv_bn_relus(module, **self.model_fuse_fn_kwargs) + elif self.model_fuse_fn_name != "no_fuse": + module_fuse_fn = getattr(module, self.model_fuse_fn_name, None) + if module_fuse_fn is None or not callable(module_fuse_fn): + raise ValueError( + "Invalid model_fuse_fn_name. " + "Module has no callable function {}".format(self.model_fuse_fn_name) + ) + module_fuse_fn(**self.model_fuse_fn_kwargs) + + def _calibrate_if_possible(self, module: Module): + if self.num_calibration_steps == 0 and self.calibration_dataloader_: + _LOGGER.warning( + f"num_calibration_steps is {self.num_calibration_steps}." + f"Calibration data loader will not be used." + ) + elif self.num_calibration_steps and not self.calibration_dataloader_: + raise ValueError( + f"num_calibration_steps is {self.num_calibration_steps}. " + "Calibration data loader is not set. Pass a " + "calibration_data_loader with initialize(...) method." 
+ ) + + elif not self.calibration_dataloader_: + return + + self._calibrate(module) + + def _calibrate(self, module: Module): + class_name = self.__class__.__name__.replace("PyTorch", "") + _LOGGER.info( + f"Running {class_name} calibration with " + f"{len(self.calibration_dataloader_)} samples..." + ) + + module_training = module.training + module.eval() + + run_calibration_forward( + module, + self.calibration_dataloader_, + self.num_calibration_steps, + self.calibration_function_, + ) + + if module_training: + module.train() + else: + self._disable_quantization_observer(module) + + +class _QuantizationSchemesDict(dict): + # wrapper class for dict to override the __str__ method for yaml serialization + + def __str__(self): + return str({submodule: scheme.dict() for submodule, scheme in self.items()}) + + +def _load_quantization_schemes_dict( + schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]], + default_scheme: QuantizationScheme, +) -> Dict[str, QuantizationScheme]: + if schemes_dict is None: + return {} + return _QuantizationSchemesDict( + { + submodule: QuantizationScheme.load(scheme, default=default_scheme) + for submodule, scheme in schemes_dict.items() + } + ) diff --git a/src/sparseml/modifiers/quantization/utils/__init__.py b/src/sparseml/modifiers/quantization_legacy/utils/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization/utils/__init__.py rename to src/sparseml/modifiers/quantization_legacy/utils/__init__.py diff --git a/src/sparseml/modifiers/quantization/utils/constants.py b/src/sparseml/modifiers/quantization_legacy/utils/constants.py similarity index 100% rename from src/sparseml/modifiers/quantization/utils/constants.py rename to src/sparseml/modifiers/quantization_legacy/utils/constants.py diff --git a/src/sparseml/modifiers/quantization/utils/fake_quant_wrapper.py b/src/sparseml/modifiers/quantization_legacy/utils/fake_quant_wrapper.py similarity index 100% rename from src/sparseml/modifiers/quantization/utils/fake_quant_wrapper.py rename to src/sparseml/modifiers/quantization_legacy/utils/fake_quant_wrapper.py diff --git a/src/sparseml/modifiers/quantization/utils/helpers.py b/src/sparseml/modifiers/quantization_legacy/utils/helpers.py similarity index 99% rename from src/sparseml/modifiers/quantization/utils/helpers.py rename to src/sparseml/modifiers/quantization_legacy/utils/helpers.py index 318769e22ad..dd93d46dcfb 100644 --- a/src/sparseml/modifiers/quantization/utils/helpers.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/helpers.py @@ -26,7 +26,7 @@ from torch import quantization as torch_quantization from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU -from sparseml.modifiers.quantization.utils.quantization_scheme import ( +from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import ( QuantizationArgs, QuantizationScheme, get_observer, diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py similarity index 99% rename from src/sparseml/modifiers/quantization/utils/quantization_scheme.py rename to src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py index 29e64bf6477..e97771c437d 100644 --- a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py @@ -30,7 +30,7 @@ except Exception: torch_quantization = None -from sparseml.modifiers.quantization.utils.fake_quant_wrapper import 
FakeQuantizeWrapper +from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import FakeQuantizeWrapper __all__ = [ diff --git a/src/sparseml/modifiers/quantization/utils/quantize.py b/src/sparseml/modifiers/quantization_legacy/utils/quantize.py similarity index 98% rename from src/sparseml/modifiers/quantization/utils/quantize.py rename to src/sparseml/modifiers/quantization_legacy/utils/quantize.py index 3b6d17cab65..89e9d2faaa9 100644 --- a/src/sparseml/modifiers/quantization/utils/quantize.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/quantize.py @@ -22,17 +22,17 @@ from packaging import version from torch.nn import Identity, Module -from sparseml.modifiers.quantization.utils.constants import ( +from sparseml.modifiers.quantization_legacy.utils.constants import ( FUSED_MODULE_NAMES, NON_QUANTIZABLE_MODULE_NAMES, ) -from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper -from sparseml.modifiers.quantization.utils.helpers import ( +from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import FakeQuantizeWrapper +from sparseml.modifiers.quantization_legacy.utils.helpers import ( QATWrapper, configure_module_default_qconfigs, prepare_embeddings_qat, ) -from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme +from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import QuantizationScheme from sparseml.pytorch.utils import get_layer from sparseml.utils.fsdp.context import fix_fsdp_module_name diff --git a/src/sparseml/modifiers/quantization_vllm/base.py b/src/sparseml/modifiers/quantization_vllm/base.py deleted file mode 100644 index e6af6485aa3..00000000000 --- a/src/sparseml/modifiers/quantization_vllm/base.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional - -from pydantic import Field - -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationScheme, - QuantizationStatus, -) -from sparseml.core import Event, Modifier - - -__all__ = ["QuantizationModifier"] - - -class QuantizationModifier(Modifier): - """ - Enables post training quantization (PTQ) and quantization aware training (QAT) for a - given module or its submodules. After calibration (PTQ) or the start epoch (QAT), - the specified module(s) forward pass will emulate quantized execution and the - modifier will be enabled until training is completed. - - :param config_groups: dictionary specifying quantization schemes to apply to target - modules. Modules not matching a scheme target will NOT be quantized. - :param ignore: optional list of module class names or submodule names to not - quantize even if they match a target in config_groups. Defaults to empty list. - :param disable_quantization_observer_epoch: Epoch to disable updates to the module - quantization observers. At this point, quantized weights and zero points will - not be updated. 
Leave None to not disable observers during QAT. Default is None - :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used - """ - - config_groups: Dict[str, QuantizationScheme] - ignore: List[str] = Field(default_factory=list) - disable_quantization_observer_epoch: Optional[float] = None - num_calibration_steps: Optional[int] = None - - def create_init_config(self) -> QuantizationConfig: - return QuantizationConfig( - config_groups=self.config_groups, - quantization_status=QuantizationStatus.INITIALIZED, - ignore=self.ignore, - ) - - def calculate_disable_observer_epoch(self) -> float: - """ - Get the epoch at which we want to disable to quantization observer - :return epoch to disable at, or -1 if it is not set - """ - return ( - self.disable_quantization_observer_epoch - if self.disable_quantization_observer_epoch is not None - else -1 - ) - - def check_should_disable_observer(self, event: Event) -> bool: - """ - Given the current index, determine if we should disable the observer - - :param event: Event to get index from - :return: True if observer should be disabled, False otherwise - """ - disable_epoch = self.calculate_disable_observer_epoch() - if disable_epoch == -1: - return False - if event.current_index >= disable_epoch: - return True - return False diff --git a/src/sparseml/modifiers/quantization_vllm/pytorch.py b/src/sparseml/modifiers/quantization_vllm/pytorch.py deleted file mode 100644 index a6b5e1bc288..00000000000 --- a/src/sparseml/modifiers/quantization_vllm/pytorch.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Any - -from torch.nn import Module - -from compressed_tensors.quantization import ( - apply_quantization_config, - freeze_module_quantization, - set_module_for_calibration, -) -from sparseml.core import Event, EventType, State -from sparseml.modifiers.quantization_vllm.base import QuantizationModifier -from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward - - -_LOGGER = logging.getLogger(__name__) - - -class QuantizationModifierPyTorch(QuantizationModifier): - """ - PyTorch specific implementation of QuantizationModifier - - Enables post training quantization (PTQ) and quantization aware training (QAT) for a - given module or its submodules. After calibration (PTQ) or the start epoch (QAT), - the specified module(s) forward pass will emulate quantized execution and the - modifier will be enabled until training is completed. - - :param config_groups: dictionary specifying quantization schemes to apply to target - modules. Modules not matching a scheme target will NOT be quantized. - :param ignore: optional list of module class names or submodule names to not - quantize even if they match a target in config_groups. Defaults to empty list. 
- :param disable_quantization_observer_epoch: Epoch to disable updates to the module - quantization observers. At this point, quantized weights and zero points will - not be updated. Leave None to not disable observers during QAT. Default is None - :param num_calibration_steps: Number of steps to run post training calibration for. - When None, the entire calibration_dataloader is used - """ - - calibration_dataloader_: Any = None - calibration_function_: Any = None - - def on_initialize_structure(self, state: State, **kwargs): - module = state.model.model - self._apply_modifier_to_model(module) - module.apply(freeze_module_quantization) - - def on_initialize(self, state: State, **kwargs) -> bool: - if self.end and self.end != -1: - raise ValueError( - "end_epoch is disabled for QuantizationModifier and can only be set to" - " -1 or None. Given {}".format(self.end) - ) - - self.calibration_dataloader_ = state.data.calib - module = state.model.model - - # intialize quantization in appropriate modules - self._apply_modifier_to_model(module) - - if self.calculate_start() == -1: # one-shot - module.apply(set_module_for_calibration) - self._calibrate_if_possible(module) - module.apply(freeze_module_quantization) - - return True - - def on_finalize(self, state: State, **kwargs) -> bool: - return True - - def on_start(self, state: State, event: Event, **kwargs): - module = state.model.model - module.apply(set_module_for_calibration) - - def on_update(self, state: State, event: Event, **kwargs): - if event.type_ == EventType.BATCH_START: - if self.check_should_disable_observer(event): - module = state.model.model - module.apply(freeze_module_quantization) - - def on_end(self, state: State, event: Event, **kwargs): - module = state.model.model - module.apply(freeze_module_quantization) - - def on_event(self, state: State, event: Event, **kwargs): - pass - - def _apply_modifier_to_model(self, model: Module): - modifier_as_config = self.create_init_config() - apply_quantization_config(model, modifier_as_config) - - def _calibrate_if_possible(self, module: Module): - if self.num_calibration_steps == 0 and self.calibration_dataloader_: - _LOGGER.warning( - f"num_calibration_steps is {self.num_calibration_steps}." - f"Calibration data loader will not be used." - ) - elif self.num_calibration_steps and not self.calibration_dataloader_: - raise ValueError( - f"num_calibration_steps is {self.num_calibration_steps}. " - "Calibration data loader is not set. Pass a " - "calibration_data_loader with initialize(...) method." - ) - - elif not self.calibration_dataloader_: - return - - self._calibrate(module) - - def _calibrate(self, module: Module): - class_name = self.__class__.__name__.replace("PyTorch", "") - _LOGGER.info( - f"Running {class_name} calibration with " - f"{len(self.calibration_dataloader_)} samples..." 
- ) - - module_training = module.training - module.eval() - - run_calibration_forward( - module, - self.calibration_dataloader_, - self.num_calibration_steps, - self.calibration_function_, - ) - - if module_training: - module.train() diff --git a/src/sparseml/transformers/sparsification/modification/modifying_bert.py b/src/sparseml/transformers/sparsification/modification/modifying_bert.py index b1c273999ba..2632600ed8c 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_bert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_bert.py @@ -25,8 +25,8 @@ from torch import nn from transformers.models.bert.modeling_bert import BertSelfAttention -from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATMatMul +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py index 2cc9915b900..9aa0389590c 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py @@ -27,8 +27,8 @@ MultiHeadSelfAttention, ) -from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATMatMul +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_llama.py b/src/sparseml/transformers/sparsification/modification/modifying_llama.py index d51827fc8f3..5e480376a74 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_llama.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_llama.py @@ -32,11 +32,11 @@ repeat_kv, ) -from sparseml.modifiers.quantization.modification.modification_objects import ( +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( QuantizableIdentity, QuantizableMatMul, ) -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py index 1a03d635027..a50f31eb588 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py @@ -32,11 +32,11 @@ repeat_kv, ) -from 
sparseml.modifiers.quantization.modification.modification_objects import ( +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( QuantizableIdentity, QuantizableMatMul, ) -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py index 469ca36a736..62fd6bef7f8 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py @@ -20,8 +20,8 @@ from torch import nn from transformers.models.mobilebert.modeling_mobilebert import MobileBertEmbeddings -from sparseml.modifiers.quantization.modification.modification_objects import QATLinear -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATLinear +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_opt.py b/src/sparseml/transformers/sparsification/modification/modifying_opt.py index 5f696ee36c7..fb448316cab 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_opt.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_opt.py @@ -23,11 +23,11 @@ from torch import nn from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2 -from sparseml.modifiers.quantization.modification.modification_objects import ( +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( QuantizableBatchMatmul, QuantizableIdentity, ) -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py index 76e75862fff..3132411d332 100644 --- a/src/sparseml/transformers/sparsification/sparse_model.py +++ b/src/sparseml/transformers/sparsification/sparse_model.py @@ -31,7 +31,7 @@ from transformers.file_utils import WEIGHTS_NAME from compressed_tensors.compressors import ModelCompressor -from sparseml.modifiers.quantization.modification import modify_model +from sparseml.modifiers.quantization_legacy.modification import modify_model from sparseml.pytorch.model_load.helpers import ( apply_recipe_structure_to_model, log_model_load, diff --git a/tests/sparseml/modifiers/quantization/modification/test_modify_model.py b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py index 2bde19a5757..16c13af7207 100644 --- a/tests/sparseml/modifiers/quantization/modification/test_modify_model.py +++ 
b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py @@ -17,8 +17,8 @@ import pytest -from sparseml.modifiers.quantization.modification import modify_model -from sparseml.modifiers.quantization.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification import modify_model +from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, standardize_lookup_name diff --git a/tests/sparseml/modifiers/quantization/test_base.py b/tests/sparseml/modifiers/quantization/test_base.py index d0bd316c534..491f03c1866 100644 --- a/tests/sparseml/modifiers/quantization/test_base.py +++ b/tests/sparseml/modifiers/quantization/test_base.py @@ -19,7 +19,7 @@ from sparseml.core.event import Event from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.modifiers.quantization import LegacyQuantizationModifier +from sparseml.modifiers.quantization_legacy import LegacyQuantizationModifier from tests.sparseml.modifiers.conf import setup_modifier_factory diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 7f962d5b017..01c5fb0cbf9 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -20,9 +20,9 @@ from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch -from sparseml.modifiers.quantization.pytorch import LegacyQuantizationModifierPyTorch -from sparseml.modifiers.quantization_vllm.base import QuantizationModifier +from sparseml.modifiers.quantization_legacy.gptq.pytorch import GPTQModifierPyTorch +from sparseml.modifiers.quantization_legacy.pytorch import LegacyQuantizationModifierPyTorch +from sparseml.modifiers.quantization.base import QuantizationModifier from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory from tests.sparseml.pytorch.helpers import LinearNet from tests.testing_utils import requires_torch diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py index b8ece5d4180..b1327f4ce3d 100644 --- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py @@ -21,7 +21,7 @@ from sparseml.core.event import Event, EventType from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.modifiers.quantization.pytorch import LegacyQuantizationModifierPyTorch +from sparseml.modifiers.quantization_legacy.pytorch import LegacyQuantizationModifierPyTorch from sparseml.pytorch.sparsification.quantization.quantize import ( is_qat_helper_module, is_quantizable_module, diff --git a/tests/sparseml/transformers/sparsification/modification/conftest.py b/tests/sparseml/transformers/sparsification/modification/conftest.py index d6a9fd1c0ad..9c19cc702c2 100644 --- a/tests/sparseml/transformers/sparsification/modification/conftest.py +++ b/tests/sparseml/transformers/sparsification/modification/conftest.py @@ -18,7 +18,7 @@ from transformers import AutoConfig, AutoModel from accelerate import 
init_empty_weights -from sparseml.modifiers.quantization.modification import modify_model +from sparseml.modifiers.quantization_legacy.modification import modify_model from sparseml.pytorch.model_load.helpers import apply_recipe_structure_to_model from sparseml.transformers import SparseAutoConfig, SparseAutoModelForCausalLM From 25d8b1d636e57bd2e5d4a662decd0270a3ebb5c4 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 28 May 2024 20:24:45 +0000 Subject: [PATCH 3/5] fixing imports --- src/sparseml/modifiers/__init__.py | 2 +- src/sparseml/modifiers/quantization/__init__.py | 3 ++- .../gptq/__init__.py | 0 .../gptq/base.py | 0 .../gptq/pytorch.py | 4 ++-- .../gptq/utils/__init__.py | 0 .../gptq/utils/gptq_wrapper.py | 0 .../quantization/quantization/__init__.py | 17 +++++++++++++++++ .../quantization/{ => quantization}/base.py | 0 .../quantization/{ => quantization}/pytorch.py | 2 +- .../modification/modify_model.py | 4 +++- .../utils/quantization_scheme.py | 4 +++- .../quantization_legacy/utils/quantize.py | 8 ++++++-- .../modification/modifying_bert.py | 8 ++++++-- .../modification/modifying_distilbert.py | 8 ++++++-- .../modification/modifying_llama.py | 4 +++- .../modification/modifying_mistral.py | 4 +++- .../modification/modifying_mobilebert.py | 8 ++++++-- .../modification/modifying_opt.py | 4 +++- .../modification/test_modify_model.py | 4 +++- .../modifiers/pruning/sparsegpt/test_pytorch.py | 8 +++++--- .../modifiers/quantization/test_pytorch.py | 4 +++- 22 files changed, 73 insertions(+), 23 deletions(-) rename src/sparseml/modifiers/{quantization_legacy => quantization}/gptq/__init__.py (100%) rename src/sparseml/modifiers/{quantization_legacy => quantization}/gptq/base.py (100%) rename src/sparseml/modifiers/{quantization_legacy => quantization}/gptq/pytorch.py (97%) rename src/sparseml/modifiers/{quantization_legacy => quantization}/gptq/utils/__init__.py (100%) rename src/sparseml/modifiers/{quantization_legacy => quantization}/gptq/utils/gptq_wrapper.py (100%) create mode 100644 src/sparseml/modifiers/quantization/quantization/__init__.py rename src/sparseml/modifiers/quantization/{ => quantization}/base.py (100%) rename src/sparseml/modifiers/quantization/{ => quantization}/pytorch.py (98%) diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py index 4b1b5365146..d9f790343d9 100644 --- a/src/sparseml/modifiers/__init__.py +++ b/src/sparseml/modifiers/__init__.py @@ -18,5 +18,5 @@ from .logarithmic_equalization import * from .obcq import * from .pruning import * -from .quantization_legacy import * +from .quantization import * from .smoothquant import * diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py index 9cdf715c135..2e1cdc0d24c 100644 --- a/src/sparseml/modifiers/quantization/__init__.py +++ b/src/sparseml/modifiers/quantization/__init__.py @@ -14,4 +14,5 @@ # flake8: noqa -from .base import * +from .gptq import * +from .quantization import * diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/__init__.py b/src/sparseml/modifiers/quantization/gptq/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/__init__.py rename to src/sparseml/modifiers/quantization/gptq/__init__.py diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/base.py rename to 
src/sparseml/modifiers/quantization/gptq/base.py diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py similarity index 97% rename from src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py rename to src/sparseml/modifiers/quantization/gptq/pytorch.py index c76382db647..e9e3f715625 100644 --- a/src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py +++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py @@ -19,8 +19,8 @@ from sparseml.core.model import ModifiableModel from sparseml.core.state import State -from sparseml.modifiers.quantization_legacy.gptq.base import GPTQModifier -from sparseml.modifiers.quantization_legacy.gptq.utils.gptq_wrapper import GPTQWrapper +from sparseml.modifiers.quantization.gptq.base import GPTQModifier +from sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper from sparseml.modifiers.utils.layer_compressor import LayerCompressor from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward from sparseml.utils.fsdp.context import fix_fsdp_module_name diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/utils/__init__.py b/src/sparseml/modifiers/quantization/gptq/utils/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/utils/__init__.py rename to src/sparseml/modifiers/quantization/gptq/utils/__init__.py diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/utils/gptq_wrapper.py rename to src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py diff --git a/src/sparseml/modifiers/quantization/quantization/__init__.py b/src/sparseml/modifiers/quantization/quantization/__init__.py new file mode 100644 index 00000000000..9cdf715c135 --- /dev/null +++ b/src/sparseml/modifiers/quantization/quantization/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# flake8: noqa + +from .base import * diff --git a/src/sparseml/modifiers/quantization/base.py b/src/sparseml/modifiers/quantization/quantization/base.py similarity index 100% rename from src/sparseml/modifiers/quantization/base.py rename to src/sparseml/modifiers/quantization/quantization/base.py diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/quantization/pytorch.py similarity index 98% rename from src/sparseml/modifiers/quantization/pytorch.py rename to src/sparseml/modifiers/quantization/quantization/pytorch.py index 8761b16007a..246fd3ce52a 100644 --- a/src/sparseml/modifiers/quantization/pytorch.py +++ b/src/sparseml/modifiers/quantization/quantization/pytorch.py @@ -23,7 +23,7 @@ set_module_for_calibration, ) from sparseml.core import Event, EventType, State -from sparseml.modifiers.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization.quantization.base import QuantizationModifier from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward diff --git a/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py b/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py index b2dc72cfb83..97a1f1022da 100644 --- a/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py +++ b/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py @@ -15,7 +15,9 @@ import logging import os -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) _LOGGER = logging.getLogger(__name__) diff --git a/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py index e97771c437d..f235cbfdf8c 100644 --- a/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py @@ -30,7 +30,9 @@ except Exception: torch_quantization = None -from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import FakeQuantizeWrapper +from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import ( + FakeQuantizeWrapper, +) __all__ = [ diff --git a/src/sparseml/modifiers/quantization_legacy/utils/quantize.py b/src/sparseml/modifiers/quantization_legacy/utils/quantize.py index 89e9d2faaa9..038ae5cab92 100644 --- a/src/sparseml/modifiers/quantization_legacy/utils/quantize.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/quantize.py @@ -26,13 +26,17 @@ FUSED_MODULE_NAMES, NON_QUANTIZABLE_MODULE_NAMES, ) -from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import FakeQuantizeWrapper +from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import ( + FakeQuantizeWrapper, +) from sparseml.modifiers.quantization_legacy.utils.helpers import ( QATWrapper, configure_module_default_qconfigs, prepare_embeddings_qat, ) -from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import QuantizationScheme +from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import ( + QuantizationScheme, +) from sparseml.pytorch.utils import get_layer from sparseml.utils.fsdp.context import fix_fsdp_module_name diff --git a/src/sparseml/transformers/sparsification/modification/modifying_bert.py b/src/sparseml/transformers/sparsification/modification/modifying_bert.py index 2632600ed8c..fccb65ea885 100644 --- 
a/src/sparseml/transformers/sparsification/modification/modifying_bert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_bert.py @@ -25,8 +25,12 @@ from torch import nn from transformers.models.bert.modeling_bert import BertSelfAttention -from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATMatMul -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( + QATMatMul, +) +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py index 9aa0389590c..d2bf92dd637 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py @@ -27,8 +27,12 @@ MultiHeadSelfAttention, ) -from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATMatMul -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( + QATMatMul, +) +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_llama.py b/src/sparseml/transformers/sparsification/modification/modifying_llama.py index 5e480376a74..d7aea9ac1c6 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_llama.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_llama.py @@ -36,7 +36,9 @@ QuantizableIdentity, QuantizableMatMul, ) -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py index a50f31eb588..a27a75d5992 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py @@ -36,7 +36,9 @@ QuantizableIdentity, QuantizableMatMul, ) -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py index 62fd6bef7f8..2ab9d819fb5 100644 --- 
a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py @@ -20,8 +20,12 @@ from torch import nn from transformers.models.mobilebert.modeling_mobilebert import MobileBertEmbeddings -from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATLinear -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( + QATLinear, +) +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_opt.py b/src/sparseml/transformers/sparsification/modification/modifying_opt.py index fb448316cab..eb42dd6d686 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_opt.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_opt.py @@ -27,7 +27,9 @@ QuantizableBatchMatmul, QuantizableIdentity, ) -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/tests/sparseml/modifiers/quantization/modification/test_modify_model.py b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py index 16c13af7207..4ad1cb6580b 100644 --- a/tests/sparseml/modifiers/quantization/modification/test_modify_model.py +++ b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py @@ -18,7 +18,9 @@ import pytest from sparseml.modifiers.quantization_legacy.modification import modify_model -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, standardize_lookup_name diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 01c5fb0cbf9..0fcb66eee9c 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -20,9 +20,11 @@ from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.modifiers.quantization_legacy.gptq.pytorch import GPTQModifierPyTorch -from sparseml.modifiers.quantization_legacy.pytorch import LegacyQuantizationModifierPyTorch -from sparseml.modifiers.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch +from sparseml.modifiers.quantization.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization_legacy.pytorch import ( + LegacyQuantizationModifierPyTorch, +) from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory from tests.sparseml.pytorch.helpers import LinearNet from tests.testing_utils 
import requires_torch diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py index b1327f4ce3d..2e9750c60c7 100644 --- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py @@ -21,7 +21,9 @@ from sparseml.core.event import Event, EventType from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.modifiers.quantization_legacy.pytorch import LegacyQuantizationModifierPyTorch +from sparseml.modifiers.quantization_legacy.pytorch import ( + LegacyQuantizationModifierPyTorch, +) from sparseml.pytorch.sparsification.quantization.quantize import ( is_qat_helper_module, is_quantizable_module, From 893edd9b15eb2886b199fc8b0f17f8c67b78ee1b Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 28 May 2024 20:47:42 +0000 Subject: [PATCH 4/5] update import --- src/sparseml/modifiers/quantization/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py index 2e1cdc0d24c..fe2676473d5 100644 --- a/src/sparseml/modifiers/quantization/__init__.py +++ b/src/sparseml/modifiers/quantization/__init__.py @@ -15,4 +15,3 @@ # flake8: noqa from .gptq import * -from .quantization import * From 2049a89f55cc1c2f4601303b96c588892e6b04f4 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 28 May 2024 20:52:01 +0000 Subject: [PATCH 5/5] fix imports --- src/sparseml/modifiers/quantization/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py index fe2676473d5..ebdf28a6d5b 100644 --- a/src/sparseml/modifiers/quantization/__init__.py +++ b/src/sparseml/modifiers/quantization/__init__.py @@ -13,5 +13,3 @@ # limitations under the License. # flake8: noqa - -from .gptq import *
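
Note for reviewers (illustrative sketch, not part of the diff): after this series the
old QAT-style modifier lives under the quantization_legacy namespace with a "Legacy"
prefix, while the compressed-tensors flow takes over the plain QuantizationModifier
name under quantization/quantization. A minimal summary of the resulting import
layout, using only paths that already appear in the test hunks above:

    # compressed-tensors based flow, now the default QuantizationModifier
    from sparseml.modifiers.quantization.quantization.base import QuantizationModifier
    from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch

    # old QAT flow, still importable under the legacy namespace
    from sparseml.modifiers.quantization_legacy import LegacyQuantizationModifier
    from sparseml.modifiers.quantization_legacy.pytorch import (
        LegacyQuantizationModifierPyTorch,
    )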
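
The renamed QuantizationModifier keeps the config_groups / ignore /
disable_quantization_observer_epoch / num_calibration_steps interface shown in the
deleted quantization_vllm/base.py above. A hedged one-shot PTQ sketch follows; the
nested scheme fields (targets, weights, num_bits, symmetric) follow the
compressed-tensors QuantizationScheme / QuantizationArgs models and are assumptions
here, not values taken from this diff:

    from sparseml.modifiers.quantization.quantization.base import QuantizationModifier

    # one config group quantizing Linear weights to int8; lm_head is left unquantized
    modifier = QuantizationModifier(
        config_groups={
            "config_group_0": {
                "targets": ["Linear"],
                "weights": {"num_bits": 8, "symmetric": True},
            }
        },
        ignore=["lm_head"],
        num_calibration_steps=512,
    )

    # create_init_config() (defined in the base class above) builds the
    # QuantizationConfig that is applied to the model at initialize time
    config = modifier.create_init_config()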