
Implement Rmsprop optimiser #8

Merged: merged 10 commits into metaopt:main on May 1, 2022

Conversation

future-xy (Contributor)

No description provided.

@future-xy future-xy requested a review from JieRen98 April 29, 2022 08:38
del params
nu = _update_moment_per_elem_norm(updates, state.nu, decay, 2, inplace)
if inplace:
def f(g, n): return g.mul_(torch.rsqrt_(n.add_(eps)))
Collaborator:

This doesn't look right: n.add_ will modify n in place.

Collaborator:

Here we only want to modify g, not n.
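The aliasing problem the reviewer describes can be demonstrated directly. A minimal sketch with hypothetical values, where g stands for a gradient-like update and n for the stored second-moment state:

```python
import torch

eps = 1e-8
g = torch.tensor([3.0])  # gradient-like update
n = torch.tensor([4.0])  # stored second-moment estimate

# Buggy in-place chain: n.add_ and torch.rsqrt_ both mutate n,
# so the optimizer state is silently overwritten.
g.mul_(torch.rsqrt_(n.add_(eps)))
print(n)  # tensor([0.5000]) -- the state was destroyed

# Safe variant: n.add(eps) allocates a temporary, so only g is
# modified in place and the state n survives.
g2 = torch.tensor([3.0])
n2 = torch.tensor([4.0])
g2.mul_(torch.rsqrt(n2.add(eps)))
print(n2)  # tensor([4.]) -- the state is intact
```

Both g and g2 end up scaled to 1.5, but only the second variant leaves the second-moment estimate usable for the next step.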

Contributor (Author):

Fixed this, and likewise fixed the other places that would modify n. Also fixed a bug in a test script.

def f(g, n): return g.mul(torch.rsqrt(n.add(eps)))
# """The followings are pytorch style"""
# if inplace:
# def f(g, n): return g.div_(torch.sqrt_(n).add_(eps))
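Putting the fix together, here is a minimal sketch of the corrected scaling step. The name scale_by_rms_step is hypothetical, and flat lists of tensors stand in for the parameter pytrees the real code maps over:

```python
import torch

def scale_by_rms_step(updates, nu, eps=1e-8, inplace=True):
    # Scale each update g by rsqrt(nu + eps). In both branches,
    # n.add(eps) is out of place, so the stored second moment nu
    # is never mutated; only g may be updated in place.
    if inplace:
        def f(g, n):
            return g.mul_(torch.rsqrt(n.add(eps)))
    else:
        def f(g, n):
            return g.mul(torch.rsqrt(n.add(eps)))
    return [f(g, n) for g, n in zip(updates, nu)]

updates = [torch.tensor([3.0])]
nu = [torch.tensor([4.0])]
scaled = scale_by_rms_step(updates, nu)
print(scaled[0])  # tensor([1.5000])
print(nu[0])      # tensor([4.]) -- state untouched
```

The inplace branch still reuses g's storage for the result; the point of the fix is only that nu stays read-only.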
Collaborator:

Suggest changing this one as well; it should probably be torch.sqrt(n).add_
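The distinction this comment points at can be sketched with hypothetical standalone values: torch.sqrt_ mutates its argument, while torch.sqrt(n).add_ only mutates a freshly allocated temporary.

```python
import torch

eps = 1e-8
n = torch.tensor([4.0])  # second-moment state

# torch.sqrt_ mutates its argument, which would corrupt the state:
bad = n.clone()
torch.sqrt_(bad)
print(bad)  # tensor([2.]) -- the clone was overwritten

# torch.sqrt(n) takes the square root out of place; add_ then
# mutates only the new temporary, leaving n intact.
denom = torch.sqrt(n).add_(eps)
print(n)      # tensor([4.])
print(denom)  # tensor([2.])
```

So torch.sqrt(n).add_(eps) gets the memory savings of one in-place op without touching the optimizer state.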

Contributor (Author):

Line 358?

Collaborator:

Yes.

@JieRen98 (Collaborator):

Shouldn't it be RMSprop -> RMSProp?

@future-xy (Contributor, Author):

TF/PyTorch/Keras use RMSprop, while Optax uses RMSProp. Let's follow Optax, then.

@JieRen98 (Collaborator):

> TF/PyTorch/Keras use RMSprop, while Optax uses RMSProp. Let's follow Optax, then.

I've seen quite a few papers call it RMSProp.

@future-xy (Contributor, Author):

> I've seen quite a few papers call it RMSProp.

The change has been committed.

@JieRen98 JieRen98 merged commit 5d29a07 into metaopt:main May 1, 2022