Skip to content

Commit

Permalink
Fixed issue with torch 1.12 where _scale_fn_ref is missing in CyclicLR (
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Oct 26, 2023
1 parent ffe6d25 commit 23b4f7a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,5 +1643,7 @@ def get_scheduler_state(scheduler) -> Dict[str, Tensor]:

state = scheduler.state_dict()
if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
del state["_scale_fn_ref"]
# A check is needed since torch 1.12 does not have the _scale_fn_ref attribute, while other versions do
if "_scale_fn_ref" in state:
del state["_scale_fn_ref"]
return state

0 comments on commit 23b4f7a

Please sign in to comment.