Skip to content

Commit

Permalink
Modifier Refactor OBCQ Implementation (#1737)
Browse files Browse the repository at this point in the history
  • Loading branch information
Satrat committed Oct 11, 2023
1 parent 9a6e8e3 commit 6c3f054
Show file tree
Hide file tree
Showing 43 changed files with 4,277 additions and 35 deletions.
13 changes: 13 additions & 0 deletions src/sparseml/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions src/sparseml/experimental/sparsegpt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
38 changes: 22 additions & 16 deletions src/sparseml/experimental/sparsegpt/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

SUPPORTED_MODELS = ["opt", "mpt", "llama-2"]
SUPPORTED_MODELS = ["opt", "mpt", "llama"]


def load_model(args, model_key: str = None, *gargs, **kwargs):
    """
    Load a model using the loader of the matching model family.

    The family-specific module is imported lazily, and always through its
    fully qualified ``sparseml.experimental.sparsegpt`` path so the lookup
    does not depend on the current working directory.

    :param args: settings object forwarded to the family loader; also used to
        derive the model key (via ``_get_model_key``) when ``model_key`` is
        not given
    :param model_key: one of ``SUPPORTED_MODELS``; derived from ``args`` when
        None
    :param gargs: extra positional arguments forwarded to the family loader
    :param kwargs: extra keyword arguments forwarded to the family loader
    :return: whatever the family-specific ``load_model`` returns
    :raises ValueError: if the model key is not recognized
    """
    model_key = _get_model_key(args) if model_key is None else model_key
    if model_key == "opt":
        from sparseml.experimental.sparsegpt.opt import load_model as _load_model
    elif model_key == "mpt":
        from sparseml.experimental.sparsegpt.mpt import load_model as _load_model
    elif model_key == "llama":
        from sparseml.experimental.sparsegpt.llama2 import load_model as _load_model
    else:
        raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
    return _load_model(args, *gargs, **kwargs)
Expand All @@ -31,11 +31,11 @@ def load_model(args, model_key: str = None, *gargs, **kwargs):
def load_data(args, model_key: str = None, *gargs, **kwargs):
    """
    Load data using the data loader of the matching model family.

    The family-specific module is imported lazily, and always through its
    fully qualified ``sparseml.experimental.sparsegpt`` path so the lookup
    does not depend on the current working directory.

    :param args: settings object forwarded to the family data loader; also
        used to derive the model key (via ``_get_model_key``) when
        ``model_key`` is not given
    :param model_key: one of ``SUPPORTED_MODELS``; derived from ``args`` when
        None
    :param gargs: extra positional arguments forwarded to the family loader
    :param kwargs: extra keyword arguments forwarded to the family loader
    :return: whatever the family-specific ``load_data`` returns
    :raises ValueError: if the model key is not recognized
    """
    model_key = _get_model_key(args) if model_key is None else model_key
    if model_key == "opt":
        from sparseml.experimental.sparsegpt.opt import load_data as _load_data
    elif model_key == "mpt":
        from sparseml.experimental.sparsegpt.mpt import load_data as _load_data
    elif model_key == "llama":
        from sparseml.experimental.sparsegpt.llama2 import load_data as _load_data
    else:
        raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
    return _load_data(args, *gargs, **kwargs)
Expand All @@ -46,9 +46,9 @@ def evaluate_perplexity(
):
model_key = _get_model_key(args) if model_key is None else model_key
if model_key == "opt":
from opt import ppl_eval as _ppl_eval
elif model_key == "llama-2":
from llama2 import ppl_eval as _ppl_eval
from sparseml.experimental.sparsegpt.opt import ppl_eval as _ppl_eval
elif model_key == "llama":
from sparseml.experimental.sparsegpt.llama2 import ppl_eval as _ppl_eval
else:
raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
return _ppl_eval(args, model, dataloader, dev, *gargs, **kwargs)
Expand All @@ -57,11 +57,17 @@ def evaluate_perplexity(
def prepare_sparsegpt(model, dataloader, args, model_key: str = None, **kwargs):
    """
    Run the ``prepare_sparsegpt`` routine of the matching model family.

    The family-specific module is imported lazily, and always through its
    fully qualified ``sparseml.experimental.sparsegpt`` path so the lookup
    does not depend on the current working directory.

    :param model: model forwarded to the family routine
    :param dataloader: data loader forwarded to the family routine
    :param args: settings object forwarded to the family routine; also used
        to derive the model key (via ``_get_model_key``) when ``model_key``
        is not given
    :param model_key: one of ``SUPPORTED_MODELS``; derived from ``args`` when
        None
    :param kwargs: extra keyword arguments forwarded to the family routine
    :return: whatever the family-specific ``prepare_sparsegpt`` returns
    :raises ValueError: if the model key is not recognized
    """
    model_key = _get_model_key(args) if model_key is None else model_key
    if model_key == "opt":
        from sparseml.experimental.sparsegpt.opt import (
            prepare_sparsegpt as _prepare_sparsegpt,
        )
    elif model_key == "mpt":
        from sparseml.experimental.sparsegpt.mpt import (
            prepare_sparsegpt as _prepare_sparsegpt,
        )
    elif model_key == "llama":
        from sparseml.experimental.sparsegpt.llama2 import (
            prepare_sparsegpt as _prepare_sparsegpt,
        )
    else:
        raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
    return _prepare_sparsegpt(model, dataloader, args, **kwargs)
Expand Down
112 changes: 112 additions & 0 deletions src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model
from sparseml.experimental.sparsegpt.llama2 import load_data
from sparseml.experimental.sparsegpt.main import sequential
from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
from sparseml.transformers.sparsification.obcq.obcq import one_shot
from sparseml.transformers.sparsification.obcq.utils.helpers import llama_forward


# --- Model and data -------------------------------------------------------
# NOTE(review): machine-specific absolute path; point it at a local
# checkpoint before running this script anywhere else.
model_name = (
    "/home/sadkins/ml-experiments/nlg-text_generation/"
    "llama_chat-llama_7b_chat-base/dense/training"
)
dataset = "open_platypus"
nsamples = 128
data_sequence_length = 2048

# --- Recipes --------------------------------------------------------------
# Relative paths: presumably resolved from the repository root.
experimental_recipe = (
    "src/sparseml/experimental/sparsegpt/examples/llama2/recipes/"
    "llama_recipe.yaml"
)
prod_recipe = "src/sparseml/transformers/sparsification/obcq/example_llama.yaml"

# --- Settings copied into ExperimentalArgs --------------------------------
sparsity = 0.5
nbits = 8
smooth_quant = 0
observer_batches = 128
sequential_hessian = 0
prunen = 0
prunem = 0
percdamp = 0.01
blocksize = 128
ptq_only = 0

# --- Runtime --------------------------------------------------------------
device = "cuda:0"
seed = 0


class ExperimentalArgs:
    """
    Settings object for the experimental SparseGPT code path.

    An instance is passed as the ``args`` argument to ``load_model``,
    ``load_data``, ``sequential`` and ``evaluate_perplexity`` in this script.
    Every attribute is a copy of a module-level constant.
    """

    # Model, data and recipe
    model = model_name
    dataset = dataset
    nsamples = nsamples
    data_sequence_length = data_sequence_length
    recipe = experimental_recipe

    # Values taken from the module-level compression settings
    sparsity = sparsity
    prunen = prunen
    prunem = prunem
    percdamp = percdamp
    blocksize = blocksize
    sequential_hessian_within_layer = sequential_hessian
    wbits = nbits
    smoothquant = smooth_quant
    observer_batches = observer_batches
    ptq_only = ptq_only

    # Reproducibility
    seed = seed


class ProdArgs:
    """
    Settings for the production ``one_shot`` code path.

    The script's main block reads ``model``, ``dataset``, ``nsamples``,
    ``device`` and ``recipe`` and passes them to ``one_shot`` as keyword
    arguments.
    """

    # Forwarded to one_shot
    model = model_name
    dataset = dataset
    nsamples = nsamples
    recipe = prod_recipe
    device = device

    # Not read by this script's main block
    eval = False
    save = False


def run_experimental_obcq(experimental_args):
    """
    Run the experimental ``sequential`` routine on a freshly loaded model.

    :param experimental_args: settings object handed to ``load_model``,
        ``load_data`` and ``sequential``
    :return: the object returned by ``load_model`` (presumably modified in
        place by ``sequential``; confirm against its implementation)
    """
    loaded_model = load_model(experimental_args)
    # Only the first element of load_data's 3-tuple is needed here.
    calib_loader, _, _ = load_data(experimental_args, data_sequence_length)
    sequential(loaded_model, calib_loader, device, experimental_args)

    del calib_loader
    return loaded_model


if __name__ == "__main__":
    # ---- Experimental implementation ----
    experimental_args = ExperimentalArgs()
    experimental_model = run_experimental_obcq(experimental_args)
    # Second element of load_data's result is the loader used for evaluation.
    _, eval_loader, _ = load_data(experimental_args, data_sequence_length)
    exp_perplexity = evaluate_perplexity(
        experimental_args,
        experimental_model,
        eval_loader,
        device,
        max_samples_per_iteration=8,
    )
    # Drop references and release cached CUDA memory before the next run.
    del eval_loader, experimental_model
    torch.cuda.empty_cache()

    # ---- Production implementation ----
    prod_args = ProdArgs()
    one_shot_kwargs = dict(
        model_path=prod_args.model,
        dataset_name=prod_args.dataset,
        num_samples=prod_args.nsamples,
        device=prod_args.device,
        recipe_file=prod_args.recipe,
    )
    production_model = one_shot(**one_shot_kwargs)
    torch.cuda.empty_cache()

    # Reload the evaluation loader with the same experimental settings so
    # both models are scored on the same data.
    _, eval_loader, _ = load_data(experimental_args, data_sequence_length)
    prod_perplexity = ppl_eval_general(
        llama_forward,
        production_model,
        eval_loader,
        device,
        max_samples_per_iteration=8,
    )

    summary = (
        f"Experimental Perplexity: {exp_perplexity}, "
        f"Production Perplexity: {prod_perplexity}"
    )
    print(summary)
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Quantization variables
# Referenced below through eval(<name>).
observer_freeze_epoch: 1
bn_freeze_epoch: 1
qat_start_epoch: 0

quantization_modifiers:
  - !QuantizationModifier
    start_epoch: eval(qat_start_epoch)
    disable_quantization_observer_epoch: eval(observer_freeze_epoch)
    freeze_bn_stats_epoch: eval(bn_freeze_epoch)
    # Module types and layer names listed under "ignore":
    # three module classes plus mlp.down_proj of layers 0-31.
    ignore:
      - LlamaRotaryEmbedding
      - LlamaRMSNorm
      - SiLUActivation
      - model.layers.0.mlp.down_proj
      - model.layers.1.mlp.down_proj
      - model.layers.2.mlp.down_proj
      - model.layers.3.mlp.down_proj
      - model.layers.4.mlp.down_proj
      - model.layers.5.mlp.down_proj
      - model.layers.6.mlp.down_proj
      - model.layers.7.mlp.down_proj
      - model.layers.8.mlp.down_proj
      - model.layers.9.mlp.down_proj
      - model.layers.10.mlp.down_proj
      - model.layers.11.mlp.down_proj
      - model.layers.12.mlp.down_proj
      - model.layers.13.mlp.down_proj
      - model.layers.14.mlp.down_proj
      - model.layers.15.mlp.down_proj
      - model.layers.16.mlp.down_proj
      - model.layers.17.mlp.down_proj
      - model.layers.18.mlp.down_proj
      - model.layers.19.mlp.down_proj
      - model.layers.20.mlp.down_proj
      - model.layers.21.mlp.down_proj
      - model.layers.22.mlp.down_proj
      - model.layers.23.mlp.down_proj
      - model.layers.24.mlp.down_proj
      - model.layers.25.mlp.down_proj
      - model.layers.26.mlp.down_proj
      - model.layers.27.mlp.down_proj
      - model.layers.28.mlp.down_proj
      - model.layers.29.mlp.down_proj
      - model.layers.30.mlp.down_proj
      - model.layers.31.mlp.down_proj
    # Per-module-type scheme override for Embedding modules.
    scheme_overrides:
      Embedding:
        input_activations: null
        weights:
          num_bits: 8
          symmetric: False
112 changes: 112 additions & 0 deletions src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model
from sparseml.experimental.sparsegpt.main import sequential
from sparseml.experimental.sparsegpt.opt import load_data
from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
from sparseml.transformers.sparsification.obcq.obcq import one_shot
from sparseml.transformers.sparsification.obcq.utils.helpers import opt_forward


# --- Model and data -------------------------------------------------------
model_name = "facebook/opt-1.3b"
dataset = "c4"
nsamples = 128
data_sequence_length = 2048

# --- Recipes --------------------------------------------------------------
# Relative paths: presumably resolved from the repository root.
experimental_recipe = (
    "src/sparseml/experimental/sparsegpt/examples/opt/recipes/"
    "opt-1.3b-opt_pretrain-pruned50_quantW8A8.md"
)
prod_recipe = "src/sparseml/transformers/sparsification/obcq/example.yaml"

# --- Settings copied into ExperimentalArgs --------------------------------
sparsity = 0.5
nbits = 8
smooth_quant = 0
observer_batches = 128
sequential_hessian = 0
prunen = 0
prunem = 0
percdamp = 0.01
blocksize = 128
ptq_only = 0

# --- Runtime --------------------------------------------------------------
device = "cuda:0"
seed = 0


class ExperimentalArgs:
    """
    Settings object for the experimental SparseGPT code path.

    An instance is passed as the ``args`` argument to ``load_model``,
    ``load_data``, ``sequential`` and ``evaluate_perplexity`` in this script.
    Every attribute is a copy of a module-level constant.
    """

    # Model, data and recipe
    model = model_name
    dataset = dataset
    nsamples = nsamples
    data_sequence_length = data_sequence_length
    recipe = experimental_recipe

    # Values taken from the module-level compression settings
    sparsity = sparsity
    prunen = prunen
    prunem = prunem
    percdamp = percdamp
    blocksize = blocksize
    sequential_hessian_within_layer = sequential_hessian
    wbits = nbits
    smoothquant = smooth_quant
    observer_batches = observer_batches
    ptq_only = ptq_only

    # Reproducibility
    seed = seed


class ProdArgs:
    """
    Settings for the production ``one_shot`` code path.

    The script's main block reads ``model``, ``dataset``, ``nsamples``,
    ``device`` and ``recipe`` and passes them to ``one_shot`` as keyword
    arguments.
    """

    # Forwarded to one_shot
    model = model_name
    dataset = dataset
    nsamples = nsamples
    recipe = prod_recipe
    device = device

    # Not read by this script's main block
    save = False


def run_experimental_obcq(experimental_args):
    """
    Run the experimental ``sequential`` routine on a freshly loaded model.

    :param experimental_args: settings object handed to ``load_model``,
        ``load_data`` and ``sequential``
    :return: the object returned by ``load_model`` (presumably modified in
        place by ``sequential``; confirm against its implementation)
    """
    loaded_model = load_model(experimental_args)
    # Only the first element of load_data's 3-tuple is needed here.
    calib_loader, _, _ = load_data(experimental_args, data_sequence_length)
    sequential(loaded_model, calib_loader, device, experimental_args)

    del calib_loader
    return loaded_model


if __name__ == "__main__":
    # ---- Experimental implementation ----
    experimental_args = ExperimentalArgs()
    experimental_model = run_experimental_obcq(experimental_args)
    # The run above used the class-level dataset; evaluation switches the
    # instance to wikitext2 before reloading data.
    experimental_args.dataset = "wikitext2"
    _, eval_loader, _ = load_data(experimental_args, data_sequence_length)
    exp_perplexity = evaluate_perplexity(
        experimental_args,
        experimental_model,
        eval_loader,
        device,
        max_samples_per_iteration=8,
    )

    # Drop references and release cached CUDA memory before the next run.
    del eval_loader, experimental_model
    torch.cuda.empty_cache()

    # ---- Production implementation ----
    prod_args = ProdArgs()
    one_shot_kwargs = dict(
        model_path=prod_args.model,
        dataset_name=prod_args.dataset,
        num_samples=prod_args.nsamples,
        device=prod_args.device,
        recipe_file=prod_args.recipe,
    )
    production_model = one_shot(**one_shot_kwargs)

    # Score the production model on the same wikitext2 evaluation data.
    experimental_args.dataset = "wikitext2"
    _, eval_loader, _ = load_data(experimental_args, data_sequence_length)
    prod_perplexity = ppl_eval_general(
        opt_forward,
        production_model,
        eval_loader,
        device,
        max_samples_per_iteration=8,
    )

    summary = (
        f"Experimental Perplexity: {exp_perplexity}, "
        f"Production Perplexity: {prod_perplexity}"
    )
    print(summary)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

export CUDA_VISIBLE_DEVICES=0

ROOT=$HOME/src/neuralmagic/sparseml/src/sparseml/experimental/sparsegpt
ROOT=$HOME/sparseml/src/sparseml/experimental/sparsegpt

DATASET=c4

RECIPE_DIR=$ROOT/recipes
RECIPE_DIR=$ROOT/examples/opt/recipes
RECIPE_NAME=opt-1.3b-opt_pretrain-pruned50_quantW8A8

SRC_MODEL_ORG=facebook
Expand Down
Loading

0 comments on commit 6c3f054

Please sign in to comment.