Commit

Merge branch 'main' into prod_smooth_quant

Satrat committed Oct 31, 2023
2 parents 788c16b + b622bba commit 800d08b
Showing 34 changed files with 2,192 additions and 114 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -795,3 +795,6 @@ fabric.properties
*.resources
test-results/
integrations/pytorch/pytorch_vision*

# local log files
nm_temp_test_logs/*
2 changes: 1 addition & 1 deletion README.md
@@ -128,7 +128,7 @@ More information on installation such as optional dependencies and requirements

### Recipes

- To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparamters that should be applied by SparseML.
+ To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparameters that should be applied by SparseML.

`Recipes` are YAML-files formatted as a list of `modifiers`, which encode the instructions for SparseML. Example `modifiers` can be anything from setting the learning rate to encoding the hyperparameters of the gradual magnitude pruning algorithm. The SparseML system parses the `recipes` into a native format for each framework and applies the modifications to the model and training pipeline.
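For illustration, a minimal sketch of such a recipe, mirroring the doctest that appears later in this commit (the values are placeholders, not tuned settings):

```python
# A minimal recipe sketch applied through the Python API; SparseML also
# accepts the same YAML from a file path. Values are placeholders.
from sparseml.core.recipe import Recipe  # import path assumed from the repo layout

recipe_str = """
test_stage:
    pruning_modifiers:
        ConstantPruningModifier:
            start: 0.0
            end: 2.0
            targets: ['re:.*weight']
"""
recipe = Recipe.create_instance(recipe_str)
```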

2 changes: 1 addition & 1 deletion setup.py
@@ -79,7 +79,7 @@
_transformers_deps = _pytorch_deps + [
f"{'nm-transformers' if is_release else 'nm-transformers-nightly'}"
f"~={version_nm_deps}",
"datasets<=2.11",
"datasets<=2.14.6",
"scikit-learn",
"seqeval",
"einops",
2 changes: 1 addition & 1 deletion src/sparseml/core/lifecycle/event.py
@@ -198,7 +198,7 @@ def optim_pre_step_events(self) -> List[Event]:
and self.type_ is not None
and self.type_ != EventType.OPTIM_POST_STEP
):
- raise ValueError("optim pre step must be called after optim post step")
+ raise ValueError("optim pre step must be called before optim post step")

if (
self.type_first == EventType.LOSS_CALCULATED
3 changes: 3 additions & 0 deletions src/sparseml/core/lifecycle/session.py
@@ -50,6 +50,9 @@ def reset(self):
except Exception:
pass

if self.state and self.state.data:
# reset data if it exists
self.state.data.reset()
self.state = None
self.recipe_container = RecipeContainer()
self.modifiers = []
8 changes: 8 additions & 0 deletions src/sparseml/core/model/base.py
@@ -126,3 +126,11 @@ def get_matching_layer(
:param model: model to search for targets
"""
raise NotImplementedError()

def qat_active(self) -> bool:
"""
Checks if quantization aware training is set up in the model
:return: True if QAT is active in any layer, False otherwise
"""
raise NotImplementedError()
9 changes: 9 additions & 0 deletions src/sparseml/core/model/pytorch.py
@@ -25,6 +25,7 @@
get_matching_layer,
get_param,
get_params,
qat_active,
set_layer,
set_param,
)
@@ -105,3 +106,11 @@ def get_matching_layer(
:param model: model to search for targets
"""
return get_matching_layer(target, name_to_match, model)

def qat_active(self) -> bool:
"""
Checks if quantization aware training is set up in the model
:return: True if QAT is active in any layer, False otherwise
"""
return qat_active(self.model)
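As a rough sketch of what such a check can look like (an assumption for illustration, not the actual helper shipped in sparseml.utils.pytorch), QAT detection typically walks the module tree looking for fake-quantize observers:

```python
# Hypothetical sketch of a QAT check; the real sparseml helper may differ.
import torch

def qat_active_sketch(model: torch.nn.Module) -> bool:
    # QAT-prepared models contain FakeQuantize submodules inserted by
    # torch.ao.quantization.prepare_qat (or an equivalent pipeline)
    return any(
        isinstance(module, torch.ao.quantization.FakeQuantize)
        for module in model.modules()
    )
```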
2 changes: 1 addition & 1 deletion src/sparseml/core/recipe/modifier.py
@@ -113,4 +113,4 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
:return: the dictionary representation of the modifier
"""
- return {self.type: self.args}
+ return {self.type: self.args, "group": f"{self.group}_modifiers"}
64 changes: 59 additions & 5 deletions src/sparseml/core/recipe/recipe.py
@@ -399,18 +399,20 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
... targets: ['re:.*weight']
... '''
>>> recipe = Recipe.create_instance(recipe_str)
- >>> recipe.dict()
- Traceback (most recent call last):
- ...
- KeyError: 'group'
+ >>> recipe_dict = recipe.dict()
+ >>> stage = recipe_dict["stages"]["test"]
+ >>> pruning_mods = stage[0]['modifiers']['pruning']
+ >>> modifier_args = pruning_mods[0]['ConstantPruningModifier']
+ >>> modifier_args == {'start': 0.0, 'end': 2.0, 'targets': ['re:.*weight']}
+ True
:return: A dictionary representation of the recipe
"""
dict_ = super().dict(*args, **kwargs)
stages = {}

for stage in dict_["stages"]:
- name = stage["group"]
+ name = f"{stage['group']}_stage"
del stage["group"]

if name not in stages:
@@ -422,6 +424,58 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:

return dict_

def yaml(self, file_path: Optional[str] = None) -> str:
"""
Return a yaml string representation of the recipe.
:param file_path: optional file path to save yaml to
:return: The yaml string representation of the recipe
"""
file_stream = None if file_path is None else open(file_path, "w")
yaml_dict = self._get_yaml_dict()

ret = yaml.dump(
yaml_dict, stream=file_stream, allow_unicode=True, sort_keys=False
)

if file_stream is not None:
file_stream.close()

return ret

def _get_yaml_dict(self) -> Dict[str, Any]:
"""
Get a dictionary representation of the recipe for yaml serialization
The returned dict will only contain information necessary for yaml
serialization (ignores metadata, version, etc), and must not be used
in place of the dict method
:return: A dictionary representation of the recipe for yaml serialization
"""

def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
# convert a list of modifiers to a dict of modifiers
return {
key: value
for modifier in modifier_group
for key, value in modifier.items()
}

def _stage_to_dict(stage: List[Dict[str, Any]]):
# convert a list of stages to a dict of modifiers
return {
modifier_group_name: _modifier_group_to_dict(modifier_group)
for stage_modifiers in stage
for modifier_group_name, modifier_group in stage_modifiers[
"modifiers"
].items()
}

return {
stage_name: _stage_to_dict(stage=stage)
for stage_name, stage in self.dict()["stages"].items()
}
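With the `_stage` and `_modifiers` suffixes re-appended by `dict()` and `RecipeModifier.dict()` above, the dumped string stays loadable as a recipe. A usage sketch (the file name is a placeholder):

```python
# Sketch: dump a recipe to YAML, given `recipe` from Recipe.create_instance
# (see the doctest above); the file name is a placeholder.
yaml_str = recipe.yaml()              # return the YAML string
recipe.yaml(file_path="recipe.yaml")  # or write it straight to disk
```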


@dataclass
class RecipeTuple:
8 changes: 8 additions & 0 deletions src/sparseml/core/state.py
@@ -45,6 +45,14 @@ class Data:
test: Optional[ModifiableData] = None
calib: Optional[ModifiableData] = None

def reset(self):
"""
Reset self to initial state
"""
attribs = Data().__dict__
for attrib_name, attrib_value in attribs.items():
setattr(self, attrib_name, attrib_value)
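Since the splits on `Data` default to `None`, copying a fresh instance's `__dict__` restores those defaults. A quick sketch of the intended behavior:

```python
# Sketch: reset() restores the dataclass defaults (all splits cleared).
from sparseml.core.state import Data

data = Data()
data.calib = object()  # stand-in for a ModifiableData instance
data.reset()
assert data.calib is None
```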


@dataclass
class Hardware:
16 changes: 16 additions & 0 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -138,6 +138,21 @@ class Config:
multiply_batch_by_num_att_heads=False,
)

# Mistral has a config/model definition "MistralForCausalLM" but is based on Llama2.
# It contains these additions to Llama2-7b:
# * Sliding Window Attention
# * GQA (Grouped Query Attention)
# * Byte-fallback BPE tokenizer
MISTRAL_CONFIG = KeyValueCacheConfig(
model_name="mistral",
additional_transforms=AdditionalTransformsLLAMA,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="hidden_size",
transpose_value_input=None,
transpose_key_input=None,
multiply_batch_by_num_att_heads=False,
)

# Reusing the CodeGen transforms because it happens to match what we need for GPTNeo
additional_transforms_gpt_neo = AdditionalTransformsCodeGen

@@ -160,6 +175,7 @@ def get_kv_cache_config(
BLOOM_CONFIG,
MPT_CONFIG,
LLAMA_CONFIG,
MISTRAL_CONFIG,
GPT_NEO_CONFIG,
],
) -> KeyValueCacheConfig:
111 changes: 104 additions & 7 deletions src/sparseml/modifiers/obcq/base.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.


- from typing import List, Optional, Union
+ import logging
+ from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(Modifier):
"""
@@ -34,7 +37,9 @@ class SparseGPTModifier:
:param sparsity: Sparsity to compress model to
:param block_size: Used to determine number of columns to compress in one pass
- :param quantize: Whether or not model is quantized (affects layer names)
+ :param quantize: Whether or not to quantize weights during SparseGPT. Set to True
+ to quantize using an existing quantization modifier, or pass in the configuration
+ for a quantization modifier if one does not already exist in the recipe
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param sequential_update: Whether or not to update weights sequentially by layer,
@@ -48,16 +53,108 @@ class SparseGPTModifier:
model.decoder for OPT or just model for Llama
"""

- sparsity: float
+ sparsity: Union[float, List[float]]
block_size: int
- quantize: bool
+ quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
prunen: Optional[int] = 0
prunem: Optional[int] = 0
targets: Union[str, List[str], None] = ALL_TOKEN
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None
compressible_layers_: List = None
quantization_modifier_: Any = None

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of compressible layer names
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to False, but a "
"quantization modifier is already active on the model; "
"resetting quantize to True"
)
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, Dict):
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"type {type(self.quantize)}"
)
if len(self.quantize) != 1:
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"{len(self.quantize)} modifiers"
)
if quantization_already_active:
_LOGGER.warning(
"Attempting to initialize quantization for SparseGPT "
"but a quantization modifier has already been applied. "
"The quantization configuration defined under the "
"SparseGPT modifier will be ignored."
)
self.quantize = True
return
self._build_quant_modifier_from_dict(self.quantize, state.framework)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)
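To make the accepted shapes concrete, a hedged recipe sketch for the dict form of `quantize` (stage and group names and values are placeholders; `QuantizationModifier: {}` matches the default config built above):

```python
# Sketch: SparseGPTModifier with an inline quantization config. Using
# `quantize: True` instead would reuse an existing quantization modifier
# or fall back to the default 8-bit one.
recipe_str = """
obcq_stage:
    obcq_modifiers:
        SparseGPTModifier:
            sparsity: 0.5
            block_size: 128
            quantize:
                QuantizationModifier: {}
"""
```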

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
modifier_type,
framework=framework,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
return # single sparsity will be applied to all layers

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {self.targets}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

for layer_name in self.targets:
if layer_name.startswith("re:"):
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)

- def on_initialize_structure(self, state: "State", **kwargs):
-     pass  # nothing needed for this modifier
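As a concrete illustration of the rules enforced above (layer names are hypothetical), a valid layer-wise profile pairs one sparsity with each explicitly named layer and avoids `re:` patterns:

```python
# Sketch of a valid layer-wise sparsity pairing; names are hypothetical.
targets = ["model.layers.0", "model.layers.1", "model.layers.2"]
sparsity = [0.5, 0.6, 0.7]
assert len(targets) == len(sparsity)  # enforced by _validate_layerwise_sparsity
assert not any(t.startswith("re:") for t in targets)  # regex targets rejected
```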
