Force conversion of LR variables to float due to type coercion from manager serialization (#384)
markurtz committed Sep 10, 2021
1 parent 47d9472 commit 8dda5e7
Showing 2 changed files with 9 additions and 0 deletions.
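For context, a minimal sketch of the failure mode this commit guards against: when a learning-rate value round-trips through the manager's serialization, numeric fields such as init_lr can come back as strings, and comparing a string against a float during validation raises a TypeError before the value is ever range-checked. The YAML round-trip and variable names below are illustrative assumptions, not code from this repository.

    import yaml

    # A learning-rate value that round-trips through YAML serialization can
    # come back as a string rather than a float (illustrative example).
    serialized = yaml.safe_dump({"init_lr": "0.005"})
    init_lr = yaml.safe_load(serialized)["init_lr"]  # -> "0.005" (str)

    try:
        if init_lr <= 0.0:  # str vs. float comparison
            raise ValueError("init_lr must be greater than 0")
    except TypeError as err:
        print(f"validation fails before the value is checked: {err}")

    # The pattern applied in this commit: coerce strings to float first.
    if isinstance(init_lr, str):
        init_lr = float(init_lr)
    assert init_lr == 0.005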
3 changes: 3 additions & 0 deletions src/sparseml/optim/learning_rate.py
@@ -156,6 +156,9 @@ def validate_lr_info(self):
         else:
             raise ValueError("unknown lr_class given of {}".format(self._lr_class))

+        if isinstance(self._init_lr, str):
+            self._init_lr = float(self._init_lr)
+
         if self._init_lr <= 0.0:
             raise ValueError("init_lr must be greater than 0")

6 changes: 6 additions & 0 deletions src/sparseml/pytorch/optim/modifier_lr.py
@@ -414,6 +414,9 @@ def validate(self):
         if self.lr_func not in lr_funcs:
             raise ValueError(f"lr_func must be one of {lr_funcs}")

+        if isinstance(self.init_lr, str):
+            self.init_lr = float(self.init_lr)
+
         if (
             (not self.init_lr and self.init_lr != 0)
             or self.init_lr < 0.0
@@ -423,6 +426,9 @@ def validate(self):
                 f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
             )

+        if isinstance(self.final_lr, str):
+            self.final_lr = float(self.final_lr)
+
         if (
             (not self.final_lr and self.final_lr != 0)
             or self.final_lr < 0.0
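A short usage sketch of the behavior the two hunks above enable; the _LRFieldsSketch class is a hypothetical stand-in that only mirrors the coercion and range checks shown in the diff, not the actual sparseml modifier class.

    class _LRFieldsSketch:
        # Hypothetical stand-in mirroring the validate() logic shown above.

        def __init__(self, init_lr, final_lr):
            self.init_lr = init_lr
            self.final_lr = final_lr

        def validate(self):
            # Coerce serialized string values back to floats before validating.
            if isinstance(self.init_lr, str):
                self.init_lr = float(self.init_lr)
            if isinstance(self.final_lr, str):
                self.final_lr = float(self.final_lr)

            for name, value in (("init_lr", self.init_lr), ("final_lr", self.final_lr)):
                if (not value and value != 0) or value < 0.0 or value > 1.0:
                    raise ValueError(
                        f"{name} must be within range [0.0, 1.0], given {value}"
                    )

    # Values arriving as strings (e.g. from a serialized recipe) now validate
    # cleanly instead of raising a TypeError on comparison.
    mod = _LRFieldsSketch(init_lr="0.1", final_lr="0.001")
    mod.validate()
    assert isinstance(mod.init_lr, float) and isinstance(mod.final_lr, float)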
