Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparsification Refactor for LLMs #1713

Merged
merged 31 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a3934c2
Initial start implementation
markurtz Sep 1, 2023
d953be4
add in further completion state for session and events
markurtz Sep 5, 2023
42aef3d
add in recipe helper functions for merging, loading, and running call…
markurtz Sep 6, 2023
6682784
minor fixes for new framework
markurtz Sep 6, 2023
5b0f190
add constant pruning modifier
markurtz Sep 7, 2023
b8452a5
add magnitude pruning modifier
markurtz Sep 9, 2023
f04ca6f
knowledge distillation implementation
markurtz Sep 10, 2023
c745492
fix import errors and multiframework inits
Satrat Sep 14, 2023
bc73e15
fix import errors and multiframework inits
Satrat Sep 14, 2023
5438e05
initialization
Satrat Sep 14, 2023
996c533
RecipeModifiers working
Satrat Sep 15, 2023
9635acb
fix import errors
markurtz Sep 17, 2023
7ecd5c6
modifiers loading in stages
Satrat Sep 19, 2023
3e2954e
adding test files
Satrat Sep 19, 2023
5eed10d
merge with base and update
Satrat Sep 19, 2023
6b83b02
modifier factory implementation
markurtz Sep 19, 2023
e857729
running example, but sparsity not working correctly
Satrat Sep 19, 2023
55eecc3
merge in factory, make it functional
Satrat Sep 19, 2023
bc5798d
fix polynomial scheduler, leave masks enabled on end
Satrat Sep 20, 2023
a35581d
remove e2e files
Satrat Sep 20, 2023
71869be
add on_event for modifier lifecycle and add initial integration for t…
markurtz Sep 20, 2023
2d04ea0
leave_enabled fixes
Satrat Sep 20, 2023
7b182e4
fixing evals and finalization
Satrat Sep 20, 2023
6c2255f
Add test
rahul-tuli Sep 21, 2023
abeedb7
Add changes to allow accepting strings
rahul-tuli Sep 21, 2023
571d21d
fix recipe staging issue
Satrat Sep 22, 2023
952e4ee
style
Satrat Sep 22, 2023
ed8e0ba
style fixes
Satrat Sep 22, 2023
7236de7
Merge branch 'main' into sparsification-refactor
Satrat Sep 22, 2023
bfd7f84
bug fixes that came up during obcq implementation
Satrat Sep 26, 2023
baf4b66
Merge branch 'main' into sparsification-refactor
bfineran Oct 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/sparseml/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .data import *
from .event import *
from .factory import *
from .framework import *
from .framework_object import *
from .lifecycle import *
from .model import *
from .modifier import *
from .optimizer import *
from .recipe import *
from .state import *
17 changes: 17 additions & 0 deletions src/sparseml/core/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import ModifiableData
38 changes: 38 additions & 0 deletions src/sparseml/core/data/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Generic, TypeVar

from sparseml.core.framework_object import MultiFrameworkObject


__all__ = ["ModifiableData"]

DT = TypeVar("DT") # Dataset Type


