
Commit

Autoupdate min_lrs for ReduceLROnPlateau if possible, fixes pytorch#1…
janeyx99 authored and pytorchmergebot committed Oct 10, 2024
1 parent d50d5df commit f9ed39c
Showing 2 changed files with 70 additions and 0 deletions.
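Context for the diff below (editor's note, not part of the commit): a minimal sketch of the scenario the change targets, assuming an SGD optimizer with two param groups and a scalar min_lr; the parameter shapes, learning rates, and patience value are illustrative, not taken from the commit. Before this change, adding a param group left scheduler.min_lrs shorter than optimizer.param_groups, so the next reduction would index past the end of min_lrs; with the change, a scalar min_lr is remembered as a default and min_lrs is regrown automatically.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Two param groups with a scalar min_lr (illustrative values).
w1 = torch.nn.Parameter(torch.rand(2, 3))
w2 = torch.nn.Parameter(torch.rand(2, 3))
opt = SGD([{"params": [w1]}, {"params": [w2]}], lr=0.5)
scheduler = ReduceLROnPlateau(opt, mode="min", patience=0, min_lr=1e-5)

# Mid-training, a new param group appears (e.g. a layer is unfrozen).
w3 = torch.nn.Parameter(torch.rand(2, 3))
opt.add_param_group({"params": [w3], "lr": 0.05})

opt.step()
scheduler.step(1.0)  # first metric becomes the best so far
opt.step()
scheduler.step(1.0)  # plateau with patience=0 triggers _reduce_lr, which now
                     # regrows min_lrs from the scalar to cover all three groups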
54 changes: 54 additions & 0 deletions test/optim/test_lrscheduler.py
@@ -2405,6 +2405,60 @@ def test_lr_scheduler_state_dict_load(self, LRClass, weights_only):
        scheduler2.load_state_dict(state_dict_loaded)
        self.assertEqual(scheduler2.state_dict(), state_dict)

    @parametrize("min_lr", ["scalar", "list"])
    def test_add_param_group_does_not_break_reduce_lr_on_plateau(self, min_lr):
        epochs = 20
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
        metrics = [1] * 7 + [0.6] + [0.5] * 12
        scheduler = ReduceLROnPlateau(
            self.opt,
            mode="min",
            threshold_mode="rel",
            threshold=0.1,
            patience=5,
            cooldown=5,
            min_lr=0 if min_lr == "scalar" else [1e-5, 1e-4],
        )
        for epoch in range(epochs):
            # Point is to test the use case in #104361
            if epoch == 8:
                param = torch.nn.Parameter(torch.rand(2, 3))
                self.opt.add_param_group({"params": [param], "lr": 0.05})
                if min_lr == "list":
                    scheduler.min_lrs.append(1e-6)
            self.opt.step()
            scheduler.step(metrics[epoch])
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def test_add_param_group_errors_reduce_lr_on_plateau(self):
        scheduler = ReduceLROnPlateau(
            self.opt,
            mode="min",
            threshold_mode="rel",
            threshold=1e-5,
            patience=0,
            cooldown=0,
            min_lr=[1e-5, 1e-4],
        )
        param = torch.nn.Parameter(torch.rand(2, 3))
        self.opt.add_param_group({"params": [param], "lr": 0.05})
        self.opt.step()
        scheduler.step(1)
        with self.assertRaisesRegex(RuntimeError, "The number of param groups in the"):
            self.opt.step()
            scheduler.step(1.3)

    @parametrize(
        "LRClass",
        [
16 changes: 16 additions & 0 deletions torch/optim/lr_scheduler.py
@@ -1318,8 +1318,10 @@ def __init__(
                raise ValueError(
                    f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
                )
            self.default_min_lr = None
            self.min_lrs = list(min_lr)
        else:
            self.default_min_lr = min_lr
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
Expand Down Expand Up @@ -1375,6 +1377,20 @@ def step(self, metrics: SupportsFloat, epoch=None): # type: ignore[override]
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        if len(self.optimizer.param_groups) != len(self.min_lrs):
            if self.default_min_lr is None:
                raise RuntimeError(
                    "The number of param groups in the `optimizer` "
                    f"({len(self.optimizer.param_groups)}) differs "
                    f"from when `ReduceLROnPlateau` was initialized "
                    f"({len(self.min_lrs)}), usually due to a new "
                    "param group being added to the optimizer. Please "
                    "modify the `min_lrs` field to match the length "
                    "of the `optimizer` param groups."
                )
            else:
                self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)

        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
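For the list-valued min_lr path (editor's note, not part of the commit): a hedged sketch mirroring the second test above, reusing the same illustrative optimizer setup as the earlier sketch. When per-group floors were given explicitly there is no single default to regrow from, so the scheduler now raises the RuntimeError added above unless min_lrs is kept in sync by hand.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

w1 = torch.nn.Parameter(torch.rand(2, 3))
w2 = torch.nn.Parameter(torch.rand(2, 3))
opt = SGD([{"params": [w1]}, {"params": [w2]}], lr=0.5)
scheduler = ReduceLROnPlateau(opt, mode="min", patience=0, min_lr=[1e-5, 1e-4])

w3 = torch.nn.Parameter(torch.rand(2, 3))
opt.add_param_group({"params": [w3], "lr": 0.05})

# Appending a floor for the new group keeps min_lrs in sync; omit this line
# to hit the RuntimeError path exercised by the test above.
scheduler.min_lrs.append(1e-6)

opt.step()
scheduler.step(1.0)
opt.step()
scheduler.step(1.0)  # plateau -> _reduce_lr sees three matching min_lr floors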
