Modifier Refactor OBCQ Implementation #1737

Merged: 189 commits (Oct 11, 2023)

Commits
bd729f3
First abstract structure with OPT, MPT
natuan Aug 18, 2023
0476a14
Merge branch 'main' into tuan/sparsegpt
Satrat Aug 24, 2023
e9c4e95
scaffolding
Satrat Aug 29, 2023
07a7273
First abstract structure with OPT, MPT
natuan Aug 18, 2023
73ccd2e
Updated
natuan Aug 29, 2023
f742f4a
Updated
natuan Aug 29, 2023
7c2d3d4
Initial implementation of Llama2 integration
anmarques Aug 30, 2023
3f4c5d8
Initial implementation of Llama2 integration
anmarques Aug 30, 2023
c181380
Fix llama-2 key
anmarques Aug 30, 2023
dd0f5f8
Add modelutils
anmarques Aug 30, 2023
99f8bbe
Make smoothquant optional
anmarques Aug 30, 2023
db901e4
probe sequence length from model
anmarques Aug 30, 2023
48ebe81
probe sequence length from model
anmarques Aug 30, 2023
2be7def
Fix typo
anmarques Aug 30, 2023
fb346ad
add in recipe and modifier logic
Satrat Aug 30, 2023
c4fc152
Catch additional arguments to load_model and load_data
anmarques Aug 30, 2023
2b50607
Redifinition of seqlen
anmarques Aug 30, 2023
15b2c10
Runable OPT main code path
natuan Aug 31, 2023
34c3d84
mpt modifier
Satrat Aug 31, 2023
4cac803
Merge branch 'tuan/sparsegpt' into sa/sgpt_modifier
Satrat Aug 31, 2023
a3934c2
Initial start implementation
markurtz Sep 1, 2023
6a35b7f
State before rebasing
anmarques Sep 1, 2023
e02abb2
State before rebasing
anmarques Sep 1, 2023
2fb9e66
Merge remote-tracking branch 'origin/tuan/sparsegpt' into research/sp…
anmarques Sep 1, 2023
466213c
Initial support for MPT
natuan Sep 4, 2023
d953be4
add in further completion state for session and events
markurtz Sep 5, 2023
7fc49ac
Initial commit for Llama2
natuan Sep 5, 2023
ff5b9a1
Return model after override Llama attn
natuan Sep 5, 2023
090af1c
working refactor
Satrat Sep 5, 2023
594339f
More memory-efficient layer compression
natuan Sep 5, 2023
5750edf
Make memory-efficient layer compressor default
natuan Sep 5, 2023
6c22fa1
finalization, clean up file saving
Satrat Sep 5, 2023
f3f9119
SmoothQuant enabled
natuan Sep 5, 2023
b2a6817
support for quantization, clean up
Satrat Sep 5, 2023
fa934c5
remove unsued files
Satrat Sep 5, 2023
a0deb44
remove unsued files
Satrat Sep 5, 2023
42aef3d
add in recipe helper functions for merging, loading, and running call…
markurtz Sep 6, 2023
6682784
minor fixes for new framework
markurtz Sep 6, 2023
59d1a72
move to transformers subdir
Satrat Sep 6, 2023
96b9248
quality
Satrat Sep 6, 2023
c2471e1
Merge branch 'main' into sa/sgpt_modifier
Satrat Sep 6, 2023
5b0f190
add constant pruning modifier
markurtz Sep 7, 2023
4e24413
fix modifier inheritance
Satrat Sep 7, 2023
836464d
PR comments, serialization
Satrat Sep 7, 2023
4b511c9
move compression to initializer
Satrat Sep 7, 2023
39bfbbf
clean up recipe
Satrat Sep 7, 2023
aced1d8
Rebase
anmarques Sep 7, 2023
6f3d188
Rebase
anmarques Sep 7, 2023
ae6c2e2
Rebase
anmarques Sep 7, 2023
ae2f8e0
Rebase
anmarques Sep 7, 2023
b52e0b4
Fix call to model eval
anmarques Sep 7, 2023
3647653
Fixes to caching
anmarques Sep 7, 2023
4652fcf
Add support for llama2
anmarques Sep 7, 2023
a81ba7b
Fix ptq-only option
anmarques Sep 7, 2023
f4f0a63
revert quant modifier
Satrat Sep 7, 2023
4e72fd6
Fixes
anmarques Sep 7, 2023
751a5ac
Rebase
anmarques Sep 7, 2023
71e8473
Rebase
anmarques Sep 7, 2023
48dd147
Bugs fixed
natuan Sep 8, 2023
924992c
Rebase
anmarques Sep 8, 2023
7b28a38
Rebase
anmarques Sep 8, 2023
ebf00e8
Rebase
anmarques Sep 8, 2023
4aa9e62
Rebase
anmarques Sep 8, 2023
c19d386
merge and move model loading to helper
Satrat Sep 8, 2023
6200f18
Rebase
anmarques Sep 8, 2023
6444a94
Rebase
anmarques Sep 8, 2023
b8452a5
add magntitude pruning modifier
markurtz Sep 9, 2023
f04ca6f
knowledge distillation implementation
markurtz Sep 10, 2023
83b23f7
docstrings
Satrat Sep 11, 2023
a32321c
docstrings
Satrat Sep 11, 2023
1bb260a
Rebase
anmarques Sep 12, 2023
2bb221f
Rebase
anmarques Sep 12, 2023
16cd0e5
basic llama modifier(not fully tested)
Satrat Sep 12, 2023
8967b20
working llama example
Satrat Sep 12, 2023
bba4cbb
Evaluate model in eval mode
anmarques Sep 12, 2023
19f8610
rebase
anmarques Sep 13, 2023
e89c17f
remove evaluate model for opt
anmarques Sep 13, 2023
3bd26e0
Fix model key for llama2
anmarques Sep 13, 2023
9307286
Fix model key for llama2
anmarques Sep 13, 2023
daf66cf
Fix dataset loading
anmarques Sep 13, 2023
d5f4b62
leave outputs as list
anmarques Sep 13, 2023
75539d5
Clean up
anmarques Sep 13, 2023
355f5c5
Fixes for input data as list
anmarques Sep 13, 2023
c745492
fix import errors and multiframework inits
Satrat Sep 14, 2023
bc73e15
fix import errors and multiframework inits
Satrat Sep 14, 2023
5438e05
initialization
Satrat Sep 14, 2023
4d0fdc3
First abstract structure with OPT, MPT
natuan Aug 18, 2023
0d00837
Updated
natuan Aug 29, 2023
e4a4878
Updated
natuan Aug 29, 2023
bb00c20
Runable OPT main code path
natuan Aug 31, 2023
277ee89
Initial support for MPT
natuan Sep 4, 2023
b454e26
Initial commit for Llama2
natuan Sep 5, 2023
e459c92
Return model after override Llama attn
natuan Sep 5, 2023
da511e9
More memory-efficient layer compression
natuan Sep 5, 2023
42a2845
Make memory-efficient layer compressor default
natuan Sep 5, 2023
55ddec5
SmoothQuant enabled
natuan Sep 5, 2023
aaa68a1
Bugs fixed
natuan Sep 8, 2023
cbc9360
Make load data more flexible
natuan Sep 11, 2023
f48cef6
Initialize scales in eval mode; clean up
natuan Sep 13, 2023
80f9208
Example script and recipe for OPT
natuan Sep 14, 2023
e61dc21
Formatting
natuan Sep 14, 2023
dff49e3
Copyright
natuan Sep 14, 2023
c63001d
Fix code styles
natuan Sep 15, 2023
bd213a5
Format
natuan Sep 15, 2023
32cf3c9
Fixes for channelwise quantization
anmarques Sep 15, 2023
c2afba0
Name fixes
anmarques Sep 15, 2023
996c533
RecipeModifiers working
Satrat Sep 15, 2023
aacdd54
Merge branch 'tuan/sparsegpt' into sa/sgpt_modifier
Satrat Sep 15, 2023
4dab25c
Merge branch 'sa/llama_modifiers' into sa/sgpt_modifier
Satrat Sep 15, 2023
5ae7a87
remove unused buffer
anmarques Sep 15, 2023
ef9da3a
Support for smoothquant
anmarques Sep 15, 2023
9635acb
fix import errors
markurtz Sep 17, 2023
5845359
Reformat smoothquant dict
anmarques Sep 18, 2023
e3166d0
Push smoothquant to a separate file
anmarques Sep 18, 2023
8e797c5
perplexity evaluation for opt
Satrat Sep 18, 2023
7ecd5c6
modifiers loading in stages
Satrat Sep 19, 2023
3e2954e
adding test files
Satrat Sep 19, 2023
5eed10d
merge with base and update
Satrat Sep 19, 2023
0807ba6
Rebase
anmarques Sep 19, 2023
3137424
Rebase
anmarques Sep 19, 2023
72f1d33
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
69bb017
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
7a08733
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
69494db
Fix counter
anmarques Sep 19, 2023
b4fccfb
Rebase
anmarques Sep 19, 2023
ad9d7ea
Rebase
anmarques Sep 19, 2023
4bcf07b
Rebase
anmarques Sep 19, 2023
a981621
Rebase
anmarques Sep 19, 2023
5d9bbfb
Add license message
anmarques Sep 19, 2023
6b83b02
modifier factory implementation
markurtz Sep 19, 2023
e857729
running example, but sparsity not working correctly
Satrat Sep 19, 2023
b134431
Account for when keys are not matched
anmarques Sep 19, 2023
55027ce
Expand caching to include inputs. Move main ppl logic to utils
anmarques Sep 19, 2023
925fa61
Expand caching to include inputs. Move main ppl logic to utils
anmarques Sep 19, 2023
e88aa5d
Update opt integration
anmarques Sep 19, 2023
55eecc3
merge in factory, make it functional
Satrat Sep 19, 2023
bc5798d
fix polynomial scheduler, leave masks enabled on end
Satrat Sep 20, 2023
a35581d
remove e2e files
Satrat Sep 20, 2023
71869be
add on_event for modifier lifecycle and add initial integration for t…
markurtz Sep 20, 2023
2d04ea0
leave_enabled fixes
Satrat Sep 20, 2023
7b182e4
fixing evals and finalization
Satrat Sep 20, 2023
031c539
rebasing research and cleaning up perplexity
Satrat Sep 21, 2023
ddf35be
remove local files
Satrat Sep 21, 2023
4e3fd35
style and base obcq modifier
Satrat Sep 21, 2023
4c34fae
Merge branch 'sa/sgpt_modifier' into refactor_obcq
Satrat Sep 21, 2023
8015028
[untested] convert obcq to new framework
Satrat Sep 21, 2023
727928f
obcq working
Satrat Sep 21, 2023
6c2255f
Add test
rahul-tuli Sep 21, 2023
abeedb7
Add changes to allow accepting strings
rahul-tuli Sep 21, 2023
c7848e5
update llama example recipe
Satrat Sep 22, 2023
571d21d
fix recipe staging issue
Satrat Sep 22, 2023
952e4ee
style
Satrat Sep 22, 2023
ed8e0ba
style fixes
Satrat Sep 22, 2023
7236de7
Merge branch 'main' into sparsification-refactor
Satrat Sep 22, 2023
42a235b
reorg file structure
Satrat Sep 25, 2023
ef0ef18
quant modifier in new framework, opt tested
Satrat Sep 25, 2023
4d1c716
Merge branch 'sparsification-refactor' into refactor_obcq
Satrat Sep 25, 2023
8887b61
Removing custom smoothquant from this branch
anmarques Sep 26, 2023
4597497
Merge branch 'main' into research/sparsegpt/llama2
anmarques Sep 26, 2023
a1d15c8
post one shot calibration, recipe update
Satrat Sep 26, 2023
bfd7f84
bug fixes that came up during obcq implementation
Satrat Sep 26, 2023
05e0efb
Merge branch 'sparsification-refactor' into refactor_obcq
Satrat Sep 26, 2023
be06113
moving obcq script
Satrat Sep 26, 2023
629f9c5
quant fix, prunen and m
Satrat Sep 27, 2023
13ac2b9
fix experimental import paths, add comparison test
Satrat Sep 27, 2023
0cadcf3
return perplexity
Satrat Sep 27, 2023
76a8391
Merge branch 'research/sparsegpt/llama2' into refactor_obcq
Satrat Sep 29, 2023
d9e969c
style
Satrat Sep 29, 2023
eec247d
specify compressible layers in recipe
Satrat Sep 29, 2023
5a985d0
move attention cache to base class
Satrat Oct 3, 2023
bae96af
move attention cache to base class
Satrat Oct 3, 2023
6067521
clean up test scripts
Satrat Oct 3, 2023
27ba629
clean up bottom compressors and comments
Satrat Oct 4, 2023
0bdee59
documentation
Satrat Oct 4, 2023
ac29d68
fix typos
Satrat Oct 4, 2023
7617f9c
PR comments
Satrat Oct 5, 2023
f216d53
small bug fix on logger
Satrat Oct 5, 2023
dcbbbd4
Merge branch 'main' into refactor_obcq
Satrat Oct 5, 2023
de3dc2f
fixing transformer dependency, adding registry for dataset loaders
Satrat Oct 5, 2023
39ff676
fixing bugs in dataset loading and perplexity
Satrat Oct 6, 2023
7423f20
cleanup
Satrat Oct 6, 2023
25ba312
Merge branch 'main' into refactor_obcq
Satrat Oct 6, 2023
13c3a6c
fix memory issue in comparison
Satrat Oct 6, 2023
b5e9d6d
return perplexity
Satrat Oct 6, 2023
1decfe5
adding split to all datasets
Satrat Oct 6, 2023
0a0cff9
Merge branch 'refactor_obcq' of https://github.com/neuralmagic/sparse…
Satrat Oct 6, 2023
5fa7361
Update src/sparseml/modifiers/obcq/base.py
Satrat Oct 9, 2023
f1441cb
Merge branch 'main' into refactor_obcq
Satrat Oct 9, 2023
29271a3
fixing dataset issues
Satrat Oct 10, 2023
13 changes: 13 additions & 0 deletions src/sparseml/experimental/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions src/sparseml/experimental/sparsegpt/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
38 changes: 22 additions & 16 deletions src/sparseml/experimental/sparsegpt/dispatch.py
@@ -12,17 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SUPPORTED_MODELS = ["opt", "mpt", "llama-2"]
+SUPPORTED_MODELS = ["opt", "mpt", "llama"]
 
 
 def load_model(args, model_key: str = None, *gargs, **kwargs):
     model_key = _get_model_key(args) if model_key is None else model_key
     if model_key == "opt":
-        from opt import load_model as _load_model
+        from sparseml.experimental.sparsegpt.opt import load_model as _load_model
     elif model_key == "mpt":
-        from mpt import load_model as _load_model
-    elif model_key == "llama-2":
-        from llama2 import load_model as _load_model
+        from sparseml.experimental.sparsegpt.mpt import load_model as _load_model
+    elif model_key == "llama":
+        from sparseml.experimental.sparsegpt.llama2 import load_model as _load_model
     else:
         raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _load_model(args, *gargs, **kwargs)
@@ -31,11 +31,11 @@ def load_model(args, model_key: str = None, *gargs, **kwargs):
 def load_data(args, model_key: str = None, *gargs, **kwargs):
     model_key = _get_model_key(args) if model_key is None else model_key
     if model_key == "opt":
-        from opt import load_data as _load_data
+        from sparseml.experimental.sparsegpt.opt import load_data as _load_data
     elif model_key == "mpt":
-        from mpt import load_data as _load_data
-    elif model_key == "llama-2":
-        from llama2 import load_data as _load_data
+        from sparseml.experimental.sparsegpt.mpt import load_data as _load_data
+    elif model_key == "llama":
+        from sparseml.experimental.sparsegpt.llama2 import load_data as _load_data
     else:
         raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _load_data(args, *gargs, **kwargs)
@@ -46,9 +46,9 @@ def evaluate_perplexity(
 ):
     model_key = _get_model_key(args) if model_key is None else model_key
     if model_key == "opt":
-        from opt import ppl_eval as _ppl_eval
-    elif model_key == "llama-2":
-        from llama2 import ppl_eval as _ppl_eval
+        from sparseml.experimental.sparsegpt.opt import ppl_eval as _ppl_eval
+    elif model_key == "llama":
+        from sparseml.experimental.sparsegpt.llama2 import ppl_eval as _ppl_eval
     else:
         raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _ppl_eval(args, model, dataloader, dev, *gargs, **kwargs)
@@ -57,11 +57,11 @@
 def prepare_sparsegpt(model, dataloader, args, model_key: str = None, **kwargs):
     model_key = _get_model_key(args) if model_key is None else model_key
     if model_key == "opt":
-        from opt import prepare_sparsegpt as _prepare_sparsegpt
+        from sparseml.experimental.sparsegpt.opt import (
+            prepare_sparsegpt as _prepare_sparsegpt,
+        )
     elif model_key == "mpt":
-        from mpt import prepare_sparsegpt as _prepare_sparsegpt
-    elif model_key == "llama-2":
-        from llama2 import prepare_sparsegpt as _prepare_sparsegpt
+        from sparseml.experimental.sparsegpt.mpt import (
+            prepare_sparsegpt as _prepare_sparsegpt,
+        )
+    elif model_key == "llama":
+        from sparseml.experimental.sparsegpt.llama2 import (
+            prepare_sparsegpt as _prepare_sparsegpt,
+        )
     else:
         raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _prepare_sparsegpt(model, dataloader, args, **kwargs)
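For orientation, here is a minimal usage sketch of these dispatch helpers, mirroring how the comparison scripts later in this diff call them. The args namespace and its field set are assumptions for illustration; in practice the model key is usually inferred from args rather than passed explicitly.

# Sketch only: the exact fields each model-specific loader reads from `args`
# vary; the ones below mirror the ExperimentalArgs classes defined in the
# comparison scripts further down.
from types import SimpleNamespace

from sparseml.experimental.sparsegpt.dispatch import (
    evaluate_perplexity,
    load_data,
    load_model,
)

args = SimpleNamespace(
    model="facebook/opt-1.3b",  # assumption: any supported model path or stub
    dataset="c4",
    nsamples=128,
    data_sequence_length=2048,
    seed=0,
)
model = load_model(args, model_key="opt")
calibration_data, testloader, _ = load_data(args, model_key="opt")
perplexity = evaluate_perplexity(args, model, testloader, "cuda:0", model_key="opt")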
112 changes: 112 additions & 0 deletions src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py
@@ -0,0 +1,112 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model
from sparseml.experimental.sparsegpt.llama2 import load_data
from sparseml.experimental.sparsegpt.main import sequential
from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
from sparseml.transformers.sparsification.obcq.obcq import one_shot
from sparseml.transformers.sparsification.obcq.utils.helpers import llama_forward


dataset = "open_platypus"
model_name = "/home/sadkins/ml-experiments/nlg-text_generation/"
model_name += "llama_chat-llama_7b_chat-base/dense/training"
sparsity = 0.5
nbits = 8
smooth_quant = 0
observer_batches = 128
nsamples = 128
data_sequence_length = 2048
sequential_hessian = 0
experimental_recipe = "src/sparseml/experimental/sparsegpt/examples/llama2/recipes/"
experimental_recipe += "llama_recipe.yaml"
prod_recipe = "src/sparseml/transformers/sparsification/obcq/example_llama.yaml"
device = "cuda:0"
seed = 0
prunen = 0
prunem = 0
percdamp = 0.01
blocksize = 128
ptq_only = 0


class ExperimentalArgs:
model = model_name
dataset = dataset
data_sequence_length = data_sequence_length
sequential_hessian_within_layer = sequential_hessian
recipe = experimental_recipe
sparsity = sparsity
wbits = nbits
observer_batches = observer_batches
nsamples = nsamples
smoothquant = smooth_quant
seed = seed
prunen = prunen
prunem = prunem
percdamp = percdamp
blocksize = blocksize
ptq_only = ptq_only


class ProdArgs:
model = model_name
dataset = dataset
nsamples = nsamples
device = device
recipe = prod_recipe
eval = False
save = False


def run_experimental_obcq(experimental_args):
model = load_model(experimental_args)
calibration_data, _, _ = load_data(experimental_args, data_sequence_length)
sequential(model, calibration_data, device, experimental_args)

del calibration_data
return model


if __name__ == "__main__":
experimental_args = ExperimentalArgs()
exp_model = run_experimental_obcq(experimental_args)
_, testloader, _ = load_data(experimental_args, data_sequence_length)
exp_perplexity = evaluate_perplexity(
experimental_args, exp_model, testloader, device, max_samples_per_iteration=8
)
del testloader
del exp_model
torch.cuda.empty_cache()

prod_args = ProdArgs()
prod_model = one_shot(
model_path=prod_args.model,
dataset_name=prod_args.dataset,
num_samples=prod_args.nsamples,
device=prod_args.device,
recipe_file=prod_args.recipe,
)
torch.cuda.empty_cache()

_, testloader, _ = load_data(experimental_args, data_sequence_length)
prod_perplexity = ppl_eval_general(
llama_forward, prod_model, testloader, device, max_samples_per_iteration=8
)
print(
f"Experimental Perplexity: {exp_perplexity}, "
f"Production Perplexity: {prod_perplexity}"
)
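A small, hypothetical follow-up that could turn the printed comparison above into an automated check; the 5% relative tolerance is an assumption and is not part of this PR.

# Hypothetical closeness check appended after the comparison script above;
# the threshold is illustrative only.
rel_diff = abs(exp_perplexity - prod_perplexity) / max(exp_perplexity, 1e-9)
assert rel_diff < 0.05, (
    f"Experimental ({exp_perplexity}) and production ({prod_perplexity}) "
    f"perplexities differ by {rel_diff:.1%}"
)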
52 changes: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# Quantization variables
observer_freeze_epoch: 1
bn_freeze_epoch: 1
qat_start_epoch: 0

quantization_modifiers:
- !QuantizationModifier
start_epoch: eval(qat_start_epoch)
disable_quantization_observer_epoch: eval(observer_freeze_epoch)
freeze_bn_stats_epoch: eval(bn_freeze_epoch)
ignore:
- LlamaRotaryEmbedding
- LlamaRMSNorm
- SiLUActivation
- model.layers.0.mlp.down_proj
- model.layers.1.mlp.down_proj
- model.layers.2.mlp.down_proj
- model.layers.3.mlp.down_proj
- model.layers.4.mlp.down_proj
- model.layers.5.mlp.down_proj
- model.layers.6.mlp.down_proj
- model.layers.7.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.17.mlp.down_proj
- model.layers.18.mlp.down_proj
- model.layers.19.mlp.down_proj
- model.layers.20.mlp.down_proj
- model.layers.21.mlp.down_proj
- model.layers.22.mlp.down_proj
- model.layers.23.mlp.down_proj
- model.layers.24.mlp.down_proj
- model.layers.25.mlp.down_proj
- model.layers.26.mlp.down_proj
- model.layers.27.mlp.down_proj
- model.layers.28.mlp.down_proj
- model.layers.29.mlp.down_proj
- model.layers.30.mlp.down_proj
- model.layers.31.mlp.down_proj
scheme_overrides:
Embedding:
input_activations: null
weights:
num_bits: 8
symmetric: False
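The 32 per-layer entries in the ignore list above follow a regular pattern; if a recipe like this were generated programmatically, the list could be built as in this sketch (the layer count of 32 matches the entries above and assumes a Llama-2-7B-sized model):

# Rebuilds the ignore list from the recipe above for a 32-layer model.
ignore = ["LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation"]
ignore += [f"model.layers.{i}.mlp.down_proj" for i in range(32)]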
112 changes: 112 additions & 0 deletions src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py
@@ -0,0 +1,112 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model
from sparseml.experimental.sparsegpt.main import sequential
from sparseml.experimental.sparsegpt.opt import load_data
from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
from sparseml.transformers.sparsification.obcq.obcq import one_shot
from sparseml.transformers.sparsification.obcq.utils.helpers import opt_forward


dataset = "c4"
model_name = "facebook/opt-1.3b"
sparsity = 0.5
nbits = 8
smooth_quant = 0
observer_batches = 128
nsamples = 128
data_sequence_length = 2048
sequential_hessian = 0
experimental_recipe = "src/sparseml/experimental/sparsegpt/examples/opt/recipes/"
experimental_recipe += "opt-1.3b-opt_pretrain-pruned50_quantW8A8.md"
prod_recipe = "src/sparseml/transformers/sparsification/obcq/example.yaml"
device = "cuda:0"
seed = 0
prunen = 0
prunem = 0
percdamp = 0.01
blocksize = 128
ptq_only = 0


class ExperimentalArgs:
model = model_name
dataset = dataset
data_sequence_length = data_sequence_length
sequential_hessian_within_layer = sequential_hessian
recipe = experimental_recipe
sparsity = sparsity
wbits = nbits
observer_batches = observer_batches
nsamples = nsamples
smoothquant = smooth_quant
seed = seed
prunen = prunen
prunem = prunem
percdamp = percdamp
blocksize = blocksize
ptq_only = ptq_only


class ProdArgs:
model = model_name
dataset = dataset
nsamples = nsamples
device = device
recipe = prod_recipe
save = False


def run_experimental_obcq(experimental_args):
model = load_model(experimental_args)
calibration_data, _, _ = load_data(experimental_args, data_sequence_length)
sequential(model, calibration_data, device, experimental_args)

del calibration_data
return model


if __name__ == "__main__":
experimental_args = ExperimentalArgs()
exp_model = run_experimental_obcq(experimental_args)
experimental_args.dataset = "wikitext2"
_, testloader, _ = load_data(experimental_args, data_sequence_length)
exp_perplexity = evaluate_perplexity(
experimental_args, exp_model, testloader, device, max_samples_per_iteration=8
)

del testloader
del exp_model
torch.cuda.empty_cache()

prod_args = ProdArgs()
prod_model = one_shot(
model_path=prod_args.model,
dataset_name=prod_args.dataset,
num_samples=prod_args.nsamples,
device=prod_args.device,
recipe_file=prod_args.recipe,
)
experimental_args.dataset = "wikitext2"
_, testloader, _ = load_data(experimental_args, data_sequence_length)
prod_perplexity = ppl_eval_general(
opt_forward, prod_model, testloader, device, max_samples_per_iteration=8
)
print(
f"Experimental Perplexity: {exp_perplexity}, "
f"Production Perplexity: {prod_perplexity}"
)
@@ -2,11 +2,11 @@
 
 export CUDA_VISIBLE_DEVICES=0
 
-ROOT=$HOME/src/neuralmagic/sparseml/src/sparseml/experimental/sparsegpt
+ROOT=$HOME/sparseml/src/sparseml/experimental/sparsegpt
 
 DATASET=c4
 
-RECIPE_DIR=$ROOT/recipes
+RECIPE_DIR=$ROOT/examples/opt/recipes
 RECIPE_NAME=opt-1.3b-opt_pretrain-pruned50_quantW8A8
 
 SRC_MODEL_ORG=facebook