
Bug in MultiStepLR lr scheduler #31828

Closed
Steve-Tod opened this issue Jan 3, 2020 · 7 comments
Assignees
Labels
module: optimizer Related to torch.optim
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Steve-Tod

Steve-Tod commented Jan 3, 2020

🐛 Bug

Passing the epoch argument to the step() function of MultiStepLR leads to an incorrect learning rate.

To Reproduce

from torch import nn
import torch
net = nn.Linear(30, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
s = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
print(s.get_lr())
s.step(1)
print(s.get_lr())

Output

[0.001]
[1.0000000000000002e-06]

Expected behavior

[0.001]
[0.001]

Environment

PyTorch version: 1.4.0a0+d5bf51b
Is debug build: No
CUDA used to build PyTorch: 9.0

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.14.0

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 430.26
cuDNN version: Could not collect

Versions of relevant libraries:

[pip] numpy==1.17.3
[pip] torch==1.4.0a0+d5bf51b
[conda] blas 1.0 mkl
[conda] magma-cuda90 2.5.0 1 pytorch
[conda] mkl 2019.4 243
[conda] mkl-include 2019.4 243
[conda] mkl-service 2.3.0 py36he904b0f_0
[conda] mkl_fft 1.0.15 py36ha843d7b_0
[conda] mkl_random 1.1.0 py36hd6b4f25_0
[conda] torch 1.4.0a0+d5bf51b pypi_0 pypi

Additional context

A possible cause is that the milestones attribute of MultiStepLR is a Counter rather than a list, which leads to incorrect behavior of bisect in the get_lr function.
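
To make the hypothesis concrete, here is a minimal sketch of the suspected failure mode (an illustration only, not the actual PyTorch code): bisect_right indexes its first argument by position, while a collections.Counter is indexed by key and returns 0 for keys it has not seen, so bisecting a Counter of milestones with a small epoch can overshoot.

from bisect import bisect_right
from collections import Counter

milestones = [10, 20, 30]
counter = Counter(milestones)        # Counter({10: 1, 20: 1, 30: 1})

print(bisect_right(milestones, 1))   # 0 -> 0.001 * 0.1**0 = 0.001 (expected)
print(bisect_right(counter, 1))      # 3 -> 0.001 * 0.1**3 = 1e-06 (observed)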

cc @vincentqb

@ssnl
Collaborator

ssnl commented Jan 3, 2020

you should use get_last_lr.

@Steve-Tod
Author

you should use get_last_lr.

The result is the same.

[screenshot: the same incorrect learning rate is returned by get_last_lr()]

@Steve-Tod
Author

Steve-Tod commented Jan 5, 2020

Did anyone try this?

@jerryzh168 jerryzh168 added the module: optimizer and triaged labels Jan 8, 2020
@vincentqb
Contributor

If you are compiling from master, please make sure you have the latest version. Schedulers no longer take the epoch parameter in step(); a warning against passing it will be raised with #31125. The code should be the following:

from torch import nn
import torch
net = nn.Linear(30, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
s = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)

print(s.get_last_lr())
s.step()
print(s.get_last_lr())

@vincentqb vincentqb self-assigned this Jan 9, 2020
@Steve-Tod
Author

Thank you for your response!
But I'm curious why the epoch parameter was removed. It is convenient if I want to resume from a checkpoint and continue using the same milestones.

@vincentqb
Contributor

vincentqb commented Jan 10, 2020

Thank you for your response!
But I'm curious why the epoch parameter was removed.

Not all schedulers support that parameter in the first place. Moreover, we made schedulers chainable, and the epoch parameter doesn't extend nicely to chaining; see the sketch below. See also #26423.
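
To illustrate what "chainable" means here, a minimal sketch (not code from this thread; the particular scheduler pairing is only an example): two schedulers act on the same optimizer and each is stepped once per epoch, so a single epoch argument would not compose cleanly.

from torch import nn
import torch

net = nn.Linear(30, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
exp_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
multi_sched = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)

for epoch in range(30):
    # training step would go here
    optimizer.step()
    exp_sched.step()    # each scheduler multiplies the current lr in turn
    multi_sched.step()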

It is convenient if I want to resume from a checkpoint and continue using the same milestones.

For this, you can run the scheduler in a loop, or you can save and load the scheduler state; see the sketch below.
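
For example, a minimal sketch of both options, reusing the setup from the reproduction above (the checkpoint path and resume epoch are made up for illustration):

from torch import nn
import torch

net = nn.Linear(30, 10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
s = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)

# Option 1: replay the scheduler up to the epoch being resumed.
resume_epoch = 15  # hypothetical checkpoint epoch
for _ in range(resume_epoch):
    s.step()  # stepping without optimizer.step() only raises a warning here
print(s.get_last_lr())  # lr has dropped by gamma once (milestone at 10 passed)

# Option 2: save and reload the scheduler state alongside the checkpoint.
torch.save({'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': s.state_dict()}, 'checkpoint.pth')

ckpt = torch.load('checkpoint.pth')
net.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
s.load_state_dict(ckpt['scheduler'])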

@AIRCAP

AIRCAP commented Feb 25, 2020

This bug is a duplicate of #33229.
