MVP for Alternating Flow #1912

Merged Jan 9, 2024

Commits (44)
d5abe8e - initial recipe re-loading (Satrat, Nov 16, 2023)
ec0e180 - Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 5, 2023)
2d7b5b7 - loading for input recipe (Satrat, Dec 7, 2023)
2cc9e16 - Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 7, 2023)
356bd81 - persist structure across recipe loads (Satrat, Dec 7, 2023)
1b67b6f - clean up fn names (Satrat, Dec 7, 2023)
f06ed8a - Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 12, 2023)
ab5a464 - clean up duplicated code (Satrat, Dec 12, 2023)
11f4efe - delete extra file (Satrat, Dec 12, 2023)
7e960a3 - unit tests (Satrat, Dec 12, 2023)
ebb5407 - fix failing test (Satrat, Dec 12, 2023)
6a394d7 - quantization edge cases (Satrat, Dec 12, 2023)
d7974bf - quant tests (Satrat, Dec 13, 2023)
4b9014d - Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 13, 2023)
701ab2c - fixes for stage name clashes (Satrat, Dec 13, 2023)
5812488 - Merge branch 'sparse_auto_recipe' of github.com:neuralmagic/sparseml … (Satrat, Dec 13, 2023)
21473aa - clean up documentation (Satrat, Dec 13, 2023)
485501b - setup StageRunner class (Satrat, Dec 13, 2023)
2d536a3 - running one_shot from text_gen script (Satrat, Dec 13, 2023)
a4406ae - cleanup helper fns (Satrat, Dec 14, 2023)
4576a80 - precision support (Satrat, Dec 14, 2023)
27467e3 - formatting (Satrat, Dec 14, 2023)
10a0fed - Merge branch 'main' into alternate_flows (Satrat, Dec 14, 2023)
7c754e0 - WIP for alternating (Satrat, Dec 15, 2023)
0eb06bf - fixing device issue (Satrat, Dec 15, 2023)
f45326d - Merge branch 'sparse_auto_recipe' into alternating_flow_pt2 (Satrat, Dec 15, 2023)
e46dd96 - Merge branch 'main' into sparse_auto_recipe (Satrat, Dec 15, 2023)
d308987 - MVP for alternating flows (Satrat, Dec 15, 2023)
fe9af83 - add apply flag during finalization as well (Satrat, Dec 15, 2023)
5f6e854 - clarity comments (Satrat, Dec 15, 2023)
4588eb2 - Merge branch 'sparse_auto_recipe' of github.com:neuralmagic/sparseml … (Satrat, Dec 15, 2023)
391350d - clean up docstrings (Satrat, Dec 15, 2023)
f7fb65a - fix unit test (Satrat, Dec 15, 2023)
e429929 - Merge branch 'main' into alternate_flows (Satrat, Dec 15, 2023)
7336453 - Merge branch 'sparse_auto_recipe' into alternating_flow_pt2 (Satrat, Dec 15, 2023)
2968171 - Merge branch 'alternate_flows' into alternating_flow_pt2 (Satrat, Dec 15, 2023)
ee1ee2d - add finetuning README (Satrat, Dec 20, 2023)
9004da6 - Merge branch 'main' of github.com:neuralmagic/sparseml (Satrat, Dec 21, 2023)
180a24d - Merge branch 'main' into alternating_flow_pt2 (Satrat, Dec 21, 2023)
a8760eb - cleaning up stage logic (Satrat, Dec 21, 2023)
8eba7dd - Merge branch 'main' into alternating_flow_pt2 (Satrat, Jan 2, 2024)
9ef0d4c - quality (Satrat, Jan 2, 2024)
c4562c0 - Merge branch 'main' into alternating_flow_pt2 (Satrat, Jan 8, 2024)
797413a - Merge branch 'main' into alternating_flow_pt2 (Satrat, Jan 9, 2024)
9 changes: 9 additions & 0 deletions src/sparseml/core/lifecycle/session.py
@@ -78,7 +78,11 @@ def pre_initialize_structure(
if data is not None:
mod_data.append(data)

# mark which modifiers have already had their structures initialized
# so when we consolidate the next recipe this info isn't lost
self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)

return mod_data

@@ -113,6 +117,8 @@ def finalize(self, **kwargs) -> List[Any]:
mod_data.append(data)

self.finalized = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)

return mod_data

@@ -169,6 +175,9 @@ def _check_compile_recipe(self):
self.modifiers = self.recipe_container.compiled_recipe.create_modifier(
self.state.framework
)
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True

def _check_setup_event_lifecycle(self, event_type: EventType):
if self.event_lifecycle is not None:
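These hunks are the heart of the alternating flow: stage IDs that have been applied are recorded on the recipe container during structure pre-initialization and finalization, then restored onto the modifiers created from the next compiled recipe. A minimal, self-contained sketch of that round trip, using hypothetical stub classes in place of the real SparseSession, StageModifiers, and RecipeContainer:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubStage:
    """Stand-in for StageModifiers: only the fields the bookkeeping touches."""

    group: str
    index: int
    applied: bool = False

    @property
    def unique_id(self) -> str:
        return self.group + "_" + str(self.index)


@dataclass
class StubContainer:
    """Stand-in for RecipeContainer's applied_stages / update_applied_stages."""

    applied_stages: List[str] = field(default_factory=list)

    def update_applied_stages(self, new_stages: List[str]):
        for stage in new_stages:
            if stage not in self.applied_stages:
                self.applied_stages.append(stage)


container = StubContainer()
stages = [StubStage("oneshot", 0, applied=True), StubStage("finetune", 1)]

# pre_initialize_structure / finalize: record the stages that have been applied
container.update_applied_stages([s.unique_id for s in stages if s.applied])

# _check_compile_recipe: the modifier list is rebuilt from the newly compiled
# recipe, so the applied flag is restored from the container's record
recompiled = [StubStage("oneshot", 0), StubStage("finetune", 1)]
for stage in recompiled:
    if stage.unique_id in container.applied_stages:
        stage.applied = True

assert recompiled[0].applied and not recompiled[1].applied
```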
30 changes: 28 additions & 2 deletions src/sparseml/core/modifier/stage.py
@@ -36,11 +36,14 @@ class StageModifiers(ModifierInterface, BaseModel):
:param modifiers: The modifiers to apply as a stage
:param index: The index of the stage, if applicable
:param group: The group name of the stage, if applicable
:param applied: Flag for indicating if this stage has already been
applied to the model, through structure initialization or finalization
"""

modifiers: List["Modifier"] = Field(default_factory=list)
index: Optional[int] = None
group: Optional[str] = None
applied: bool = False

@property
def initialized_structure(self) -> bool:
@@ -66,6 +69,13 @@ def finalized(self) -> bool:
"""
return all(mod.finalized for mod in self.modifiers)

@property
def unique_id(self) -> str:
"""
:return: ID for stage containing the name and index
"""
return self.group + "_" + str(self.index)

def check_initialized(self):
"""
Check if all of the stage modifiers have been initialized, and log a warning
@@ -103,7 +113,7 @@ def calculate_end(self) -> float:

def pre_initialize_structure(self, state: "State", **kwargs):
"""
Pre initialize the structure for all stage modifiers
Pre initialize the structure for all stage modifiers and mark the stage as applied

:param state: The current state of the training
:param kwargs: Additional kwargs to pass to the modifier(s)
@@ -112,6 +122,8 @@ def pre_initialize_structure(self, state: "State", **kwargs):
for modifier in self.modifiers:
modifier.pre_initialize_structure(state, **kwargs)

self.applied = True

def initialize(self, state: "State", **kwargs):
"""
Initialize all the stage modifiers
@@ -120,20 +132,30 @@ def initialize(self, state: "State", **kwargs):
:param kwargs: Additional kwargs to pass to the modifier(s)
initialize method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.initialize(state, **kwargs)

def finalize(self, state: "State", **kwargs):
"""
Finalize all the stage modifiers
Finalize all the stage modifiers and mark the stage as applied

:param state: The state of current session
:param kwargs: Additional kwargs to pass to the modifier(s)
finalize method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.finalize(state, **kwargs)

self.applied = True

def update_event(self, state: "State", event: "Event", **kwargs):
"""
Propagate the event to all the stage modifiers
@@ -143,5 +165,9 @@ def update_event(self, state: "State", event: "Event", **kwargs):
:param kwargs: Additional kwargs to pass to the modifier(s)
update_event method
"""

if self.applied:
return

for modifier in self.modifiers:
modifier.update_event(state, event, **kwargs)
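A hedged usage sketch of the new `applied` flag and `unique_id` property. The import path is inferred from the file location above, and the stage group name and the empty child-modifier list are purely illustrative:

```python
from sparseml.core.modifier.stage import StageModifiers

# a stage with no child modifiers, just to show the new fields
stage = StageModifiers(group="quantization", index=0)

print(stage.unique_id)  # "quantization_0", the key stored in applied_stages
print(stage.applied)    # False until pre_initialize_structure or finalize runs

# once applied is True, initialize / finalize / update_event become no-ops,
# which is what lets a later recipe skip stages that already ran
stage.applied = True
```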
13 changes: 13 additions & 0 deletions src/sparseml/core/recipe/container.py
@@ -44,10 +44,12 @@ class RecipeContainer:

:param compiled_recipe: the compiled recipe from the recipes list
:param recipes: the list of RecipeTuple instances to be compiled
:param applied_stages: list of recipe stages that have already been applied
"""

compiled_recipe: Optional[Recipe] = None
recipes: List[RecipeTuple] = field(default_factory=list)
applied_stages: List[str] = field(default_factory=list)

def update(
self,
@@ -118,6 +120,17 @@ def update(

return kwargs

def update_applied_stages(self, new_stages: List[str]):
"""
Updates the applied_stages list with new stages, indicating their structure
has already been applied

:param new_stages: new stage names to add
"""
for stage in new_stages:
if stage not in self.applied_stages:
self.applied_stages.append(stage)

def check_compile_recipe(self) -> bool:
"""
Check if the recipes need to be compiled into a single recipe and
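A small sketch of the new tracking list on `RecipeContainer`. The import path follows the file location above, the other fields keep their defaults, and the stage names are hypothetical:

```python
from sparseml.core.recipe.container import RecipeContainer

container = RecipeContainer()

# record stages whose structure has already been applied; duplicates are ignored
container.update_applied_stages(["oneshot_stage_0", "finetune_stage_1"])
container.update_applied_stages(["oneshot_stage_0"])

print(container.applied_stages)  # ["oneshot_stage_0", "finetune_stage_1"]
```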
22 changes: 12 additions & 10 deletions src/sparseml/core/recipe/recipe.py
@@ -484,20 +484,22 @@ def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
for key, value in modifier.items()
}

def _stage_to_dict(stage: List[Dict[str, Any]]):
# convert a list of stages to a dict of modifiers
def _stage_to_dict(stage: Dict[str, Any]):
# convert a stage to a dict of modifiers
return {
modifier_group_name: _modifier_group_to_dict(modifier_group)
for stage_modifiers in stage
for modifier_group_name, modifier_group in stage_modifiers[
"modifiers"
].items()
for modifier_group_name, modifier_group in stage["modifiers"].items()
}

return {
stage_name: _stage_to_dict(stage=stage)
for stage_name, stage in self.dict()["stages"].items()
}
final_dict = {}
for stage_name, stages in self.dict()["stages"].items():
if len(stages) == 1:
final_dict[stage_name] = _stage_to_dict(stages[0])
else:
for idx, stage in enumerate(stages):
final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)

return final_dict


@dataclass
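The refactor above changes how serialized stages are keyed: a stage name with a single entry keeps its name, while repeated entries get an index suffix. A hypothetical sketch of the resulting shape, with stage, modifier-group, and modifier names invented for illustration:

```python
# hypothetical output of the stage-to-dict conversion after this change
serialized = {
    # "sparsity_stage" appears once in the recipe, so its name is kept as-is
    "sparsity_stage": {"pruning_modifiers": {"SparseGPTModifier": {"sparsity": 0.5}}},
    # "quant_stage" appears twice, so each occurrence is suffixed with its index
    "quant_stage_0": {"quantization_modifiers": {"QuantizationModifier": {}}},
    "quant_stage_1": {"quantization_modifiers": {"QuantizationModifier": {}}},
}
```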
43 changes: 43 additions & 0 deletions src/sparseml/modifiers/quantization/utils/quantize.py
@@ -52,6 +52,8 @@
"add_input_activation_quant_wrappers",
"add_output_activation_observers",
"raise_if_torch_quantization_not_available",
"raise_if_already_quantized",
"is_module_quantized",
]


@@ -148,6 +150,18 @@ def set_quantization_schemes(
# submodule type or graph section set to ignore, skip
continue

if isinstance(submodule, torch_quantization.QuantWrapper):
# special case to catch QuantizableMatMul children
if ignore and _match_submodule_name_or_type(
submodule.module, submodule_name, ignore
):
continue

if is_qat_helper_module(submodule):
# ignore children of an already quantized module, if there is a clash it
# will have been caught in the parent
continue

# override default scheme if necessary
override_key = _match_submodule_name_or_type(
submodule, submodule_name, scheme_overrides
@@ -162,6 +176,7 @@
wrap_qat_targets[submodule_name] = submodule_scheme
elif is_module_type_override or is_quantizable_module(submodule):
# is base quantizable module or user specifically targeted module type
raise_if_already_quantized(submodule_name, submodule)
submodule.quantization_scheme = submodule_scheme

# inject any targeted QATWrappers
@@ -351,6 +366,34 @@ def raise_if_torch_quantization_not_available():
)


def raise_if_already_quantized(module_name: str, module: Module):
"""
:param module_name: name of module to check for quantization
:param module: module to check for quantization
:raises: RuntimeError if the module is already quantized; it cannot be re-quantized
"""
if is_module_quantized(module):
raise RuntimeError(
f"Unable to quantize module {module_name}, as it has already been "
"quantized. Ensure your input recipe does not contain multiple "
"QuantizationModifiers that act on the same module. "
)


def is_module_quantized(module: Module) -> bool:
"""
:param module: module to check for quantization
:return: True if the module is quantized, False otherwise
"""
if hasattr(module, "quantization_scheme") and isinstance(
module.quantization_scheme, QuantizationScheme
):
return True
if isinstance(module, torch_quantization.QuantWrapper):
return True
return False


def _match_submodule_name_or_type(
submodule: Module, submodule_name: str, names_or_types: List[str]
) -> Optional[str]:
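A hedged sketch of the two new helpers. The sparseml import path is taken from the file above; the `torch.ao.quantization` location of `QuantWrapper` is an assumption that depends on the installed torch version:

```python
import torch
from torch.ao.quantization import QuantWrapper  # assumed torch namespace

from sparseml.modifiers.quantization.utils.quantize import (
    is_module_quantized,
    raise_if_already_quantized,
)

linear = torch.nn.Linear(8, 8)
print(is_module_quantized(linear))  # False: no scheme attached, not wrapped

wrapped = QuantWrapper(linear)
print(is_module_quantized(wrapped))  # True: a QuantWrapper counts as quantized

try:
    raise_if_already_quantized("model.linear", wrapped)
except RuntimeError as err:
    print(err)  # points at the double-quantization in the input recipe
```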
78 changes: 78 additions & 0 deletions src/sparseml/pytorch/model_load/helpers.py
@@ -32,6 +32,10 @@
"apply_recipe_structure_to_model",
"reload_model_state",
"reload_model_from_checkpoint",
"save_model_and_recipe",
"fallback_to_cpu",
"parse_dtype",
"get_session_model",
]

_LOGGER = logging.getLogger(__name__)
@@ -57,6 +61,10 @@ def apply_recipe_structure_to_model(model: Module, recipe_path: str, model_path:
model=model, recipe=recipe_path, framework=Framework.pytorch
)

# no need to reload if no recipe was applied
if recipe_path is None:
return

session = session_manager.active_session()
num_stages = len(session.lifecycle.recipe_container.compiled_recipe.stages)
msg = (
@@ -173,3 +181,73 @@ def reload_model_from_checkpoint(model: Module, checkpoint: Optional[str] = None
# reload the state dict for the model from the checkpoint
if reload_model_state(model, checkpoint, orig_state_dict):
_LOGGER.info(f"Reloaded model state from checkpoint {checkpoint}")


def save_model_and_recipe(
model: Module,
save_path: str,
tokenizer: Optional[Any] = None,
):
"""
Save a model, tokenizer and the currently loaded recipe to file

:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
"""
model.save_pretrained(save_path)
if tokenizer is not None:
tokenizer.save_pretrained(save_path)

_LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))

recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
session = session_manager.active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)


def fallback_to_cpu(device: str) -> str:
"""
Takes in a device string and forces it to cpu if cuda is not available

:param device: device id to check
:return: device modified for CUDA status
"""
if "cuda" in device and not torch.cuda.is_available():
_LOGGER.warning(
f"Requested {device} but CUDA is not available, falling back to CPU"
)
return "cpu"

return device


def parse_dtype(dtype_arg: str) -> torch.dtype:
"""
:param dtype_arg: dtype string to parse
:return: torch.dtype parsed from input string
"""
dtype = "auto" # get precision from model by default
if dtype_arg == "half" or dtype_arg == "float16":
dtype = torch.float16
elif dtype_arg == "bfloat16":
dtype = torch.bfloat16
elif dtype_arg == "full" or dtype_arg == "float32":
dtype = torch.float32

return dtype


def get_session_model() -> Module:
"""
:return: pytorch module stored by the active SparseSession, or None if no session
is active
"""
session = session_manager.active_session()
if not session:
return None

active_model = session.state.model.model
return active_model
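A quick usage sketch for the newly exported helpers, with signatures as shown in the diff and the import path following the file location above:

```python
from sparseml.pytorch.model_load.helpers import (
    fallback_to_cpu,
    get_session_model,
    parse_dtype,
)

device = fallback_to_cpu("cuda:0")  # "cuda:0", or "cpu" with a warning if CUDA is absent
dtype = parse_dtype("bfloat16")     # torch.bfloat16
default = parse_dtype("auto")       # "auto": defer to the model's own precision

# get_session_model() returns the module held by the active SparseSession;
# shown for reference only, since it requires an initialized session:
# model = get_session_model()
```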
4 changes: 4 additions & 0 deletions src/sparseml/transformers/finetune/data/data_args.py
@@ -58,6 +58,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "Optional percentages of each split to download"},
)
num_calibration_samples: Optional[int] = field(
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached preprocessed datasets or not."},
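A minimal sketch of the new argument, assuming the dataclass's other fields keep their defaults (any required fields not visible in this hunk would still need to be supplied):

```python
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments

# cap one-shot calibration at 256 samples instead of the default 512
data_args = DataTrainingArguments(num_calibration_samples=256)
print(data_args.num_calibration_samples)
```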