
fix transformers and quantization imports #1770

Merged (12 commits) on Oct 18, 2023
2 changes: 1 addition & 1 deletion Makefile
@@ -25,7 +25,7 @@ ifneq ($(findstring onnx,$(TARGETS)),onnx)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/onnx
endif
ifneq ($(findstring pytorch,$(TARGETS)),pytorch)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch --ignore tests/sparseml/modifiers
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch
endif
ifneq ($(findstring pytorch_models,$(TARGETS)),pytorch_models)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch/models
1 change: 0 additions & 1 deletion src/sparseml/modifiers/obcq/__init__.py
@@ -15,4 +15,3 @@
# flake8: noqa

from .base import *
from .pytorch import *
7 changes: 6 additions & 1 deletion src/sparseml/modifiers/obcq/utils/sparsegpt.py
@@ -18,7 +18,6 @@

import torch
import torch.nn as nn
import transformers


DEBUG = False
@@ -42,6 +41,8 @@ class SparseGPT:
"""

def __init__(self, layer):
import transformers

self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
@@ -61,6 +62,8 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
:param inp: tensor containing layer input
:param out: tensor containing layer output
"""
import transformers

if DEBUG:
self._inp1 = inp
self.out1 = out
@@ -97,6 +100,8 @@ def fasterprune(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
import transformers

W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
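The hunks above move `import transformers` from module scope into the methods that need it, so this module can be imported even when transformers is not installed. A minimal sketch of the same deferred-import pattern, using an illustrative class that is not the actual SparseGPT implementation:

```python
import torch.nn as nn


class ExampleCompressor:
    """Hypothetical stand-in showing the deferred-import pattern."""

    def __init__(self, layer: nn.Module):
        # deferred import: transformers is resolved when the object is built,
        # not when this module is imported
        import transformers

        self.layer = layer
        # illustrative use: Hugging Face Conv1D layers store their weights
        # transposed relative to nn.Linear, so pruning code often special-cases them
        self.transpose_weight = isinstance(layer, transformers.Conv1D)
```

The trade-off is a small lookup cost on first use in exchange for keeping transformers optional on the import path.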
1 change: 0 additions & 1 deletion src/sparseml/modifiers/quantization/__init__.py
@@ -15,4 +15,3 @@
# flake8: noqa

from .base import *
from .pytorch import *
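Together with the matching `obcq/__init__.py` change above, dropping `from .pytorch import *` keeps the package `__init__` framework-agnostic: the torch-backed classes are no longer imported eagerly, and callers opt in explicitly. A short usage sketch, using only import paths that appear elsewhere in this diff (the behavior described is the intent of the change, not verified output):

```python
# framework-agnostic base class: importable without pulling in the PyTorch stack
from sparseml.modifiers.quantization import QuantizationModifier

# the PyTorch implementation is now an explicit import, matching the updated test below
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
```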
13 changes: 13 additions & 0 deletions src/sparseml/modifiers/quantization/utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/sparseml/pytorch/modifiers/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/sparseml/pytorch/modifiers/quantization/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -16,7 +16,7 @@
from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
from sparseml.modifiers.quantization import QuantizationModifier
from tests.sparseml.modifiers.conf import setup_modifier_factory
from tests.sparseml.pytorch.modifiers.conf import setup_modifier_factory


def test_quantization_registered():
@@ -18,14 +18,18 @@
from sparseml.core.event import Event, EventType
from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
from sparseml.modifiers.quantization import QuantizationModifierPyTorch
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
from sparseml.pytorch.sparsification.quantization.quantize import (
is_qat_helper_module,
is_quantizable_module,
)
from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import ConvNet, LinearNet
from tests.sparseml.pytorch.modifiers.conf import (
LifecyleTestingHarness,
setup_modifier_factory,
)
from tests.sparseml.pytorch.sparsification.quantization.test_modifier_quantization import ( # noqa E501
_match_submodule_name_or_type,
_test_qat_wrapped_module,
_test_quantized_module,
)
@@ -56,7 +60,12 @@ def _test_qat_applied(modifier, model):
_test_qat_wrapped_module(model, name)
elif is_quantizable:
# check each target module is quantized
_test_quantized_module(model, modifier, module, name)
override_key = _match_submodule_name_or_type(
module,
name,
list(modifier.scheme_overrides.keys()),
)
_test_quantized_module(model, modifier, module, name, override_key)
else:
# check all non-target modules are not quantized
assert not hasattr(module, "quantization_scheme")
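The updated call resolves an `override_key` from the modifier's `scheme_overrides` before checking each quantized module. For context, a rough sketch of how such a name-or-type lookup could behave; this is an assumed re-implementation for illustration, not the actual `_match_submodule_name_or_type` helper:

```python
from typing import List, Optional

import torch.nn as nn


def match_submodule_name_or_type(
    module: nn.Module, name: str, keys: List[str]
) -> Optional[str]:
    # prefer an exact submodule-name match, then fall back to the class name;
    # return None when no scheme override applies to this module
    if name in keys:
        return name
    if module.__class__.__name__ in keys:
        return module.__class__.__name__
    return None
```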
@@ -15,6 +15,8 @@
import os

import pytest
import torch
from packaging import version
from torch.nn import Identity

from sparseml.pytorch.sparsification.quantization.helpers import QATWrapper
@@ -57,7 +59,17 @@ def _assert_observers_eq(observer_1, observer_2):

if hasattr(observer_1, "p"):
# assume observer is a partial, test by properties dict
assert observer_1.p.keywords == observer_2.p.keywords
observer_1_keywords = observer_1.p.keywords
observer_2_keywords = observer_2.p.keywords
TORCH_VERSION = version.parse(torch.__version__)
if (
TORCH_VERSION.major < 2
): # can't match observer class instances before 2.0
if "observer" in observer_1_keywords:
del observer_1_keywords["observer"]
if "observer" in observer_2_keywords:
del observer_2_keywords["observer"]
assert observer_1_keywords == observer_2_keywords
else:
# default to plain `==`
assert observer_1 == observer_2
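The new branch compares the observers' `functools.partial` keyword arguments, dropping the `observer` entry first on torch < 2.0, where observer class instances do not compare equal. A small sketch of the same version gate in isolation (the helper name is illustrative):

```python
import torch
from packaging import version

TORCH_VERSION = version.parse(torch.__version__)


def comparable_keywords(keywords: dict) -> dict:
    # copy first so the partial's own keyword dict is not mutated, then drop
    # the "observer" entry on torch < 2.0 where instances cannot be compared
    keywords = dict(keywords)
    if TORCH_VERSION.major < 2:
        keywords.pop("observer", None)
    return keywords


# usage sketch:
# assert comparable_keywords(observer_1.p.keywords) == comparable_keywords(observer_2.p.keywords)
```

Unlike the in-place `del` in the test above, the sketch copies the dict so the observers' partials are left untouched.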