fix transformers and quantization imports #1770

Merged · 12 commits · Oct 18, 2023
Makefile: 2 changes, 1 addition & 1 deletion
@@ -25,7 +25,7 @@ ifneq ($(findstring onnx,$(TARGETS)),onnx)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/onnx
endif
ifneq ($(findstring pytorch,$(TARGETS)),pytorch)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch --ignore tests/sparseml/modifiers
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch
endif
ifneq ($(findstring pytorch_models,$(TARGETS)),pytorch_models)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch/models
src/sparseml/modifiers/obcq/__init__.py: 1 change, 0 additions & 1 deletion
@@ -15,4 +15,3 @@
# flake8: noqa

from .base import *
from .pytorch import *
src/sparseml/modifiers/obcq/utils/sparsegpt.py: 7 changes, 6 additions & 1 deletion
@@ -18,7 +18,6 @@

import torch
import torch.nn as nn
import transformers


DEBUG = False
@@ -42,6 +41,8 @@ class SparseGPT:
"""

def __init__(self, layer):
import transformers

self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
@@ -61,6 +62,8 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
:param inp: tensor containing layer input
:param out: tensor containing layer output
"""
import transformers

if DEBUG:
self._inp1 = inp
self.out1 = out
@@ -97,6 +100,8 @@ def fasterprune(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
import transformers

W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
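The hunks above drop the module-level `import transformers` and defer it into the SparseGPT methods that actually touch transformers types, so importing the obcq modifier package no longer requires transformers to be installed. A minimal sketch of the same deferred-import pattern, using a hypothetical helper name rather than the sparseml code:

import torch.nn as nn


def get_prunable_weight(layer: nn.Module):
    # Deferred import: a missing transformers install surfaces here, at call
    # time, instead of when the package is first imported.
    try:
        import transformers
    except ImportError as err:
        raise RuntimeError(
            "transformers must be installed to prune transformer layers"
        ) from err

    weight = layer.weight.data.clone()
    if isinstance(layer, transformers.Conv1D):
        # transformers' Conv1D stores weights transposed relative to nn.Linear
        weight = weight.t()
    return weight

Python caches the module after the first successful import, so repeated calls pay only a dictionary lookup.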
src/sparseml/modifiers/quantization/__init__.py: 1 change, 0 additions & 1 deletion
@@ -15,4 +15,3 @@
# flake8: noqa

from .base import *
from .pytorch import *
src/sparseml/modifiers/quantization/base.py: 40 changes, 0 additions & 40 deletions
@@ -15,10 +15,6 @@
from typing import Any, Dict, List, Optional

from sparseml.core import Event, Modifier, State
from sparseml.modifiers.quantization.utils.quantization_scheme import (
QuantizationScheme,
QuantizationSchemeLoadable,
)


__all__ = ["QuantizationModifier"]
@@ -57,15 +53,6 @@ class QuantizationModifier(Modifier):
| model_fuse_fn_name: 'fuse_module'
| strict: True

:param scheme: Default QuantizationScheme to use when enabling quantization
in a module. May also be a dictionary to be loaded into the QuantizationScheme
class. A string alias may also be used, supported aliases:
['default', 'deepsparse', 'tensorrt'].
If None, the default scheme (`QuantizationScheme()`) will be used.
Default is None
:param scheme_overrides: optional mapping of module type names or submodule type
names to quantization schemes to override them with. If a scheme is mapped to
'default', then it will use the scheme set in the modifier scheme property
:param ignore: optional list of module class names or submodule names
to not quantize. Default is None
:param disable_quantization_observer_epoch: Epoch to disable updates to the module
@@ -85,8 +72,6 @@ class QuantizationModifier(Modifier):
scheme_overrides or ignore are not found in a given module. Default True
"""

scheme: Optional[QuantizationSchemeLoadable] = None
scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None
ignore: Optional[List[str]] = None
disable_quantization_observer_epoch: Optional[float] = None
freeze_bn_stats_epoch: Optional[float] = None
@@ -98,10 +83,6 @@ class QuantizationModifier(Modifier):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.scheme = QuantizationScheme.load(self.scheme)
self.scheme_overrides = _load_quantization_schemes_dict(
self.scheme_overrides, self.scheme
)
if self.model_fuse_fn_kwargs is None:
self.model_fuse_fn_kwargs = {}
if self.ignore is None:
@@ -158,24 +139,3 @@ def check_should_disable_observer(self, event: Event) -> bool:

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier


class _QuantizationSchemesDict(dict):
# wrapper class for dict to override the __str__ method for yaml serialization

def __str__(self):
return str({submodule: scheme.dict() for submodule, scheme in self.items()})


def _load_quantization_schemes_dict(
schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]],
default_scheme: QuantizationScheme,
) -> Dict[str, QuantizationScheme]:
if schemes_dict is None:
return {}
return _QuantizationSchemesDict(
{
submodule: QuantizationScheme.load(scheme, default=default_scheme)
for submodule, scheme in schemes_dict.items()
}
)
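With these removals, the framework-agnostic QuantizationModifier keeps only settings that never touch torch (`ignore`, the observer/BN epochs, fuse-function options), while `scheme`, `scheme_overrides`, and their loading logic move into the PyTorch implementation in the next file. A self-contained sketch of that split, using hypothetical class names rather than the sparseml classes:

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class BaseQuantModifier:
    # framework-agnostic: importable without torch or transformers installed
    ignore: List[str] = field(default_factory=list)
    disable_quantization_observer_epoch: Optional[float] = None
    freeze_bn_stats_epoch: Optional[float] = None


@dataclass
class TorchQuantModifier(BaseQuantModifier):
    # torch-specific scheme handling lives only in the framework subclass
    scheme: Optional[str] = None
    scheme_overrides: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        # resolve aliases once at construction, in the spirit of
        # QuantizationScheme.load(...) in the real modifier
        self.scheme = self.scheme or "default"
        self.scheme_overrides = {
            name: (self.scheme if alias == "default" else alias)
            for name, alias in self.scheme_overrides.items()
        }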
src/sparseml/modifiers/quantization/pytorch.py: 50 changes, 49 additions & 1 deletion
@@ -14,7 +14,7 @@

import logging
from itertools import cycle
from typing import Any, Callable
from typing import Any, Callable, Dict, Optional

import torch
from torch.nn import Module
@@ -26,6 +26,10 @@
freeze_bn_stats,
fuse_module_conv_bn_relus,
)
from sparseml.modifiers.quantization.utils.quantization_scheme import (
QuantizationScheme,
QuantizationSchemeLoadable,
)
from sparseml.modifiers.quantization.utils.quantize import (
convert_module_qat_from_schemes,
raise_if_torch_quantization_not_available,
@@ -38,12 +42,35 @@


class QuantizationModifierPyTorch(QuantizationModifier):
"""
Pytorch-specific implementation of quantization modifier

:param scheme: Default QuantizationScheme to use when enabling quantization
in a module. May also be a dictionary to be loaded into the QuantizationScheme
class. A string alias may also be used, supported aliases:
['default', 'deepsparse', 'tensorrt'].
If None, the default scheme (`QuantizationScheme()`) will be used.
Default is None
:param scheme_overrides: optional mapping of module type names or submodule type
names to quantization schemes to override them with. If a scheme is mapped to
'default', then it will use the scheme set in the modifier scheme property
"""

scheme: Optional[QuantizationSchemeLoadable] = None
scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None
calibration_dataloader_: Any = None
calibration_function_: Any = None
qat_enabled_: bool = False
quantization_observer_disabled_: bool = False
bn_stats_frozen_: bool = False

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.scheme = QuantizationScheme.load(self.scheme)
self.scheme_overrides = _load_quantization_schemes_dict(
self.scheme_overrides, self.scheme
)

def on_initialize(self, state: State, **kwargs) -> bool:
raise_if_torch_quantization_not_available()
if self.end and self.end != -1:
@@ -181,3 +208,24 @@ def _calibrate(self, module: Module):
module.train()
else:
self._disable_quantization_observer(module)


class _QuantizationSchemesDict(dict):
# wrapper class for dict to override the __str__ method for yaml serialization

def __str__(self):
return str({submodule: scheme.dict() for submodule, scheme in self.items()})


def _load_quantization_schemes_dict(
schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]],
default_scheme: QuantizationScheme,
) -> Dict[str, QuantizationScheme]:
if schemes_dict is None:
return {}
return _QuantizationSchemesDict(
{
submodule: QuantizationScheme.load(scheme, default=default_scheme)
for submodule, scheme in schemes_dict.items()
}
)
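For reference, a hedged usage sketch of the relocated fields (argument values are illustrative; in practice the modifier is normally built from a recipe rather than constructed directly):

from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch

modifier = QuantizationModifierPyTorch(
    scheme="deepsparse",                        # string alias resolved via QuantizationScheme.load
    scheme_overrides={"Embedding": "default"},  # 'default' falls back to the modifier's scheme
    ignore=["classifier"],                      # module types / submodule names left unquantized
)
# after __init__, scheme is a QuantizationScheme and scheme_overrides is the
# yaml-friendly _QuantizationSchemesDict defined above
print(modifier.scheme)
print(modifier.scheme_overrides)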
src/sparseml/modifiers/quantization/utils/__init__.py: 13 changes, 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/sparseml/modifiers/__init__.py: 13 changes, 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/sparseml/modifiers/quantization/__init__.py: 13 changes, 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/sparseml/pytorch/modifiers/__init__.py: 13 changes, 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/sparseml/pytorch/modifiers/quantization/__init__.py: 13 changes, 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -18,14 +18,15 @@
from sparseml.core.event import Event, EventType
from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
from sparseml.modifiers.quantization import QuantizationModifierPyTorch
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
from sparseml.pytorch.sparsification.quantization.quantize import (
is_qat_helper_module,
is_quantizable_module,
)
from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import ConvNet, LinearNet
from tests.sparseml.pytorch.sparsification.quantization.test_modifier_quantization import ( # noqa E501
_match_submodule_name_or_type,
_test_qat_wrapped_module,
_test_quantized_module,
)
@@ -56,7 +57,12 @@ def _test_qat_applied(modifier, model):
_test_qat_wrapped_module(model, name)
elif is_quantizable:
# check each target module is quantized
_test_quantized_module(model, modifier, module, name)
override_key = _match_submodule_name_or_type(
module,
name,
list(modifier.scheme_overrides.keys()),
)
_test_quantized_module(model, modifier, module, name, override_key)
else:
# check all non-target modules are not quantized
assert not hasattr(module, "quantization_scheme")
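The test now resolves which `scheme_overrides` key, if any, applies to a module before asserting how it was quantized. A plausible, self-contained mirror of that lookup (hypothetical helper; the real `_match_submodule_name_or_type` may differ in its details):

from typing import List, Optional

import torch.nn as nn


def match_submodule_name_or_type(
    module: nn.Module, name: str, keys: List[str]
) -> Optional[str]:
    # prefer an exact submodule-name match, then fall back to the module type
    if name in keys:
        return name  # e.g. "seq.fc1"
    type_name = module.__class__.__name__
    if type_name in keys:
        return type_name  # e.g. "Linear"
    return None


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
print(match_submodule_name_or_type(model[0], "0", ["Linear"]))  # -> "Linear"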
@@ -15,6 +15,8 @@
import os

import pytest
import torch
from packaging import version
from torch.nn import Identity

from sparseml.pytorch.sparsification.quantization.helpers import QATWrapper
@@ -57,7 +59,17 @@ def _assert_observers_eq(observer_1, observer_2):

if hasattr(observer_1, "p"):
# assume observer is a partial, test by properties dict
assert observer_1.p.keywords == observer_2.p.keywords
observer_1_keywords = observer_1.p.keywords
observer_2_keywords = observer_2.p.keywords
TORCH_VERSION = version.parse(torch.__version__)
if (
TORCH_VERSION.major < 2
): # can't match observer class instances before 2.0
if "observer" in observer_1_keywords:
del observer_1_keywords["observer"]
if "observer" in observer_2_keywords:
del observer_2_keywords["observer"]
assert observer_1_keywords == observer_2_keywords
else:
# default to plain `==`
assert observer_1 == observer_2
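The intent of the guarded comparison above: before torch 2.0, the partial-wrapped observer settings carry an "observer" entry whose class objects do not compare equal across instances, so that key is dropped before the dicts are compared. A self-contained sketch of the same check (hypothetical helper name, illustrative keyword dicts):

from packaging import version


def observer_keywords_equal(kw1: dict, kw2: dict, torch_version: str) -> bool:
    kw1, kw2 = dict(kw1), dict(kw2)  # copy so the callers' dicts stay intact
    if version.parse(torch_version).major < 2:
        # observer class instances can't be matched reliably before 2.0
        kw1.pop("observer", None)
        kw2.pop("observer", None)
    return kw1 == kw2


a = {"dtype": "qint8", "observer": object()}
b = {"dtype": "qint8", "observer": object()}
print(observer_keywords_equal(a, b, "1.13.1"))  # True: observer key ignored
print(observer_keywords_equal(a, b, "2.1.0"))   # False: observers compared on 2.x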