diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8ab81e1..45c5d9db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,10 +24,6 @@ repos: - id: detect-private-key - id: debug-statements - id: double-quote-string-fixer - - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 - hooks: - - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.0.278 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index cb158207..40725bab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Implement `Adadelta`, `RAdam`, `Adamax` optimizer by [@JieRen98](https://github.com/JieRen98) and [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#171](https://github.com/metaopt/torchopt/pull/171). ### Changed diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index d00e2333..0112e877 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -30,9 +30,12 @@ Functional Optimizers .. autosummary:: FuncOptimizer + adadelta adagrad adam adamw + adamax + radam rmsprop sgd @@ -42,6 +45,11 @@ Wrapper for Function Optimizer .. autoclass:: FuncOptimizer :members: +Functional AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adadelta + Functional AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -57,6 +65,16 @@ Functional AdamW Optimizer .. autofunction:: adamw +Functional AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adamax + +Functional RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: radam + Functional RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -76,12 +94,23 @@ Classic Optimizers .. autosummary:: + AdaDelta + Adadelta AdaGrad + Adagrad Adam AdamW + AdaMax + Adamax + RAdam RMSProp SGD +Classic AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdaDelta + Classic AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -97,6 +126,16 @@ Classic AdamW Optimizer .. autoclass:: AdamW +Classic AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AdaMax + +Classic RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RAdam + Classic RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -116,12 +155,23 @@ Differentiable Meta-Optimizers .. autosummary:: + MetaAdaDelta + MetaAdadelta MetaAdaGrad + MetaAdagrad MetaAdam MetaAdamW + MetaAdaMax + MetaAdamax + MetaRAdam MetaRMSProp MetaSGD +Differentiable Meta-AdaDelta Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdaDelta + Differentiable Meta-AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -137,6 +187,16 @@ Differentiable Meta-AdamW Optimizer .. autoclass:: MetaAdamW +Differentiable Meta-AdaMax Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdaMax + +Differentiable Meta-RAdam Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaRAdam + Differentiable Meta-RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst index f6b82826..9445adb8 100644 --- a/docs/source/explicit_diff/explicit_diff.rst +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -53,9 +53,15 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho .. 
autosummary:: torchopt.MetaOptimizer + torchopt.MetaAdaDelta + torchopt.MetaAdadelta torchopt.MetaAdaGrad + torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW + torchopt.MetaAdaMax + torchopt.MetaAdamax + torchopt.MetaRAdam torchopt.MetaRMSProp torchopt.MetaSGD diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst index 54c8ef71..4f2e17f8 100644 --- a/docs/source/optimizer/optim.rst +++ b/docs/source/optimizer/optim.rst @@ -18,9 +18,12 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`, .. autosummary:: torchopt.FuncOptimizer + torchopt.adadelta torchopt.adagrad torchopt.adam torchopt.adamw + torchopt.adamax + torchopt.radam torchopt.rmsprop torchopt.sgd @@ -85,9 +88,15 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi .. autosummary:: torchopt.Optimizer + torchopt.AdaDelta + torchopt.Adadelta torchopt.AdaGrad + torchopt.Adagrad torchopt.Adam torchopt.AdamW + torchopt.AdaMax + torchopt.Adamax + torchopt.RAdam torchopt.RMSProp torchopt.SGD diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 49fdbb69..6e0cca78 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -175,3 +175,10 @@ ctx Duchi invertible AdaGrad +Adadelta +Zeiler +radam +adamax +RAdam +AdaDelta +AdaMax diff --git a/tests/test_alias.py b/tests/test_alias.py index a0a78129..aef35b96 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -144,6 +144,63 @@ def test_sgd( _set_use_chain_flat(True) + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + rho=[0.9, 0.95], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_adadelta( + dtype: torch.dtype, + lr: float, + rho: float, + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adadelta( + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adadelta( + model_ref.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], @@ -210,6 +267,120 @@ def test_adam( _set_use_chain_flat(True) + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_radam( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base,
loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.radam( + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.RAdam( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + use_chain_flat=[True, False], +) +def test_adamax( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adamax( + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adamax( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], outer_lr=[1e-2, 1e-3, 1e-4], diff --git a/tests/test_import.py b/tests/test_import.py index 1b6dea38..f7523756 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -25,17 +25,24 @@ def test_accelerated_op_import() -> None: def test_alias_import() -> None: + torchopt.adadelta torchopt.adagrad torchopt.adam torchopt.adamw + torchopt.adamax + torchopt.radam torchopt.rmsprop torchopt.sgd + torchopt.alias.adadelta + torchopt.alias.adagrad torchopt.alias.adam torchopt.alias.adamw + torchopt.alias.adamax + torchopt.alias.radam torchopt.alias.rmsprop torchopt.alias.sgd - from torchopt import adagrad, adam, adamw, rmsprop, sgd - from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd + from torchopt import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd + from torchopt.alias import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd def test_diff_import() -> None: @@ -108,25 +115,38 @@ def test_nn_import() -> None: def test_optim_import() -> None: torchopt.FuncOptimizer + torchopt.MetaAdaDelta + torchopt.MetaAdadelta 
torchopt.MetaAdaGrad torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW + torchopt.MetaAdaMax + torchopt.MetaAdamax + torchopt.MetaRAdam torchopt.MetaRMSProp torchopt.MetaRMSprop torchopt.MetaSGD + torchopt.AdaDelta + torchopt.Adadelta torchopt.AdaGrad torchopt.Adagrad torchopt.Adam torchopt.AdamW + torchopt.AdaMax + torchopt.Adamax torchopt.Optimizer torchopt.RMSProp torchopt.RMSprop torchopt.SGD + torchopt.optim.meta.MetaAdaDelta + torchopt.optim.meta.MetaAdadelta torchopt.optim.meta.MetaAdaGrad torchopt.optim.meta.MetaAdagrad torchopt.optim.meta.MetaAdam torchopt.optim.meta.MetaAdamW + torchopt.optim.meta.MetaAdaMax + torchopt.optim.meta.MetaAdamax torchopt.optim.meta.MetaRMSProp torchopt.optim.meta.MetaRMSprop torchopt.optim.meta.MetaSGD @@ -139,14 +159,22 @@ def test_optim_import() -> None: torchopt.optim.func.FuncOptimizer from torchopt import ( SGD, + AdaDelta, + Adadelta, AdaGrad, Adagrad, Adam, + AdaMax, + Adamax, AdamW, FuncOptimizer, + MetaAdaDelta, + MetaAdadelta, MetaAdaGrad, MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, MetaRMSprop, @@ -158,11 +186,16 @@ def test_optim_import() -> None: from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdaDelta, + MetaAdadelta, MetaAdaGrad, MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, + MetaRAdam, MetaRMSProp, MetaRMSprop, MetaSGD, diff --git a/tests/test_optim.py b/tests/test_optim.py index 6ec81918..dc3941d9 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -144,6 +144,153 @@ def test_Adam( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + rho=[0.9, 0.95], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_Adadelta( + dtype: torch.dtype, + lr: float, + rho: float, + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adadelta( + model.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.Adadelta( + model_ref.parameters(), + lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_RAdam( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.RAdam( + model.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.RAdam( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + 
optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + weight_decay=[0.0, 1e-2], +) +def test_Adamax( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + weight_decay: float, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adamax( + model.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + optim_ref = torch.optim.Adamax( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], diff --git a/torchopt/__init__.py b/torchopt/__init__.py index a8c9fa1d..a089f3dc 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -33,18 +33,37 @@ visual, ) from torchopt.accelerated_op import is_available as accelerated_op_available -from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd +from torchopt.alias import adadelta, adagrad, adam, adamax, adamw, radam, rmsprop, sgd from torchopt.clip import clip_grad_norm from torchopt.combine import chain from torchopt.hook import register_hook -from torchopt.optim import SGD, AdaGrad, Adagrad, Adam, AdamW, Optimizer, RMSProp, RMSprop +from torchopt.optim import ( + SGD, + AdaDelta, + Adadelta, + AdaGrad, + Adagrad, + Adam, + AdaMax, + Adamax, + AdamW, + Optimizer, + RAdam, + RMSProp, + RMSprop, +) from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdaDelta, + MetaAdadelta, MetaAdaGrad, MetaAdagrad, MetaAdam, + MetaAdaMax, + MetaAdamax, MetaAdamW, MetaOptimizer, + MetaRAdam, MetaRMSProp, MetaRMSprop, MetaSGD, @@ -64,6 +83,9 @@ __all__ = [ 'accelerated_op_available', 'adam', + 'adamax', + 'adadelta', + 'radam', 'adamw', 'adagrad', 'rmsprop', @@ -75,6 +97,11 @@ 'Optimizer', 'SGD', 'Adam', + 'AdaMax', + 'Adamax', + 'AdaDelta', + 'Adadelta', + 'RAdam', 'AdamW', 'AdaGrad', 'Adagrad', @@ -83,6 +110,11 @@ 'MetaOptimizer', 'MetaSGD', 'MetaAdam', + 'MetaAdaMax', + 'MetaAdamax', + 'MetaAdaDelta', + 'MetaAdadelta', + 'MetaRAdam', 'MetaAdamW', 'MetaAdaGrad', 'MetaAdagrad', diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index ae7dd2b5..3ea721c4 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -31,11 +31,14 @@ # ============================================================================== r"""The aliases of preset :class:`GradientTransformation`\s for optimizers.""" +from torchopt.alias.adadelta import adadelta from torchopt.alias.adagrad import adagrad from torchopt.alias.adam import adam +from torchopt.alias.adamax import adamax from torchopt.alias.adamw import adamw +from torchopt.alias.radam import radam from torchopt.alias.rmsprop import rmsprop from torchopt.alias.sgd import sgd -__all__ = ['adagrad', 'adam', 'adamw', 'rmsprop', 'sgd'] +__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd'] diff --git 
a/torchopt/alias/adadelta.py b/torchopt/alias/adadelta.py new file mode 100644 index 00000000..2e3640f2 --- /dev/null +++ b/torchopt/alias/adadelta.py @@ -0,0 +1,98 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adadelta optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adadelta +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adadelta'] + + +# pylint: disable-next=too-many-arguments +def adadelta( + lr: ScalarOrSchedule = 1e-3, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaDelta optimizer. + + Adadelta is a per-dimension learning rate method for gradient descent. + + References: + - Zeiler, 2012: https://arxiv.org/abs/1212.5701 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the Adadelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. 
+ """ + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= rho <= 1: # pragma: no cover + raise ValueError(f'Invalid rho value: {rho}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adadelta_scaler_fn = scale_by_adadelta + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adadelta_scaler_fn = adadelta_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + adadelta_scaler_fn( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 25910abd..3f983c38 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -91,7 +91,7 @@ def adagrad( *, maximize: bool = False, ) -> GradientTransformation: - """The functional AdaGrad optimizer. + """Create a functional version of the AdaGrad optimizer. AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training. diff --git a/torchopt/alias/adamax.py b/torchopt/alias/adamax.py new file mode 100644 index 00000000..ffa19e37 --- /dev/null +++ b/torchopt/alias/adamax.py @@ -0,0 +1,100 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset :class:`GradientTransformation` for the Adamax optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adamax +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adamax'] + + +# pylint: disable-next=too-many-arguments +def adamax( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the AdaMax optimizer. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. 
(default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant added to the denominator + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= b1 <= 1: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0 <= b2 <= 1: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adamax_scaler_fn = scale_by_adamax + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adamax_scaler_fn = adamax_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + adamax_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/alias/radam.py b/torchopt/alias/radam.py new file mode 100644 index 00000000..230c1151 --- /dev/null +++ b/torchopt/alias/radam.py @@ -0,0 +1,102 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Preset :class:`GradientTransformation` for the RAdam optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_radam +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['radam'] + + +# pylint: disable-next=too-many-arguments +def radam( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Create a functional version of the RAdam optimizer. + + RAdam (Rectified Adam) is a variant of Adam that rectifies the variance of the adaptive learning rate. + + References: + - Liu, 2019: https://arxiv.org/abs/1908.03265 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created with + flag ``requires_grad=True``, this flag is often used in Meta-Learning algorithms. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
+ """ + b1, b2 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not 0 <= b1 <= 1: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0 <= b2 <= 1: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + radam_scaler_fn = scale_by_radam + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + radam_scaler_fn = radam_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=False), + radam_scaler_fn( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index 8e390a5c..20da5fca 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -15,10 +15,13 @@ """object oriented optimizer implementations.""" from torchopt.optim import meta +from torchopt.optim.adadelta import AdaDelta, Adadelta from torchopt.optim.adagrad import AdaGrad, Adagrad from torchopt.optim.adam import Adam +from torchopt.optim.adamax import AdaMax, Adamax from torchopt.optim.adamw import AdamW from torchopt.optim.base import Optimizer from torchopt.optim.func import FuncOptimizer +from torchopt.optim.radam import RAdam from torchopt.optim.rmsprop import RMSProp, RMSprop from torchopt.optim.sgd import SGD diff --git a/torchopt/optim/adadelta.py b/torchopt/optim/adadelta.py new file mode 100644 index 00000000..7c73cb58 --- /dev/null +++ b/torchopt/optim/adadelta.py @@ -0,0 +1,75 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adadelta optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaDelta', 'Adadelta'] + + +class AdaDelta(Optimizer): + """The classic AdaDelta optimizer. + + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. + - The differentiable meta-AdaDelta optimizer: :class:`torchopt.MetaAdaDelta`.
+ """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the AdaDelta optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the AdaDelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + """ + super().__init__( + params, + alias.adadelta( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) + + +Adadelta = AdaDelta # alias for PyTorch compatibility diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 055e0ad5..a7e8c72b 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -33,7 +33,7 @@ class AdaGrad(Optimizer): See Also: - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - - The differentiable meta AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`. + - The differentiable meta-AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`. """ # pylint: disable-next=too-many-arguments diff --git a/torchopt/optim/adamax.py b/torchopt/optim/adamax.py new file mode 100644 index 00000000..904c05a0 --- /dev/null +++ b/torchopt/optim/adamax.py @@ -0,0 +1,75 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adamax optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['AdaMax', 'Adamax'] + + +class AdaMax(Optimizer): + """The classic AdaMax optimizer. + + See Also: + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The differentiable meta-AdaMax optimizer: :class:`torchopt.MetaAdaMax`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the AdaMax optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. 
(default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the AdaMax paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + """ + super().__init__( + params, + alias.adamax( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) + + +Adamax = AdaMax # alias for PyTorch compatibility diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 94038464..7a7839a3 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -34,9 +34,12 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods and update the parameters. See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - The functional Adam optimizer: :func:`torchopt.adam`. - The functional AdamW optimizer: :func:`torchopt.adamw`. + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The functional RAdam optimizer: :func:`torchopt.radam`. - The functional RMSprop optimizer: :func:`torchopt.rmsprop`. - The functional SGD optimizer: :func:`torchopt.sgd`. """ diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index 28f374cc..516f2b5f 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -14,9 +14,12 @@ # ============================================================================== """Differentiable Meta-Optimizers.""" +from torchopt.optim.meta.adadelta import MetaAdaDelta, MetaAdadelta from torchopt.optim.meta.adagrad import MetaAdaGrad, MetaAdagrad from torchopt.optim.meta.adam import MetaAdam +from torchopt.optim.meta.adamax import MetaAdaMax, MetaAdamax from torchopt.optim.meta.adamw import MetaAdamW from torchopt.optim.meta.base import MetaOptimizer +from torchopt.optim.meta.radam import MetaRAdam from torchopt.optim.meta.rmsprop import MetaRMSProp, MetaRMSprop from torchopt.optim.meta.sgd import MetaSGD diff --git a/torchopt/optim/meta/adadelta.py b/torchopt/optim/meta/adadelta.py new file mode 100644 index 00000000..36d8d9ad --- /dev/null +++ b/torchopt/optim/meta/adadelta.py @@ -0,0 +1,77 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adadelta optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaDelta', 'MetaAdadelta'] + + +class MetaAdaDelta(MetaOptimizer): + """The differentiable AdaDelta optimizer. 
+ + See Also: + - The functional AdaDelta optimizer: :func:`torchopt.adadelta`. + - The classic AdaDelta optimizer: :class:`torchopt.Adadelta`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta AdaDelta optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1.0`) + rho (float, optional): Coefficients used for computing running averages of gradient and its square. + (default: :const:`0.9`) + eps (float, optional): A small constant applied to the square root (as in the AdaDelta paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-6`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`True`) + """ + super().__init__( + module, + alias.adadelta( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) + + +MetaAdadelta = MetaAdaDelta # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 079d76db..4e8ef0eb 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -31,7 +31,7 @@ class MetaAdaGrad(MetaOptimizer): See Also: - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - - The classic AdaGrad optimizer: :class:`torchopt.AdaGrad`. + - The classic AdaGrad optimizer: :class:`torchopt.Adagrad`. """ # pylint: disable-next=too-many-arguments diff --git a/torchopt/optim/meta/adamax.py b/torchopt/optim/meta/adamax.py new file mode 100644 index 00000000..01082af2 --- /dev/null +++ b/torchopt/optim/meta/adamax.py @@ -0,0 +1,77 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable Adamax optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaAdaMax', 'MetaAdamax'] + + +class MetaAdaMax(MetaOptimizer): + """The differentiable AdaMax optimizer. + + See Also: + - The functional AdaMax optimizer: :func:`torchopt.adamax`. + - The classic AdaMax optimizer: :class:`torchopt.Adamax`.
+ """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 2e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta AdaMax optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`2e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the AdaMax paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`True`) + """ + super().__init__( + module, + alias.adamax( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) + + +MetaAdamax = MetaAdaMax # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/radam.py b/torchopt/optim/meta/radam.py new file mode 100644 index 00000000..baf4cdd2 --- /dev/null +++ b/torchopt/optim/meta/radam.py @@ -0,0 +1,74 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable RAdam optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['MetaRAdam'] + + +class MetaRAdam(MetaOptimizer): + """The differentiable RAdam optimizer. + + See Also: + - The functional RAdam optimizer: :func:`torchopt.radam`. + - The classic RAdam optimizer: :class:`torchopt.RAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + moment_requires_grad: bool = True, + ) -> None: + """Initialize the meta-RAdam optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling.
+ (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`True`) + """ + super().__init__( + module, + alias.radam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=moment_requires_grad, + ), + ) diff --git a/torchopt/optim/radam.py b/torchopt/optim/radam.py new file mode 100644 index 00000000..c2f6a211 --- /dev/null +++ b/torchopt/optim/radam.py @@ -0,0 +1,72 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RAdam optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['RAdam'] + + +class RAdam(Optimizer): + """The classic RAdam optimizer. + + See Also: + - The functional RAdam optimizer: :func:`torchopt.radam`. + - The differentiable meta-RAdam optimizer: :class:`torchopt.MetaRAdam`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + ) -> None: + r"""Initialize the RAdam optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to the square root (as in the RAdam paper) + to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+ (default: :const:`0.0`) + """ + super().__init__( + params, + alias.radam( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + moment_requires_grad=False, + ), + ) diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index 47c49ea1..c75fcb5d 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -34,7 +34,10 @@ from torchopt.transform.add_decayed_weights import add_decayed_weights, masked from torchopt.transform.nan_to_num import nan_to_num from torchopt.transform.scale import scale +from torchopt.transform.scale_by_adadelta import scale_by_adadelta from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam +from torchopt.transform.scale_by_adamax import scale_by_adamax +from torchopt.transform.scale_by_radam import scale_by_radam from torchopt.transform.scale_by_rms import scale_by_rms from torchopt.transform.scale_by_rss import scale_by_rss from torchopt.transform.scale_by_schedule import scale_by_schedule @@ -49,6 +52,9 @@ 'add_decayed_weights', 'masked', 'scale_by_adam', + 'scale_by_adamax', + 'scale_by_adadelta', + 'scale_by_radam', 'scale_by_accelerated_adam', 'scale_by_rss', 'scale_by_rms', diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py new file mode 100644 index 00000000..fb5431a3 --- /dev/null +++ b/torchopt/transform/scale_by_adadelta.py @@ -0,0 +1,165 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Preset transformations for scaling updates by Adadelta.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_adadelta'] + + +class ScaleByAdadeltaState(NamedTuple): + """State for the Adadelta algorithm.""" + + mu: Updates + nu: Updates + + +def scale_by_adadelta( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adadelta algorithm. + + References: + - Zeiler, 2012: https://arxiv.org/abs/1212.5701 + + Args: + rho (float, optional): Decay rate for the squared grads. + (default: :const:`0.9`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple.
+ """ + return _scale_by_adadelta( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adadelta_flat( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adadelta( + rho=rho, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adadelta( + rho: float = 0.9, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= rho < 1.0: # pragma: no cover + raise ValueError(f'Invalid rho parameter at index 0: {rho}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdadeltaState(mu=mu, nu=nu) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + rho, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + if inplace: + + def f( + g: torch.Tensor, # pylint: disable=unused-argument + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) + + else: + + def f( + g: torch.Tensor, # pylint: disable=unused-argument + m: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) + + updates = tree_map(f, updates, mu, state.nu) + + nu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.nu, + rho, + order=2, + inplace=inplace, + already_flattened=already_flattened, + ) + + return updates, ScaleByAdadeltaState(mu=mu, nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_adadelta.flat = _scale_by_adadelta_flat # type: ignore[attr-defined] +scale_by_adadelta.impl = _scale_by_adadelta # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py new file mode 100644 index 00000000..504e82cd --- /dev/null +++ b/torchopt/transform/scale_by_adamax.py @@ -0,0 +1,164 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Preset transformations for scaling updates by Adamax.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat, update_moment +from torchopt.typing import OptState, Params, Updates + + +__all__ = ['scale_by_adamax'] + + +class ScaleByAdamaxState(NamedTuple): + """State for the Adamax algorithm.""" + + mu: Updates + nu: Updates + t: int + + +def scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adamax algorithm, a variant of Adam based on the infinity norm. + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of grads. + (default: :const:`0.9`) + b2 (float, optional): Decay rate for the exponentially weighted infinity norm of grads. + (default: :const:`0.999`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-6`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adamax( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adamax_flat( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adamax( + b1=b1, + b2=b2, + eps=eps, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b1 parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid b2 parameter at index 1: {b2}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # infinity-norm accumulator + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + return ScaleByAdamaxState(mu=mu, nu=nu, t=1) + + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: + mu = update_moment.impl( # type: ignore[attr-defined] + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + + def update_nu( + g: torch.Tensor, + n: torch.Tensor, + ) -> torch.Tensor: + return torch.max(n.mul(b2), g.abs().add_(eps)) + + nu = tree_map(update_nu, updates, state.nu) + + one_minus_b1_pow_t = 1 - b1**state.t + + def f( + n: torch.Tensor, + m: torch.Tensor, + ) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t)
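+ # f(n, m) is the Adamax update direction: the bias-corrected first moment m / (1 - b1**t) divided by the infinity-norm accumulator n; the (negative) learning rate is applied afterwards by scale_by_neg_lr in the alias chain.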
+
+        updates = tree_map(f, nu, mu)
+
+        return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1)
+
+    return GradientTransformation(init_fn, update_fn)
+
+
+scale_by_adamax.flat = _scale_by_adamax_flat  # type: ignore[attr-defined]
+scale_by_adamax.impl = _scale_by_adamax  # type: ignore[attr-defined]
diff --git a/torchopt/transform/scale_by_radam.py b/torchopt/transform/scale_by_radam.py
new file mode 100644
index 00000000..acb85a82
--- /dev/null
+++ b/torchopt/transform/scale_by_radam.py
@@ -0,0 +1,204 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Preset transformations for scaling updates by RAdam."""
+
+# pylint: disable=invalid-name
+
+from __future__ import annotations
+
+import math
+from typing import NamedTuple
+
+import torch
+
+from torchopt import pytree
+from torchopt.base import GradientTransformation
+from torchopt.transform.utils import tree_map_flat, update_moment
+from torchopt.typing import OptState, Params, Updates
+
+
+__all__ = ['scale_by_radam']
+
+
+class ScaleByRAdamState(NamedTuple):
+    """State for the RAdam algorithm."""
+
+    mu: Updates
+    nu: Updates
+    t: int
+
+
+def scale_by_radam(
+    b1: float = 0.9,
+    b2: float = 0.999,
+    eps: float = 1e-6,
+    moment_requires_grad: bool = False,
+) -> GradientTransformation:
+    """Rescale updates according to the RAdam algorithm.
+
+    References:
+        - Liu, 2019: https://arxiv.org/abs/1908.03265
+
+    Args:
+        b1 (float, optional): Decay rate for the exponentially weighted average of grads.
+            (default: :const:`0.9`)
+        b2 (float, optional): Decay rate for the exponentially weighted average of squared grads.
+            (default: :const:`0.999`)
+        eps (float, optional): Term added to the denominator to improve numerical stability.
+            (default: :const:`1e-6`)
+        moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
+            ``requires_grad = True``. (default: :data:`False`)
+
+    Returns:
+        An (init_fn, update_fn) tuple.
+    """
+    return _scale_by_radam(
+        b1=b1,
+        b2=b2,
+        eps=eps,
+        moment_requires_grad=moment_requires_grad,
+        already_flattened=False,
+    )
+
+
+def _scale_by_radam_flat(
+    b1: float = 0.9,
+    b2: float = 0.999,
+    eps: float = 1e-6,
+    moment_requires_grad: bool = False,
+) -> GradientTransformation:
+    return _scale_by_radam(
+        b1=b1,
+        b2=b2,
+        eps=eps,
+        moment_requires_grad=moment_requires_grad,
+        already_flattened=True,
+    )
+
+
+def _scale_by_radam(
+    b1: float = 0.9,
+    b2: float = 0.999,
+    eps: float = 1e-6,
+    moment_requires_grad: bool = False,
+    *,
+    already_flattened: bool = False,
+) -> GradientTransformation:
+    # pylint: disable=unneeded-not
+    if not eps >= 0.0:  # pragma: no cover
+        raise ValueError(f'Invalid epsilon value: {eps}')
+    if not 0.0 <= b1 < 1.0:  # pragma: no cover
+        raise ValueError(f'Invalid b1 parameter at index 0: {b1}')
+    if not 0.0 <= b2 < 1.0:  # pragma: no cover
+        raise ValueError(f'Invalid b2 parameter at index 1: {b2}')
+    # pylint: enable=unneeded-not
+
+    if already_flattened:  # noqa: SIM108
+        tree_map = tree_map_flat
+    else:
+        tree_map = pytree.tree_map  # type: ignore[assignment]
+
+    def init_fn(params: Params) -> OptState:
+        mu = tree_map(  # first moment
+            lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+            params,
+        )
+        nu = tree_map(  # second moment
+            lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+            params,
+        )
+        return ScaleByRAdamState(mu=mu, nu=nu, t=1)
+
+    def update_fn(
+        updates: Updates,
+        state: OptState,
+        *,
+        params: Params | None = None,  # pylint: disable=unused-argument
+        inplace: bool = True,
+    ) -> tuple[Updates, OptState]:
+        mu = update_moment.impl(  # type: ignore[attr-defined]
+            updates,
+            state.mu,
+            b1,
+            order=1,
+            inplace=inplace,
+            already_flattened=already_flattened,
+        )
+
+        nu = update_moment.impl(  # type: ignore[attr-defined]
+            updates,
+            state.nu,
+            b2,
+            order=2,
+            inplace=inplace,
+            already_flattened=already_flattened,
+        )
+
+        # Maximum and current length of the approximated SMA (Liu et al., 2019)
+        rho_inf = 2 / (1 - b2) - 1
+        one_minus_b1_pow_t = 1 - b1**state.t
+        one_minus_b2_pow_t = 1 - b2**state.t
+        rho = rho_inf - 2 * state.t * b2**state.t / one_minus_b2_pow_t
+
+        if rho > 5:
+            # Variance is tractable: apply the rectification term, folding in the
+            # bias correction of the second moment.
+            numerator = math.sqrt(
+                one_minus_b2_pow_t
+                * (rho - 4)
+                * (rho - 2)
+                * rho_inf
+                / ((rho_inf - 4) * (rho_inf - 2) * rho),
+            )
+            if inplace:
+
+                def f(
+                    m: torch.Tensor,
+                    v: torch.Tensor,
+                ) -> torch.Tensor:
+                    return m.mul(numerator / one_minus_b1_pow_t).div_(v.sqrt().add_(eps))
+
+            else:
+
+                def f(
+                    m: torch.Tensor,
+                    v: torch.Tensor,
+                ) -> torch.Tensor:
+                    return m.mul(numerator / one_minus_b1_pow_t).div(v.sqrt().add(eps))
+
+        else:
+            # Variance is not tractable yet: take an un-adapted step using only the
+            # bias-corrected first moment.
+            if inplace:
+
+                def f(
+                    m: torch.Tensor,
+                    v: torch.Tensor,  # pylint: disable=unused-argument
+                ) -> torch.Tensor:
+                    return m.div(one_minus_b1_pow_t)
+
+            else:
+
+                def f(
+                    m: torch.Tensor,
+                    v: torch.Tensor,  # pylint: disable=unused-argument
+                ) -> torch.Tensor:
+                    return m.div(one_minus_b1_pow_t)
+
+        updates = tree_map(f, mu, nu)
+
+        return updates, ScaleByRAdamState(mu=mu, nu=nu, t=state.t + 1)
+
+    return GradientTransformation(init_fn, update_fn)
+
+
+scale_by_radam.flat = _scale_by_radam_flat  # type: ignore[attr-defined]
+scale_by_radam.impl = _scale_by_radam  # type: ignore[attr-defined]
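
Reviewer note (not part of the diff): the two transformations above are low-level building blocks; in practice they are consumed through the corresponding high-level aliases (torchopt.adamax, torchopt.radam) or the Optimizer/MetaOptimizer wrappers added elsewhere in this PR. The sketch below is illustrative only. It drives the functional RAdam alias on a toy parameter tuple, following the same init/update/apply_updates pattern used in the new unit tests; the keyword names (lr, betas, eps) are assumed to mirror the existing torchopt.adam alias.

    # Illustrative usage sketch (assumption: torchopt.radam wraps scale_by_radam
    # plus learning-rate scaling, like the other functional aliases).
    import torch
    import torchopt

    # A toy parameter pytree: a plain tuple of tensors.
    params = (torch.randn(4, requires_grad=True), torch.zeros(1, requires_grad=True))

    optimizer = torchopt.radam(lr=1e-3, betas=(0.9, 0.999), eps=1e-6)
    state = optimizer.init(params)

    def loss_fn(ps):
        w, b = ps
        return (w.sum() + b.sum()) ** 2

    for _ in range(10):
        loss = loss_fn(params)
        grads = torch.autograd.grad(loss, params)
        # Functional update: compute scaled updates, then apply them out-of-place.
        # (Passing inplace=True instead would mutate the parameters, as in the tests.)
        updates, state = optimizer.update(grads, state, params=params, inplace=False)
        params = torchopt.apply_updates(params, updates, inplace=False)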