Commit 73ccd2e

Updated

natuan committed Aug 29, 2023
1 parent 07a7273 commit 73ccd2e
Showing 5 changed files with 59 additions and 39 deletions.
33 changes: 18 additions & 15 deletions src/sparseml/experimental/sparsegpt/dispatch.py
@@ -1,33 +1,36 @@
 SUPPORTED_MODELS = ["opt", "mpt"]
 
 
-def load_model(args):
-    key = _get_model_key(args)
-    if key == "opt":
+def load_model(args, model_key: str = None):
+    model_key = _get_model_key(args) if model_key is None else model_key
+    if model_key == "opt":
         from opt import load_model as _load_model
-    else:
+    elif model_key == "mpt":
         from mpt import load_model as _load_model
-
+    else:
+        raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _load_model(args)
 
 
-def load_data(args):
-    key = _get_model_key(args)
-    if key == "opt":
+def load_data(args, model_key: str = None):
+    model_key = _get_model_key(args) if model_key is None else model_key
+    if model_key == "opt":
         from opt import load_data as _load_data
-    else:
+    elif model_key == "mpt":
         from mpt import load_data as _load_data
-
+    else:
+        raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _load_data(args)
 
 
-def prepare_sparsegpt(model, dataloader, args, **kwargs):
-    key = _get_model_key(args)
-    if key == "opt":
+def prepare_sparsegpt(model, dataloader, args, model_key: str = None, **kwargs):
+    model_key = _get_model_key(args) if model_key is None else model_key
+    if model_key == "opt":
         from opt import prepare_sparsegpt as _prepare_sparsegpt
-    else:
+    elif model_key == "mpt":
         from mpt import prepare_sparsegpt as _prepare_sparsegpt
-
+    else:
+        raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}")
     return _prepare_sparsegpt(model, dataloader, args, **kwargs)
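For orientation, a minimal usage sketch of the reworked dispatch API. This is an illustration under assumptions, not part of the commit: the Namespace fields mirror attributes referenced elsewhere in this changeset, and it presumes the code is run from the sparsegpt directory so the bare opt/mpt imports inside dispatch.py resolve.

    from argparse import Namespace

    from dispatch import load_data, load_model, prepare_sparsegpt

    # Hypothetical args; the fields mirror those read by opt.load_data in this commit.
    args = Namespace(
        model="facebook/opt-1.3b",  # placeholder checkpoint name
        dataset="c4",
        nsamples=128,
        data_sequence_length=2048,
        seed=0,
    )

    # Passing model_key explicitly skips _get_model_key(args); omitting it keeps
    # the previous behavior of inferring the key from args.
    model = load_model(args, model_key="opt")
    dataloader, testloader, tokenizer = load_data(args, model_key="opt")
    sparsegpt = prepare_sparsegpt(model, dataloader, args, model_key="opt")
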
3 changes: 2 additions & 1 deletion src/sparseml/experimental/sparsegpt/main.py
@@ -49,6 +49,7 @@ def _save(model, tokenizer, save_path):
     choices=["wikitext2", "ptb", "c4"],
     help="Where to extract calibration data from.",
 )
+parser.add_argument("--data-sequence-length", type=int, default=2048)
 parser.add_argument("--recipe", type=str, default=None)
 parser.add_argument("--observer-batches", type=int, default=100)
 parser.add_argument(
@@ -135,8 +136,8 @@ def _save(model, tokenizer, save_path):
 if args.log_wandb:
     assert has_wandb, "wandb not installed try `pip install wandb`"
     wandb.init(config=args)
-
 model = load_model(args)
+import pdb; pdb.set_trace()
 dataloader, testloader, tokenizer = load_data(args)
 
 if args.wbits < 16 or ((args.sparsity or args.prunen) and not args.gmp):
8 changes: 5 additions & 3 deletions src/sparseml/experimental/sparsegpt/model_preprocessor.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 
-from smoothquant.smooth import smooth_lm
 from sparseml.pytorch.optim.manager import ScheduledModifierManager
 
 
@@ -23,6 +22,7 @@ def __init__(self, model, smooth_activation_file, alpha: float = 0.5):
         self.alpha = alpha
 
     def __call__(self, dev: str = "cuda:0", **kwargs) -> Tuple[nn.Module, Dict]:
+        from smoothquant.smooth import smooth_lm
         self.model.to(dev)
         act_scales = torch.load(self.smooth_activation_file)
         smooth_lm(self.model, act_scales, 0.5)
@@ -42,10 +42,12 @@ def __call__(self, model, dev: str = "cuda:0") -> Tuple[nn.Module, Dict]:
         model.train()
         manager.apply_structure(model, epoch=0.1)
         model.eval()
-        model = self.initialize_scales_from_batches(model, dev)
+        model = self._initialize_scales_from_batches(model, dev)
         return model, {"manager": manager}
 
-    def initialize_scales_from_batches_whole(self, model, dev):
+    def _initialize_scales_from_batches(self, model, dev):
+        # TODO: have another version with layer-wise execution to save memory on
+        # very large models
         print("Collecting data statistics for quantization scales...")
         model.train()
         model.to(dev)
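Moving the smoothquant import into __call__ means the package is only required when the SmoothQuant preprocessor actually runs. A generic sketch of that deferred-import pattern, with illustrative names that are not from the commit:

    class OptionalDependencyStep:
        """Illustrative only: defer an optional import until the step executes."""

        def __call__(self, model):
            # Importing here keeps the module importable even when the optional
            # dependency is absent; the failure surfaces only if the step is used.
            try:
                from smoothquant.smooth import smooth_lm  # noqa: F401
            except ImportError as err:
                raise RuntimeError(
                    "smoothquant must be installed to run this preprocessing step"
                ) from err
            return model
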
4 changes: 2 additions & 2 deletions src/sparseml/experimental/sparsegpt/opt.py
@@ -12,7 +12,7 @@
 
 class SequentialSparseGPT_OPT(SequentialSparseGPT):
     def compressible_layers(self):
-        return self.model.model.decoders.layers
+        return self.model.model.decoder.layers
 
 
 class OPTBottomCompressor(BaseCompressor):
@@ -259,7 +259,7 @@ def skip(*args, **kwargs):
 def load_data(args):
     name = args.dataset
     nsamples = args.nsamples
-    seqlen = args.max_seq_len
+    seqlen = args.data_sequence_length
     model = args.model
     seed = args.seed
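The opt.py side of this change reads the calibration sequence length from the new flag added in main.py. A self-contained sketch of how the flag surfaces as args.data_sequence_length (the real parser defines many more options):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--data-sequence-length", type=int, default=2048)
    args = parser.parse_args(["--data-sequence-length", "1024"])

    # argparse maps the dashes to underscores, giving the attribute that
    # opt.load_data now uses in place of args.max_seq_len.
    print(args.data_sequence_length)  # 1024
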
50 changes: 32 additions & 18 deletions src/sparseml/experimental/sparsegpt/sequential.py
@@ -1,4 +1,5 @@
 from typing import List, Optional
+from copy import deepcopy
 
 import torch
 
@@ -36,52 +37,65 @@ def compressible_layers(self):
 
     def pre_compress(self, dev: str = "cuda:0", **kwargs):
         model = self.model
+        all_extras = {}
         for processor in self.model_preprocessors:
-            model, extras = processor.pre_process(model, dev=dev, **kwargs)
-            kwargs.update(extras)
-        return model, kwargs
+            # We assume the processors are independent, and therefore
+            # pass in the initial kwargs into each of them
+            model, extras = processor(model, dev=dev, **kwargs)
+            all_extras.update(extras)
+        return model, all_extras
 
     def compress(self, dev: str = "cuda:0", **kwargs):
+        accum_kwargs = deepcopy(kwargs)
+
-        self.model, kwargs = self.pre_compress(**kwargs)
+        import pdb; pdb.set_trace()
+        self.model, extras = self.pre_compress(**kwargs)
 
         # Step 0: BottomCompressor accomplishes two things:
         # 1) Compress the embedding if needed
         # 2) Pass the calibration data through the (compressed) bottom part of the network, capturing the outputs
         # which will become the inputs to the first decoder layer
         # Also return attention_mask as part of kwargs
-        self.model, extras = self.bottom_compressor.compress(dev=dev, **kwargs)
-        kwargs.update(extras)
+        accum_kwargs.update(extras)
+        self.model, extras = self.bottom_compressor.compress(dev=dev, **accum_kwargs)
+        accum_kwargs.update(extras)
 
         # Step 1: Sequentially prune/quantize decoder layers
-        inputs = kwargs["outputs"]
 
         for idx, layer in enumerate(self.compressible_layers):
-            layer_kwargs = kwargs.deepcopy()
+            if "outputs" not in accum_kwargs:
+                raise RuntimeError("The 'outputs' key is expected but not found from the "
+                                   "return of the bottom compressor")
+            inputs = accum_kwargs["outputs"]
+
             layer_compressor = LayerCompressor(
-                self.model, layer, idx, inputs, self.manager, **layer_kwargs
+                self.model, layer, idx, inputs, self.manager, **accum_kwargs
             )
 
             # Set up SparseGPT object, compute Hessian
-            self.model, layer_kwargs = layer_compressor.pre_compress(**layer_kwargs)
+            self.model, layer_kwargs = layer_compressor.pre_compress(**accum_kwargs)
+            accum_kwargs.update(layer_kwargs)
 
             # Joinly prune/quantize using SparseGPT
-            self.model, layer_kwargs = layer_compressor.compress(**layer_kwargs)
+            self.model, layer_kwargs = layer_compressor.compress(**accum_kwargs)
+            accum_kwargs.update(layer_kwargs)
 
             # Compute outputs given compressed layer, memory clean up etc
             (
                 self.model,
                 layer_kwargs,
-            ) = layer_compressor.post_compress(**layer_kwargs)
-            inputs = layer_kwargs["outputs"]
+            ) = layer_compressor.post_compress(**accum_kwargs)
+            accum_kwargs.update(layer_kwargs)
 
         # Step 2: Prune/quantize head
         # TODO: Need update here -- see MPT for head quantization example
-        head_compressor = LayerCompressor(model.head, inputs)
-        head_compressor.pre_compress()
-        head_compressor.compress()
-        model, extras = head_compressor.post_compress()
+        if self.head_compressor is not None:
+            head_compressor = LayerCompressor(self.model.head, inputs)
+            head_compressor.pre_compress()
+            head_compressor.compress()
+            model, extras = head_compressor.post_compress()
 
-        return model, extras
+        return model, accum_kwargs
 
     def post_compress(self, **kwargs):
         use_cache = kwargs["use_cache"]
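The compress() rewrite replaces per-layer copies of kwargs with a single accum_kwargs dict that every stage reads from and writes back to. A toy sketch of that accumulation pattern, with stage functions invented for illustration:

    from copy import deepcopy
    from typing import Dict, Tuple


    def bottom_stage(**kwargs) -> Tuple[str, Dict]:
        # Stand-in for the bottom compressor: produces calibration "outputs".
        return "model-after-bottom", {"outputs": [1, 2, 3], "attention_mask": None}


    def layer_stage(**kwargs) -> Tuple[str, Dict]:
        # Stand-in for a layer compressor: consumes the previous stage's outputs.
        return "model-after-layer", {"outputs": [x * 2 for x in kwargs["outputs"]]}


    def compress(**kwargs) -> Tuple[str, Dict]:
        accum_kwargs = deepcopy(kwargs)  # never mutate the caller's kwargs
        model = None
        for stage in (bottom_stage, layer_stage):
            model, extras = stage(**accum_kwargs)
            accum_kwargs.update(extras)  # later stages see everything produced so far
        return model, accum_kwargs


    model, extras = compress(use_cache=False)
    print(extras["outputs"])  # [2, 4, 6]
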
