[OBCQ Recipe UX Change] Remove target_ids (#1804)
* Remove target_ids

* Define target ids based on inputs

* Fix cache

* Remove remaining instances of target_ids

* Address review comments!
* Make default target ids a global
* Use named arguments
rahul-tuli authored and bfineran committed Nov 16, 2023
1 parent 5254cfb commit 920f513
Showing 7 changed files with 30 additions and 23 deletions.
2 changes: 0 additions & 2 deletions src/sparseml/modifiers/obcq/base.py
@@ -50,7 +50,6 @@ class SparseGPTModifier(Modifier):
         shape. Defaults to 0:0 which represents an unstructured mask.
     :param targets: list of layer names to compress during OBCQ, or '__ALL__'
         to compress every layer in the model
-    :param target_ids: list of keys in model output to cache
     """

     sparsity: Union[float, List[float]]
@@ -62,7 +61,6 @@ class SparseGPTModifier(Modifier):
     prunen_: Optional[int] = None
     prunem_: Optional[int] = None
     targets: Union[str, List[str], None] = ALL_TOKEN
-    target_ids: Optional[List[str]] = None
     layer_prefix: Optional[str] = None
     compressible_layers_: List = None
     quantization_modifier_: Any = None
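For users configuring the modifier in Python rather than YAML, the field simply disappears. A rough sketch of constructing the modifier after this change (the import path is inferred from the file path above; field values are illustrative and other fields of the real class are omitted):

from sparseml.modifiers.obcq.base import SparseGPTModifier

# Illustrative values only; the class defines further fields not shown in this diff.
modifier = SparseGPTModifier(
    sparsity=0.5,
    targets="__ALL__",   # or an explicit list of layer names
    layer_prefix=None,
    # target_ids=[...]   # no longer a field after this commit
)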
8 changes: 5 additions & 3 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -108,7 +108,6 @@ def apply_obcq(
         # decoder layer. Also return attention_mask as part of kwargs
         extras = self.compress_bottom(
             dev=self.device_,
-            target_ids=self.target_ids,
             layer_prefix=self.layer_prefix_,
             **accum_kwargs,
         )
@@ -166,7 +165,6 @@ def compress_bottom(
         dataloader: List = None,
         nsamples: int = None,
         dev: str = "cuda:0",
-        target_ids: List[str] = None,
         layer_prefix: Optional[str] = None,
     ) -> Dict:
         """
@@ -182,7 +180,11 @@
         """
         layer_prefix = layer_prefix or self.layer_prefix_
         cached_inputs = cache_attention_inputs(
-            self.model, dataloader, dev, nsamples, target_ids, layer_prefix
+            model=self.model,
+            dataloader=dataloader,
+            device=dev,
+            nsamples=nsamples,
+            layer_prefix=layer_prefix,
         )

         outputs = cached_inputs.pop("inputs")
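The same simplification applies anywhere the caching helper is called directly: everything is passed by name and no output-key list is threaded through. A minimal sketch under those assumptions (the wrapper name collect_first_layer_inputs is hypothetical; the import path is inferred from the helpers.py file path below; model and dataloader must be supplied by the caller):

from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs


def collect_first_layer_inputs(model, dataloader, device="cuda:0", nsamples=128):
    """Capture what the first decoder layer sees for each calibration batch."""
    cached = cache_attention_inputs(
        model=model,
        dataloader=dataloader,
        device=device,
        nsamples=nsamples,
        layer_prefix=None,  # e.g. "decoder" for OPT-style models, as in opt_forward below
    )
    inputs = cached.pop("inputs")  # positional inputs, one entry per batch
    return inputs, cached          # remaining entries are the inferred target kwargs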
30 changes: 21 additions & 9 deletions src/sparseml/modifiers/obcq/utils/helpers.py
@@ -13,31 +13,46 @@
 # limitations under the License.

 import logging
+from collections import defaultdict
 from math import ceil

 import torch


 _LOGGER = logging.getLogger(__name__)
+_DEFAULT_TARGET_IDS = [
+    "attention_mask",
+    "position_ids",
+]


 class Catcher(torch.nn.Module):
-    def __init__(self, module, target_keys):
+    def __init__(self, module):
         super().__init__()
         self.module = module
-        self.cache = {key: [] for key in target_keys}
-        self.target_keys = target_keys
+        self.target_keys = None
+        self.cache = defaultdict(list)
+        self.cache["inputs"] = []

     def forward(self, *args, **kwargs):
         self.cache["inputs"].append(args)
+        if self.target_keys is None:
+            self.target_keys = self._get_target_keys(kwargs.keys())
+
         for key in self.target_keys:
             self.cache[key].append(kwargs[key])
         raise ValueError

     def get_cache(self):
         return self.cache

+    def _get_target_keys(self, input_keys):
+        target_keys = []
+        for key in _DEFAULT_TARGET_IDS:
+            if key in input_keys:
+                target_keys.append(key)
+        return target_keys


 def replace_module(model, old_module, new_module):
     for module_name, module in model.named_modules():
@@ -51,8 +66,8 @@ def replace_module(model, old_module, new_module):
     setattr(current_module, module_name[-1], new_module)


-def catch(model, attention_layer, target_keys, data_loader, nsamples):
-    catcher_module = Catcher(attention_layer, target_keys)
+def catch(model, attention_layer, data_loader, nsamples):
+    catcher_module = Catcher(attention_layer)
     replace_module(model, attention_layer, catcher_module)
     device = next(attention_layer.parameters()).device
     for input_id, inp in enumerate(data_loader):
@@ -108,9 +123,7 @@ def execute_offloaded_module(
     return new_buffer


-def cache_attention_inputs(
-    model, dataloader, device, nsamples, target_ids, layer_prefix
-):
+def cache_attention_inputs(model, dataloader, device, nsamples, layer_prefix):
     if layer_prefix:
         embed_tokens = getattr(model.model, layer_prefix).embed_tokens
         first_layer = getattr(model.model, layer_prefix).layers[0]
@@ -122,7 +135,6 @@
     cached_inputs = catch(
         model,
         first_layer,
-        target_ids,  # ["attention_mask"],
         dataloader,
         nsamples,
     )
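To make the new caching behavior concrete, here is a small, self-contained sketch of the mechanism (the MiniCatcher name and the dummy module are invented for illustration; the key-selection logic mirrors _DEFAULT_TARGET_IDS and Catcher._get_target_keys above):

from collections import defaultdict

import torch

# Mirrors _DEFAULT_TARGET_IDS above: only kwargs in this list are eligible for caching.
DEFAULT_TARGET_IDS = ["attention_mask", "position_ids"]


class MiniCatcher(torch.nn.Module):
    """Toy stand-in for Catcher: records positional inputs plus any default target kwargs."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.target_keys = None
        self.cache = defaultdict(list)

    def forward(self, *args, **kwargs):
        self.cache["inputs"].append(args)
        if self.target_keys is None:
            # Inferred on the first call from the kwargs the layer actually receives.
            self.target_keys = [k for k in DEFAULT_TARGET_IDS if k in kwargs]
        for key in self.target_keys:
            self.cache[key].append(kwargs[key])
        raise ValueError  # stop the forward pass once the inputs are captured


# Usage: only kwargs that are actually passed end up cached.
catcher = MiniCatcher(torch.nn.Identity())
try:
    catcher(torch.ones(1, 4), attention_mask=torch.ones(1, 4))
except ValueError:
    pass
print(sorted(catcher.cache))  # ['attention_mask', 'inputs']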
1 change: 0 additions & 1 deletion src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -55,4 +55,3 @@ test_stage:
         "model.decoder.layers.22",
         "model.decoder.layers.23"
       ]
-      target_ids: ["attention_mask"]
Additional changed file (path not shown):
@@ -90,5 +90,4 @@ test_stage:
         "model.layers.29",
         "model.layers.30",
         "model.layers.31",
-      ]
-      target_ids: ["attention_mask", "position_ids"]
+      ]
Additional changed file (path not shown):
@@ -38,7 +38,7 @@ def opt_forward(model: Module, data_loader: List, device: str, nsamples: int = N
     :return: logits output of the model
     """
     cached_inputs = cache_attention_inputs(
-        model, data_loader, device, nsamples, ["attention_mask"], "decoder"
+        model, data_loader, device, nsamples, "decoder"
     )
     buffer = [b[0] for b in cached_inputs.pop("inputs")]
     for layer in model.model.decoder.layers:
@@ -86,9 +86,7 @@ def llama_forward(model: Module, data_loader: List, device: str, nsamples: int =
     :return: logits output of the model
     """
-    cached_inputs = cache_attention_inputs(
-        model, data_loader, device, nsamples, ["attention_mask", "position_ids"], None
-    )
+    cached_inputs = cache_attention_inputs(model, data_loader, device, nsamples, None)
     buffer = [b[0] for b in cached_inputs.pop("inputs")]
     for layer in model.model.layers:
         buffer = execute_offloaded_module(
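Because the keys are matched against the kwargs the first decoder layer actually receives, the Llama path keeps caching position_ids even though nothing requests it explicitly. A hedged usage sketch (assumes a Hugging Face Llama-style model and a tokenized calibration dataloader already exist; the expected keys follow from the old recipe values and the llama_forward/opt_forward changes above):

from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs

# `model` and `calibration_loader` are assumed to be defined elsewhere.
cached = cache_attention_inputs(
    model=model,
    dataloader=calibration_loader,
    device="cuda:0",
    nsamples=8,
    layer_prefix=None,  # llama_forward above passes None; opt_forward passes "decoder"
)
print([key for key in cached if key != "inputs"])
# Expected for a Llama-style decoder layer: ['attention_mask', 'position_ids']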
3 changes: 1 addition & 2 deletions tests/sparseml/transformers/obcq/test_tiny.yaml
@@ -39,5 +39,4 @@ test_stage:
         "model.layers.3",
         "model.layers.4",
         "model.layers.5"
-      ]
-      target_ids: ["attention_mask", "position_ids"]
+      ]