@dataclass
class ModifiableData(Generic[DT], MultiFrameworkObject):
    """
    Interface for a dataset wrapper whose batch size can be queried and changed.

    The batch-related methods defined here are placeholders that always raise
    ``NotImplementedError``; concrete subclasses are expected to override them.
    """

    # the wrapped dataset object (of the generic dataset type DT)
    data: DT = None
    # number of samples; this class never sets it itself
    num_samples: int = None

    def get_batch_size(self) -> int:
        """
        :return: the current batch size
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def set_batch_size(self, batch_size: int):
        """
        :param batch_size: the batch size to switch to
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def get_num_batches(self) -> int:
        """
        :return: the number of batches
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
140 changes: 140 additions & 0 deletions src/sparseml/core/data/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping, Sequence

import torch
from torch.utils.data import DataLoader

from sparseml.core.data.base import ModifiableData


__all__ = ["ModifiableDataPyTorch", "DynamicBatchSizeDataLoader"]


class DynamicBatchSizeDataLoader:
    """
    Wraps a data loader and re-batches what it yields to a batch size that can
    be changed after construction.

    Iterating yields the wrapped loader's batches unchanged while the current
    batch size equals the loader's own, splits each batch when the current
    size is smaller, and merges consecutive batches when it is larger.
    Batches may be tensors, mappings, or sequences (nested arbitrarily) whose
    leaves are tensors; slicing and concatenation happen along dimension 0.

    Samples that do not fill a complete batch of the current size are dropped,
    both when splitting and when merging.

    :param data_loader: the loader to wrap; it must expose ``batch_size`` and
        be iterable
    """

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.current_batch_size = data_loader.batch_size

    def __iter__(self):
        if self.current_batch_size == self.data_loader.batch_size:
            yield from self.data_loader
        elif self.current_batch_size < self.data_loader.batch_size:
            yield from self._data_split_iter()
        else:
            yield from self._data_merge_iter()

    def set_batch_size(self, batch_size: int):
        """
        :param batch_size: the batch size to use for subsequent iterations
        """
        self.current_batch_size = batch_size

    def get_batch_size(self) -> int:
        """
        :return: the batch size currently used when iterating
        """
        return self.current_batch_size

    def _data_split_iter(self):
        """
        Yield each batch of the wrapped loader cut into pieces of the current
        batch size.

        :raises ValueError: if the current batch size is not smaller than the
            wrapped loader's batch size
        """
        if self.current_batch_size >= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be less than the original batch size"
            )

        for batch in self.data_loader:
            # use the real sample count, not the nominal loader batch size, so
            # that a short final batch does not produce empty slices
            num_samples = DynamicBatchSizeDataLoader._num_samples(batch)
            num_splits = num_samples // self.current_batch_size
            for i in range(num_splits):
                start_idx = i * self.current_batch_size
                end_idx = (i + 1) * self.current_batch_size
                yield DynamicBatchSizeDataLoader.split_batch(batch, start_idx, end_idx)

    def _data_merge_iter(self):
        """
        Yield batches of the current batch size assembled from consecutive
        batches of the wrapped loader.

        :raises ValueError: if the current batch size is not larger than the
            wrapped loader's batch size
        """
        if self.current_batch_size <= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be greater than the original batch size"
            )

        buffer = []
        buffer_size = 0
        for batch in self.data_loader:
            buffer.append(batch)
            # len(batch) would count keys/fields for mapping and sequence
            # batches, so measure the leading dimension of a leaf tensor
            buffer_size += DynamicBatchSizeDataLoader._num_samples(batch)
            while buffer_size >= self.current_batch_size:
                merged = DynamicBatchSizeDataLoader.merge_batches(buffer)
                yield DynamicBatchSizeDataLoader.split_batch(
                    merged, 0, self.current_batch_size
                )
                remaining = buffer_size - self.current_batch_size
                # carry over only a non-empty remainder
                buffer = (
                    [
                        DynamicBatchSizeDataLoader.split_batch(
                            merged, self.current_batch_size, buffer_size
                        )
                    ]
                    if remaining > 0
                    else []
                )
                buffer_size = remaining

    @staticmethod
    def _is_sequence(value) -> bool:
        """
        :return: True for sequence containers; str and bytes are excluded
            because recursing into them never reaches a tensor leaf
        """
        return isinstance(value, Sequence) and not isinstance(value, (str, bytes))

    @staticmethod
    def _num_samples(batch) -> int:
        """
        :param batch: a tensor, mapping, or sequence batch
        :return: the length of the first tensor leaf found in the batch,
            or 0 if the batch contains no elements
        :raises TypeError: if an unsupported type is encountered
        """
        if isinstance(batch, torch.Tensor):
            return len(batch)
        if isinstance(batch, Mapping):
            children = batch.values()
        elif DynamicBatchSizeDataLoader._is_sequence(batch):
            children = batch
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

        for child in children:
            return DynamicBatchSizeDataLoader._num_samples(child)
        return 0

    @staticmethod
    def split_batch(batch, start_idx, end_idx):
        """
        Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided
        indices.

        :param batch: the batch to slice
        :param start_idx: first sample index to keep (inclusive)
        :param end_idx: last sample index to keep (exclusive)
        :return: a tensor slice, or a dict / list with every element sliced
        :raises TypeError: if the batch (or a nested element) is unsupported
        """
        if isinstance(batch, torch.Tensor):
            return batch[start_idx:end_idx]
        elif isinstance(batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.split_batch(value, start_idx, end_idx)
                for key, value in batch.items()
            }
        elif DynamicBatchSizeDataLoader._is_sequence(batch):
            return [
                DynamicBatchSizeDataLoader.split_batch(item, start_idx, end_idx)
                for item in batch
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

    @staticmethod
    def merge_batches(batches):
        """
        Merges a sequence of batches into a single batch.

        :param batches: non-empty list of batches sharing the same structure;
            the structure is taken from the first one
        :return: tensors concatenated along dimension 0, or a dict / list with
            every element merged
        :raises TypeError: if the batches (or a nested element) are unsupported
        """
        sample_batch = batches[0]
        if isinstance(sample_batch, torch.Tensor):
            return torch.cat(batches, dim=0)
        elif isinstance(sample_batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.merge_batches(
                    [batch[key] for batch in batches]
                )
                for key in sample_batch.keys()
            }
        elif DynamicBatchSizeDataLoader._is_sequence(sample_batch):
            return [
                DynamicBatchSizeDataLoader.merge_batches(
                    [batch[i] for batch in batches]
                )
                for i in range(len(sample_batch))
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(sample_batch)}")


class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]):
    """
    ``ModifiableData`` implementation that wraps a PyTorch ``DataLoader`` in a
    ``DynamicBatchSizeDataLoader`` and forwards batch-size calls to it.

    :param data_loader: the loader to wrap
    :param framework: accepted but not used by this constructor
    """

    def __init__(self, data_loader: DataLoader, framework=None):
        super().__init__()
        wrapped_loader = DynamicBatchSizeDataLoader(data_loader)
        self.data = wrapped_loader

    def get_batch_size(self) -> int:
        """
        :return: the batch size currently set on the wrapped loader
        """
        wrapped_loader = self.data
        return wrapped_loader.get_batch_size()

    def set_batch_size(self, batch_size: int):
        """
        :param batch_size: the batch size to set on the wrapped loader
        """
        wrapped_loader = self.data
        wrapped_loader.set_batch_size(batch_size)

    def get_num_batches(self) -> int:
        """
        :return: ``num_samples`` floor-divided by the current batch size
        """
        total_samples = self.num_samples
        return total_samples // self.data.get_batch_size()
146 changes: 146 additions & 0 deletions src/sparseml/core/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Optional


__all__ = [
"EventType",
"Event",
]


class EventType(Enum):
    """
    The kinds of events that can be raised, grouped by lifecycle.
    """

    # training lifecycle
    PRE_INIT = "pre_init"
    INITIALIZE = "initialize"
    FINALIZE = "finalize"

    # batch lifecycle
    BATCH_START = "batch_start"
    LOSS_CALCULATED = "loss_calculated"
    BATCH_END = "batch_end"

    # step lifecycle
    OPTIM_PRE_STEP = "optim_pre_step"
    OPTIM_POST_STEP = "optim_post_step"

    def order(self) -> int:
        """
        :return: the integer rank assigned to this event type
        :raises ValueError: if the event type has no rank assigned
        """
        # built per call: a dict class attribute would be turned into an
        # enum member
        ranking = {
            EventType.PRE_INIT: 0,
            EventType.INITIALIZE: 10,
            EventType.FINALIZE: 20,
            EventType.BATCH_START: 100,
            EventType.LOSS_CALCULATED: 110,
            EventType.OPTIM_PRE_STEP: 120,
            EventType.OPTIM_POST_STEP: 130,
            EventType.BATCH_END: 140,
        }
        if self not in ranking:
            raise ValueError(f"invalid event type {self}")

        return ranking[self]


@dataclass
class Event:
    """
    A point in a run, described by its type and by global step / batch
    counters, with helpers to express that point in epochs.

    An event is "epoch based" when ``steps_per_epoch`` is set; the epoch
    properties (``epoch``, ``epoch_full``, ``epoch_step``, ``epoch_batch``)
    require it and fail with a ``TypeError`` otherwise.
    """

    # kind of event; forward reference so the annotation is not evaluated here
    type_: Optional["EventType"] = None

    # steps in one epoch; None means the event is step based
    steps_per_epoch: Optional[int] = None
    # multiplier used to convert steps to batches when it is 2 or more
    batches_per_step: Optional[int] = None
    # stored only; not read by any method of this class
    invocations_per_step: Optional[int] = None

    global_step: int = 0
    global_batch: int = 0

    @property
    def epoch_based(self) -> bool:
        """
        :return: True if ``steps_per_epoch`` is set
        """
        return self.steps_per_epoch is not None

    @property
    def epoch(self) -> int:
        """
        :return: the number of whole epochs completed at ``global_step``
        """
        return self.global_step // self.steps_per_epoch

    @property
    def epoch_full(self) -> float:
        """
        :return: ``global_step`` expressed as a fractional number of epochs
        """
        return self.global_step / float(self.steps_per_epoch)

    @property
    def epoch_step(self) -> int:
        """
        :return: the step index within the current epoch
        """
        return self.global_step % self.steps_per_epoch

    @property
    def epoch_batch(self) -> int:
        """
        :return: the batch index within the current epoch
        """
        batches_per_epoch = (
            self.steps_per_epoch * self.batches_per_step
            if self.batches_per_step
            else self.steps_per_epoch
        )

        return self.global_batch % batches_per_epoch

    @property
    def current_index(self) -> float:
        """
        :return: ``global_step`` for step based events, otherwise the
            fractional epoch (``epoch_full``)
        :raises ValueError: if the fractional epoch is more than one epoch
            past the whole epoch count
        """
        if not self.epoch_based:
            return self.global_step

        if self.epoch_full - self.epoch > 1.0:
            raise ValueError("too many steps per epoch for epoch based event")

        return self.epoch_full

    @current_index.setter
    def current_index(self, value: float):
        """
        Set ``global_step`` and ``global_batch`` from an index.

        :param value: a step count for step based events, otherwise a
            (fractional) epoch; the resulting step count is truncated to int
        """
        steps = value * self.steps_per_epoch if self.epoch_based else value
        self.global_step = int(steps)
        self.global_batch = (
            self.global_step
            if self.batches_per_step is None or self.batches_per_step < 2
            else self.global_step * self.batches_per_step
        )

    def should_update(
        self, start: Optional[float], end: Optional[float], update: Optional[float]
    ) -> bool:
        """
        Decide whether an update is due at ``current_index``.

        :param start: lower bound (inclusive) of the active range, None for
            no lower bound
        :param end: upper bound (inclusive) of the active range, None for no
            upper bound
        :param update: update interval; None or a non-positive value means
            every index inside the range qualifies
        :return: True if the current index is inside the range and falls on a
            multiple of ``update``
        """
        current = self.current_index

        if start is not None and current < start:
            return False

        if end is not None and current > end:
            return False

        if update is None or update <= 0.0:
            return True

        # floating point error can leave the remainder of an exact multiple
        # just below `update` instead of at 0 (0.3 % 0.1 == 0.0999...), so
        # accept a remainder within tolerance of either end of the interval
        remainder = current % update
        return remainder < 1e-10 or update - remainder < 1e-10

    def new_instance(self, **kwargs) -> "Event":
        """
        :param kwargs: attribute names and values to set on the copy
        :return: a deep copy of this event with the given attributes set
        """
        instance = deepcopy(self)
        for key, value in kwargs.items():
            setattr(instance, key, value)

        return instance
Loading
Loading