Sparsification Refactor for LLMs (#1713)
* Initial implementation

* add in further completion state for session and events

* add in recipe helper functions for merging, loading, and running callbacks

* minor fixes for new framework

* add constant pruning modifier

* add magnitude pruning modifier

* knowledge distillation implementation

* fix import errors and multiframework inits

* fix import errors and multiframework inits

* initialization

* RecipeModifiers working

* fix import errors

* modifiers loading in stages

* adding test files

* modifier factory implementation

* running example, but sparsity not working correctly

* fix polynomial scheduler, leave masks enabled on end

* remove e2e files

* add on_event for modifier lifecycle and add initial integration for torchvision

* leave_enabled fixes

* fixing evals and finalization

* Add test

* Add changes to allow accepting strings

* fix recipe staging issue

* style

* style fixes

* bug fixes that came up during obcq implementation

---------

Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: rahul-tuli <rahul@neuralmagic.com>
3 people committed Oct 4, 2023
1 parent 2d0f8a0 commit b9d6b70
Showing 67 changed files with 5,160 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/sparseml/core/__init__.py
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .data import *
from .event import *
from .factory import *
from .framework import *
from .framework_object import *
from .lifecycle import *
from .model import *
from .modifier import *
from .optimizer import *
from .recipe import *
from .state import *
17 changes: 17 additions & 0 deletions src/sparseml/core/data/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import ModifiableData
38 changes: 38 additions & 0 deletions src/sparseml/core/data/base.py
@@ -0,0 +1,38 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

from sparseml.core.framework_object import MultiFrameworkObject


__all__ = ["ModifiableData"]

DT = TypeVar("DT") # Dataset Type


@dataclass
class ModifiableData(Generic[DT], MultiFrameworkObject):
    """
    Base class wrapping a dataset so modifiers can query and adjust its
    batching; framework-specific subclasses implement the accessors below.
    """

    data: Optional[DT] = None
    num_samples: Optional[int] = None

    def get_num_batches(self) -> int:
        raise NotImplementedError()

    def set_batch_size(self, batch_size: int):
        raise NotImplementedError()

    def get_batch_size(self) -> int:
        raise NotImplementedError()
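
For reference, a minimal self-contained sketch of the contract ModifiableData defines; ModifiableDataSketch is a hypothetical stand-in with the MultiFrameworkObject machinery stubbed out, not part of this commit:

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

DT = TypeVar("DT")


@dataclass
class ModifiableDataSketch(Generic[DT]):
    # hypothetical stand-in for ModifiableData, minus MultiFrameworkObject
    data: Optional[DT] = None
    num_samples: Optional[int] = None
    batch_size: int = 1

    def get_num_batches(self) -> int:
        # integer division, mirroring ModifiableDataPyTorch below
        return self.num_samples // self.batch_size

    def set_batch_size(self, batch_size: int):
        self.batch_size = batch_size

    def get_batch_size(self) -> int:
        return self.batch_size


samples = list(range(10))
view = ModifiableDataSketch(data=samples, num_samples=len(samples))
view.set_batch_size(2)
assert view.get_num_batches() == 5  # 10 samples in batches of 2
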
140 changes: 140 additions & 0 deletions src/sparseml/core/data/pytorch.py
@@ -0,0 +1,140 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping, Sequence

import torch
from torch.utils.data import DataLoader

from sparseml.core.data.base import ModifiableData


__all__ = ["ModifiableDataPyTorch", "DynamicBatchSizeDataLoader"]


class DynamicBatchSizeDataLoader:
    """
    Wraps a PyTorch DataLoader so batches can be re-yielded at a batch size
    other than the one the loader was constructed with, splitting or merging
    batches on the fly.
    """

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.current_batch_size = data_loader.batch_size

    def __iter__(self):
        if self.current_batch_size == self.data_loader.batch_size:
            yield from self.data_loader
        elif self.current_batch_size < self.data_loader.batch_size:
            yield from self._data_split_iter()
        else:
            yield from self._data_merge_iter()

    def set_batch_size(self, batch_size: int):
        self.current_batch_size = batch_size

    def get_batch_size(self) -> int:
        return self.current_batch_size

    def _data_split_iter(self):
        if self.current_batch_size >= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be less than the original batch size"
            )

        for batch in self.data_loader:
            num_splits = self.data_loader.batch_size // self.current_batch_size
            for i in range(num_splits):
                start_idx = i * self.current_batch_size
                end_idx = (i + 1) * self.current_batch_size
                yield DynamicBatchSizeDataLoader.split_batch(batch, start_idx, end_idx)

    def _data_merge_iter(self):
        if self.current_batch_size <= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be greater than the original batch size"
            )

        buffer = []
        buffer_size = 0
        for batch in self.data_loader:
            buffer.append(batch)
            # note: len() reflects the sample count only for tensor batches;
            # for mapping/sequence batches it counts fields, not samples
            buffer_size += len(batch)
            while buffer_size >= self.current_batch_size:
                merged = DynamicBatchSizeDataLoader.merge_batches(buffer)
                yield DynamicBatchSizeDataLoader.split_batch(
                    merged, 0, self.current_batch_size
                )
                buffer = [
                    DynamicBatchSizeDataLoader.split_batch(
                        merged, self.current_batch_size, buffer_size
                    )
                ]
                buffer_size -= self.current_batch_size

    @staticmethod
    def split_batch(batch, start_idx, end_idx):
        """
        Splits a batch based on its type (Tensor, Mapping, Sequence) and the
        provided indices.
        """
        if isinstance(batch, torch.Tensor):
            return batch[start_idx:end_idx]
        elif isinstance(batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.split_batch(value, start_idx, end_idx)
                for key, value in batch.items()
            }
        elif isinstance(batch, Sequence):
            return [
                DynamicBatchSizeDataLoader.split_batch(item, start_idx, end_idx)
                for item in batch
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

    @staticmethod
    def merge_batches(batches):
        """
        Merges a sequence of batches into a single batch.
        """
        sample_batch = batches[0]
        if isinstance(sample_batch, torch.Tensor):
            return torch.cat(batches, dim=0)
        elif isinstance(sample_batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.merge_batches(
                    [batch[key] for batch in batches]
                )
                for key in sample_batch.keys()
            }
        elif isinstance(sample_batch, Sequence):
            return [
                DynamicBatchSizeDataLoader.merge_batches(
                    [batch[i] for batch in batches]
                )
                for i in range(len(sample_batch))
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(sample_batch)}")


class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]):
    def __init__(self, data_loader: DataLoader, framework=None):
        super().__init__()
        self.data = DynamicBatchSizeDataLoader(data_loader)
        # derive the sample count from the wrapped dataset so
        # get_num_batches has a numerator (assumes a map-style dataset)
        self.num_samples = len(data_loader.dataset)

    def get_num_batches(self) -> int:
        return self.num_samples // self.data.get_batch_size()

    def set_batch_size(self, batch_size: int):
        self.data.set_batch_size(batch_size)

    def get_batch_size(self) -> int:
        return self.data.get_batch_size()
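
A quick usage sketch for the split and merge paths, assuming this commit's sparseml is installed; a bare tensor dataset is used so each batch is a tensor and len(batch) equals the sample count:

import torch
from torch.utils.data import DataLoader

from sparseml.core.data.pytorch import DynamicBatchSizeDataLoader

# a bare 1-D tensor acts as a map-style dataset, so each batch is a tensor
data = torch.arange(16.0)
loader = DynamicBatchSizeDataLoader(DataLoader(data, batch_size=8))

loader.set_batch_size(4)  # split path: 8-sample batches re-yielded as 4
assert [b.shape[0] for b in loader] == [4, 4, 4, 4]

loader.set_batch_size(16)  # merge path: two 8-sample batches merged into 16
assert [b.shape[0] for b in loader] == [16]

Note that the merge path only yields full batches; a trailing remainder smaller than the requested size is dropped when iteration ends.
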
146 changes: 146 additions & 0 deletions src/sparseml/core/event.py
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Optional


__all__ = [
"EventType",
"Event",
]


class EventType(Enum):
    # training lifecycle
    PRE_INIT = "pre_init"
    INITIALIZE = "initialize"
    FINALIZE = "finalize"

    # batch lifecycle
    BATCH_START = "batch_start"
    LOSS_CALCULATED = "loss_calculated"
    BATCH_END = "batch_end"

    # step lifecycle
    OPTIM_PRE_STEP = "optim_pre_step"
    OPTIM_POST_STEP = "optim_post_step"

    def order(self) -> int:
        if self == EventType.PRE_INIT:
            return 0
        elif self == EventType.INITIALIZE:
            return 10
        elif self == EventType.FINALIZE:
            return 20
        elif self == EventType.BATCH_START:
            return 100
        elif self == EventType.LOSS_CALCULATED:
            return 110
        elif self == EventType.OPTIM_PRE_STEP:
            return 120
        elif self == EventType.OPTIM_POST_STEP:
            return 130
        elif self == EventType.BATCH_END:
            return 140
        else:
            raise ValueError(f"invalid event type {self}")


@dataclass
class Event:
    type_: Optional[EventType] = None

    steps_per_epoch: Optional[int] = None
    batches_per_step: Optional[int] = None
    invocations_per_step: Optional[int] = None

    global_step: int = 0
    global_batch: int = 0

    @property
    def epoch_based(self) -> bool:
        return self.steps_per_epoch is not None

    @property
    def epoch(self) -> int:
        return self.global_step // self.steps_per_epoch

    @property
    def epoch_full(self) -> float:
        return self.global_step / float(self.steps_per_epoch)

    @property
    def epoch_step(self) -> int:
        return self.global_step % self.steps_per_epoch

    @property
    def epoch_batch(self) -> int:
        batches_per_epoch = (
            self.steps_per_epoch * self.batches_per_step
            if self.batches_per_step
            else self.steps_per_epoch
        )

        return self.global_batch % batches_per_epoch

    @property
    def current_index(self) -> float:
        if not self.epoch_based:
            return self.global_step

        if self.epoch_full - self.epoch > 1.0:
            raise ValueError("too many steps per epoch for epoch based event")

        return self.epoch_full

    @current_index.setter
    def current_index(self, value: float):
        self.global_step = (
            int(value) if not self.epoch_based else int(value * self.steps_per_epoch)
        )
        self.global_batch = (
            self.global_step
            if self.batches_per_step is None or self.batches_per_step < 2
            else self.global_step * self.batches_per_step
        )

    def should_update(
        self, start: Optional[float], end: Optional[float], update: Optional[float]
    ) -> bool:
        current = self.current_index

        if start is not None and current < start:
            return False

        if end is not None and current > end:
            return False

        return update is None or update <= 0.0 or current % update < 1e-10

    def new_instance(self, **kwargs) -> "Event":
        instance = deepcopy(self)
        for key, value in kwargs.items():
            setattr(instance, key, value)

        return instance
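
A sketch of how a modifier schedule can be driven off should_update; the start/end/update values here are illustrative, not from this commit:

from sparseml.core.event import Event, EventType

event = Event(type_=EventType.BATCH_START, steps_per_epoch=100)

fired = []
for step in range(300):  # three epochs of 100 optimizer steps
    event.global_step = step
    # apply between epoch 1.0 and 2.0, re-firing every half epoch
    if event.should_update(start=1.0, end=2.0, update=0.5):
        fired.append(event.current_index)

assert fired == [1.0, 1.5, 2.0]
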