Modifier Refactor OBCQ Implementation #1737

Merged 189 commits on Oct 11, 2023

Changes from 175 commits

Commits (189)
bd729f3
First abstract structure with OPT, MPT
natuan Aug 18, 2023
0476a14
Merge branch 'main' into tuan/sparsegpt
Satrat Aug 24, 2023
e9c4e95
scaffolding
Satrat Aug 29, 2023
07a7273
First abstract structure with OPT, MPT
natuan Aug 18, 2023
73ccd2e
Updated
natuan Aug 29, 2023
f742f4a
Updated
natuan Aug 29, 2023
7c2d3d4
Initial implementation of Llama2 integration
anmarques Aug 30, 2023
3f4c5d8
Initial implementation of Llama2 integration
anmarques Aug 30, 2023
c181380
Fix llama-2 key
anmarques Aug 30, 2023
dd0f5f8
Add modelutils
anmarques Aug 30, 2023
99f8bbe
Make smoothquant optional
anmarques Aug 30, 2023
db901e4
probe sequence length from model
anmarques Aug 30, 2023
48ebe81
probe sequence length from model
anmarques Aug 30, 2023
2be7def
Fix typo
anmarques Aug 30, 2023
fb346ad
add in recipe and modifier logic
Satrat Aug 30, 2023
c4fc152
Catch additional arguments to load_model and load_data
anmarques Aug 30, 2023
2b50607
Redifinition of seqlen
anmarques Aug 30, 2023
15b2c10
Runable OPT main code path
natuan Aug 31, 2023
34c3d84
mpt modifier
Satrat Aug 31, 2023
4cac803
Merge branch 'tuan/sparsegpt' into sa/sgpt_modifier
Satrat Aug 31, 2023
a3934c2
Initial start implementation
markurtz Sep 1, 2023
6a35b7f
State before rebasing
anmarques Sep 1, 2023
e02abb2
State before rebasing
anmarques Sep 1, 2023
2fb9e66
Merge remote-tracking branch 'origin/tuan/sparsegpt' into research/sp…
anmarques Sep 1, 2023
466213c
Initial support for MPT
natuan Sep 4, 2023
d953be4
add in further completion state for session and events
markurtz Sep 5, 2023
7fc49ac
Initial commit for Llama2
natuan Sep 5, 2023
ff5b9a1
Return model after override Llama attn
natuan Sep 5, 2023
090af1c
working refactor
Satrat Sep 5, 2023
594339f
More memory-efficient layer compression
natuan Sep 5, 2023
5750edf
Make memory-efficient layer compressor default
natuan Sep 5, 2023
6c22fa1
finalization, clean up file saving
Satrat Sep 5, 2023
f3f9119
SmoothQuant enabled
natuan Sep 5, 2023
b2a6817
support for quantization, clean up
Satrat Sep 5, 2023
fa934c5
remove unsued files
Satrat Sep 5, 2023
a0deb44
remove unsued files
Satrat Sep 5, 2023
42aef3d
add in recipe helper functions for merging, loading, and running call…
markurtz Sep 6, 2023
6682784
minor fixes for new framework
markurtz Sep 6, 2023
59d1a72
move to transformers subdir
Satrat Sep 6, 2023
96b9248
quality
Satrat Sep 6, 2023
c2471e1
Merge branch 'main' into sa/sgpt_modifier
Satrat Sep 6, 2023
5b0f190
add constant pruning modifier
markurtz Sep 7, 2023
4e24413
fix modifier inheritance
Satrat Sep 7, 2023
836464d
PR comments, serialization
Satrat Sep 7, 2023
4b511c9
move compression to initializer
Satrat Sep 7, 2023
39bfbbf
clean up recipe
Satrat Sep 7, 2023
aced1d8
Rebase
anmarques Sep 7, 2023
6f3d188
Rebase
anmarques Sep 7, 2023
ae6c2e2
Rebase
anmarques Sep 7, 2023
ae2f8e0
Rebase
anmarques Sep 7, 2023
b52e0b4
Fix call to model eval
anmarques Sep 7, 2023
3647653
Fixes to caching
anmarques Sep 7, 2023
4652fcf
Add support for llama2
anmarques Sep 7, 2023
a81ba7b
Fix ptq-only option
anmarques Sep 7, 2023
f4f0a63
revert quant modifier
Satrat Sep 7, 2023
4e72fd6
Fixes
anmarques Sep 7, 2023
751a5ac
Rebase
anmarques Sep 7, 2023
71e8473
Rebase
anmarques Sep 7, 2023
48dd147
Bugs fixed
natuan Sep 8, 2023
924992c
Rebase
anmarques Sep 8, 2023
7b28a38
Rebase
anmarques Sep 8, 2023
ebf00e8
Rebase
anmarques Sep 8, 2023
4aa9e62
Rebase
anmarques Sep 8, 2023
c19d386
merge and move model loading to helper
Satrat Sep 8, 2023
6200f18
Rebase
anmarques Sep 8, 2023
6444a94
Rebase
anmarques Sep 8, 2023
b8452a5
add magntitude pruning modifier
markurtz Sep 9, 2023
f04ca6f
knowledge distillation implementation
markurtz Sep 10, 2023
83b23f7
docstrings
Satrat Sep 11, 2023
a32321c
docstrings
Satrat Sep 11, 2023
1bb260a
Rebase
anmarques Sep 12, 2023
2bb221f
Rebase
anmarques Sep 12, 2023
16cd0e5
basic llama modifier(not fully tested)
Satrat Sep 12, 2023
8967b20
working llama example
Satrat Sep 12, 2023
bba4cbb
Evaluate model in eval mode
anmarques Sep 12, 2023
19f8610
rebase
anmarques Sep 13, 2023
e89c17f
remove evaluate model for opt
anmarques Sep 13, 2023
3bd26e0
Fix model key for llama2
anmarques Sep 13, 2023
9307286
Fix model key for llama2
anmarques Sep 13, 2023
daf66cf
Fix dataset loading
anmarques Sep 13, 2023
d5f4b62
leave outputs as list
anmarques Sep 13, 2023
75539d5
Clean up
anmarques Sep 13, 2023
355f5c5
Fixes for input data as list
anmarques Sep 13, 2023
c745492
fix import errors and multiframework inits
Satrat Sep 14, 2023
bc73e15
fix import errors and multiframework inits
Satrat Sep 14, 2023
5438e05
initialization
Satrat Sep 14, 2023
4d0fdc3
First abstract structure with OPT, MPT
natuan Aug 18, 2023
0d00837
Updated
natuan Aug 29, 2023
e4a4878
Updated
natuan Aug 29, 2023
bb00c20
Runable OPT main code path
natuan Aug 31, 2023
277ee89
Initial support for MPT
natuan Sep 4, 2023
b454e26
Initial commit for Llama2
natuan Sep 5, 2023
e459c92
Return model after override Llama attn
natuan Sep 5, 2023
da511e9
More memory-efficient layer compression
natuan Sep 5, 2023
42a2845
Make memory-efficient layer compressor default
natuan Sep 5, 2023
55ddec5
SmoothQuant enabled
natuan Sep 5, 2023
aaa68a1
Bugs fixed
natuan Sep 8, 2023
cbc9360
Make load data more flexible
natuan Sep 11, 2023
f48cef6
Initialize scales in eval mode; clean up
natuan Sep 13, 2023
80f9208
Example script and recipe for OPT
natuan Sep 14, 2023
e61dc21
Formatting
natuan Sep 14, 2023
dff49e3
Copyright
natuan Sep 14, 2023
c63001d
Fix code styles
natuan Sep 15, 2023
bd213a5
Format
natuan Sep 15, 2023
32cf3c9
Fixes for channelwise quantization
anmarques Sep 15, 2023
c2afba0
Name fixes
anmarques Sep 15, 2023
996c533
RecipeModifiers working
Satrat Sep 15, 2023
aacdd54
Merge branch 'tuan/sparsegpt' into sa/sgpt_modifier
Satrat Sep 15, 2023
4dab25c
Merge branch 'sa/llama_modifiers' into sa/sgpt_modifier
Satrat Sep 15, 2023
5ae7a87
remove unused buffer
anmarques Sep 15, 2023
ef9da3a
Support for smoothquant
anmarques Sep 15, 2023
9635acb
fix import errors
markurtz Sep 17, 2023
5845359
Reformat smoothquant dict
anmarques Sep 18, 2023
e3166d0
Push smoothquant to a separate file
anmarques Sep 18, 2023
8e797c5
perplexity evaluation for opt
Satrat Sep 18, 2023
7ecd5c6
modifiers loading in stages
Satrat Sep 19, 2023
3e2954e
adding test files
Satrat Sep 19, 2023
5eed10d
merge with base and update
Satrat Sep 19, 2023
0807ba6
Rebase
anmarques Sep 19, 2023
3137424
Rebase
anmarques Sep 19, 2023
72f1d33
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
69bb017
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
7a08733
Add support to logarithmic activation equalization
anmarques Sep 19, 2023
69494db
Fix counter
anmarques Sep 19, 2023
b4fccfb
Rebase
anmarques Sep 19, 2023
ad9d7ea
Rebase
anmarques Sep 19, 2023
4bcf07b
Rebase
anmarques Sep 19, 2023
a981621
Rebase
anmarques Sep 19, 2023
5d9bbfb
Add license message
anmarques Sep 19, 2023
6b83b02
modifier factory implementation
markurtz Sep 19, 2023
e857729
running example, but sparsity not working correctly
Satrat Sep 19, 2023
b134431
Account for when keys are not matched
anmarques Sep 19, 2023
55027ce
Expand caching to include inputs. Move main ppl logic to utils
anmarques Sep 19, 2023
925fa61
Expand caching to include inputs. Move main ppl logic to utils
anmarques Sep 19, 2023
e88aa5d
Update opt integration
anmarques Sep 19, 2023
55eecc3
merge in factory, make it functional
Satrat Sep 19, 2023
bc5798d
fix polynomial scheduler, leave masks enabled on end
Satrat Sep 20, 2023
a35581d
remove e2e files
Satrat Sep 20, 2023
71869be
add on_event for modifier lifecycle and add initial integration for t…
markurtz Sep 20, 2023
2d04ea0
leave_enabled fixes
Satrat Sep 20, 2023
7b182e4
fixing evals and finalization
Satrat Sep 20, 2023
031c539
rebasing research and cleaning up perplexity
Satrat Sep 21, 2023
ddf35be
remove local files
Satrat Sep 21, 2023
4e3fd35
style and base obcq modifier
Satrat Sep 21, 2023
4c34fae
Merge branch 'sa/sgpt_modifier' into refactor_obcq
Satrat Sep 21, 2023
8015028
[untested] convert obcq to new framework
Satrat Sep 21, 2023
727928f
obcq working
Satrat Sep 21, 2023
6c2255f
Add test
rahul-tuli Sep 21, 2023
abeedb7
Add changes to allow accepting strings
rahul-tuli Sep 21, 2023
c7848e5
update llama example recipe
Satrat Sep 22, 2023
571d21d
fix recipe staging issue
Satrat Sep 22, 2023
952e4ee
style
Satrat Sep 22, 2023
ed8e0ba
style fixes
Satrat Sep 22, 2023
7236de7
Merge branch 'main' into sparsification-refactor
Satrat Sep 22, 2023
42a235b
reorg file structure
Satrat Sep 25, 2023
ef0ef18
quant modifier in new framework, opt tested
Satrat Sep 25, 2023
4d1c716
Merge branch 'sparsification-refactor' into refactor_obcq
Satrat Sep 25, 2023
8887b61
Removing custom smoothquant from this branch
anmarques Sep 26, 2023
4597497
Merge branch 'main' into research/sparsegpt/llama2
anmarques Sep 26, 2023
a1d15c8
post one shot calibration, recipe update
Satrat Sep 26, 2023
bfd7f84
bug fixes that came up during obcq implementation
Satrat Sep 26, 2023
05e0efb
Merge branch 'sparsification-refactor' into refactor_obcq
Satrat Sep 26, 2023
be06113
moving obcq script
Satrat Sep 26, 2023
629f9c5
quant fix, prunen and m
Satrat Sep 27, 2023
13ac2b9
fix experimental import paths, add comparison test
Satrat Sep 27, 2023
0cadcf3
return perplexity
Satrat Sep 27, 2023
76a8391
Merge branch 'research/sparsegpt/llama2' into refactor_obcq
Satrat Sep 29, 2023
d9e969c
style
Satrat Sep 29, 2023
eec247d
specify compressible layers in recipe
Satrat Sep 29, 2023
5a985d0
move attention cache to base class
Satrat Oct 3, 2023
bae96af
move attention cache to base class
Satrat Oct 3, 2023
6067521
clean up test scripts
Satrat Oct 3, 2023
27ba629
clean up bottom compressors and comments
Satrat Oct 4, 2023
0bdee59
documentation
Satrat Oct 4, 2023
ac29d68
fix typos
Satrat Oct 4, 2023
7617f9c
PR comments
Satrat Oct 5, 2023
f216d53
small bug fix on logger
Satrat Oct 5, 2023
dcbbbd4
Merge branch 'main' into refactor_obcq
Satrat Oct 5, 2023
de3dc2f
fixing transformer dependency, adding registry for dataset loaders
Satrat Oct 5, 2023
39ff676
fixing bugs in dataset loading and perplexity
Satrat Oct 6, 2023
7423f20
cleanup
Satrat Oct 6, 2023
25ba312
Merge branch 'main' into refactor_obcq
Satrat Oct 6, 2023
13c3a6c
fix memory issue in comparison
Satrat Oct 6, 2023
b5e9d6d
return perplexity
Satrat Oct 6, 2023
1decfe5
adding split to all datasets
Satrat Oct 6, 2023
0a0cff9
Merge branch 'refactor_obcq' of https://github.com/neuralmagic/sparse…
Satrat Oct 6, 2023
5fa7361
Update src/sparseml/modifiers/obcq/base.py
Satrat Oct 9, 2023
f1441cb
Merge branch 'main' into refactor_obcq
Satrat Oct 9, 2023
29271a3
fixing dataset issues
Satrat Oct 10, 2023
27 changes: 27 additions & 0 deletions src/sparseml/core/__init__.py
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .data import *
from .event import *
from .factory import *
from .framework import *
from .framework_object import *
from .lifecycle import *
from .model import *
from .modifier import *
from .optimizer import *
from .recipe import *
from .state import *
17 changes: 17 additions & 0 deletions src/sparseml/core/data/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import ModifiableData
38 changes: 38 additions & 0 deletions src/sparseml/core/data/base.py
@@ -0,0 +1,38 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Generic, TypeVar

from sparseml.core.framework_object import MultiFrameworkObject


__all__ = ["ModifiableData"]

DT = TypeVar("DT") # Dataset Type


@dataclass
class ModifiableData(Generic[DT], MultiFrameworkObject):
    data: DT = None
    num_samples: int = None

    def get_num_batches(self) -> int:
        raise NotImplementedError()

    def set_batch_size(self, batch_size: int):
        raise NotImplementedError()

    def get_batch_size(self) -> int:
        raise NotImplementedError()
140 changes: 140 additions & 0 deletions src/sparseml/core/data/pytorch.py
@@ -0,0 +1,140 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping, Sequence

import torch
from torch.utils.data import DataLoader

from sparseml.core.data.base import ModifiableData


__all__ = ["ModifiableDataPyTorch", "DynamicBatchSizeDataLoader"]


class DynamicBatchSizeDataLoader:
    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.current_batch_size = data_loader.batch_size

    def __iter__(self):
        if self.current_batch_size == self.data_loader.batch_size:
            yield from self.data_loader
        elif self.current_batch_size < self.data_loader.batch_size:
            yield from self._data_split_iter()
        else:
            yield from self._data_merge_iter()

    def set_batch_size(self, batch_size: int):
        self.current_batch_size = batch_size

    def get_batch_size(self) -> int:
        return self.current_batch_size

    def _data_split_iter(self):
        if self.current_batch_size >= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be less than the original batch size"
            )

        for batch in self.data_loader:
            num_splits = self.data_loader.batch_size // self.current_batch_size
            for i in range(num_splits):
                start_idx = i * self.current_batch_size
                end_idx = (i + 1) * self.current_batch_size
                yield DynamicBatchSizeDataLoader.split_batch(batch, start_idx, end_idx)

    def _data_merge_iter(self):
        if self.current_batch_size <= self.data_loader.batch_size:
            raise ValueError(
                "Current batch size must be greater than the original batch size"
            )

        buffer = []
        buffer_size = 0
        for batch in self.data_loader:
            buffer.append(batch)
            buffer_size += len(batch)
            while buffer_size >= self.current_batch_size:
                merged = DynamicBatchSizeDataLoader.merge_batches(buffer)
                yield DynamicBatchSizeDataLoader.split_batch(
                    merged, 0, self.current_batch_size
                )
                buffer = [
                    DynamicBatchSizeDataLoader.split_batch(
                        merged, self.current_batch_size, buffer_size
                    )
                ]
                buffer_size -= self.current_batch_size

    @staticmethod
    def split_batch(batch, start_idx, end_idx):
        """
        Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided
        indices.
        """
        if isinstance(batch, torch.Tensor):
            return batch[start_idx:end_idx]
        elif isinstance(batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.split_batch(value, start_idx, end_idx)
                for key, value in batch.items()
            }
        elif isinstance(batch, Sequence):
            return [
                DynamicBatchSizeDataLoader.split_batch(item, start_idx, end_idx)
                for item in batch
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(batch)}")

    @staticmethod
    def merge_batches(batches):
        """
        Merges a sequence of batches into a single batch.
        """
        sample_batch = batches[0]
        if isinstance(sample_batch, torch.Tensor):
            return torch.cat(batches, dim=0)
        elif isinstance(sample_batch, Mapping):
            return {
                key: DynamicBatchSizeDataLoader.merge_batches(
                    [batch[key] for batch in batches]
                )
                for key in sample_batch.keys()
            }
        elif isinstance(sample_batch, Sequence):
            return [
                DynamicBatchSizeDataLoader.merge_batches(
                    [batch[i] for batch in batches]
                )
                for i in range(len(sample_batch))
            ]
        else:
            raise TypeError(f"Unsupported batch type: {type(sample_batch)}")


class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]):
    def __init__(self, data_loader: DataLoader, framework=None):
        super().__init__()
        self.data = DynamicBatchSizeDataLoader(data_loader)

    def get_num_batches(self) -> int:
        return self.num_samples // self.data.get_batch_size()

    def set_batch_size(self, batch_size: int):
        self.data.set_batch_size(batch_size)

    def get_batch_size(self) -> int:
        return self.data.get_batch_size()
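
A minimal usage sketch for the DynamicBatchSizeDataLoader added above (not part of this PR's diff; FeatureDataset and the shapes are illustrative assumptions). Items here collate into plain tensors; dict- and list-style batches are handled recursively by split_batch and merge_batches:

# Hypothetical usage sketch (assumed names; not from the PR): exercising
# DynamicBatchSizeDataLoader with a toy dataset whose items collate into tensors.
import torch
from torch.utils.data import DataLoader, Dataset

from sparseml.core.data.pytorch import DynamicBatchSizeDataLoader


class FeatureDataset(Dataset):
    # toy dataset returning plain tensors so default collation yields one tensor per batch
    def __init__(self, features: torch.Tensor):
        self.features = features

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.features[idx]


loader = DataLoader(FeatureDataset(torch.randn(16, 8)), batch_size=4)
dynamic = DynamicBatchSizeDataLoader(loader)

dynamic.set_batch_size(2)  # smaller than the loader's 4: each batch is split in two
assert all(batch.shape == (2, 8) for batch in dynamic)

dynamic.set_batch_size(8)  # larger than the loader's 4: consecutive batches are merged
assert all(batch.shape == (8, 8) for batch in dynamic)

dynamic.set_batch_size(4)  # equal to the loader's 4: batches pass through unchanged
assert all(batch.shape == (4, 8) for batch in dynamic)

Re-batching happens lazily at iteration time, so the same underlying DataLoader can serve different batch sizes without being rebuilt.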
146 changes: 146 additions & 0 deletions src/sparseml/core/event.py
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Optional


__all__ = [
    "EventType",
    "Event",
]


class EventType(Enum):
    # training lifecycle
    PRE_INIT = "pre_init"
    INITIALIZE = "initialize"
    FINALIZE = "finalize"

    # batch lifecycle
    BATCH_START = "batch_start"
    LOSS_CALCULATED = "loss_calculated"
    BATCH_END = "batch_end"

    # step lifecycle
    OPTIM_PRE_STEP = "optim_pre_step"
    OPTIM_POST_STEP = "optim_post_step"

    def order(self) -> int:
        if self == EventType.PRE_INIT:
            return 0
        elif self == EventType.INITIALIZE:
            return 10
        elif self == EventType.FINALIZE:
            return 20
        elif self == EventType.BATCH_START:
            return 100
        elif self == EventType.LOSS_CALCULATED:
            return 110
        elif self == EventType.OPTIM_PRE_STEP:
            return 120
        elif self == EventType.OPTIM_POST_STEP:
            return 130
        elif self == EventType.BATCH_END:
            return 140
        else:
            raise ValueError(f"invalid event type {self}")


@dataclass
class Event:
    type_: EventType = None

    steps_per_epoch: int = None
    batches_per_step: int = None
    invocations_per_step: int = None

    global_step: int = 0
    global_batch: int = 0

    @property
    def epoch_based(self) -> bool:
        return self.steps_per_epoch is not None

    @property
    def epoch(self) -> int:
        return self.global_step // self.steps_per_epoch

    @property
    def epoch_full(self) -> float:
        return self.global_step / float(self.steps_per_epoch)

    @property
    def epoch_step(self) -> int:
        return self.global_step % self.steps_per_epoch

    @property
    def epoch_batch(self) -> int:
        batches_per_epoch = (
            self.steps_per_epoch * self.batches_per_step
            if self.batches_per_step
            else self.steps_per_epoch
        )

        return self.global_batch % batches_per_epoch

    @property
    def current_index(self) -> float:
        if not self.epoch_based:
            return self.global_step

        if self.epoch_full - self.epoch > 1.0:
            raise ValueError("too many steps per epoch for epoch based event")

        return self.epoch_full

    @current_index.setter
    def current_index(self, value: float):
        if not self.epoch_based:
            self.global_step = int(value)
            self.global_batch = (
                self.global_step
                if self.batches_per_step is None or self.batches_per_step < 2
                else self.global_step * self.batches_per_step
            )
            return

        self.global_step = int(value * self.steps_per_epoch)
        self.global_batch = (
            self.global_step
            if self.batches_per_step is None or self.batches_per_step < 2
            else self.global_step * self.batches_per_step
        )

    def should_update(
        self, start: Optional[float], end: Optional[float], update: float
    ):
        current = self.current_index

        if start is not None and current < start:
            return False

        if end is not None and current > end:
            return False

        return update is None or update <= 0.0 or current % update < 1e-10

    def new_instance(self, **kwargs) -> "Event":
        instance = deepcopy(self)
        for key, value in kwargs.items():
            setattr(instance, key, value)

        return instance
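
A minimal usage sketch for the Event dataclass above (not part of this PR's diff; the step counts and schedule bounds are illustrative assumptions). With steps_per_epoch set, current_index is a fractional epoch, and should_update gates a start/end/update schedule against it:

# Hypothetical usage sketch (illustrative numbers; not from the PR):
# epoch-based indexing and schedule gating with the Event dataclass.
from sparseml.core.event import Event, EventType

event = Event(type_=EventType.BATCH_START, steps_per_epoch=100)

event.current_index = 1.5  # fractional epochs once steps_per_epoch is set
assert event.global_step == 150
assert event.epoch == 1 and event.epoch_step == 50
assert event.epoch_full == 1.5

# a modifier scheduled from epoch 1.0 to 3.0, updating every 0.5 epochs
assert event.should_update(start=1.0, end=3.0, update=0.5)
assert not event.should_update(start=2.0, end=3.0, update=0.5)  # before its start

later = event.new_instance(global_step=321, global_batch=321)
assert later.epoch == 3
assert not later.should_update(start=1.0, end=3.0, update=0.5)  # past its end

When steps_per_epoch is unset, the same index is interpreted as a raw global step, so the scheduling logic works for both epoch-based and step-based training loops.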