Support for Stacking Recipes (#1897)
* initial recipe re-loading

* loading for input recipe

* persist structure across recipe loads

* clean up fn names

* clean up duplicated code

* delete extra file

* unit tests

* fix failing test

* quantization edge cases

* quant tests

* fixes for stage name clashes

* clean up documentation

* add apply flag during finalization as well

* clarify comments

* fix unit test
Satrat committed Dec 21, 2023
1 parent f088321 commit 0eaf565
Showing 8 changed files with 339 additions and 12 deletions.
9 changes: 9 additions & 0 deletions src/sparseml/core/lifecycle/session.py
@@ -78,7 +78,11 @@ def pre_initialize_structure(
if data is not None:
mod_data.append(data)

# mark which modifiers have already had their structures initialized
# so when we consolidate the next recipe this info isn't lost
self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)

return mod_data

@@ -113,6 +117,8 @@ def finalize(self, **kwargs) -> List[Any]:
mod_data.append(data)

self.finalized = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)

return mod_data

@@ -169,6 +175,9 @@ def _check_compile_recipe(self):
self.modifiers = self.recipe_container.compiled_recipe.create_modifier(
self.state.framework
)
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True

def _check_setup_event_lifecycle(self, event_type: EventType):
if self.event_lifecycle is not None:
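The session changes above record which stages have already been applied and re-flag them once the next recipe is compiled. A minimal standalone sketch of that bookkeeping, using simplified Stage/Container stand-ins rather than the real session, StageModifiers, and RecipeContainer classes:

# Standalone sketch of the applied-stage bookkeeping added above; "Stage" and
# "Container" are simplified stand-ins, not the real SparseML classes.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Stage:
    group: str
    index: int
    applied: bool = False

    @property
    def unique_id(self) -> str:
        # name + index, mirroring StageModifiers.unique_id
        return self.group + "_" + str(self.index)


@dataclass
class Container:
    applied_stages: List[str] = field(default_factory=list)

    def update_applied_stages(self, new_stages: List[str]):
        # mirror RecipeContainer.update_applied_stages: append without duplicates
        for stage in new_stages:
            if stage not in self.applied_stages:
                self.applied_stages.append(stage)


container = Container()
stages = [Stage("sparsity", 0, applied=True), Stage("quant", 1)]

# after pre_initialize_structure / finalize: persist the ids of applied stages
container.update_applied_stages([s.unique_id for s in stages if s.applied])

# after the next (stacked) recipe is compiled: re-flag stages already applied
recompiled = [Stage("sparsity", 0), Stage("quant", 1), Stage("quant", 2)]
for stage in recompiled:
    if stage.unique_id in container.applied_stages:
        stage.applied = True

print(container.applied_stages)          # ['sparsity_0']
print([s.applied for s in recompiled])   # [True, False, False]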
30 changes: 28 additions & 2 deletions src/sparseml/core/modifier/stage.py
@@ -36,11 +36,14 @@ class StageModifiers(ModifierInterface, BaseModel):
:param modifiers: The modifiers to apply as a stage
:param index: The index of the stage, if applicable
:param group: The group name of the stage, if applicable
:param applied: Flag indicating if this stage has already been
applied to the model, through structure initialization or finalization
"""

modifiers: List["Modifier"] = Field(default_factory=list)
index: Optional[int] = None
group: Optional[str] = None
applied: bool = False

@property
def initialized_structure(self) -> bool:
@@ -66,6 +69,13 @@ def finalized(self) -> bool:
"""
return all(mod.finalized for mod in self.modifiers)

@property
def unique_id(self) -> str:
"""
:return: ID for stage containing the name and index
"""
return self.group + "_" + str(self.index)

def check_initialized(self):
"""
Check if all of the stage modifiers have been initialized, and log a warning
@@ -103,7 +113,7 @@ def calculate_end(self) -> float:

def pre_initialize_structure(self, state: "State", **kwargs):
"""
Pre initialize the structure for all stage modifiers
Pre initialize the structure for all stage modifiers and mark the stage as applied
:param state: The current state of the training
:param kwargs: Additional kwargs to pass to the modifier(s)
@@ -112,6 +122,8 @@ def pre_initialize_structure(self, state: "State", **kwargs):
for modifier in self.modifiers:
modifier.pre_initialize_structure(state, **kwargs)

self.applied = True

def initialize(self, state: "State", **kwargs):
"""
Initialize all the stage modifiers
@@ -120,20 +132,30 @@ def initialize(self, state: "State", **kwargs):
:param kwargs: Additional kwargs to pass to the modifier(s)
initialize method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.initialize(state, **kwargs)

def finalize(self, state: "State", **kwargs):
"""
Finalize all the stage modifiers
Finalize all the stage modifiers and mark the stage as applied
:param state: The state of current session
:param kwargs: Additional kwargs to pass to the modifier(s)
finalize method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.finalize(state, **kwargs)

self.applied = True

def update_event(self, state: "State", event: "Event", **kwargs):
"""
Propagate the event to all the stage modifiers
@@ -143,5 +165,9 @@ def update_event(self, state: "State", event: "Event", **kwargs):
:param kwargs: Additional kwargs to pass to the modifier(s)
update_event method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.update_event(state, event, **kwargs)
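The applied flag added above turns an already-applied stage into a no-op, so re-loading a stacked recipe does not re-run stages that have already modified the model. A minimal sketch of that guard, using a simplified stand-in rather than the real StageModifiers class:

# Simplified stand-in for the applied guard in StageModifiers above; not the
# real class, just the control flow it adds.
class StageStandIn:
    def __init__(self, group: str, index: int):
        self.group = group
        self.index = index
        self.applied = False
        self.initialize_calls = 0

    @property
    def unique_id(self) -> str:
        # name + index, matching the unique_id property above
        return self.group + "_" + str(self.index)

    def pre_initialize_structure(self):
        # initializing the structure marks the whole stage as applied
        self.applied = True

    def initialize(self):
        if self.applied:
            return  # stage was applied by an earlier recipe load, skip re-running
        self.initialize_calls += 1


stage = StageStandIn("test_stage", 0)
stage.initialize()                 # runs: initialize_calls == 1
stage.pre_initialize_structure()   # stage is now marked as applied
stage.initialize()                 # no-op: initialize_calls stays at 1
print(stage.unique_id, stage.initialize_calls)  # test_stage_0 1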
13 changes: 13 additions & 0 deletions src/sparseml/core/recipe/container.py
@@ -44,10 +44,12 @@ class RecipeContainer:
:param compiled_recipe: the compiled recipe from the recipes list
:param recipes: the list of RecipeTuple instances to be compiled
:param applied_stages: list of recipe stages that have already been applied
"""

compiled_recipe: Optional[Recipe] = None
recipes: List[RecipeTuple] = field(default_factory=list)
applied_stages: List[str] = field(default_factory=list)

def update(
self,
@@ -118,6 +120,17 @@ def update(

return kwargs

def update_applied_stages(self, new_stages: List[str]):
"""
Updates the applied_stages list with new stages, indicating their structure
has already been applied
:param new_stages: new stage names to add
"""
for stage in new_stages:
if stage not in self.applied_stages:
self.applied_stages.append(stage)

def check_compile_recipe(self) -> bool:
"""
Check if the recipes need to be compiled into a single recipe and
22 changes: 12 additions & 10 deletions src/sparseml/core/recipe/recipe.py
@@ -494,20 +494,22 @@ def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
for key, value in modifier.items()
}

def _stage_to_dict(stage: List[Dict[str, Any]]):
# convert a list of stages to a dict of modifiers
def _stage_to_dict(stage: Dict[str, Any]):
# convert a stage to a dict of modifiers
return {
modifier_group_name: _modifier_group_to_dict(modifier_group)
for stage_modifiers in stage
for modifier_group_name, modifier_group in stage_modifiers[
"modifiers"
].items()
for modifier_group_name, modifier_group in stage["modifiers"].items()
}

return {
stage_name: _stage_to_dict(stage=stage)
for stage_name, stage in self.dict()["stages"].items()
}
final_dict = {}
for stage_name, stages in self.dict()["stages"].items():
if len(stages) == 1:
final_dict[stage_name] = _stage_to_dict(stages[0])
else:
for idx, stage in enumerate(stages):
final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)

return final_dict


@dataclass
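The rewrite above resolves stage-name clashes between stacked recipes: a name that maps to a single stage keeps its key, while a name shared by several stages is suffixed with an index. A small standalone sketch of that naming rule, with the stage payloads reduced to placeholders and _stage_to_dict reduced to a trivial stand-in:

# Standalone sketch of the stage-name clash handling above; the stage payloads
# are placeholders and _stage_to_dict is reduced to an identity lookup.
from typing import Any, Dict, List


def _stage_to_dict(stage: Dict[str, Any]) -> Dict[str, Any]:
    return stage["modifiers"]  # trivial stand-in for the real conversion


def stages_to_dict(stages: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
    final_dict = {}
    for stage_name, stage_list in stages.items():
        if len(stage_list) == 1:
            # unique stage name, keep the key as-is
            final_dict[stage_name] = _stage_to_dict(stage_list[0])
        else:
            # same stage name used by multiple stacked recipes: suffix an index
            for idx, stage in enumerate(stage_list):
                final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)
    return final_dict


stages = {
    "test_stage": [
        {"modifiers": {"pruning_modifiers": "first recipe"}},
        {"modifiers": {"obcq_modifiers": "second recipe"}},
    ],
}
print(stages_to_dict(stages))
# {'test_stage_0': {'pruning_modifiers': 'first recipe'},
#  'test_stage_1': {'obcq_modifiers': 'second recipe'}}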
43 changes: 43 additions & 0 deletions src/sparseml/modifiers/quantization/utils/quantize.py
@@ -52,6 +52,8 @@
"add_input_activation_quant_wrappers",
"add_output_activation_observers",
"raise_if_torch_quantization_not_available",
"raise_if_already_quantized",
"is_module_quantized",
]


@@ -148,6 +150,18 @@ def set_quantization_schemes(
# submodule type or graph section set to ignore, skip
continue

if isinstance(submodule, torch_quantization.QuantWrapper):
# special case to catch QuantizableMatMul children
if ignore and _match_submodule_name_or_type(
submodule.module, submodule_name, ignore
):
continue

if is_qat_helper_module(submodule):
# ignore children of an already quantized module, if there is a clash it
# will have been caught in the parent
continue

# override default scheme if necessary
override_key = _match_submodule_name_or_type(
submodule, submodule_name, scheme_overrides
@@ -162,6 +176,7 @@
wrap_qat_targets[submodule_name] = submodule_scheme
elif is_module_type_override or is_quantizable_module(submodule):
# is base quantizable module or user specifically targeted module type
raise_if_already_quantized(submodule_name, submodule)
submodule.quantization_scheme = submodule_scheme

# inject any targeted QATWrappers
@@ -351,6 +366,34 @@ def raise_if_torch_quantization_not_available():
)


def raise_if_already_quantized(module_name: str, module: Module):
"""
:param module_name: name of module to check for quantization
:param module: module to check for quantization
:raises: RuntimeError if module is already quantized, it cannot be re-quantized
"""
if is_module_quantized(module):
raise RuntimeError(
f"Unable to quantize module {module_name}, as it has already been "
"quantized. Ensure your input recipe does not contain multiple "
"QuantizationModifiers that act on the same module. "
)


def is_module_quantized(module: Module) -> bool:
"""
:param module: module to check for quantization
:return: True if the module is quantized, False otherwise
"""
if hasattr(module, "quantization_scheme") and isinstance(
module.quantization_scheme, QuantizationScheme
):
return True
if isinstance(module, torch_quantization.QuantWrapper):
return True
return False


def _match_submodule_name_or_type(
submodule: Module, submodule_name: str, names_or_types: List[str]
) -> Optional[str]:
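The new helpers above guard against re-quantizing a module when two stacked recipes both contain a QuantizationModifier targeting it. A standalone sketch of that check, with QuantizationScheme, QuantWrapper, and Linear used as stand-ins for the real sparseml and torch types:

# Standalone sketch of the double-quantization guard above; QuantizationScheme,
# QuantWrapper, and Linear are stand-ins for the real sparseml / torch types.
class QuantizationScheme:          # stand-in for sparseml's QuantizationScheme
    pass


class QuantWrapper:                # stand-in for torch.quantization.QuantWrapper
    def __init__(self, module):
        self.module = module


class Linear:                      # stand-in for a torch.nn module
    pass


def is_module_quantized(module) -> bool:
    # a module counts as quantized if a scheme was attached or it was wrapped
    scheme = getattr(module, "quantization_scheme", None)
    return isinstance(scheme, QuantizationScheme) or isinstance(module, QuantWrapper)


def raise_if_already_quantized(module_name: str, module):
    if is_module_quantized(module):
        raise RuntimeError(
            f"Unable to quantize module {module_name}, as it has already been "
            "quantized. Ensure your input recipe does not contain multiple "
            "QuantizationModifiers that act on the same module."
        )


layer = Linear()
raise_if_already_quantized("model.layer1", layer)  # fine, not quantized yet
layer.quantization_scheme = QuantizationScheme()   # first modifier attaches a scheme
try:
    raise_if_already_quantized("model.layer1", layer)  # second modifier clashes
except RuntimeError as err:
    print(err)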
3 changes: 3 additions & 0 deletions tests/sparseml/core/lifecycle/test_session.py
@@ -80,6 +80,9 @@ def test_session_initialize_propagates_layer_prefix_to_model(

class ModifierMock(ModifierInterface):
initialized_ = False
applied = False
group = "test"
unique_id = "test_0"

def __init__(self, *args, **kwargs) -> None:
super().__init__()
14 changes: 14 additions & 0 deletions tests/sparseml/transformers/obcq/test_additional_sparsity.yaml
@@ -0,0 +1,14 @@
test_stage:
obcq_modifiers:
SparseGPTModifier:
sparsity: 0.7
block_size: 128
sequential_update: True
quantize: False
percdamp: 0.01
prunen: 0
prunem: 0
targets: [
"model.layers.0"
]
target_ids: ["attention_mask", "position_ids"]