SmoothQuant Modifier for OBCQ #1758

Merged Oct 31, 2023 (219 commits)

Changes from 207 commits

Commits
bd729f3
First abstract structure with OPT, MPT
natuan Aug 18, 2023
0476a14
Merge branch 'main' into tuan/sparsegpt
Satrat Aug 24, 2023
e9c4e95
scaffolding
Satrat Aug 29, 2023
07a7273
First abstract structure with OPT, MPT
natuan Aug 18, 2023
73ccd2e
Updated
natuan Aug 29, 2023
f742f4a
Updated
natuan Aug 29, 2023
7c2d3d4
Initial implementation of Llama2 integration
anmarques Aug 30, 2023
3f4c5d8
Initial implementation of Llama2 integration
anmarques Aug 30, 2023
c181380
Fix llama-2 key
anmarques Aug 30, 2023
dd0f5f8
Add modelutils
anmarques Aug 30, 2023
99f8bbe
Make smoothquant optional
anmarques Aug 30, 2023
db901e4
probe sequence length from model
anmarques Aug 30, 2023
48ebe81
probe sequence length from model
anmarques Aug 30, 2023
2be7def
Fix typo
anmarques Aug 30, 2023
fb346ad
add in recipe and modifier logic
Satrat Aug 30, 2023
c4fc152
Catch additional arguments to load_model and load_data
anmarques Aug 30, 2023
2b50607
Redifinition of seqlen
anmarques Aug 30, 2023
15b2c10
Runable OPT main code path
natuan Aug 31, 2023
34c3d84
mpt modifier
Satrat Aug 31, 2023
4cac803
Merge branch 'tuan/sparsegpt' into sa/sgpt_modifier
Satrat Aug 31, 2023
a3934c2
Initial start implementation
markurtz Sep 1, 2023
6a35b7f
State before rebasing
anmarques Sep 1, 2023
e02abb2
State before rebasing
anmarques Sep 1, 2023
2fb9e66
Merge remote-tracking branch 'origin/tuan/sparsegpt' into research/sp…
anmarques Sep 1, 2023
466213c
Initial support for MPT
natuan Sep 4, 2023
d953be4
add in further completion state for session and events
markurtz Sep 5, 2023
7fc49ac
Initial commit for Llama2
natuan Sep 5, 2023
ff5b9a1
Return model after override Llama attn
natuan Sep 5, 2023
090af1c
working refactor
Satrat Sep 5, 2023
594339f
More memory-efficient layer compression
natuan Sep 5, 2023
5750edf
Make memory-efficient layer compressor default
natuan Sep 5, 2023
6c22fa1
finalization, clean up file saving
Satrat Sep 5, 2023
f3f9119
SmoothQuant enabled
natuan Sep 5, 2023
b2a6817
support for quantization, clean up
Satrat Sep 5, 2023
fa934c5
remove unsued files
Satrat Sep 5, 2023
a0deb44
remove unsued files
Satrat Sep 5, 2023
42aef3d
add in recipe helper functions for merging, loading, and running call…
markurtz Sep 6, 2023
6682784
minor fixes for new framework
markurtz Sep 6, 2023
59d1a72
move to transformers subdir
Satrat Sep 6, 2023
96b9248
quality
Satrat Sep 6, 2023
c2471e1
Merge branch 'main' into sa/sgpt_modifier
Satrat Sep 6, 2023
5b0f190
add constant pruning modifier
markurtz Sep 7, 2023
4e24413
fix modifier inheritance
Satrat Sep 7, 2023
836464d
PR comments, serialization
Satrat Sep 7, 2023
4b511c9
move compression to initializer
Satrat Sep 7, 2023
39bfbbf
clean up recipe
Satrat Sep 7, 2023
aced1d8
Rebase
anmarques Sep 7, 2023
6f3d188
Rebase
anmarques Sep 7, 2023
ae6c2e2
Rebase
anmarques Sep 7, 2023
ae2f8e0
Rebase
anmarques Sep 7, 2023
b52e0b4
Fix call to model eval
anmarques Sep 7, 2023
3647653
Fixes to caching
anmarques Sep 7, 2023
4652fcf
Add support for llama2
anmarques Sep 7, 2023
a81ba7b
Fix ptq-only option
anmarques Sep 7, 2023
f4f0a63
revert quant modifier
Satrat Sep 7, 2023
4e72fd6
Fixes
anmarques Sep 7, 2023
751a5ac
Rebase
anmarques Sep 7, 2023
71e8473
Rebase
anmarques Sep 7, 2023
48dd147
Bugs fixed
natuan Sep 8, 2023
924992c
Rebase
anmarques Sep 8, 2023
7b28a38
Rebase
anmarques Sep 8, 2023
ebf00e8
Rebase
anmarques Sep 8, 2023
4aa9e62
Rebase
anmarques Sep 8, 2023
c19d386
merge and move model loading to helper
Satrat Sep 8, 2023
6200f18
Rebase
anmarques Sep 8, 2023
6444a94
Rebase
anmarques Sep 8, 2023
b8452a5
add magntitude pruning modifier
markurtz Sep 9, 2023
f04ca6f
knowledge distillation implementation
markurtz Sep 10, 2023
83b23f7
docstrings
Satrat Sep 11, 2023
a32321c
docstrings
Satrat Sep 11, 2023
1bb260a
Rebase
anmarques Sep 12, 2023
2bb221f
Rebase
anmarques Sep 12, 2023
16cd0e5
basic llama modifier(not fully tested)
Satrat Sep 12, 2023
8967b20
working llama example
Satrat Sep 12, 2023
bba4cbb
Evaluate model in eval mode
anmarques Sep 12, 2023
19f8610
rebase
anmarques Sep 13, 2023
e89c17f
remove evaluate model for opt
anmarques Sep 13, 2023
3bd26e0
Fix model key for llama2
anmarques Sep 13, 2023
9307286
Fix model key for llama2
anmarques Sep 13, 2023
daf66cf
Fix dataset loading
anmarques Sep 13, 2023
d5f4b62
leave outputs as list
anmarques Sep 13, 2023
75539d5
Clean up
anmarques Sep 13, 2023
355f5c5
Fixes for input data as list
anmarques Sep 13, 2023
c745492
fix import errors and multiframework inits
Satrat Sep 14, 2023
bc73e15
fix import errors and multiframework inits
Satrat Sep 14, 2023
5438e05
initialization
Satrat Sep 14, 2023
4d0fdc3
First abstract structure with OPT, MPT
natuan Aug 18, 2023
0d00837
Updated
natuan Aug 29, 2023
e4a4878
Updated
natuan Aug 29, 2023
bb00c20
Runable OPT main code path
natuan Aug 31, 2023
277ee89
Initial support for MPT
natuan Sep 4, 2023
b454e26
Initial commit for Llama2
natuan Sep 5, 2023
e459c92
Return model after override Llama attn
natuan Sep 5, 2023
da511e9
More memory-efficient layer compression
natuan Sep 5, 2023
42a2845
Make memory-efficient layer compressor default
natuan Sep 5, 2023
55ddec5
SmoothQuant enabled
natuan Sep 5, 2023
aaa68a1
Bugs fixed
natuan Sep 8, 2023
cbc9360
Make load data more flexible
natuan Sep 11, 2023
f48cef6
Initialize scales in eval mode; clean up
natuan Sep 13, 2023
80f9208
Example script and recipe for OPT
natuan Sep 14, 2023
e61dc21
Formatting
natuan Sep 14, 2023
dff49e3
Copyright
natuan Sep 14, 2023
c63001d
Fix code styles
natuan Sep 15, 2023
bd213a5
Format
natuan Sep 15, 2023
32cf3c9
Fixes for channelwise quantization
anmarques Sep 15, 2023
c2afba0
Name fixes
anmarques Sep 15, 2023
996c533
RecipeModifiers working
Satrat Sep 15, 2023
aacdd54
Merge branch 'tuan/sparsegpt' into sa/sgpt_modifier
Satrat Sep 15, 2023
4dab25c
Merge branch 'sa/llama_modifiers' into sa/sgpt_modifier
Satrat Sep 15, 2023
5ae7a87
remove unused buffer
anmarques Sep 15, 2023
ef9da3a
Support for smoothquant
anmarques Sep 15, 2023
9635acb
fix import errors
markurtz Sep 17, 2023
5845359
Reformat smoothquant dict
anmarques Sep 18, 2023
e3166d0
Push smoothquant to a separate file
anmarques Sep 18, 2023
8e797c5
perplexity evaluation for opt
Satrat Sep 18, 2023
7ecd5c6
modifiers loading in stages
Satrat Sep 19, 2023
3e2954e
adding test files
Satrat Sep 19, 2023
5eed10d
merge with base and update
Satrat Sep 19, 2023
0807ba6
Rebase
anmarques Sep 19, 2023
3137424
Rebase
anmarques Sep 19, 2023
72f1d33
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
69bb017
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
7a08733
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
69494db
Fix counter
anmarques Sep 19, 2023
b4fccfb
Rebase
anmarques Sep 19, 2023
ad9d7ea
Rebase
anmarques Sep 19, 2023
4bcf07b
Rebase
anmarques Sep 19, 2023
a981621
Rebase
anmarques Sep 19, 2023
5d9bbfb
Add license message
anmarques Sep 19, 2023
6b83b02
modifier factory implementation
markurtz Sep 19, 2023
e857729
running example, but sparsity not working correctly
Satrat Sep 19, 2023
b134431
Account for when keys are not matched
anmarques Sep 19, 2023
55027ce
Expand caching to include inputs. Move main ppl logic to utils
anmarques Sep 19, 2023
925fa61
Expand caching to include inputs. Move main ppl logic to utils
anmarques Sep 19, 2023
e88aa5d
Update opt integration
anmarques Sep 19, 2023
55eecc3
merge in factory, make it functional
Satrat Sep 19, 2023
bc5798d
fix polynomial scheduler, leave masks enabled on end
Satrat Sep 20, 2023
a35581d
remove e2e files
Satrat Sep 20, 2023
71869be
add on_event for modifier lifecycle and add initial integration for t…
markurtz Sep 20, 2023
2d04ea0
leave_enabled fixes
Satrat Sep 20, 2023
7b182e4
fixing evals and finalization
Satrat Sep 20, 2023
031c539
rebasing research and cleaning up perplexity
Satrat Sep 21, 2023
ddf35be
remove local files
Satrat Sep 21, 2023
4e3fd35
style and base obcq modifier
Satrat Sep 21, 2023
4c34fae
Merge branch 'sa/sgpt_modifier' into refactor_obcq
Satrat Sep 21, 2023
8015028
[untested] convert obcq to new framework
Satrat Sep 21, 2023
727928f
obcq working
Satrat Sep 21, 2023
6c2255f
Add test
rahul-tuli Sep 21, 2023
abeedb7
Add changes to allow accepting strings
rahul-tuli Sep 21, 2023
c7848e5
update llama example recipe
Satrat Sep 22, 2023
571d21d
fix recipe staging issue
Satrat Sep 22, 2023
952e4ee
style
Satrat Sep 22, 2023
ed8e0ba
style fixes
Satrat Sep 22, 2023
7236de7
Merge branch 'main' into sparsification-refactor
Satrat Sep 22, 2023
42a235b
reorg file structure
Satrat Sep 25, 2023
ef0ef18
quant modifier in new framework, opt tested
Satrat Sep 25, 2023
4d1c716
Merge branch 'sparsification-refactor' into refactor_obcq
Satrat Sep 25, 2023
8887b61
Removing custom smoothquant from this branch
anmarques Sep 26, 2023
4597497
Merge branch 'main' into research/sparsegpt/llama2
anmarques Sep 26, 2023
a1d15c8
post one shot calibration, recipe update
Satrat Sep 26, 2023
bfd7f84
bug fixes that came up during obcq implementation
Satrat Sep 26, 2023
05e0efb
Merge branch 'sparsification-refactor' into refactor_obcq
Satrat Sep 26, 2023
be06113
moving obcq script
Satrat Sep 26, 2023
629f9c5
quant fix, prunen and m
Satrat Sep 27, 2023
13ac2b9
fix experimental import paths, add comparison test
Satrat Sep 27, 2023
0cadcf3
return perplexity
Satrat Sep 27, 2023
76a8391
Merge branch 'research/sparsegpt/llama2' into refactor_obcq
Satrat Sep 29, 2023
d9e969c
style
Satrat Sep 29, 2023
eec247d
specify compressible layers in recipe
Satrat Sep 29, 2023
5a985d0
move attention cache to base class
Satrat Oct 3, 2023
bae96af
move attention cache to base class
Satrat Oct 3, 2023
6067521
clean up test scripts
Satrat Oct 3, 2023
27ba629
clean up bottom compressors and comments
Satrat Oct 4, 2023
0bdee59
documentation
Satrat Oct 4, 2023
ac29d68
fix typos
Satrat Oct 4, 2023
7617f9c
PR comments
Satrat Oct 5, 2023
f216d53
small bug fix on logger
Satrat Oct 5, 2023
dcbbbd4
Merge branch 'main' into refactor_obcq
Satrat Oct 5, 2023
de3dc2f
fixing transformer dependency, adding registry for dataset loaders
Satrat Oct 5, 2023
39ff676
fixing bugs in dataset loading and perplexity
Satrat Oct 6, 2023
7423f20
cleanup
Satrat Oct 6, 2023
25ba312
Merge branch 'main' into refactor_obcq
Satrat Oct 6, 2023
13c3a6c
fix memory issue in comparison
Satrat Oct 6, 2023
b5e9d6d
return perplexity
Satrat Oct 6, 2023
1decfe5
adding split to all datasets
Satrat Oct 6, 2023
0a0cff9
Merge branch 'refactor_obcq' of https://github.com/neuralmagic/sparse…
Satrat Oct 6, 2023
5fa7361
Update src/sparseml/modifiers/obcq/base.py
Satrat Oct 9, 2023
f1441cb
Merge branch 'main' into refactor_obcq
Satrat Oct 9, 2023
40cc9a5
outlines for smoothquant modifier
Satrat Oct 9, 2023
be99aff
Merge branch 'refactor_obcq' into prod_smooth_quant
Satrat Oct 9, 2023
a95d97a
smoothquant runs for OPT
Satrat Oct 10, 2023
459ab7b
recipe update, small bugfix
Satrat Oct 10, 2023
29271a3
fixing dataset issues
Satrat Oct 10, 2023
caed150
Merge branch 'refactor_obcq' into prod_smooth_quant
Satrat Oct 10, 2023
046276b
docstrings
Satrat Oct 11, 2023
e6bf7e9
fix hook deletion
Satrat Oct 11, 2023
4bb2b2c
fix llama recipe
Satrat Oct 11, 2023
4f2f166
Merge branch 'main' into prod_smooth_quant
Satrat Oct 11, 2023
c8be5b4
quality
Satrat Oct 11, 2023
525bed4
addressing PR comments
Satrat Oct 12, 2023
9b81f69
Merge branch 'main' into prod_smooth_quant
Satrat Oct 12, 2023
d488f84
make modifier more generic to framework
Satrat Oct 12, 2023
ded5f00
Merge branch 'prod_smooth_quant' of github.com:neuralmagic/sparseml i…
Satrat Oct 12, 2023
0ee5faa
rename alpha, adding alias
Satrat Oct 16, 2023
b03ab84
Merge branch 'main' into prod_smooth_quant
Satrat Oct 16, 2023
8244bc5
Merge branch 'main' into prod_smooth_quant
Satrat Oct 19, 2023
45e3630
remove pytorch import
Satrat Oct 19, 2023
6b0393d
PR comments on logarithmic equalization
Satrat Oct 20, 2023
3012189
Update src/sparseml/modifiers/smoothquant/base.py
Satrat Oct 20, 2023
12ab5e6
Update src/sparseml/modifiers/smoothquant/base.py
Satrat Oct 20, 2023
06b7f7d
Update src/sparseml/modifiers/smoothquant/base.py
Satrat Oct 20, 2023
042cb3f
PR comments
Satrat Oct 20, 2023
68b1841
Merge branch 'prod_smooth_quant' of github.com:neuralmagic/sparseml i…
Satrat Oct 20, 2023
11d1be3
clean up comments
Satrat Oct 20, 2023
7a83668
Merge branch 'main' into prod_smooth_quant
Satrat Oct 20, 2023
31b8bab
Merge branch 'main' into prod_smooth_quant
Satrat Oct 20, 2023
788c16b
fix calibration model reference, adding tiny integration test
Satrat Oct 23, 2023
800d08b
Merge branch 'main' into prod_smooth_quant
Satrat Oct 31, 2023
de2b0cd
style
Satrat Oct 31, 2023
12 changes: 11 additions & 1 deletion src/sparseml/core/model/base.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, Generic, List, Optional, TypeVar, Union
from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union

from sparseml.core.framework import Framework
from sparseml.core.framework_object import MultiFrameworkObject
@@ -116,3 +116,13 @@ def set_param(self, target: str, param: PT):
:param param: the param instance to set
"""
raise NotImplementedError()

def get_matching_layer(
self, target: str, name_to_match: str, model: LT
) -> Optional[Tuple[str, LT]]:
"""
:param target: layer name to target when searching model
:param name_to_match: name to match targets to
:param model: model to search for targets
"""
raise NotImplementedError()
11 changes: 11 additions & 0 deletions src/sparseml/core/model/pytorch.py
@@ -22,6 +22,7 @@
get_layer,
get_layers,
get_layers_params,
get_matching_layer,
get_param,
get_params,
set_layer,
@@ -94,3 +95,13 @@ def set_param(self, target: str, param: Parameter):
:param param: the parameter to set
"""
return set_param(target, param, self.model)

def get_matching_layer(
self, target: str, name_to_match: str, model: Module
) -> Optional[Tuple[str, Module]]:
"""
:param target: layer name to target when searching model
:param name_to_match: name to match targets to
:param model: model to search for targets
"""
return get_matching_layer(target, name_to_match, model)
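
The new get_matching_layer API pairs a target expression (for example "re:.*q_proj") with the submodule that sits closest to a given anchor layer; the PyTorch implementation simply delegates to the shared get_matching_layer helper imported at the top of the file, whose body is not shown in this diff. The following is only a minimal sketch of the matching behavior described in the SmoothQuant modifier's _resolve_mappings docstring further down (regex targets, best candidate chosen by longest shared name prefix); the function name and the prefix heuristic are assumptions for illustration, not the library implementation.

# Illustrative sketch only -- the real get_matching_layer helper imported above may differ.
# Assumes "re:"-prefixed targets are regexes and that the best candidate is the one sharing
# the longest name prefix with the anchor layer (hypothetical heuristic).
import re
from typing import Optional, Tuple

from torch.nn import Module


def get_matching_layer_sketch(
    target: str, name_to_match: str, model: Module
) -> Optional[Tuple[str, Module]]:
    pattern = target[3:] if target.startswith("re:") else re.escape(target)
    candidates = [
        (name, layer)
        for name, layer in model.named_modules()
        if re.search(pattern, name)
    ]
    if not candidates:
        return None

    def shared_prefix(name: str) -> int:
        # a match living under the same block as the anchor (e.g. the same
        # "...layers.0." prefix) scores higher than a match elsewhere
        count = 0
        for left, right in zip(name, name_to_match):
            if left != right:
                break
            count += 1
        return count

    return max(candidates, key=lambda pair: shared_prefix(pair[0]))

Under that heuristic, a "re:.*q_proj" target anchored at a block's self_attn_layer_norm would resolve to the q_proj in the same block rather than one in a different part of the model.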
1 change: 1 addition & 0 deletions src/sparseml/modifiers/__init__.py
@@ -18,3 +18,4 @@
from .obcq import *
from .pruning import *
from .quantization import *
from .smoothquant import *
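
With the new export in place, the modifier can be constructed directly from sparseml.modifiers (or referenced from a recipe). A minimal construction sketch, assuming the Modifier base class needs no additional required arguments here; the mapping, ignore, and step values are illustrative, not defaults shipped with this PR:

from sparseml.modifiers import SmoothQuantModifier

# illustrative values only; "alpha" is the declared alias for smoothing_strength
smoothquant = SmoothQuantModifier(
    alpha=0.5,
    mappings=[
        (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"),
    ],
    ignore=[],
    num_calibration_steps=512,
)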
28 changes: 7 additions & 21 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -13,8 +13,7 @@
# limitations under the License.

import logging
from itertools import cycle
from typing import Any, Callable, Dict, Optional
from typing import Any, Dict, Optional

import torch
from torch.nn import Module
@@ -35,7 +34,7 @@
raise_if_torch_quantization_not_available,
set_quantization_schemes,
)
from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward


_LOGGER = logging.getLogger(__name__)
@@ -184,26 +183,13 @@ def _calibrate(self, module: Module):
module_training = module.training
module.eval()

forward_fn: Callable = (
self.calibration_function_
if self.calibration_function_
else tensors_module_forward
)

model_device = next(module.parameters()).device
_dataloader = (
self.calibration_dataloader_
if self.num_calibration_steps is None
else cycle(self.calibration_dataloader_)
run_calibration_forward(
module,
self.calibration_dataloader_,
self.num_calibration_steps,
self.calibration_function_,
)

for batch_idx, batch in enumerate(_dataloader):
if self.num_calibration_steps and batch_idx >= self.num_calibration_steps:
break
batch = tensors_to_device(batch, model_device)
with torch.no_grad():
forward_fn(batch, module=module)

if module_training:
module.train()
else:
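
The inline calibration loop removed above is consolidated into run_calibration_forward in sparseml.modifiers.utils.pytorch_helpers, which this diff imports but does not show. As a rough sketch, the helper presumably reassembles the deleted logic: fall back to a plain module forward when no calibration function is given, cycle the dataloader when a step count is set, move each batch to the model's device, run it under no_grad, and stop after num_calibration_steps. The body below is that reconstruction, not the library function; the dict-style batch handling is an assumption.

# Sketch reconstructed from the removed inline loop; the real helper lives in
# sparseml.modifiers.utils.pytorch_helpers and may differ in its details.
from itertools import cycle
from typing import Callable, Iterable, Optional

import torch
from torch.nn import Module


def run_calibration_forward_sketch(
    module: Module,
    calibration_dataloader: Iterable,
    num_calibration_steps: Optional[int] = None,
    calibration_function: Optional[Callable] = None,
):
    # default to a plain forward pass; assumes dict-style batches for simplicity
    forward_fn = calibration_function or (lambda batch, module: module(**batch))
    model_device = next(module.parameters()).device

    # cycle so a fixed step count can exceed a single pass over the dataloader
    dataloader = (
        calibration_dataloader
        if num_calibration_steps is None
        else cycle(calibration_dataloader)
    )

    for batch_idx, batch in enumerate(dataloader):
        if num_calibration_steps and batch_idx >= num_calibration_steps:
            break
        batch = {key: value.to(model_device) for key, value in batch.items()}
        with torch.no_grad():
            forward_fn(batch, module=module)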
17 changes: 17 additions & 0 deletions src/sparseml/modifiers/smoothquant/__init__.py
@@ -0,0 +1,17 @@
# flake8: noqa

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *
172 changes: 172 additions & 0 deletions src/sparseml/modifiers/smoothquant/base.py
@@ -0,0 +1,172 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Dict, Generic, List, Optional, Tuple, TypeVar

from pydantic import Field

from sparseml.core import Modifier
from sparseml.core.model import ModifiableModel
from sparseml.core.model.base import LT
from sparseml.core.state import Event, State


VT = TypeVar("VT") # represents a generic vector

__all__ = ["SmoothQuantScale", "SmoothQuantMapping", "SmoothQuantModifier"]


@dataclass
class SmoothQuantScale(Generic[VT]):
"""
Dataclass for storing the channel-wise minimum and maximum values for a layer. This
is updated each forward pass during calibration

:param min_channel_vals: minimum output value seen so far, per channel
:param max_channel_vals: maximum output value seen so far, per channel
"""

min_channel_vals: VT
max_channel_vals: VT


@dataclass
class SmoothQuantMapping(Generic[LT]):
"""
Dataclass for storing the mapping between an activation layer and the following
weights that must be balanced during smoothing

:param smooth_name: name of the activation layer
:param smooth_layer: PyTorch module storing the activation layer
:param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be
balanced to offset the smoothing of smooth_layer
"""

smooth_name: str
smooth_layer: LT
balance_layers: List[LT]


class SmoothQuantModifier(Modifier):
"""
Implements the SmoothQuant algorithm from https://arxiv.org/abs/2211.10438. This
modifier performs a channel-wise smoothing of outliers in activations, making them
easier to quantize by reducing the dynamic range. The smoothing is offset by
applying the inverse operation to the next layer of weights, making the weights
slightly more difficult to quantize.

Because this modifier manipulates the weights of the model, it can only be used in
one-shot and not during training. Activation ranges are determined by running a
small set of calibration data through the model.

:param smoothing_strength: alpha, intensity of smoothing to perform (0-1 range)
:param mappings: list of activation layers to smooth, and which layers to offset
the smoothing to for each activation
:param ignore: list of layers to ignore, even if they match a regex in mappings
:param logarithmic_equalization: Whether to use a logarithmic scale for smoothing
:param num_calibration_steps: number of samples to use for calibration, or None to
use the whole dataset
"""

smoothing_strength: float = Field(..., alias="alpha")
mappings: List[Tuple]
ignore: Optional[List[str]] = None
logarithmic_equalization: Optional[bool] = False
num_calibration_steps: Optional[int] = None

resolved_mappings_: Dict = None
scales_: Dict = None

def on_initialize_structure(self, state: "State", **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run SmoothQuant on the given state

:param state: state to run SmoothQuant on
:return: True on a successful run, False otherwise
"""
if self.end and self.end != -1:
raise ValueError(
"SmoothQuantModifier can only be applied during one-shot. Expected end"
" to be None or -1, got {}".format(self.end)
)
if self.start and self.start != -1:
raise ValueError(
"SmoothQuantModifier can only be applied during one-shot. Expected "
"start to be None or -1, got {}".format(self.start)
)

self.ignore = [] if not self.ignore else self.ignore
self.resolved_mappings_ = self._resolve_mappings(state.model)
self.scales_ = {}

def _resolve_mappings(self, model: ModifiableModel):
"""
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.

For each activation in the mapping list, we find the corresponding weight to
balance by searching for the longest substring. For instance, if our balance
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
would match model.layer.0.q_proj to model.layer.0.self_attn_layer_norm and
repeat for model.layer.1 and so on
"""
resolved_mappings = []
for to_balance, to_smooth in self.mappings:
to_smooth_layers = model.get_layers(to_smooth)
for layer_name, smooth_layer in to_smooth_layers.items():
if layer_name not in self.ignore:
balance_layers = []
for balance_suffix in to_balance:
# find the submodule that matches the activation layer
_, balance_layer = model.get_matching_layer(
balance_suffix, layer_name, model.model
)
if balance_layer:
balance_layers.append(balance_layer)
# each mapping can contain multiple layers to balance, but only
# one layer to smooth
mapping = SmoothQuantMapping(
layer_name, smooth_layer, balance_layers
)
resolved_mappings.append(mapping)
return resolved_mappings

def on_start(self, state: State, event: Event, **kwargs):
pass

def on_update(self, state: State, event: Event, **kwargs):
pass

def on_end(self, state: State, event: Event, **kwargs):
pass

def on_event(self, state: State, event: Event, **kwargs):
pass

def on_finalize(self, state: State, **kwargs) -> bool:
"""
Clean up by clearing the scale and mapping data

:param state: unused
:return: True
"""
self.scales_.clear()
self.resolved_mappings_.clear()

return True
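
The file above wires up the modifier's lifecycle, mappings, and per-channel statistics, but the smoothing transform itself is applied elsewhere and is not part of this diff. For context, here is a minimal sketch of the math the collected SmoothQuantScale values feed into, following the cited paper (https://arxiv.org/abs/2211.10438): each input channel gets a scale s = max|activation|^alpha / max|weight|^(1-alpha); the layer producing the activations is divided by s and the balance layers' corresponding weight columns are multiplied by s, so activation outliers shrink while the end-to-end computation is unchanged. The LayerNorm/Linear pairing below is an assumption based on the q_proj / self_attn_layer_norm example in _resolve_mappings, not the modifier's implementation.

# Minimal sketch of the SmoothQuant transform, assuming the smooth layer is a
# LayerNorm feeding one or more Linear balance layers.
from typing import List

import torch
from torch.nn import LayerNorm, Linear


def apply_smoothing_sketch(
    smooth_layer: LayerNorm,
    balance_layers: List[Linear],
    min_channel_vals: torch.Tensor,
    max_channel_vals: torch.Tensor,
    alpha: float = 0.5,  # the modifier's `smoothing_strength`
):
    # per-channel dynamic range of the activations observed during calibration
    activation_scales = torch.maximum(min_channel_vals.abs(), max_channel_vals.abs())
    # largest weight magnitude per input channel across all balance layers
    weight_scales = torch.stack(
        [layer.weight.abs().max(dim=0).values for layer in balance_layers]
    ).max(dim=0).values

    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
    scales = activation_scales.pow(alpha) / weight_scales.clamp(min=1e-5).pow(1.0 - alpha)

    # fold 1/s into the layer producing the activations, shrinking the outliers
    # that make activation quantization hard...
    smooth_layer.weight.data /= scales
    if smooth_layer.bias is not None:
        smooth_layer.bias.data /= scales

    # ...and fold s into the input channels of the following weights so the
    # overall function is mathematically unchanged
    for layer in balance_layers:
        layer.weight.data *= scales.view(1, -1)

In recipe terms, a mapping entry such as [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"] (an illustrative example, not a shipped default) would trigger this transform once per matched block, with the layer norm as the smooth layer and the attention projections as the balance layers.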