feat: adagrad optimizer support #80

Merged
54 commits, merged on Mar 22, 2023
Commits
640e3b7
feat(torchopt): adagrad optimizer support
XuehaiPan Oct 13, 2022
ba6be61
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 15, 2023
21102bf
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Feb 15, 2023
035a429
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 17, 2023
c4a1899
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 20, 2023
bf029ae
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Feb 20, 2023
61552b2
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 23, 2023
3830947
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Feb 27, 2023
5dcf35f
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 1, 2023
eb31c43
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
dac67fb
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
51bafa9
Merge branch 'feature/adagrad' of https://github.com/Benjamin-eecs/to…
Benjamin-eecs Mar 3, 2023
a953329
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
9786565
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
449bdb0
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
75b2bfb
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
3f28f98
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
ae56e25
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
91c7086
feat: adagrad integration
Benjamin-eecs Mar 3, 2023
2f78e60
feat: adagrad integration
Benjamin-eecs Mar 4, 2023
c8e74f4
feat: adagrad integration
Benjamin-eecs Mar 4, 2023
fd4e257
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 7, 2023
95be0cb
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 9, 2023
9718cc0
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 9, 2023
7e76a7e
feat: adagrad integration
Benjamin-eecs Mar 11, 2023
1077916
feat: adagrad integration
Benjamin-eecs Mar 11, 2023
fc43b03
Merge branch 'feature/adagrad' of https://github.com/Benjamin-eecs/to…
Benjamin-eecs Mar 11, 2023
93d9daf
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 13, 2023
11c99d7
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 14, 2023
5e64fe1
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 17, 2023
adf641e
fix: [pre-commit.ci] auto fixes [...]
pre-commit-ci[bot] Mar 17, 2023
bb50658
Merge branch 'metaopt:main' into feature/adagrad
Benjamin-eecs Mar 17, 2023
3ca005c
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
9a17c10
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
79036ed
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
85709e3
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
5636992
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
3ede2b4
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
4c4e1a3
feat: adagrad integration
Benjamin-eecs Mar 18, 2023
d431749
feat(torchopt.optim): update torchopt/alias/adagrad.py
Benjamin-eecs Mar 20, 2023
03e5b42
feat: adagrad integration
Benjamin-eecs Mar 20, 2023
11d1c1f
Merge branch 'feature/adagrad' of https://github.com/Benjamin-eecs/to…
Benjamin-eecs Mar 20, 2023
ace6b5e
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
a937d6b
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
beab339
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
26afa12
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
91bb5c2
feat: adagrad integration
Benjamin-eecs Mar 21, 2023
87b5219
Merge branch 'main' into feature/adagrad
XuehaiPan Mar 22, 2023
38482bf
fix: ca pi gu 💩
XuehaiPan Mar 22, 2023
61234cd
refactor: refactor scale_by_rss
XuehaiPan Mar 22, 2023
d34c534
test: fix eps value for AdaGrad
XuehaiPan Mar 22, 2023
ccf87e2
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Mar 22, 2023
c188839
revert: revert change in torchopt/version.py
XuehaiPan Mar 22, 2023
fbb68f8
chore: update CHANGELOG
Benjamin-eecs Mar 22, 2023
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -13,9 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Implement AdaGrad optimizer and exponential learning rate decay schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#80](https://github.com/metaopt/torchopt/pull/80).
- Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140).
- Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139).
- Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan).
- Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#143](https://github.com/metaopt/torchopt/pull/143).

### Changed

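The new functional optimizer recorded in the changelog above follows TorchOpt's usual init / update / apply_updates pattern, the same pattern the new test in tests/test_alias.py exercises further below. A minimal sketch of that usage; the tiny model and random batch are placeholders and not part of this diff:

import functorch
import torch
import torch.nn.functional as F
import torchopt

# Placeholder model and data; only the torchopt.adagrad calls mirror this PR.
model = torch.nn.Linear(4, 2)
fmodel, params = functorch.make_functional(model)
xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))

optim = torchopt.adagrad(lr=1e-2, lr_decay=0.0, weight_decay=0.0,
                         initial_accumulator_value=0.0, eps=1e-8)
optim_state = optim.init(params)

for _ in range(3):
    loss = F.cross_entropy(fmodel(params, xs), ys)
    grads = torch.autograd.grad(loss, params)
    updates, optim_state = optim.update(grads, optim_state, params=params)
    params = torchopt.apply_updates(params, updates)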
60 changes: 39 additions & 21 deletions docs/source/api/api.rst
@@ -30,17 +30,23 @@ Functional Optimizers
.. autosummary::

FuncOptimizer
adagrad
adam
sgd
rmsprop
adamw
rmsprop
sgd

Wrapper for Function Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: FuncOptimizer
:members:

Functional AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: adagrad

Functional Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -51,16 +57,16 @@ Functional AdamW Optimizer

.. autofunction:: adamw

Functional SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: sgd

Functional RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: rmsprop

Functional SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: sgd

------

Classic Optimizers
@@ -70,10 +76,16 @@ Classic Optimizers

.. autosummary::

AdaGrad
Adam
SGD
RMSProp
AdamW
RMSProp
SGD

Classic AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AdaGrad

Classic Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~
@@ -85,16 +97,16 @@ Classic AdamW Optimizer

.. autoclass:: AdamW

Classic SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SGD

Classic RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RMSProp

Classic SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SGD

------

Differentiable Meta-Optimizers
@@ -104,10 +116,16 @@ Differentiable Meta-Optimizers

.. autosummary::

MetaAdaGrad
MetaAdam
MetaSGD
MetaRMSProp
MetaAdamW
MetaRMSProp
MetaSGD

Differentiable Meta-AdaGrad Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaAdaGrad

Differentiable Meta-Adam Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -119,16 +137,16 @@ Differentiable Meta-AdamW Optimizer

.. autoclass:: MetaAdamW

Differentiable Meta-SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaSGD

Differentiable Meta-RMSProp Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaRMSProp

Differentiable Meta-SGD Optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MetaSGD

------

Implicit Differentiation
5 changes: 3 additions & 2 deletions docs/source/explicit_diff/explicit_diff.rst
@@ -53,10 +53,11 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho
.. autosummary::

torchopt.MetaOptimizer
torchopt.MetaAdaGrad
torchopt.MetaAdam
torchopt.MetaSGD
torchopt.MetaRMSProp
torchopt.MetaAdamW
torchopt.MetaRMSProp
torchopt.MetaSGD

By combining low-level API :class:`torchopt.MetaOptimizer` with the previous functional optimizer, we can achieve high-level API:

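A rough sketch of how the new Meta-AdaGrad variant slots into that pattern; the constructor signature (module plus hyperparameters) is assumed to match the other Meta-optimizers rather than taken from this diff, and the model and data are placeholders:

import torch
import torch.nn.functional as F
import torchopt

net = torch.nn.Linear(4, 1)                      # placeholder inner model
meta_optim = torchopt.MetaAdaGrad(net, lr=1e-2)  # assumed signature, mirroring e.g. MetaAdam

xs, ys = torch.randn(8, 4), torch.randn(8, 1)
inner_loss = F.mse_loss(net(xs), ys)
meta_optim.step(inner_loss)  # differentiable in-place update of net's parameters

outer_loss = F.mse_loss(net(xs), ys)
outer_loss.backward()        # gradients flow back through the AdaGrad update step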
2 changes: 1 addition & 1 deletion docs/source/implicit_diff/implicit_diff.rst
@@ -38,7 +38,7 @@ In `IMAML <https://arxiv.org/abs/1909.04630>`_, the function :math:`F` in the fi
Fixed-point Iteration
~~~~~~~~~~~~~~~~~~~~~

Sometimes the inner-level optimal solution can also be achieved by fixed point where the optionality :math:`T` takes the form:
Sometimes the inner-level optimal solution can also be achieved by fixed point where the optimality :math:`T` takes the form:

.. math::

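As a generic illustration of such a fixed-point condition (a sketch only, not the exact formula from the docs), writing :math:`\boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})` for the inner solution given outer parameters :math:`\boldsymbol{\phi}`:

.. math::

    \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}) = T\left(\boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}), \boldsymbol{\phi}\right),
    \qquad \text{equivalently} \qquad
    F(\boldsymbol{\theta}, \boldsymbol{\phi}) = T(\boldsymbol{\theta}, \boldsymbol{\phi}) - \boldsymbol{\theta} = 0.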
11 changes: 7 additions & 4 deletions docs/source/optimizer/optim.rst
@@ -18,10 +18,11 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`,
.. autosummary::

torchopt.FuncOptimizer
torchopt.adagrad
torchopt.adam
torchopt.sgd
torchopt.rmsprop
torchopt.adamw
torchopt.rmsprop
torchopt.sgd

Apply Parameter Updates
-----------------------
@@ -84,10 +85,12 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi
.. autosummary::

torchopt.Optimizer
torchopt.AdaGrad
torchopt.Adam
torchopt.SGD
torchopt.RMSProp
torchopt.AdamW
torchopt.RMSProp
torchopt.SGD


By combining low-level API :class:`torchopt.Optimizer` with the previous functional optimizer, we can achieve high-level API:

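A minimal sketch of the classic, PyTorch-like API mentioned above, assuming torchopt.AdaGrad accepts the same constructor arguments as torch.optim.Adagrad (which the new test below compares against); the model and data are placeholders:

import torch
import torch.nn.functional as F
import torchopt

model = torch.nn.Linear(4, 2)                          # placeholder model
optim = torchopt.AdaGrad(model.parameters(), lr=1e-2)  # constructor assumed to mirror torch.optim.Adagrad

xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = F.cross_entropy(model(xs), ys)

optim.zero_grad()  # original PyTorch-style API, as described in these docs
loss.backward()
optim.step()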
3 changes: 3 additions & 0 deletions docs/source/spelling_wordlist.txt
@@ -172,3 +172,6 @@ abc
ABCMeta
subclasscheck
ctx
Duchi
invertible
AdaGrad
64 changes: 64 additions & 0 deletions tests/test_alias.py
@@ -450,6 +450,70 @@ def test_adam_accelerated_cuda(
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
lr_decay=[0.0, 1e-2],
initial_accumulator_value=[0.0, 1e-1],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
maximize=[False, True],
use_chain_flat=[True, False],
)
def test_adagrad(
dtype: torch.dtype,
lr: float,
lr_decay: float,
initial_accumulator_value: float,
eps: float,
inplace: bool,
weight_decay: float,
maximize: bool,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adagrad(
lr=lr,
lr_decay=lr_decay,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value,
eps=eps,
maximize=maximize,
)
optim_state = optim.init(params)
optim_ref = torch.optim.Adagrad(
model_ref.parameters(),
lr=lr,
lr_decay=lr_decay,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value,
eps=eps,
maximize=maximize,
)
for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
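The new test above follows the same structure as the other optimizer tests in this file: the functional torchopt.adagrad chain and a reference torch.optim.Adagrad instance are stepped in lockstep on identical models, and the resulting parameters are compared at the end. It can be selected locally with standard pytest filtering, e.g. pytest tests/test_alias.py -k test_adagrad (an assumed invocation, not part of this PR).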
19 changes: 16 additions & 3 deletions tests/test_import.py
@@ -25,6 +25,7 @@ def test_accelerated_op_import() -> None:


def test_alias_import() -> None:
torchopt.adagrad
torchopt.adam
torchopt.adamw
torchopt.rmsprop
@@ -33,8 +34,8 @@ def test_alias_import() -> None:
torchopt.alias.adamw
torchopt.alias.rmsprop
torchopt.alias.sgd
from torchopt import adam, adamw, rmsprop, sgd
from torchopt.alias import adam, adamw, rmsprop, sgd
from torchopt import adagrad, adam, adamw, rmsprop, sgd
from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd


def test_diff_import() -> None:
@@ -107,17 +108,23 @@ def test_nn_import() -> None:

def test_optim_import() -> None:
torchopt.FuncOptimizer
torchopt.MetaAdaGrad
torchopt.MetaAdagrad
torchopt.MetaAdam
torchopt.MetaAdamW
torchopt.MetaRMSProp
torchopt.MetaRMSprop
torchopt.MetaSGD
torchopt.AdaGrad
torchopt.Adagrad
torchopt.Adam
torchopt.AdamW
torchopt.Optimizer
torchopt.RMSProp
torchopt.RMSprop
torchopt.SGD
torchopt.optim.meta.MetaAdaGrad
torchopt.optim.meta.MetaAdagrad
torchopt.optim.meta.MetaAdam
torchopt.optim.meta.MetaAdamW
torchopt.optim.meta.MetaRMSProp
@@ -132,21 +139,27 @@ def test_optim_import() -> None:
torchopt.optim.func.FuncOptimizer
from torchopt import (
SGD,
AdaGrad,
Adagrad,
Adam,
AdamW,
FuncOptimizer,
MetaAdaGrad,
MetaAdagrad,
MetaAdam,
MetaAdamW,
MetaOptimizer,
MetaRMSProp,
MetaRMSprop,
MetaRMSProp,
MetaSGD,
Optimizer,
RMSProp,
)
from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp
from torchopt.optim.func import FuncOptimizer
from torchopt.optim.meta import (
MetaAdaGrad,
MetaAdagrad,
MetaAdam,
MetaAdamW,
MetaOptimizer,