Support for Stacking Recipes #1897

Merged
merged 25 commits into from
Dec 21, 2023

25 commits
d5abe8e
initial recipe re-loading
Satrat Nov 16, 2023
ec0e180
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 5, 2023
2d7b5b7
loading for input recipe
Satrat Dec 7, 2023
2cc9e16
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 7, 2023
356bd81
persist structure across recipe loads
Satrat Dec 7, 2023
1b67b6f
clean up fn names
Satrat Dec 7, 2023
f06ed8a
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 12, 2023
ab5a464
clean up duplicated code
Satrat Dec 12, 2023
11f4efe
delete extra file
Satrat Dec 12, 2023
7e960a3
unit tests
Satrat Dec 12, 2023
ebb5407
fix failing test
Satrat Dec 12, 2023
6a394d7
quantization edge cases
Satrat Dec 12, 2023
d7974bf
quant tests
Satrat Dec 13, 2023
4b9014d
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 13, 2023
701ab2c
fixes for stage name clashes
Satrat Dec 13, 2023
5812488
Merge branch 'sparse_auto_recipe' of github.com:neuralmagic/sparseml …
Satrat Dec 13, 2023
21473aa
clean up documentation
Satrat Dec 13, 2023
e46dd96
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 15, 2023
fe9af83
add apply flag during finalization as well
Satrat Dec 15, 2023
5f6e854
clarity comments
Satrat Dec 15, 2023
4588eb2
Merge branch 'sparse_auto_recipe' of github.com:neuralmagic/sparseml …
Satrat Dec 15, 2023
f7fb65a
fix unit test
Satrat Dec 15, 2023
04da6d6
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 21, 2023
7b5267d
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 21, 2023
fd11412
Merge branch 'main' into sparse_auto_recipe
Satrat Dec 21, 2023
9 changes: 9 additions & 0 deletions src/sparseml/core/lifecycle/session.py
@@ -78,7 +78,11 @@ def pre_initialize_structure(
if data is not None:
mod_data.append(data)

# mark which modifiers have already had their structures initialized
# so when we consolidate the next recipe this info isn't lost
self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)

return mod_data

@@ -113,6 +117,8 @@ def finalize(self, **kwargs) -> List[Any]:
mod_data.append(data)

self.finalized = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)

return mod_data

@@ -169,6 +175,9 @@ def _check_compile_recipe(self):
self.modifiers = self.recipe_container.compiled_recipe.create_modifier(
self.state.framework
)
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True

def _check_setup_event_lifecycle(self, event_type: EventType):
if self.event_lifecycle is not None:
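
The bookkeeping above in isolation, as a hedged sketch: StageStub and ContainerStub are simplified stand-ins for illustration, not the sparseml classes. Applied stages are recorded by their unique_id once a recipe is finalized, and re-flagged after the next recipe is compiled so they are not run a second time.

from dataclasses import dataclass, field
from typing import List


@dataclass
class StageStub:
    # stand-in for a stage: identified by group name and index
    group: str
    index: int
    applied: bool = False

    @property
    def unique_id(self) -> str:
        return f"{self.group}_{self.index}"


@dataclass
class ContainerStub:
    # stand-in for the recipe container's applied-stage tracking
    applied_stages: List[str] = field(default_factory=list)

    def update_applied_stages(self, new_stages: List[str]):
        for stage in new_stages:
            if stage not in self.applied_stages:
                self.applied_stages.append(stage)


container = ContainerStub()

# first recipe load: the sparsity stage gets applied during finalization
first_load = [StageStub("sparsity", 0, applied=True)]
container.update_applied_stages([s.unique_id for s in first_load if s.applied])

# second recipe load: previously applied stages are re-flagged after compilation
second_load = [StageStub("sparsity", 0), StageStub("quantization", 1)]
for stage in second_load:
    if stage.unique_id in container.applied_stages:
        stage.applied = True

print([(s.unique_id, s.applied) for s in second_load])
# [('sparsity_0', True), ('quantization_1', False)]
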
30 changes: 28 additions & 2 deletions src/sparseml/core/modifier/stage.py
@@ -36,11 +36,14 @@ class StageModifiers(ModifierInterface, BaseModel):
:param modifiers: The modifiers to apply as a stage
:param index: The index of the stage, if applicable
:param group: The group name of the stage, if applicable
:param applied: Flag for indicating if this stage has already been
applied to the model, through structure initialization or finalization
"""

modifiers: List["Modifier"] = Field(default_factory=list)
index: Optional[int] = None
group: Optional[str] = None
applied: bool = False

@property
def initialized_structure(self) -> bool:
@@ -66,6 +69,13 @@ def finalized(self) -> bool:
"""
return all(mod.finalized for mod in self.modifiers)

@property
def unique_id(self) -> str:
"""
:return: unique ID for the stage, built from its group name and index
"""
return self.group + "_" + str(self.index)

def check_initialized(self):
"""
Check if all of the stage modifiers have been initialized, and log a warning
@@ -103,7 +113,7 @@ def calculate_end(self) -> float:

def pre_initialize_structure(self, state: "State", **kwargs):
"""
Pre initialize the structure for all stage modifiers
Pre initialize the structure for all stage modifiers and mark the stage as applied

:param state: The current state of the training
:param kwargs: Additional kwargs to pass to the modifier(s)
@@ -112,6 +122,8 @@ def pre_initialize_structure(self, state: "State", **kwargs):
for modifier in self.modifiers:
modifier.pre_initialize_structure(state, **kwargs)

self.applied = True

def initialize(self, state: "State", **kwargs):
"""
Initialize all the stage modifiers
@@ -120,20 +132,30 @@ def initialize(self, state: "State", **kwargs):
:param kwargs: Additional kwargs to pass to the modifier(s)
initialize method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.initialize(state, **kwargs)

def finalize(self, state: "State", **kwargs):
"""
Finalize all the stage modifiers
Finalize all the stage modifiers and mark the stage as applied

:param state: The state of current session
:param kwargs: Additional kwargs to pass to the modifier(s)
finalize method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.finalize(state, **kwargs)

self.applied = True

def update_event(self, state: "State", event: "Event", **kwargs):
"""
Propagate the event to all the stage modifiers
@@ -143,5 +165,9 @@ def update_event(self, state: "State", event: "Event", **kwargs):
:param kwargs: Additional kwargs to pass to the modifier(s)
update_event method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.update_event(state, event, **kwargs)
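
A hedged sketch of the guard pattern added above, using a plain Python stand-in rather than the real StageModifiers class: once applied is set, the lifecycle hooks short-circuit, so a stage carried over from an earlier recipe is never re-applied when recipes are stacked.

# Simplified stand-in demonstrating the applied guard; not the real class.
class StageSketch:
    def __init__(self, group: str, index: int):
        self.group = group
        self.index = index
        self.applied = False
        self.events_seen = 0

    @property
    def unique_id(self) -> str:
        return f"{self.group}_{self.index}"

    def finalize(self):
        if self.applied:
            return  # already applied by an earlier recipe load, do nothing
        self.applied = True

    def update_event(self):
        if self.applied:
            return  # applied stages ignore further training events
        self.events_seen += 1


stage = StageSketch("sparsity", 0)
stage.update_event()
stage.finalize()
stage.update_event()  # no-op: the stage has already been applied
print(stage.unique_id, stage.applied, stage.events_seen)  # sparsity_0 True 1
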
13 changes: 13 additions & 0 deletions src/sparseml/core/recipe/container.py
@@ -44,10 +44,12 @@ class RecipeContainer:

:param compiled_recipe: the compiled recipe from the recipes list
:param recipes: the list of RecipeTuple instances to be compiled
:param applied_stages: list of recipe stages that have already been applied
"""

compiled_recipe: Optional[Recipe] = None
recipes: List[RecipeTuple] = field(default_factory=list)
applied_stages: List[str] = field(default_factory=list)

def update(
self,
@@ -118,6 +120,17 @@ def update(

return kwargs

def update_applied_stages(self, new_stages: List[str]):
"""
Updates the applied_stages list with new stages, indicating their structure
has already been applied

:param new_stages: new stage names to add
"""
for stage in new_stages:
if stage not in self.applied_stages:
self.applied_stages.append(stage)

def check_compile_recipe(self) -> bool:
"""
Check if the recipes need to be compiled into a single recipe and
22 changes: 12 additions & 10 deletions src/sparseml/core/recipe/recipe.py
@@ -494,20 +494,22 @@ def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
for key, value in modifier.items()
}

def _stage_to_dict(stage: List[Dict[str, Any]]):
# convert a list of stages to a dict of modifiers
def _stage_to_dict(stage: Dict[str, Any]):
# convert a stage to a dict of modifiers
return {
modifier_group_name: _modifier_group_to_dict(modifier_group)
for stage_modifiers in stage
for modifier_group_name, modifier_group in stage_modifiers[
"modifiers"
].items()
for modifier_group_name, modifier_group in stage["modifiers"].items()
}

return {
stage_name: _stage_to_dict(stage=stage)
for stage_name, stage in self.dict()["stages"].items()
}
final_dict = {}
for stage_name, stages in self.dict()["stages"].items():
if len(stages) == 1:
final_dict[stage_name] = _stage_to_dict(stages[0])
else:
for idx, stage in enumerate(stages):
final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)

return final_dict


@dataclass
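
To illustrate the renaming above with a hedged sketch (the nested dict shape here is an assumption inferred from this diff, not the documented recipe schema): when stacked recipes reuse the same stage name, each occurrence is written out under <name>_<index> so the entries do not collide in the serialized dict.

# Assumed shape of self.dict()["stages"], based on this diff only.
from typing import Any, Dict, List


def stages_to_dict(stages_by_name: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
    def stage_to_dict(stage: Dict[str, Any]) -> Dict[str, Any]:
        # simplified: the real code also expands each modifier group
        return dict(stage["modifiers"])

    final_dict = {}
    for stage_name, stages in stages_by_name.items():
        if len(stages) == 1:
            final_dict[stage_name] = stage_to_dict(stages[0])
        else:
            # duplicate stage names get an index suffix to avoid clashes
            for idx, stage in enumerate(stages):
                final_dict[f"{stage_name}_{idx}"] = stage_to_dict(stage)
    return final_dict


stages = {
    "test_stage": [
        {"modifiers": {"obcq_modifiers": {"SparseGPTModifier": {"sparsity": 0.5}}}},
        {"modifiers": {"obcq_modifiers": {"SparseGPTModifier": {"sparsity": 0.7}}}},
    ]
}
print(list(stages_to_dict(stages)))  # ['test_stage_0', 'test_stage_1']
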
43 changes: 43 additions & 0 deletions src/sparseml/modifiers/quantization/utils/quantize.py
@@ -52,6 +52,8 @@
"add_input_activation_quant_wrappers",
"add_output_activation_observers",
"raise_if_torch_quantization_not_available",
"raise_if_already_quantized",
"is_module_quantized",
]


@@ -148,6 +150,18 @@ def set_quantization_schemes(
# submodule type or graph section set to ignore, skip
continue

if isinstance(submodule, torch_quantization.QuantWrapper):
# special case to catch QuantizableMatMul children
if ignore and _match_submodule_name_or_type(
submodule.module, submodule_name, ignore
):
continue

if is_qat_helper_module(submodule):
# ignore children of an already quantized module; if there is a clash it
# will have been caught in the parent
continue

# override default scheme if necessary
override_key = _match_submodule_name_or_type(
submodule, submodule_name, scheme_overrides
@@ -162,6 +176,7 @@
wrap_qat_targets[submodule_name] = submodule_scheme
elif is_module_type_override or is_quantizable_module(submodule):
# is base quantizable module or user specifically targeted module type
raise_if_already_quantized(submodule_name, submodule)
submodule.quantization_scheme = submodule_scheme

# inject any targeted QATWrappers
@@ -351,6 +366,34 @@ def raise_if_torch_quantization_not_available():
)


def raise_if_already_quantized(module_name: str, module: Module):
"""
:param module_name: name of module to check for quantization
:param module: module to check for quantization
:raises: RuntimeError if the module is already quantized; it cannot be re-quantized
"""
if is_module_quantized(module):
raise RuntimeError(
f"Unable to quantize module {module_name}, as it has already been "
"quantized. Ensure your input recipe does not contain multiple "
"QuantizationModifiers that act on the same module. "
)


def is_module_quantized(module: Module) -> bool:
"""
:param module: module to check for quantization
:return: True if the module is quantized, False otherwise
"""
if hasattr(module, "quantization_scheme") and isinstance(
module.quantization_scheme, QuantizationScheme
):
return True
if isinstance(module, torch_quantization.QuantWrapper):
return True
return False


def _match_submodule_name_or_type(
submodule: Module, submodule_name: str, names_or_types: List[str]
) -> Optional[str]:
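
A hedged, torch-free sketch of the guard above (FakeScheme and FakeModule are invented stand-ins; the real is_module_quantized also treats torch QuantWrapper instances as already quantized): a module that already carries a quantization_scheme is treated as quantized, and a second QuantizationModifier targeting it raises instead of silently re-quantizing.

# Simplified stand-ins; the real helpers live in
# src/sparseml/modifiers/quantization/utils/quantize.py.
class FakeScheme:
    pass


class FakeModule:
    pass


def is_module_quantized(module) -> bool:
    # quantized here means: a quantization scheme has already been attached
    return isinstance(getattr(module, "quantization_scheme", None), FakeScheme)


def raise_if_already_quantized(module_name: str, module) -> None:
    if is_module_quantized(module):
        raise RuntimeError(
            f"Unable to quantize module {module_name}, as it has already been "
            "quantized. Ensure your input recipe does not contain multiple "
            "QuantizationModifiers that act on the same module."
        )


layer = FakeModule()
raise_if_already_quantized("model.layers.0", layer)  # fine: not yet quantized
layer.quantization_scheme = FakeScheme()
try:
    raise_if_already_quantized("model.layers.0", layer)
except RuntimeError as err:
    print(err)
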
3 changes: 3 additions & 0 deletions tests/sparseml/core/lifecycle/test_session.py
@@ -80,6 +80,9 @@ def test_session_initialize_propagates_layer_prefix_to_model(

class ModifierMock(ModifierInterface):
initialized_ = False
applied = False
group = "test"
unique_id = "test_0"

def __init__(self, *args, **kwargs) -> None:
super().__init__()
14 changes: 14 additions & 0 deletions tests/sparseml/transformers/obcq/test_additional_sparsity.yaml
@@ -0,0 +1,14 @@
test_stage:
  obcq_modifiers:
    SparseGPTModifier:
      sparsity: 0.7
      block_size: 128
      sequential_update: True
      quantize: False
      percdamp: 0.01
      prunen: 0
      prunem: 0
      targets: [
        "model.layers.0"
      ]
      target_ids: ["attention_mask", "position_ids"]