fix transformers and quantization imports (#1770)
* fix transformers and quantization imports

* update GHA

* transformers dependency

* move import

* moving tests around

* fix torch version issue

* move all modifiers under pytorch

* fixing 1.9 tests

* fix key error

* bug fix for 1.9

* move torch imports out of base quant
Satrat committed Oct 18, 2023
1 parent 5dd16ac commit 9c3ef78
Showing 13 changed files with 142 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -25,7 +25,7 @@ ifneq ($(findstring onnx,$(TARGETS)),onnx)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/onnx
endif
ifneq ($(findstring pytorch,$(TARGETS)),pytorch)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch --ignore tests/sparseml/modifiers
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch
endif
ifneq ($(findstring pytorch_models,$(TARGETS)),pytorch_models)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch/models
1 change: 0 additions & 1 deletion src/sparseml/modifiers/obcq/__init__.py
@@ -15,4 +15,3 @@
# flake8: noqa

from .base import *
from .pytorch import *
7 changes: 6 additions & 1 deletion src/sparseml/modifiers/obcq/utils/sparsegpt.py
@@ -18,7 +18,6 @@

import torch
import torch.nn as nn
import transformers


DEBUG = False
@@ -42,6 +41,8 @@ class SparseGPT:
"""

def __init__(self, layer):
import transformers

self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
@@ -61,6 +62,8 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
:param inp: tensor containing layer input
:param out: tensor containing layer output
"""
import transformers

if DEBUG:
self._inp1 = inp
self.out1 = out
@@ -97,6 +100,8 @@ def fasterprune(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
import transformers

W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
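Note: the hunks above move import transformers from module scope into the methods that use it, so importing sparseml's OBCQ utilities no longer requires transformers to be installed. A minimal sketch of the same deferred-import pattern, with illustrative names only (not the actual SparseGPT implementation):

import torch.nn as nn


class ExampleCompressor:
    """Toy example: defer the optional transformers dependency to call time."""

    def __init__(self, layer: nn.Module):
        # Importing here instead of at module scope means the package itself
        # imports cleanly without transformers installed; the dependency is
        # only needed when a compressor is actually constructed.
        import transformers

        weight = layer.weight.data.clone()
        if isinstance(layer, transformers.Conv1D):
            # GPT-2 style Conv1D stores its weight transposed relative to Linear
            weight = weight.t()
        self.rows, self.columns = weight.shape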
1 change: 0 additions & 1 deletion src/sparseml/modifiers/quantization/__init__.py
Expand Up @@ -15,4 +15,3 @@
# flake8: noqa

from .base import *
from .pytorch import *
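With from .pytorch import * dropped from the package __init__, importing sparseml.modifiers.quantization no longer pulls in the torch-backed implementation; code that needs it imports it from its own module, as the updated test import later in this diff does. A sketch of the intended import split (assuming an install without the torch extras):

# framework-agnostic base stays importable from the package
from sparseml.modifiers.quantization import QuantizationModifier

# the PyTorch implementation is imported explicitly from its own module
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch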
40 changes: 0 additions & 40 deletions src/sparseml/modifiers/quantization/base.py
@@ -15,10 +15,6 @@
from typing import Any, Dict, List, Optional

from sparseml.core import Event, Modifier, State
from sparseml.modifiers.quantization.utils.quantization_scheme import (
QuantizationScheme,
QuantizationSchemeLoadable,
)


__all__ = ["QuantizationModifier"]
@@ -57,15 +53,6 @@ class QuantizationModifier(Modifier):
| model_fuse_fn_name: 'fuse_module'
| strict: True
:param scheme: Default QuantizationScheme to use when enabling quantization
in a module. May also be a dictionary to be loaded into the QuantizationScheme
class. A string alias may also be used, supported aliases:
['default', 'deepsparse', 'tensorrt'].
If None, the default scheme (`QuantizationScheme()`) will be used.
Default is None
:param scheme_overrides: optional mapping of module type names or submodule type
names to quantization schemes to override them with. If a scheme is mapped to
'default', then it will use the scheme set in the modifier scheme property
:param ignore: optional list of module class names or submodule names
to not quantize. Default is None
:param disable_quantization_observer_epoch: Epoch to disable updates to the module
@@ -85,8 +72,6 @@ class QuantizationModifier(Modifier):
scheme_overrides or ignore are not found in a given module. Default True
"""

scheme: Optional[QuantizationSchemeLoadable] = None
scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None
ignore: Optional[List[str]] = None
disable_quantization_observer_epoch: Optional[float] = None
freeze_bn_stats_epoch: Optional[float] = None
@@ -98,10 +83,6 @@ class QuantizationModifier(Modifier):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.scheme = QuantizationScheme.load(self.scheme)
self.scheme_overrides = _load_quantization_schemes_dict(
self.scheme_overrides, self.scheme
)
if self.model_fuse_fn_kwargs is None:
self.model_fuse_fn_kwargs = {}
if self.ignore is None:
@@ -158,24 +139,3 @@ def check_should_disable_observer(self, event: Event) -> bool:

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier


class _QuantizationSchemesDict(dict):
# wrapper class for dict to override the __str__ method for yaml serialization

def __str__(self):
return str({submodule: scheme.dict() for submodule, scheme in self.items()})


def _load_quantization_schemes_dict(
schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]],
default_scheme: QuantizationScheme,
) -> Dict[str, QuantizationScheme]:
if schemes_dict is None:
return {}
return _QuantizationSchemesDict(
{
submodule: QuantizationScheme.load(scheme, default=default_scheme)
for submodule, scheme in schemes_dict.items()
}
)
50 changes: 49 additions & 1 deletion src/sparseml/modifiers/quantization/pytorch.py
@@ -14,7 +14,7 @@

import logging
from itertools import cycle
from typing import Any, Callable
from typing import Any, Callable, Dict, Optional

import torch
from torch.nn import Module
@@ -26,6 +26,10 @@
freeze_bn_stats,
fuse_module_conv_bn_relus,
)
from sparseml.modifiers.quantization.utils.quantization_scheme import (
QuantizationScheme,
QuantizationSchemeLoadable,
)
from sparseml.modifiers.quantization.utils.quantize import (
convert_module_qat_from_schemes,
raise_if_torch_quantization_not_available,
@@ -38,12 +42,35 @@


class QuantizationModifierPyTorch(QuantizationModifier):
"""
Pytorch-specific implementation of quantization modifier
:param scheme: Default QuantizationScheme to use when enabling quantization
in a module. May also be a dictionary to be loaded into the QuantizationScheme
class. A string alias may also be used, supported aliases:
['default', 'deepsparse', 'tensorrt'].
If None, the default scheme (`QuantizationScheme()`) will be used.
Default is None
:param scheme_overrides: optional mapping of module type names or submodule type
names to quantization schemes to override them with. If a scheme is mapped to
'default', then it will use the scheme set in the modifier scheme property
"""

scheme: Optional[QuantizationSchemeLoadable] = None
scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None
calibration_dataloader_: Any = None
calibration_function_: Any = None
qat_enabled_: bool = False
quantization_observer_disabled_: bool = False
bn_stats_frozen_: bool = False

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.scheme = QuantizationScheme.load(self.scheme)
self.scheme_overrides = _load_quantization_schemes_dict(
self.scheme_overrides, self.scheme
)

def on_initialize(self, state: State, **kwargs) -> bool:
raise_if_torch_quantization_not_available()
if self.end and self.end != -1:
@@ -181,3 +208,24 @@ def _calibrate(self, module: Module):
module.train()
else:
self._disable_quantization_observer(module)


class _QuantizationSchemesDict(dict):
# wrapper class for dict to override the __str__ method for yaml serialization

def __str__(self):
return str({submodule: scheme.dict() for submodule, scheme in self.items()})


def _load_quantization_schemes_dict(
schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]],
default_scheme: QuantizationScheme,
) -> Dict[str, QuantizationScheme]:
if schemes_dict is None:
return {}
return _QuantizationSchemesDict(
{
submodule: QuantizationScheme.load(scheme, default=default_scheme)
for submodule, scheme in schemes_dict.items()
}
)
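With this move, scheme and scheme_overrides (and the QuantizationScheme.load handling behind them) now live on QuantizationModifierPyTorch rather than on the framework-agnostic base class. A hedged usage sketch based on the docstring above; the aliases and override semantics come from that docstring, while the exact set of constructor arguments accepted may differ:

from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch

# scheme accepts a QuantizationScheme, a dict, a string alias
# ('default', 'deepsparse', 'tensorrt'), or None (falls back to
# QuantizationScheme()); scheme_overrides maps module or submodule type
# names to schemes, with 'default' meaning "use the modifier-level scheme".
modifier = QuantizationModifierPyTorch(
    scheme="default",
    scheme_overrides={"Linear": "deepsparse"},
    ignore=["Embedding"],
)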
13 changes: 13 additions & 0 deletions src/sparseml/modifiers/quantization/utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/sparseml/modifiers/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/sparseml/modifiers/quantization/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/sparseml/pytorch/modifiers/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/sparseml/pytorch/modifiers/quantization/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -18,14 +18,15 @@
from sparseml.core.event import Event, EventType
from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
from sparseml.modifiers.quantization import QuantizationModifierPyTorch
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
from sparseml.pytorch.sparsification.quantization.quantize import (
is_qat_helper_module,
is_quantizable_module,
)
from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import ConvNet, LinearNet
from tests.sparseml.pytorch.sparsification.quantization.test_modifier_quantization import ( # noqa E501
_match_submodule_name_or_type,
_test_qat_wrapped_module,
_test_quantized_module,
)
@@ -56,7 +57,12 @@ def _test_qat_applied(modifier, model):
_test_qat_wrapped_module(model, name)
elif is_quantizable:
# check each target module is quantized
_test_quantized_module(model, modifier, module, name)
override_key = _match_submodule_name_or_type(
module,
name,
list(modifier.scheme_overrides.keys()),
)
_test_quantized_module(model, modifier, module, name, override_key)
else:
# check all non-target modules are not quantized
assert not hasattr(module, "quantization_scheme")
@@ -15,6 +15,8 @@
import os

import pytest
import torch
from packaging import version
from torch.nn import Identity

from sparseml.pytorch.sparsification.quantization.helpers import QATWrapper
@@ -57,7 +59,17 @@ def _assert_observers_eq(observer_1, observer_2):

if hasattr(observer_1, "p"):
# assume observer is a partial, test by properties dict
assert observer_1.p.keywords == observer_2.p.keywords
observer_1_keywords = observer_1.p.keywords
observer_2_keywords = observer_2.p.keywords
TORCH_VERSION = version.parse(torch.__version__)
if (
TORCH_VERSION.major < 2
): # can't match observer class instances before 2.0
if "observer" in observer_1_keywords:
del observer_1_keywords["observer"]
if "observer" in observer_2_keywords:
del observer_2_keywords["observer"]
assert observer_1_keywords == observer_2_keywords
else:
# default to plain `==`
assert observer_1 == observer_2
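The test now gates the observer comparison on the installed torch version, since observer class instances inside partials cannot be matched reliably before torch 2.0. A standalone sketch of the same packaging-based version check (illustrative helper, not the test code itself):

import torch
from packaging import version

TORCH_VERSION = version.parse(torch.__version__)


def comparable_keywords(keywords: dict) -> dict:
    # Drop the observer entry on torch < 2.0, where observer class instances
    # don't compare equal; everything else is compared as-is.
    if TORCH_VERSION.major < 2:
        return {k: v for k, v in keywords.items() if k != "observer"}
    return keywords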
