Refactored SparseGPT (#1705)
* First abstract structure with OPT, MPT

* Updated

* Updated

* Runnable OPT main code path

* Initial support for MPT

* Initial commit for Llama2

* Return model after overriding Llama attn

* More memory-efficient layer compression

* Make memory-efficient layer compressor default

* SmoothQuant enabled

* Bugs fixed

* Make load data more flexible

* Initialize scales in eval mode; clean up

* Example script and recipe for OPT

* Formatting

* Copyright

* Fix code styles

* Format

---------

Co-authored-by: abhinavnmagic <121893843+abhinavnmagic@users.noreply.github.com>
natuan and abhinavnmagic committed Sep 18, 2023
1 parent e5faef5 commit bd38892
Showing 12 changed files with 2,206 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/sparseml/experimental/sparsegpt/dispatch.py
@@ -0,0 +1,76 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

SUPPORTED_MODELS = ["opt", "mpt", "llama"]


def load_model(args, model_key: str = None):
model_key = _get_model_key(args) if model_key is None else model_key
if model_key == "opt":
from opt import load_model as _load_model
elif model_key == "mpt":
from mpt import load_model as _load_model
elif model_key == "llama":
from llama import load_model as _load_model
else:
raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
return _load_model(args)


def load_data(args, model_key: str = None, dataset: str = None):
model_key = _get_model_key(args) if model_key is None else model_key
if model_key == "opt":
from opt import load_data as _load_data
elif model_key == "mpt":
from mpt import load_data as _load_data
elif model_key == "llama":
from llama import load_data as _load_data
else:
raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
return _load_data(args, dataset=dataset)


def evaluate_perplexity(args, model, dataloader, dev, model_key: str = None):
model_key = _get_model_key(args) if model_key is None else model_key
if model_key == "opt":
from opt import ppl_eval as _ppl_eval
else:
raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
return _ppl_eval(args, model, dataloader, dev)


def prepare_sparsegpt(model, dataloader, args, model_key: str = None, **kwargs):
model_key = _get_model_key(args) if model_key is None else model_key
if model_key == "opt":
from opt import prepare_sparsegpt as _prepare_sparsegpt
elif model_key == "mpt":
from mpt import prepare_sparsegpt as _prepare_sparsegpt
elif model_key == "llama":
from llama import prepare_sparsegpt as _prepare_sparsegpt
else:
raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
return _prepare_sparsegpt(model, dataloader, args, **kwargs)


def _get_model_key(args):
key = None
for k in SUPPORTED_MODELS:
if args.model.lower().find(k) >= 0:
key = k
break
if key is None:
raise ValueError(
f"Model {args.model} is not supported. Supported: {SUPPORTED_MODELS.keys()}"
)
return key
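For context (not part of this commit's diff): a minimal sketch of how a caller might drive the dispatch helpers above. Dispatch keys off a substring match against args.model, so a name like "facebook/opt-1.3b" selects the "opt" backend; the argparse setup and return handling below are assumptions, not code from this commit.

# Illustrative only, not part of this commit's diff: a hypothetical caller
# for the dispatch helpers above. Return shapes are whatever the per-model
# modules (opt.py/mpt.py/llama.py) define.
import argparse

from dispatch import load_data, load_model

parser = argparse.ArgumentParser()
parser.add_argument("model")    # e.g. facebook/opt-1.3b -> matches the "opt" key
parser.add_argument("dataset")  # e.g. c4
args, _ = parser.parse_known_args()

model = load_model(args)                      # forwarded to opt.load_model
data = load_data(args, dataset=args.dataset)  # forwarded to opt.load_data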
@@ -0,0 +1,14 @@
---

quantization_modifiers:
  - !QuantizationModifier
    ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
    scheme_overrides:
      ReLU:
        input_activations: null
        output_activations: null
      LayerNorm:
        input_activations: null
        output_activations: null

---
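For orientation (not part of the diff): the recipe above is a standard sparseml quantization recipe. A minimal sketch of loading it with sparseml's ScheduledModifierManager follows; this API usage is an assumption about the surrounding driver code, and the recipe path is the one referenced by the example script below.

# Sketch only (assumed API, not code from this commit): loading a recipe like
# the one above. Recipe files are markdown with YAML front matter, which
# ScheduledModifierManager.from_yaml parses.
from sparseml.pytorch.optim import ScheduledModifierManager

manager = ScheduledModifierManager.from_yaml(
    "recipes/opt-1.3b-opt_pretrain-pruned50_quantW8A8.md"
)

# layer_compressor.py (below) only inspects manager.quantization_modifiers to
# decide whether QAT has rewrapped layers when collecting compressible modules.
print(bool(manager.quantization_modifiers))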
@@ -0,0 +1,51 @@
#!/bin/bash

export CUDA_VISIBLE_DEVICES=0

ROOT=$HOME/src/neuralmagic/sparseml/src/sparseml/experimental/sparsegpt

DATASET=c4

RECIPE_DIR=$ROOT/recipes
RECIPE_NAME=opt-1.3b-opt_pretrain-pruned50_quantW8A8

SRC_MODEL_ORG=facebook
SRC_MODEL_NAME=opt-1.3b
SRC_MODEL=$SRC_MODEL_ORG/$SRC_MODEL_NAME

SP=0.5
WBITS=8

ID=$RANDOM

SMOOTH=0
SMOOTH_DIR=$HOME/src/smoothquant/act_scales
SMOOTH_FILE=$SMOOTH_DIR/$SRC_MODEL_NAME.pt

PTQ=1

TRUE_SEQ=0

DST_MODEL_DIR=$HOME/models/opt
DST_MODEL_NAME=sparsegpt@$SRC_MODEL_NAME@$DATASET@$RECIPE_NAME@SP$SP@SQ$SMOOTH@SEQ$TRUE_SEQ@PTQ$PTQ@ID$ID
DST_MODEL=$DST_MODEL_DIR/$DST_MODEL_NAME

EVAL_DENSE=0

OBSERVER_BATCHES=100

python $ROOT/main.py $SRC_MODEL $DATASET \
--data-sequence-length 2048 \
--sequential_hessian_within_layer $TRUE_SEQ \
--recipe $RECIPE_DIR/$RECIPE_NAME.md \
--sparsity $SP \
--eval-dense $EVAL_DENSE \
--wbits $WBITS \
--observer-batches $OBSERVER_BATCHES \
--ptq $PTQ \
--ptq-init 1 \
--smoothquant $SMOOTH \
--smooth-activation-file $SMOOTH_FILE \
--save $DST_MODEL

cp "$0" $DST_MODEL/command.sh
241 changes: 241 additions & 0 deletions src/sparseml/experimental/sparsegpt/layer_compressor.py
@@ -0,0 +1,241 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Dict, List, Tuple

import torch
import torch.nn as nn

from quant import WeightFakeQuantizer
from sparsegpt import SparseGPT


DEFAULT_WBITS = 16


class BaseCompressor:
def __init__(self, model):
self.model = model

def pre_compress(self, dev: str = "cuda:0", **kwargs) -> Tuple[nn.Module, Dict]:
return self.model, {}

def compress(self, dev: str = "cuda:0", **kwargs) -> Tuple[nn.Module, Dict]:
return self.model, {}

def post_compress(self, dev: str = "cuda:0", **kwargs) -> Tuple[nn.Module, Dict]:
return self.model, {}


class LayerCompressor(BaseCompressor):
def __init__(self, model, layer, layer_index, inputs, manager, args):
super().__init__(model=model)
self.layer = layer
self.layer_index = layer_index
self.inputs = inputs
self.manager = manager
self.args = args

def compressible_modules(self, **kwargs):
if self.manager is not None and self.manager.quantization_modifiers:
# The layer names are changed due to quantization modifiers, therefore
# we need a slightly different func to retrieve layers
modules = _find_quant_layers(self.layer)
else:
modules = _find_layers(self.layer)
return modules

def pre_compress(self, **kwargs):
"""
Set up SparseGPT objects, compute Hessian
"""
if not self.args.sequential_hessian_within_layer:
subset = self.compressible_modules(**kwargs)

gpts = {}
for name in subset:
gpts[name] = SparseGPT(subset[name])
if (
self.args.wbits < 16
and self.manager is not None
and self.manager.quantization_modifiers
):
gpts[name].quantizer = WeightFakeQuantizer(subset[name])

def add_batch(name):
def tmp(_, inp, out):
gpts[name].add_batch(inp[0].data, out.data)

return tmp

handles = []
for name in gpts:
handles.append(subset[name].register_forward_hook(add_batch(name)))

# Run through the samples in order to compute Hessian matrix
nsamples = self.inputs.shape[0]
forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward)
passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs]
for j in range(nsamples):
passed_in_kwargs = {}
for arg in passed_in_args:
if isinstance(kwargs[arg], List):
passed_in_kwargs[arg] = kwargs[arg][j]
else:
passed_in_kwargs[arg] = kwargs[arg]
self.layer(self.inputs[j].unsqueeze(0), **passed_in_kwargs)[0]
for h in handles:
h.remove()

return self.model, {"gpts": gpts}
else:
return self.model, {}

def compress(self, dev: str = "cuda:0", **kwargs):
self.layer.to(dev)
self.model, extras = self.pre_compress(**kwargs)
if not self.args.sequential_hessian_within_layer:
gpts = extras["gpts"]
for name in gpts:
print(f"Compressing {name}...")
sparsity = self.args.sparsity
gpts[name].fasterprune(
sparsity,
prunen=self.args.prunen,
prunem=self.args.prunem,
percdamp=self.args.percdamp,
blocksize=self.args.blocksize,
)
gpts[name].free()
else:
self._sequentially_compress(**kwargs)

self.model, extras = self.post_compress(**kwargs)

return self.model, {"outputs": extras["outputs"]}

def post_compress(self, **kwargs):
outputs = torch.zeros_like(self.inputs)
nsamples = self.inputs.shape[0]
attention_mask = kwargs.get("attention_mask", None)
for j in range(nsamples):
attn_mask = (
attention_mask[j]
if isinstance(attention_mask, List)
else attention_mask
)
outputs[j] = self.layer(
self.inputs[j].unsqueeze(0), attention_mask=attn_mask
)[0]
self.inputs = None
torch.cuda.empty_cache()

return self.model, {"outputs": outputs}

def _sequentially_compress(self, **kwargs):
subset = self.compressible_modules(**kwargs)

forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward)
passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs]

passed_in_kwargs = {}
for arg in passed_in_args:
if isinstance(kwargs[arg], List):
passed_in_kwargs[arg] = kwargs[arg][0]
else:
passed_in_kwargs[arg] = kwargs[arg]
order = _find_dependency_order(
self.layer, subset, self.inputs[0].unsqueeze(0), **passed_in_kwargs
)

nsamples = self.inputs.shape[0]
for name in order:
gpts = SparseGPT(subset[name])
if self.args.wbits < 16:
if self.manager is not None and self.manager.quantization_modifiers:
gpts.quantizer = WeightFakeQuantizer(subset[name])

def add_batch(name):
def tmp(_, inp, out):
gpts.add_batch(inp[0].data, out.data)

return tmp

handle = subset[name].register_forward_hook(add_batch(name))
for j in range(nsamples):
passed_in_kwargs = {}
for arg in passed_in_args:
if isinstance(kwargs[arg], List):
                        passed_in_kwargs[arg] = kwargs[arg][j]
else:
passed_in_kwargs[arg] = kwargs[arg]
self.layer(self.inputs[j].unsqueeze(0), **passed_in_kwargs)[0]
handle.remove()

print(f"Compressing module {name} of layer {self.layer_index}")
gpts.fasterprune(
self.args.sparsity,
prunen=self.args.prunen,
prunem=self.args.prunem,
percdamp=self.args.percdamp,
blocksize=self.args.blocksize,
)
gpts.free()


def _find_dependency_order(layer, subset, an_input, **kwargs):
order = []

def exe_input(name):
def _exe_input(_, inp, out):
if name in subset:
order.append(name)

return _exe_input

handles = [subset[name].register_forward_hook(exe_input(name)) for name in subset]
layer(an_input, **kwargs)
for h in handles:
h.remove()
return order


def _find_quant_layers(module, layers=[torch.nn.qat.Linear], name=""):
if type(module) in layers:
pieces = name.split(".")
if pieces[-1] == "module":
name = ".".join(pieces[:-1])
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
            _find_quant_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res


def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res
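For context (not part of the diff): a minimal sketch of the per-layer loop a model-specific prepare/compress path might run on top of LayerCompressor, threading each layer's outputs into the next layer's inputs. The names decoder_layers, cached_inputs, and attention_mask are placeholders for what opt.py/mpt.py/llama.py would supply; only the LayerCompressor API used here comes from the file above.

# Hypothetical per-layer compression loop; everything outside the
# LayerCompressor calls is assumed scaffolding.
import torch

from layer_compressor import LayerCompressor


def compress_decoder_layers(model, decoder_layers, cached_inputs, attention_mask,
                            manager, args, dev="cuda:0"):
    inputs = cached_inputs  # calibration activations captured before layer 0
    for idx, layer in enumerate(decoder_layers):
        compressor = LayerCompressor(model, layer, idx, inputs, manager, args)
        # compress() builds Hessians from the calibration samples, runs
        # SparseGPT.fasterprune on each compressible module, then replays the
        # samples so extras["outputs"] can feed the next layer.
        model, extras = compressor.compress(dev=dev, attention_mask=attention_mask)
        inputs = extras["outputs"]
        layer.cpu()
        torch.cuda.empty_cache()
    return model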