Layerwise Sparsity Support for SparseGPT #1777

Merged · 12 commits · Oct 26, 2023
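This change lets the SparseGPTModifier `sparsity` field be either a single float applied uniformly to all layers or a list of per-layer values aligned one-to-one with `compress_layers`. As a rough sketch of the intended usage (the field names come from the diff below; the exact construction call and the layer names are assumptions, not shown in this PR):

from sparseml.modifiers.obcq.base import SparseGPTModifier

# Hypothetical configuration: one sparsity value per target layer.
modifier = SparseGPTModifier(
    sparsity=[0.4, 0.5, 0.6],  # must be the same length as compress_layers
    compress_layers=[          # explicit names only; "re:" patterns are rejected
        "model.layers.0",
        "model.layers.1",
        "model.layers.2",
    ],
    block_size=128,
    quantize=False,
)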
26 changes: 25 additions & 1 deletion src/sparseml/modifiers/obcq/base.py
@@ -48,7 +48,7 @@ class SparseGPTModifier(Modifier):
         model.decoder for OPT or just model for Llama
     """
 
-    sparsity: float
+    sparsity: Union[float, List[float]]
     block_size: int
     quantize: bool
     dampening_frac: Optional[float] = 0.01
@@ -61,3 +61,27 @@ class SparseGPTModifier(Modifier):
 
     def on_initialize_structure(self, state: "State", **kwargs):
         pass  # nothing needed for this modifier
+
+    def _validate_layerwise_sparsity(self):
+        if isinstance(self.sparsity, float):
+            return  # single sparsity will be applied to all layers
+
+        if not isinstance(self.compress_layers, List):
+            raise ValueError(
+                "Layer targets must be a list when specifying layer-wise"
+                f" sparsity. Got {self.compress_layers}"
+            )
+
+        if len(self.compress_layers) != len(self.sparsity):
+            raise ValueError(
+                "Number of layer targets must match the number of "
+                f"sparsities. Got {len(self.compress_layers)} layers and "
+                f"{len(self.sparsity)} sparsities"
+            )
+
+        for layer_name in self.compress_layers:
+            if "re:" in layer_name:
+                raise ValueError(
+                    "Using regular expressions for layer-wise sparsity "
+                    f"profiles is not permitted. Found {layer_name}"
+                )
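To make the accepted and rejected configurations concrete, here is a minimal standalone sketch of the same checks; the function below is an illustrative replica for demonstration, not part of the PR:

from typing import List, Union

def validate_layerwise_sparsity(sparsity: Union[float, List[float]], compress_layers) -> None:
    # A single float needs no further validation: it applies to every layer.
    if isinstance(sparsity, float):
        return
    # A sparsity list requires an explicit list of layer targets.
    if not isinstance(compress_layers, list):
        raise ValueError("Layer targets must be a list when specifying layer-wise sparsity")
    # The two lists must align one-to-one.
    if len(compress_layers) != len(sparsity):
        raise ValueError("Number of layer targets must match the number of sparsities")
    # Regex targets can match a variable number of layers, so they are rejected.
    for layer_name in compress_layers:
        if "re:" in layer_name:
            raise ValueError(f"Regex targets are not permitted. Found {layer_name}")

validate_layerwise_sparsity(0.5, None)                               # OK: uniform sparsity
validate_layerwise_sparsity([0.4, 0.6], ["decoder.0", "decoder.1"])  # OK: one value per layer
# validate_layerwise_sparsity([0.4], ["re:decoder.*"])               # raises ValueError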
25 changes: 23 additions & 2 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -57,6 +57,18 @@ def compressible_layers(self) -> List[Module]:
         :return: list of Pytorch modules to compress
         """
         compressible_dict = self.model.get_layers(self.compress_layers)
+
+        # Compare length to sparsities again in case one of the provided layers was
+        # invalid or not compressible
+        if isinstance(self.sparsity, List) and len(self.sparsity) != len(
+            compressible_dict
+        ):
+            raise ValueError(
+                "Number of compressible layers must match the number of "
+                f"sparsities. Got {len(compressible_dict)} layers and "
+                f"{len(self.sparsity)} sparsities"
+            )
+
         return [v for _, v in compressible_dict.items()]
 
     def on_initialize(self, state: "State", **kwargs) -> bool:
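The length re-check above matters because layer-name resolution can silently drop targets that do not exist on the model, which would misalign a per-layer sparsity list. A hypothetical illustration of that failure mode (the dict comprehension is a stand-in for the real `self.model.get_layers`):

# Stand-in resolver: keeps only names that resolve, silently dropping the rest.
requested = ["decoder.0", "decoder.1", "decoder.typo"]
sparsities = [0.4, 0.5, 0.6]
resolved = {name: object() for name in requested if "typo" not in name}

# Once "decoder.typo" vanishes, sparsities[idx] would pair values with the
# wrong layers; the re-check turns this silent misalignment into a loud error.
assert len(resolved) != len(sparsities)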
@@ -65,6 +77,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         :param state: session state storing input model and calibration data
         """
+        self._validate_layerwise_sparsity()
+
         self.finalization_kwargs_ = {}
         modifiable_model = state.model
         calibration_dataloader = state.data.calib
@@ -125,10 +139,17 @@ def apply_obcq(
                     "The 'outputs' key is expected but not found from the "
                     "return of the bottom compressor"
                 )
+
             inputs = accum_kwargs["outputs"]
-            _LOGGER.info(f"\n===== Compressing layer {idx}/{num_layers-1} =====")
+            layer_sparsity = (
+                self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
+            )
+            _LOGGER.info(
+                f"\n===== Compressing layer {idx+1}/{num_layers} "
+                f"to sparsity {layer_sparsity} ====="
+            )
             args = {
-                "sparsity": self.sparsity,
+                "sparsity": layer_sparsity,
                 "prunen": self.prunen,
                 "prunem": self.prunem,
                 "blocksize": self.block_size,