[Documentation] Add documentation and basic doctests to sparseml/core #1742

Merged (15 commits) on Oct 7, 2023
16 changes: 16 additions & 0 deletions src/sparseml/core/data/base.py
@@ -25,14 +25,30 @@

@dataclass
class ModifiableData(Generic[DT], MultiFrameworkObject):
"""
A base class for data that can be modified by modifiers.

:param data: The data to be modified
:param num_samples: The number of samples in the data
"""

data: DT = None
num_samples: int = None

def get_num_batches(self) -> int:
"""
:return: The number of batches in the data
"""
raise NotImplementedError()

def set_batch_size(self, batch_size: int):
"""
:param batch_size: The new batch size to use
"""
raise NotImplementedError()

def get_batch_size(self) -> int:
"""
:return: The current batch size
"""
raise NotImplementedError()
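
A minimal sketch of what a concrete subclass might look like (this ListData class is hypothetical, purely for illustration; it assumes MultiFrameworkObject permits direct subclassing in the same way ModifiableDataPyTorch below does):

class ListData(ModifiableData[list]):
    # Hypothetical subclass backed by a plain Python list
    def __init__(self, data: list):
        super().__init__()
        self.data = data
        self.num_samples = len(data)
        self._batch_size = 1

    def get_num_batches(self) -> int:
        # number of whole batches that fit into the data
        return self.num_samples // self._batch_size

    def set_batch_size(self, batch_size: int):
        self._batch_size = batch_size

    def get_batch_size(self) -> int:
        return self._batch_size
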
40 changes: 40 additions & 0 deletions src/sparseml/core/data/pytorch.py
@@ -24,6 +24,13 @@


class DynamicBatchSizeDataLoader:
"""
A wrapper for a PyTorch data loader that allows for dynamic batch sizes.
This is useful for modifiers that need to change the batch size of a data loader.

:param data_loader: The instantiated torch data loader to wrap
"""

def __init__(self, data_loader: DataLoader):
self.data_loader = data_loader
self.current_batch_size = data_loader.batch_size
@@ -37,9 +44,15 @@ def __iter__(self):
yield from self._data_merge_iter()

def set_batch_size(self, batch_size: int):
"""
:param batch_size: The new batch size to use
"""
self.current_batch_size = batch_size

def get_batch_size(self) -> int:
"""
:return: The current batch size
"""
return self.current_batch_size

def _data_split_iter(self):
@@ -83,6 +96,12 @@ def split_batch(batch, start_idx, end_idx):
"""
Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided
indices.

:raises TypeError: If the batch type is not supported
:param batch: The batch to split
:param start_idx: The start index to split at
:param end_idx: The end index to split at
:return: The split batch as a Tensor, Mapping, or Sequence, matching the input type
"""
if isinstance(batch, torch.Tensor):
return batch[start_idx:end_idx]
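
A doctest-style sketch of the Tensor and Mapping cases (outputs assume standard torch slicing, and the Mapping line assumes that branch slices each value, per the docstring; if split_batch is defined as a staticmethod of DynamicBatchSizeDataLoader, qualify the call accordingly):

>>> import torch
>>> split_batch(torch.arange(8), 0, 4)
tensor([0, 1, 2, 3])
>>> split_batch({"x": torch.arange(4)}, 0, 2)
{'x': tensor([0, 1])}
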
@@ -103,6 +122,11 @@ def split_batch(batch, start_idx, end_idx):
def merge_batches(batches):
"""
Merges a sequence of batches into a single batch.

:raises TypeError: If the batch type is not supported
:param batches: The batches to merge
:return: The merged batch as a Tensor, Mapping, or Sequence, matching the input type
"""
sample_batch = batches[0]
if isinstance(sample_batch, torch.Tensor):
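
And a doctest-style sketch for merging (assuming the Tensor branch concatenates along the batch dimension, e.g. via torch.cat):

>>> import torch
>>> merge_batches([torch.arange(3), torch.arange(3, 6)])
tensor([0, 1, 2, 3, 4, 5])
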
@@ -126,15 +150,31 @@ def merge_batches(batches):


class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]):
"""
A ModifiableData implementation for PyTorch data loaders.

:param data_loader: The data loader to wrap
:param framework: The framework the data loader is for
"""

def __init__(self, data_loader: DataLoader, framework=None):
super().__init__()
self.data = DynamicBatchSizeDataLoader(data_loader)

def get_num_batches(self) -> int:
"""
:return: The number of batches in the data
"""
return self.num_samples // self.data.get_batch_size()

def set_batch_size(self, batch_size: int):
"""
:param batch_size: The new batch size to use
"""
self.data.set_batch_size(batch_size)

def get_batch_size(self) -> int:
"""
:return: The current batch size
"""
return self.data.get_batch_size()
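
A doctest-style usage sketch for the wrapper and this ModifiableData implementation (direct construction is assumed to work as in the constructor above; num_samples is set manually here because the constructor shown does not populate it):

>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> dataset = TensorDataset(torch.arange(12).float())
>>> modifiable = ModifiableDataPyTorch(DataLoader(dataset, batch_size=4))
>>> modifiable.num_samples = len(dataset)
>>> modifiable.get_batch_size()
4
>>> modifiable.set_batch_size(2)
>>> modifiable.get_num_batches()
6
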
84 changes: 78 additions & 6 deletions src/sparseml/core/event.py
@@ -14,7 +14,7 @@

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, unique
from typing import Optional


@@ -24,7 +24,7 @@
]


@unique
class EventType(Enum):
"""
An Enum defining the different types of events that can be triggered
during sparsification. The purpose of each EventType is to trigger the
corresponding Modifier callback.
"""

# training lifecycle
PRE_INIT = "pre_init"
INITIALIZE = "initialize"
@@ -40,6 +47,14 @@ class EventType(Enum):
OPTIM_POST_STEP = "optim_post_step"

def order(self) -> int:
"""
Returns the priority order of the current EventType;
lower values have higher priority

:raises ValueError: if the event type is invalid
:return: The order of the event type, where lower values
have higher priority
"""
if self == EventType.PRE_INIT:
return 0
elif self == EventType.INITIALIZE:
@@ -62,33 +77,64 @@ def order(self) -> int:

@dataclass
class Event:
"""
A class for defining an event that can be triggered during
sparsification

:param type_: The type of event
:param steps_per_epoch: The number of steps per epoch
:param batches_per_step: The number of batches per step
:param invocations_per_step: The number of invocations per step
:param global_step: The current global step
:param global_batch: The current global batch
"""

type_: Optional[EventType] = None

steps_per_epoch: Optional[int] = None
batches_per_step: Optional[int] = None
invocations_per_step: Optional[int] = None

global_step: int = 0
global_batch: int = 0

@property
def epoch_based(self) -> bool:
"""
:return: True if the event is based on epochs, False otherwise
"""
return self.steps_per_epoch is not None

@property
def epoch(self) -> int:
"""
:raises ValueError: if the event is not epoch based
:return: The current epoch
"""
return self.global_step // self.steps_per_epoch

@property
def epoch_full(self) -> float:
"""
:raises ValueError: if the event is not epoch based
:return: The current epoch with the fraction of the current step
"""
return self.global_step / float(self.steps_per_epoch)

@property
def epoch_step(self) -> int:
"""
:raises ValueError: if the event is not epoch based
:return: The current step within the current epoch
"""
return self.global_step % self.steps_per_epoch

@property
def epoch_batch(self) -> int:
"""
:raises ValueError: if the event is not epoch based
:return: The current batch within the current epoch
"""
batches_per_epoch = (
self.steps_per_epoch * self.batches_per_step
if self.batches_per_step
@@ -99,6 +145,11 @@ def epoch_batch(self) -> int:

@property
def current_index(self) -> float:
"""
:raises ValueError: if the event is not epoch based
:return: The current index of the event, which is either the global step
or the epoch with the fraction of the current step
"""
if not self.epoch_based:
return self.global_step

@@ -109,6 +160,9 @@ def current_index(self) -> float:

@current_index.setter
def current_index(self, value: float):
"""
Sets the current index of the event
"""
if not self.epoch_based:
self.global_step = int(value)
self.global_batch = (
@@ -126,8 +180,23 @@ def current_index(self, value: float):
)

def should_update(
self, start: Optional[float], end: Optional[float], update: Optional[float]
):
"""
Returns True if the event should trigger an update, False otherwise.
An update should be triggered if the current index is within the start
and end bounds and is close (within an absolute tolerance of 1e-10)
to a multiple of the update interval

:param start: The start index to check against, set to None to
ignore start
:param end: The end index to check against, set to None to ignore
end
:param update: The update interval, set to None or 0.0 to always
update, otherwise must be greater than 0.0
:return: True if the event should be updated, False otherwise
"""
current = self.current_index

if start is not None and current < start:
Expand All @@ -139,6 +208,9 @@ def should_update(
return update is None or update <= 0.0 or current % update < 1e-10

def new_instance(self, **kwargs) -> "Event":
"""
:return: A new instance of the event with the provided kwargs
"""
instance = deepcopy(self)
for key, value in kwargs.items():
setattr(instance, key, value)
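
A doctest-style sketch tying the epoch properties and should_update together (values follow directly from the formulas above; outputs assume the implementation matches its docstrings):

>>> event = Event(type_=EventType.INITIALIZE, steps_per_epoch=100, global_step=250)
>>> event.epoch_based
True
>>> event.epoch
2
>>> event.epoch_full
2.5
>>> event.epoch_step
50
>>> event.should_update(start=1.0, end=3.0, update=0.5)  # 2.5 is a multiple of 0.5
True
>>> event.should_update(start=3.0, end=None, update=0.5)  # before start
False
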
41 changes: 41 additions & 0 deletions src/sparseml/core/factory.py
@@ -25,6 +25,18 @@


class ModifierFactory:
"""
A factory for loading and registering modifiers

:param _MAIN_PACKAGE_PATH: The path to the main modifiers package
:param _EXPERIMENTAL_PACKAGE_PATH: The path to the experimental modifiers package
:param _loaded: Whether or not the factory has been loaded
:param _main_registry: The registry of main modifiers
:param _experimental_registry: The registry of experimental modifiers
:param _registered_registry: The registry of registered modifiers
:param _errors: The errors that occurred when loading the modifiers
"""

_MAIN_PACKAGE_PATH = "sparseml.modifiers"
_EXPERIMENTAL_PACKAGE_PATH = "sparseml.modifiers.experimental"

@@ -36,6 +48,10 @@ class ModifierFactory:

@staticmethod
def refresh():
"""
Refresh the factory by reloading the modifiers.
Note: this will clear any previously registered modifiers
"""
ModifierFactory._main_registry = ModifierFactory.load_from_package(
ModifierFactory._MAIN_PACKAGE_PATH
)
Expand All @@ -46,6 +62,10 @@ def refresh():

@staticmethod
def load_from_package(package_path: str) -> Dict[str, Type[Modifier]]:
"""
:param package_path: The path to the package to load modifiers from
:return: The loaded modifiers, as a mapping of name to class
"""
loaded = {}
main_package = importlib.import_module(package_path)

@@ -93,6 +113,18 @@ def create(
allow_experimental: bool,
**kwargs,
) -> Modifier:
"""
Instantiate a modifier of the given type from registered modifiers.

:raises ValueError: If no modifier of the given type is found
:param type_: The type of modifier to create
:param framework: The framework the modifier is for
:param allow_registered: Whether or not to allow registered modifiers
:param allow_experimental: Whether or not to allow experimental modifiers
:param kwargs: Additional keyword arguments to pass to the modifier
during instantiation
:return: The instantiated modifier
"""
if type_ in ModifierFactory._errors:
raise ModifierFactory._errors[type_]

@@ -121,6 +153,15 @@ def create(

@staticmethod
def register(type_: str, modifier_class: Type[Modifier]):
"""
Register a modifier class to be used by the factory.

:raises ValueError: If the provided class does not subclass the Modifier
base class or is not a type
:param type_: The type of modifier to register
:param modifier_class: The class of the modifier to register, must subclass
the Modifier base class
"""
if not issubclass(modifier_class, Modifier):
raise ValueError(
"The provided class does not subclass the Modifier base class."
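
A hedged sketch of the register/create round trip (the CustomModifier name and its no-op body are hypothetical; the real Modifier base class may require additional constructor arguments, which create forwards via **kwargs):

class CustomModifier(Modifier):
    # hypothetical no-op modifier for illustration
    pass

# register the class under a type name, then instantiate it via the factory
ModifierFactory.register("CustomModifier", CustomModifier)
modifier = ModifierFactory.create(
    type_="CustomModifier",
    framework=None,
    allow_registered=True,
    allow_experimental=False,
)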