From 640e3b75928ef913b3d2dbf621bfbcaa43ccf424 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 13 Oct 2022 19:51:07 +0800 Subject: [PATCH 01/37] feat(torchopt): adagrad optimizer support Co-authored-by: Benjamin-eecs --- torchopt/alias/__init__.py | 3 +- torchopt/alias/adagrad.py | 102 ++++++++++++++++++++ torchopt/schedule/__init__.py | 3 +- torchopt/schedule/exponential_decay.py | 92 ++++++++++++++++++ torchopt/transform/__init__.py | 2 + torchopt/transform/scale_by_rss.py | 127 +++++++++++++++++++++++++ 6 files changed, 327 insertions(+), 2 deletions(-) create mode 100644 torchopt/alias/adagrad.py create mode 100644 torchopt/schedule/exponential_decay.py create mode 100644 torchopt/transform/scale_by_rss.py diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index b00b3c35..058ac5db 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -31,10 +31,11 @@ # ============================================================================== r"""The aliases of preset :class:`GradientTransformation`\s for optimizers.""" +from torchopt.alias.adagrad import adagrad from torchopt.alias.adam import adam from torchopt.alias.adamw import adamw from torchopt.alias.rmsprop import rmsprop from torchopt.alias.sgd import sgd -__all__ = ['adam', 'adamw', 'rmsprop', 'sgd'] +__all__ = ['adagrad', 'adam', 'adamw', 'rmsprop', 'sgd'] diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py new file mode 100644 index 00000000..4c9e9c07 --- /dev/null +++ b/torchopt/alias/adagrad.py @@ -0,0 +1,102 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Preset :class:`GradientTransformation` for the AdaGrad optimizer.""" + +from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr +from torchopt.combine import chain_flat +from torchopt.transform import scale_by_rss +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adagrad'] + + +# pylint: disable-next=too-many-arguments +def adagrad( + lr: ScalarOrSchedule = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + maximize: bool = False, +) -> GradientTransformation: + """The functional AdaGrad optimizer. + + AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each + parameter during the course of training. + WARNING: AdaGrad's main limit is the monotonic accumulation of squared gradients in the + denominator: since all terms are >0, the sum keeps growing during training and the learning rate + eventually becomes vanishingly small. + + References: + Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html + + Args: + lr: (default: :const:`1e-3`) + This is a fixed global scaling factor. + lr_decay: (default: :const:`0.0`) + Learning rate decay. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + initial_accumulator_value: (default: :const:`0.0`) + Initial value for the accumulator. + eps: (default: :const:`1e-8`) + A small constant applied to denominator outside of the square root (as in the Adam + paper) to avoid dividing by zero when rescaling. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. + use_accelerated_op: (default: :data:`False`) + If :data:`True` use our implemented fused operator. + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + # pylint: disable=unneeded-not + if not (callable(lr) or 0.0 <= lr): + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= lr_decay: + raise ValueError(f'Invalid lr_decay value: {lr_decay}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + return chain_flat( + flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), + scale_by_rss.flat(initial_accumulator_value=initial_accumulator_value, eps=eps), # type: ignore[attr-defined] + scale_by_neg_lr(lr), + ) diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py index 46f59550..35121c1a 100644 --- a/torchopt/schedule/__init__.py +++ b/torchopt/schedule/__init__.py @@ -31,7 +31,8 @@ # ============================================================================== """Learning rate schedules.""" +from torchopt.schedule.exponential_decay import exponential_decay from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule -__all__ = ['polynomial_schedule', 'linear_schedule'] +__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule'] diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py new file mode 100644 index 00000000..9c607b65 --- /dev/null +++ b/torchopt/schedule/exponential_decay.py @@ -0,0 +1,92 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Exponential learning rate decay.""" + +import logging +from typing import Optional + +from torchopt.typing import Numeric, Scalar, Schedule + + +__all__ = ['exponential_decay'] + + +def exponential_decay( + init_value: Scalar, + decay_rate: Scalar, + transition_begin: int = 0, + transition_steps: Optional[int] = None, + end_value: Optional[float] = None, +) -> Schedule: + """Constructs a schedule with either continuous or discrete exponential decay. + Args: + value: value to be held constant throughout. + Returns: + schedule: A function that maps step counts to values. + """ + if transition_steps is not None and transition_steps <= 0: + logging.info( + 'An linear schedule was set with a non-positive `transition_steps`' + ' value; this will result in a constant schedule with value ' + '`init_value`.' + ) + return lambda count: init_value + + if decay_rate == 0: + logging.info( + 'An linear schedule was set with a zero `decay_rate` value; ' + 'this will result in a constant schedule with value `init_value`.' + ) + return lambda count: init_value + + if transition_begin < 0: + logging.info( + 'An linear schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.' 
+ ) + transition_begin = 0 + + if end_value is not None: + clip_fn = max if decay_rate < 1.0 else min + + def schedule(count: Numeric) -> Numeric: + decreased_count = count - transition_begin + decayed_value = ( + init_value / (1 + (decreased_count - 1) * decay_rate) + if decreased_count > 0 + else init_value + ) + if end_value is not None: + decayed_value = clip_fn(decayed_value, end_value) + return decayed_value + + return schedule diff --git a/torchopt/transform/__init__.py b/torchopt/transform/__init__.py index 5db2be48..40b3c673 100644 --- a/torchopt/transform/__init__.py +++ b/torchopt/transform/__init__.py @@ -35,6 +35,7 @@ from torchopt.transform.scale import scale from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam from torchopt.transform.scale_by_rms import scale_by_rms +from torchopt.transform.scale_by_rss import scale_by_rss from torchopt.transform.scale_by_schedule import scale_by_schedule from torchopt.transform.scale_by_stddev import scale_by_stddev from torchopt.transform.trace import trace @@ -47,6 +48,7 @@ 'add_decayed_weights', 'scale_by_adam', 'scale_by_accelerated_adam', + 'scale_by_rss', 'scale_by_rms', 'scale_by_stddev', ] diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py new file mode 100644 index 00000000..37b44535 --- /dev/null +++ b/torchopt/transform/scale_by_rss.py @@ -0,0 +1,127 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file is modified from: +# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py +# ============================================================================== +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Preset transformations for scaling updates by the root of the sum of all squared gradients.""" + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import tree_map_flat +from torchopt.typing import Updates + + +__all__ = ['scale_by_rss'] + + +class ScaleByRssState(NamedTuple): + """State holding the sum of gradient squares to date.""" + + sum_of_squares: Updates + + +def scale_by_rss( + initial_accumulator_value: float = 0.1, + eps: float = 1e-7, +) -> GradientTransformation: + """Rescale updates by the root of the sum of all squared gradients to date. + + References: + [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) + [McMahan et al., 2010](https://arxiv.org/abs/1002.4908) + + Args: + initial_accumulator_value: Starting value for accumulators, must be >= 0. + eps: A small floating point value to avoid zero denominator. + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_rss( + initial_accumulator_value=initial_accumulator_value, + eps=eps, + already_flattened=False, + ) + + +def _scale_by_rss_flat( + initial_accumulator_value: float = 0.1, + eps: float = 1e-7, +) -> GradientTransformation: + return _scale_by_rss( + initial_accumulator_value=initial_accumulator_value, + eps=eps, + already_flattened=True, + ) + + +def _scale_by_rss( + initial_accumulator_value: float = 0.1, + eps: float = 1e-7, + *, + already_flattened: bool = False, +) -> GradientTransformation: + + if already_flattened: + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map + + def init_fn(params): + sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params) + return ScaleByRssState(sum_of_squares=sum_of_squares) + + def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + sum_of_squares = tree_map( + lambda g, t: (g.conj() * g).real + t, updates, state.sum_of_squares + ) + + if inplace: + + def f(t): + return t.add_(eps).rsqrt_() if t > 0.0 else 0.0 + + else: + + def f(t): + return t.add(eps).rsqrt() if t > 0.0 else 0.0 + + inv_sqrt_g_square = tree_map(f, sum_of_squares) + updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates) + return updates, ScaleByRssState(sum_of_squares=sum_of_squares) + + return GradientTransformation(init_fn, update_fn) + + +scale_by_rss.flat = _scale_by_rss_flat # type: ignore[attr-defined] +scale_by_rss.impl = _scale_by_rss # type: ignore[attr-defined] From 21102bf26466e70239a6e958bbea993866d93ccc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Feb 2023 10:12:50 +0000 Subject: [PATCH 02/37] fix: [pre-commit.ci] auto fixes [...] 
--- torchopt/transform/scale_by_rss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 37b44535..3947211c 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -91,7 +91,6 @@ def _scale_by_rss( *, already_flattened: bool = False, ) -> GradientTransformation: - if already_flattened: tree_map = tree_map_flat else: From bf029ae938b6321f05497c1ed242dbcb00a9e520 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Feb 2023 09:28:27 +0000 Subject: [PATCH 03/37] fix: [pre-commit.ci] auto fixes [...] --- torchopt/alias/adagrad.py | 8 ++++---- torchopt/transform/scale_by_rss.py | 5 +---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 4c9e9c07..7ee50774 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -85,13 +85,13 @@ def adagrad( The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. """ # pylint: disable=unneeded-not - if not (callable(lr) or 0.0 <= lr): + if not (callable(lr) or lr >= 0.0): raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: + if not eps >= 0.0: raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= lr_decay: + if not lr_decay >= 0.0: raise ValueError(f'Invalid lr_decay value: {lr_decay}') - if not 0.0 <= weight_decay: + if not weight_decay >= 0.0: raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 3947211c..0e3cb82b 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -91,10 +91,7 @@ def _scale_by_rss( *, already_flattened: bool = False, ) -> GradientTransformation: - if already_flattened: - tree_map = tree_map_flat - else: - tree_map = pytree.tree_map + tree_map = tree_map_flat if already_flattened else pytree.tree_map def init_fn(params): sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params) From eb31c43095369e047f1ae1630cb270c459303033 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 18:12:13 +0800 Subject: [PATCH 04/37] feat: adagrad integration --- Makefile | 2 +- tests/test_alias.py | 68 ++++++++++++++++++++++++++ torchopt/schedule/exponential_decay.py | 21 +++++--- torchopt/transform/scale_by_rss.py | 16 ++++-- 4 files changed, 95 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 750b9d9f..ac367359 100644 --- a/Makefile +++ b/Makefile @@ -117,7 +117,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ - $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest -k "test_adagrad" --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . 
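[Editor's note for reading the test diff that follows: the new test drives the functional AdaGrad optimizer through the usual init/update/apply cycle. A minimal sketch of that pattern is given below; it assumes functorch-extracted `params` and a gradient tuple `grads` are already available (as they are inside the test), and is an illustration of the API these patches introduce, not part of the patch itself.]

    import torchopt

    # Build the functional AdaGrad gradient transformation added in this series.
    optim = torchopt.adagrad(lr=1e-2, lr_decay=0.0, eps=1e-10)
    # Initialize the per-parameter sum-of-squares accumulators.
    opt_state = optim.init(params)
    # One optimization step: compute scaled updates, then apply them.
    updates, opt_state = optim.update(grads, opt_state, params=params)
    params = torchopt.apply_updates(params, updates)

[The test compares exactly this loop against `torch.optim.Adagrad` on a reference model.]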
diff --git a/tests/test_alias.py b/tests/test_alias.py index b609cf58..ce37af76 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -161,6 +161,74 @@ def test_adam( _set_use_chain_flat(True) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-2], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], +) +def test_adagrad( + dtype: torch.dtype, + lr: float, + lr_decay: float, + initial_accumulator_value: float, + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adagrad( + lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adagrad( + model_ref.parameters(), + lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], outer_lr=[1e-2, 1e-3, 1e-4], diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 9c607b65..2cd9ff7a 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -34,6 +34,8 @@ import logging from typing import Optional +import torch + from torchopt.typing import Numeric, Scalar, Schedule @@ -45,6 +47,7 @@ def exponential_decay( decay_rate: Scalar, transition_begin: int = 0, transition_steps: Optional[int] = None, + staircase: bool = False, end_value: Optional[float] = None, ) -> Schedule: """Constructs a schedule with either continuous or discrete exponential decay. 
@@ -76,17 +79,23 @@ def exponential_decay( transition_begin = 0 if end_value is not None: - clip_fn = max if decay_rate < 1.0 else min + pass def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin - decayed_value = ( - init_value / (1 + (decreased_count - 1) * decay_rate) - if decreased_count > 0 - else init_value + p = decreased_count / transition_steps + + if staircase: + p = torch.floor(p) + + decayed_value = torch.where( + decreased_count <= 0, + torch.tensor(init_value), + torch.tensor(init_value) * torch.pow(torch.tensor(decay_rate), p), ) + if end_value is not None: - decayed_value = clip_fn(decayed_value, end_value) + return torch.clamp(decayed_value, max=end_value) return decayed_value return schedule diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 0e3cb82b..e890ec50 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -38,7 +38,7 @@ from torchopt import pytree from torchopt.base import GradientTransformation from torchopt.transform.utils import tree_map_flat -from torchopt.typing import Updates +from torchopt.typing import OptState, Params, Updates __all__ = ['scale_by_rss'] @@ -93,23 +93,29 @@ def _scale_by_rss( ) -> GradientTransformation: tree_map = tree_map_flat if already_flattened else pytree.tree_map - def init_fn(params): + def init_fn(params: Params) -> OptState: sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params) return ScaleByRssState(sum_of_squares=sum_of_squares) - def update_fn(updates, state, params=None, inplace=True): # pylint: disable=unused-argument + def update_fn( + updates: Updates, + state: OptState, + params: Params | None = None, + inplace: bool = True, + ) -> tuple[Updates, OptState]: # pylint: disable=unused-argument + del params sum_of_squares = tree_map( lambda g, t: (g.conj() * g).real + t, updates, state.sum_of_squares ) if inplace: - def f(t): + def f(t: torch.Tensor) -> torch.Tensor: return t.add_(eps).rsqrt_() if t > 0.0 else 0.0 else: - def f(t): + def f(t: torch.Tensor) -> torch.Tensor: return t.add(eps).rsqrt() if t > 0.0 else 0.0 inv_sqrt_g_square = tree_map(f, sum_of_squares) From dac67fb561501e918f75de37142c7a4c554478ea Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 18:24:10 +0800 Subject: [PATCH 05/37] feat: adagrad integration --- docs/source/spelling_wordlist.txt | 1 + torchopt/alias/adagrad.py | 2 +- torchopt/schedule/exponential_decay.py | 22 ++++++++++++++++++++-- torchopt/transform/scale_by_rss.py | 2 ++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index aac17046..021871a1 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -172,3 +172,4 @@ abc ABCMeta subclasscheck ctx +Duchi diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 7ee50774..a3c079c1 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -56,7 +56,7 @@ def adagrad( parameter during the course of training. WARNING: AdaGrad's main limit is the monotonic accumulation of squared gradients in the denominator: since all terms are >0, the sum keeps growing during training and the learning rate - eventually becomes vanishingly small. + eventually becomes very small. 
References: Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 2cd9ff7a..e642c409 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -42,6 +42,7 @@ __all__ = ['exponential_decay'] +# pylint: disable-next=too-many-arguments def exponential_decay( init_value: Scalar, decay_rate: Scalar, @@ -51,10 +52,27 @@ def exponential_decay( end_value: Optional[float] = None, ) -> Schedule: """Constructs a schedule with either continuous or discrete exponential decay. + + This function applies an exponential decay function to a provided initial + value. The function returns the decayed value as follows: + ``` + decayed_value = init_value * decay_rate ^ (count / transition_steps) + ``` + If the argument `staircase` is `True`, then `count / transition_steps` is + an integer division and the decayed value follows a staircase function. Args: - value: value to be held constant throughout. + init_value: the initial learning rate. + decay_rate: must not be zero. The decay rate. + transition_begin: must be positive. After how many steps to start annealing + (before this many steps the scalar value is held fixed at `init_value`). + transition_steps: must be positive. See the decay computation above. + staircase: if `True`, decay the values at discrete intervals. + end_value: the value at which the exponential decay stops. When + `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise as + an upper bound. Has no effect when `decay_rate` = 0. + Returns: - schedule: A function that maps step counts to values. + schedule: A function that maps step counts to values. """ if transition_steps is not None and transition_steps <= 0: logging.info( diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index e890ec50..245e0e31 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -31,6 +31,8 @@ # ============================================================================== """Preset transformations for scaling updates by the root of the sum of all squared gradients.""" +from __future__ import annotations + from typing import NamedTuple import torch From a953329b1fa5cc24f981ded8c84e51ce5413755c Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 18:51:29 +0800 Subject: [PATCH 06/37] feat: adagrad integration --- tests/test_alias.py | 9 +++------ torchopt/__init__.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index ce37af76..7e619a34 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -163,14 +163,13 @@ def test_adam( @helpers.parametrize( dtype=[torch.float64], - lr=[1e-2, 1e-3, 1e-4], - lr_decay=[0.0, 1e-2], - initial_accumulator_value=[0.0, 1e-2], + lr=[1e-2], + lr_decay=[0.0], + initial_accumulator_value=[0.0], eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], - use_accelerated_op=[False, True], use_chain_flat=[True, False], ) def test_adagrad( @@ -182,7 +181,6 @@ def test_adagrad( inplace: bool, weight_decay: float, maximize: bool, - use_accelerated_op: bool, use_chain_flat: bool, ) -> None: _set_use_chain_flat(use_chain_flat) @@ -197,7 +195,6 @@ def test_adagrad( initial_accumulator_value=initial_accumulator_value, eps=eps, maximize=maximize, - use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) optim_ref = torch.optim.Adagrad( diff --git 
a/torchopt/__init__.py b/torchopt/__init__.py index 0c36ac07..0374c3bf 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -33,7 +33,7 @@ visual, ) from torchopt.accelerated_op import is_available as accelerated_op_available -from torchopt.alias import adam, adamw, rmsprop, sgd +from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd from torchopt.clip import clip_grad_norm from torchopt.combine import chain from torchopt.hook import register_hook From 9786565f3baa4a8eaac6e5eb14b9349055c1ca5d Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 19:18:58 +0800 Subject: [PATCH 07/37] feat: adagrad integration --- torchopt/transform/scale_by_rss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 245e0e31..ba480829 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -113,12 +113,12 @@ def update_fn( if inplace: def f(t: torch.Tensor) -> torch.Tensor: - return t.add_(eps).rsqrt_() if t > 0.0 else 0.0 + return torch.where(t > 0.0, t.add_(eps).rsqrt_(), torch.tensor(0.0)) else: def f(t: torch.Tensor) -> torch.Tensor: - return t.add(eps).rsqrt() if t > 0.0 else 0.0 + return torch.where(t > 0.0, t.add(eps).rsqrt(), torch.tensor(0.0)) inv_sqrt_g_square = tree_map(f, sum_of_squares) updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates) From 449bdb059dc765672e36e668198af8dd394044be Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 21:51:17 +0800 Subject: [PATCH 08/37] feat: adagrad integration --- tests/test_alias.py | 6 +++--- torchopt/alias/adagrad.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index 7e619a34..26b1c12b 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -166,10 +166,10 @@ def test_adam( lr=[1e-2], lr_decay=[0.0], initial_accumulator_value=[0.0], - eps=[1e-8], + eps=[1e-10], inplace=[True, False], - weight_decay=[0.0, 1e-2], - maximize=[False, True], + weight_decay=[0.0], + maximize=[False], use_chain_flat=[True, False], ) def test_adagrad( diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index a3c079c1..1367a885 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -33,8 +33,9 @@ from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr from torchopt.combine import chain_flat +from torchopt.schedule import exponential_decay from torchopt.transform import scale_by_rss -from torchopt.typing import GradientTransformation, ScalarOrSchedule +from torchopt.typing import GradientTransformation, Scalar __all__ = ['adagrad'] @@ -42,7 +43,7 @@ # pylint: disable-next=too-many-arguments def adagrad( - lr: ScalarOrSchedule = 1e-2, + lr: Scalar = 1e-2, lr_decay: float = 0.0, weight_decay: float = 0.0, initial_accumulator_value: float = 0.0, @@ -98,5 +99,5 @@ def adagrad( return chain_flat( flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), scale_by_rss.flat(initial_accumulator_value=initial_accumulator_value, eps=eps), # type: ignore[attr-defined] - scale_by_neg_lr(lr), + scale_by_neg_lr(exponential_decay(init_value=lr, decay_rate=lr_decay)), ) From 75b2bfb49bc8e86d44db2f6419a0f67b977bc24f Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 21:57:51 +0800 Subject: [PATCH 09/37] feat: adagrad integration --- torchopt/alias/adagrad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 1367a885..d699b096 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -33,7 +33,6 @@ from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr from torchopt.combine import chain_flat -from torchopt.schedule import exponential_decay from torchopt.transform import scale_by_rss from torchopt.typing import GradientTransformation, Scalar @@ -99,5 +98,6 @@ def adagrad( return chain_flat( flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), scale_by_rss.flat(initial_accumulator_value=initial_accumulator_value, eps=eps), # type: ignore[attr-defined] - scale_by_neg_lr(exponential_decay(init_value=lr, decay_rate=lr_decay)), + # scale_by_neg_lr(exponential_decay(init_value=lr, decay_rate=lr_decay)), + scale_by_neg_lr(lr), ) From 3f28f989257fb234b08376fbce997c57a83ba5dd Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Fri, 3 Mar 2023 22:36:05 +0800 Subject: [PATCH 10/37] feat: adagrad integration --- torchopt/alias/adagrad.py | 2 +- torchopt/transform/scale_by_rss.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index d699b096..54763593 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -70,7 +70,7 @@ def adagrad( Weight decay, add L2 penalty to parameters. initial_accumulator_value: (default: :const:`0.0`) Initial value for the accumulator. - eps: (default: :const:`1e-8`) + eps: (default: :const:`1e-10`) A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. maximize: (default: :data:`False`) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index ba480829..b43ca3bb 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -53,8 +53,8 @@ class ScaleByRssState(NamedTuple): def scale_by_rss( - initial_accumulator_value: float = 0.1, - eps: float = 1e-7, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, ) -> GradientTransformation: """Rescale updates by the root of the sum of all squared gradients to date. 
@@ -77,8 +77,8 @@ def scale_by_rss( def _scale_by_rss_flat( - initial_accumulator_value: float = 0.1, - eps: float = 1e-7, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, ) -> GradientTransformation: return _scale_by_rss( initial_accumulator_value=initial_accumulator_value, @@ -88,8 +88,8 @@ def _scale_by_rss_flat( def _scale_by_rss( - initial_accumulator_value: float = 0.1, - eps: float = 1e-7, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, *, already_flattened: bool = False, ) -> GradientTransformation: @@ -113,12 +113,16 @@ def update_fn( if inplace: def f(t: torch.Tensor) -> torch.Tensor: - return torch.where(t > 0.0, t.add_(eps).rsqrt_(), torch.tensor(0.0)) + return torch.where( + t > 0.0, torch.ones_like(t).div_(t.sqrt_().add(eps)), torch.tensor(0.0) + ) else: def f(t: torch.Tensor) -> torch.Tensor: - return torch.where(t > 0.0, t.add(eps).rsqrt(), torch.tensor(0.0)) + return torch.where( + t > 0.0, torch.ones_like(t).div(t.sqrt_().add(eps)), torch.tensor(0.0) + ) inv_sqrt_g_square = tree_map(f, sum_of_squares) updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates) From ae56e25a244f152581e0205333cfe3b66bf0f5bd Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 4 Mar 2023 00:24:51 +0800 Subject: [PATCH 11/37] feat: adagrad integration --- tests/test_alias.py | 2 +- torchopt/alias/adagrad.py | 2 +- torchopt/transform/scale_by_rss.py | 7 ++++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index 26b1c12b..91f24933 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -163,7 +163,7 @@ def test_adam( @helpers.parametrize( dtype=[torch.float64], - lr=[1e-2], + lr=[1e-3], lr_decay=[0.0], initial_accumulator_value=[0.0], eps=[1e-10], diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 54763593..620b0e81 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -62,7 +62,7 @@ def adagrad( Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html Args: - lr: (default: :const:`1e-3`) + lr: (default: :const:`1e-2`) This is a fixed global scaling factor. lr_decay: (default: :const:`0.0`) Learning rate decay. 
diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index b43ca3bb..1853e8ef 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -106,8 +106,13 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: # pylint: disable=unused-argument del params + # sum_of_squares = tree_map( + # lambda g, t: t + (g.conj() * g).real , updates, state.sum_of_squares + # ) + # sum_of_squares = torch.addcmul(state.sum_of_squares, updates, updates, value=1) + sum_of_squares = tree_map( - lambda g, t: (g.conj() * g).real + t, updates, state.sum_of_squares + lambda g, t: t.addcmul(g, g, value=1.0), updates, state.sum_of_squares ) if inplace: From 91c708648bf3482ce1ac3cb06043d58583d07632 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 4 Mar 2023 07:30:33 +0800 Subject: [PATCH 12/37] feat: adagrad integration --- torchopt/schedule/exponential_decay.py | 26 +++++++++++++++----------- torchopt/transform/scale_by_rss.py | 12 ++++++++++-- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index e642c409..9b5be493 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -101,17 +101,21 @@ def exponential_decay( def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin - p = decreased_count / transition_steps - - if staircase: - p = torch.floor(p) - - decayed_value = torch.where( - decreased_count <= 0, - torch.tensor(init_value), - torch.tensor(init_value) * torch.pow(torch.tensor(decay_rate), p), - ) - + if transition_steps is not None: + p = decreased_count / transition_steps + + if staircase: + p = torch.floor(torch.tensor(p)) + + decayed_value = torch.where( + torch.tensor(decreased_count) <= 0, + torch.tensor(init_value), + torch.tensor(init_value) * torch.pow(torch.tensor(decay_rate), p), + ) + else: + decayed_value = torch.tensor(init_value) * torch.pow( + torch.tensor(decay_rate), decreased_count + ) if end_value is not None: return torch.clamp(decayed_value, max=end_value) return decayed_value diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 1853e8ef..55798f98 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -93,10 +93,18 @@ def _scale_by_rss( *, already_flattened: bool = False, ) -> GradientTransformation: - tree_map = tree_map_flat if already_flattened else pytree.tree_map + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] def init_fn(params: Params) -> OptState: - sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params) + sum_of_squares = tree_map( + lambda t: torch.full_like( + t, initial_accumulator_value, memory_format=torch.preserve_format + ), + params, + ) return ScaleByRssState(sum_of_squares=sum_of_squares) def update_fn( From 2f78e60c57b265cc48ea19c9381d59cd81dd5948 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 4 Mar 2023 23:54:31 +0800 Subject: [PATCH 13/37] feat: adagrad integration --- tests/test_alias.py | 7 ++++++- torchopt/transform/scale_by_rss.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index 91f24933..88f6e1a0 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -206,8 +206,11 @@ def test_adagrad( eps=eps, maximize=maximize, ) - + t = 0 for xs, ys in 
loader: + t = t + 1 + if t == 1: + break xs = xs.to(dtype=dtype) pred = fmodel(params, buffers, xs) pred_ref = model_ref(xs) @@ -222,6 +225,8 @@ def test_adagrad( loss_ref.backward() optim_ref.step() + # _, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) _set_use_chain_flat(True) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 55798f98..d7af9127 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -138,7 +138,7 @@ def f(t: torch.Tensor) -> torch.Tensor: ) inv_sqrt_g_square = tree_map(f, sum_of_squares) - updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates) + updates = tree_map(lambda scale, g: g * scale, inv_sqrt_g_square, updates) return updates, ScaleByRssState(sum_of_squares=sum_of_squares) return GradientTransformation(init_fn, update_fn) From c8e74f4eee03429bdabd076e3fb9af501dcde215 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 5 Mar 2023 00:45:55 +0800 Subject: [PATCH 14/37] feat: adagrad integration --- tests/test_alias.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index 88f6e1a0..b5de7eb4 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -209,7 +209,7 @@ def test_adagrad( t = 0 for xs, ys in loader: t = t + 1 - if t == 1: + if t == 2: break xs = xs.to(dtype=dtype) pred = fmodel(params, buffers, xs) From 7e76a7e75a63f38417f192833e5577e98b450f6f Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 11 Mar 2023 19:39:08 +0800 Subject: [PATCH 15/37] feat: adagrad integration --- torchopt/alias/adagrad.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 620b0e81..c1b4882c 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -32,7 +32,7 @@ """Preset :class:`GradientTransformation` for the AdaGrad optimizer.""" from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr -from torchopt.combine import chain_flat +from torchopt.combine import chain from torchopt.transform import scale_by_rss from torchopt.typing import GradientTransformation, Scalar @@ -95,9 +95,17 @@ def adagrad( raise ValueError(f'Invalid weight_decay value: {weight_decay}') # pylint: enable=unneeded-not - return chain_flat( - flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize), - scale_by_rss.flat(initial_accumulator_value=initial_accumulator_value, eps=eps), # type: ignore[attr-defined] + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adagrad_scaler_fn = scale_by_rss + scale_by_neg_lr_fn = scale_by_neg_lr + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + adagrad_scaler_fn( + initial_accumulator_value=initial_accumulator_value, + eps=eps, + ), # scale_by_neg_lr(exponential_decay(init_value=lr, decay_rate=lr_decay)), - scale_by_neg_lr(lr), + scale_by_neg_lr_fn(lr), ) From 1077916fd0956195172165b31b67131121e27704 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 11 Mar 2023 19:51:19 +0800 Subject: [PATCH 16/37] feat: adagrad integration --- tests/test_alias.py | 12 ++++++++---- torchopt/transform/scale_by_rss.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index b5de7eb4..a8e93444 100644 --- 
a/tests/test_alias.py +++ b/tests/test_alias.py @@ -166,7 +166,7 @@ def test_adam( lr=[1e-3], lr_decay=[0.0], initial_accumulator_value=[0.0], - eps=[1e-10], + eps=[1e-5], inplace=[True, False], weight_decay=[0.0], maximize=[False], @@ -209,8 +209,8 @@ def test_adagrad( t = 0 for xs, ys in loader: t = t + 1 - if t == 2: - break + # if t == 2: + # break xs = xs.to(dtype=dtype) pred = fmodel(params, buffers, xs) pred_ref = model_ref(xs) @@ -225,7 +225,11 @@ def test_adagrad( loss_ref.backward() optim_ref.step() - # _, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + print(t) + + _, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) + + print(t) helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) _set_use_chain_flat(True) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index d7af9127..0292e9d0 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -120,7 +120,7 @@ def update_fn( # sum_of_squares = torch.addcmul(state.sum_of_squares, updates, updates, value=1) sum_of_squares = tree_map( - lambda g, t: t.addcmul(g, g, value=1.0), updates, state.sum_of_squares + lambda g, t: t.addcmul_(g, g, value=1.0), updates, state.sum_of_squares ) if inplace: From adf641e7f82bc00283a60000b33037e0e2319a1c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Mar 2023 05:10:02 +0000 Subject: [PATCH 17/37] fix: [pre-commit.ci] auto fixes [...] --- torchopt/schedule/exponential_decay.py | 9 +++++---- torchopt/transform/scale_by_rss.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 9b5be493..ea204905 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -78,21 +78,21 @@ def exponential_decay( logging.info( 'An linear schedule was set with a non-positive `transition_steps`' ' value; this will result in a constant schedule with value ' - '`init_value`.' + '`init_value`.', ) return lambda count: init_value if decay_rate == 0: logging.info( 'An linear schedule was set with a zero `decay_rate` value; ' - 'this will result in a constant schedule with value `init_value`.' + 'this will result in a constant schedule with value `init_value`.', ) return lambda count: init_value if transition_begin < 0: logging.info( 'An linear schedule was set with a negative `transition_begin` ' - 'value; this will result in `transition_begin` falling back to `0`.' 
+ 'value; this will result in `transition_begin` falling back to `0`.', ) transition_begin = 0 @@ -114,7 +114,8 @@ def schedule(count: Numeric) -> Numeric: ) else: decayed_value = torch.tensor(init_value) * torch.pow( - torch.tensor(decay_rate), decreased_count + torch.tensor(decay_rate), + decreased_count, ) if end_value is not None: return torch.clamp(decayed_value, max=end_value) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 0292e9d0..affa17e4 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -101,7 +101,9 @@ def _scale_by_rss( def init_fn(params: Params) -> OptState: sum_of_squares = tree_map( lambda t: torch.full_like( - t, initial_accumulator_value, memory_format=torch.preserve_format + t, + initial_accumulator_value, + memory_format=torch.preserve_format, ), params, ) @@ -120,21 +122,27 @@ def update_fn( # sum_of_squares = torch.addcmul(state.sum_of_squares, updates, updates, value=1) sum_of_squares = tree_map( - lambda g, t: t.addcmul_(g, g, value=1.0), updates, state.sum_of_squares + lambda g, t: t.addcmul_(g, g, value=1.0), + updates, + state.sum_of_squares, ) if inplace: def f(t: torch.Tensor) -> torch.Tensor: return torch.where( - t > 0.0, torch.ones_like(t).div_(t.sqrt_().add(eps)), torch.tensor(0.0) + t > 0.0, + torch.ones_like(t).div_(t.sqrt_().add(eps)), + torch.tensor(0.0), ) else: def f(t: torch.Tensor) -> torch.Tensor: return torch.where( - t > 0.0, torch.ones_like(t).div(t.sqrt_().add(eps)), torch.tensor(0.0) + t > 0.0, + torch.ones_like(t).div(t.sqrt_().add(eps)), + torch.tensor(0.0), ) inv_sqrt_g_square = tree_map(f, sum_of_squares) From 3ca005cc7f0d3a1f8a0dca7cf22172e59463f0a4 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 18 Mar 2023 19:12:34 +0800 Subject: [PATCH 18/37] feat: adagrad integration --- tests/test_alias.py | 4 ++-- torchopt/transform/scale_by_rss.py | 11 +++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index 7cf794eb..8d1c2aaf 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -238,7 +238,7 @@ def test_adagrad( fmodel, params, buffers = functorch.make_functional_with_buffers(model) optim = torchopt.adagrad( - lr, + lr=lr, lr_decay=lr_decay, weight_decay=weight_decay, initial_accumulator_value=initial_accumulator_value, @@ -248,7 +248,7 @@ def test_adagrad( optim_state = optim.init(params) optim_ref = torch.optim.Adagrad( model_ref.parameters(), - lr, + lr=lr, lr_decay=lr_decay, weight_decay=weight_decay, initial_accumulator_value=initial_accumulator_value, diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index affa17e4..8fbe3b84 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -116,13 +116,8 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: # pylint: disable=unused-argument del params - # sum_of_squares = tree_map( - # lambda g, t: t + (g.conj() * g).real , updates, state.sum_of_squares - # ) - # sum_of_squares = torch.addcmul(state.sum_of_squares, updates, updates, value=1) - sum_of_squares = tree_map( - lambda g, t: t.addcmul_(g, g, value=1.0), + lambda g, t: t + (g.conj() * g).real, updates, state.sum_of_squares, ) @@ -132,7 +127,7 @@ def update_fn( def f(t: torch.Tensor) -> torch.Tensor: return torch.where( t > 0.0, - torch.ones_like(t).div_(t.sqrt_().add(eps)), + torch.ones_like(t).div_(t.sqrt().add_(eps)), torch.tensor(0.0), ) @@ -141,7 +136,7 @@ def f(t: torch.Tensor) -> 
torch.Tensor: def f(t: torch.Tensor) -> torch.Tensor: return torch.where( t > 0.0, - torch.ones_like(t).div(t.sqrt_().add(eps)), + torch.ones_like(t).div(t.sqrt().add_(eps)), torch.tensor(0.0), ) From 9a17c10ba07ac3a4e068bf98b7d1d85ea21c50ed Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 18 Mar 2023 21:46:05 +0800 Subject: [PATCH 19/37] feat: adagrad integration --- tests/test_alias.py | 22 ++++---------- torchopt/alias/adagrad.py | 9 ++++-- torchopt/schedule/exponential_decay.py | 10 +++---- torchopt/schedule/polynomial.py | 40 ++++++++++++++++++++++++++ torchopt/transform/scale_by_rss.py | 2 +- 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index 8d1c2aaf..90cd6354 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -212,13 +212,13 @@ def test_adam( @helpers.parametrize( dtype=[torch.float64], - lr=[1e-3], - lr_decay=[0.0], - initial_accumulator_value=[0.0], - eps=[1e-5], + lr=[1e-2], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-10, 1e-7], inplace=[True, False], - weight_decay=[0.0], - maximize=[False], + weight_decay=[1e-2], + maximize=[False, True], use_chain_flat=[True, False], ) def test_adagrad( @@ -255,11 +255,7 @@ def test_adagrad( eps=eps, maximize=maximize, ) - t = 0 for xs, ys in loader: - t = t + 1 - # if t == 2: - # break xs = xs.to(dtype=dtype) pred = fmodel(params, buffers, xs) pred_ref = model_ref(xs) @@ -274,12 +270,6 @@ def test_adagrad( loss_ref.backward() optim_ref.step() - print(t) - - _, params_ref, buffers_ref = functorch.make_functional_with_buffers(model_ref) - - print(t) - helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) _set_use_chain_flat(True) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index c1b4882c..8038866b 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -33,6 +33,7 @@ from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr from torchopt.combine import chain +from torchopt.schedule import polynomial_schedule from torchopt.transform import scale_by_rss from torchopt.typing import GradientTransformation, Scalar @@ -93,12 +94,15 @@ def adagrad( raise ValueError(f'Invalid lr_decay value: {lr_decay}') if not weight_decay >= 0.0: raise ValueError(f'Invalid weight_decay value: {weight_decay}') + if not initial_accumulator_value >= 0.0: + raise ValueError(f'Invalid initial_accumulator_value value: {initial_accumulator_value}') # pylint: enable=unneeded-not chain_fn = chain flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay adagrad_scaler_fn = scale_by_rss scale_by_neg_lr_fn = scale_by_neg_lr + schedule_fn = polynomial_schedule.adagrad return chain_fn( flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), @@ -106,6 +110,7 @@ def adagrad( initial_accumulator_value=initial_accumulator_value, eps=eps, ), - # scale_by_neg_lr(exponential_decay(init_value=lr, decay_rate=lr_decay)), - scale_by_neg_lr_fn(lr), + scale_by_neg_lr_fn( + schedule_fn(init_value=lr, decay_rate=lr_decay, transition_begin=0), + ), ) diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index ea204905..666fcd9c 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -76,7 +76,7 @@ def exponential_decay( """ if transition_steps is not None and transition_steps <= 0: logging.info( - 'An linear schedule was set with a non-positive `transition_steps`' + 'An 
exponential schedule was set with a non-positive `transition_steps`' ' value; this will result in a constant schedule with value ' '`init_value`.', ) @@ -84,20 +84,20 @@ def exponential_decay( if decay_rate == 0: logging.info( - 'An linear schedule was set with a zero `decay_rate` value; ' + 'An exponential schedule was set with a zero `decay_rate` value; ' 'this will result in a constant schedule with value `init_value`.', ) return lambda count: init_value if transition_begin < 0: logging.info( - 'An linear schedule was set with a negative `transition_begin` ' + 'An exponential schedule was set with a negative `transition_begin` ' 'value; this will result in `transition_begin` falling back to `0`.', ) transition_begin = 0 if end_value is not None: - pass + clip_fn = torch.maximum if decay_rate < 1.0 else torch.minimum def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin @@ -118,7 +118,7 @@ def schedule(count: Numeric) -> Numeric: decreased_count, ) if end_value is not None: - return torch.clamp(decayed_value, max=end_value) + return clip_fn(decayed_value, end_value) return decayed_value return schedule diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index f7eb4a6a..e0a888e7 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -42,6 +42,43 @@ __all__ = ['polynomial_schedule', 'linear_schedule'] +# pylint: disable-next=too-many-arguments +def _adagrad_lr_decay( + init_value: Scalar, + decay_rate: Scalar, + transition_begin: int = 0, +) -> Schedule: + """Constructs a schedule dedicated to AdaGrad optimizer. + + This function applies an learning rate decay function to a provided initial + value. The function returns the decayed value as follows: + ``` + decayed_value = init_value / 1 + count * decay_rate + ``` + + Args: + init_value: the initial learning rate. + decay_rate: The decay rate. + transition_begin: must be positive. After how many steps to start annealing + (before this many steps the scalar value is held fixed at `init_value`). + + Returns: + schedule: A function that maps step counts to values. 
+ """ + if transition_begin < 0: + logging.info( + 'An exponential schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.', + ) + transition_begin = 0 + + def schedule(count: Numeric) -> Numeric: + decreased_count = count - transition_begin + return init_value / (1 + decay_rate * decreased_count) + + return schedule + + def polynomial_schedule( init_value: Scalar, end_value: Scalar, @@ -106,3 +143,6 @@ def linear_schedule( transition_steps=transition_steps, transition_begin=transition_begin, ) + + +polynomial_schedule.adagrad = _adagrad_lr_decay diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 8fbe3b84..ad23c923 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -136,7 +136,7 @@ def f(t: torch.Tensor) -> torch.Tensor: def f(t: torch.Tensor) -> torch.Tensor: return torch.where( t > 0.0, - torch.ones_like(t).div(t.sqrt().add_(eps)), + torch.ones_like(t).div(t.sqrt().add(eps)), torch.tensor(0.0), ) From 79036edcfd560179c77560afb7fb0bd2da3997df Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 19 Mar 2023 03:34:35 +0800 Subject: [PATCH 20/37] feat: adagrad integration --- CHANGELOG.md | 1 + docs/source/api/api.rst | 63 +++++++---- docs/source/explicit_diff/explicit_diff.rst | 5 +- docs/source/implicit_diff/implicit_diff.rst | 2 +- docs/source/optimizer/optim.rst | 11 +- docs/source/spelling_wordlist.txt | 2 + tests/test_alias.py | 116 ++++++++++---------- tests/test_optim.py | 59 ++++++++++ tests/test_schedule.py | 30 +++++ torchopt/__init__.py | 6 +- torchopt/alias/adagrad.py | 4 +- torchopt/optim/__init__.py | 1 + torchopt/optim/adagrad.py | 80 ++++++++++++++ torchopt/optim/meta/__init__.py | 1 + torchopt/optim/meta/adagrad.py | 77 +++++++++++++ torchopt/schedule/exponential_decay.py | 20 +--- torchopt/schedule/polynomial.py | 2 +- 17 files changed, 375 insertions(+), 105 deletions(-) create mode 100644 torchopt/optim/adagrad.py create mode 100644 torchopt/optim/meta/adagrad.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e9ace446..cf64375d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Implement AdaGrad optimizer and exponential learning rate schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs). - Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140). - Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139). - Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan). diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index b2866407..a9937f57 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -30,10 +30,14 @@ Functional Optimizers .. autosummary:: FuncOptimizer + adagrad adam - sgd - rmsprop adamw + rmsprop + sgd + + + Wrapper for Function Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -41,6 +45,11 @@ Wrapper for Function Optimizer .. autoclass:: FuncOptimizer :members: +Functional AdaGrad Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: adagrad + Functional Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -51,16 +60,16 @@ Functional AdamW Optimizer .. 
autofunction:: adamw -Functional SGD Optimizer -~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: sgd - Functional RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: rmsprop +Functional SGD Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: sgd + ------ Classic Optimizers @@ -70,10 +79,16 @@ Classic Optimizers .. autosummary:: + Adagrad Adam - SGD - RMSProp AdamW + RMSProp + SGD + +Classic AdaGrad Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: Adagrad Classic Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~ @@ -85,16 +100,16 @@ Classic AdamW Optimizer .. autoclass:: AdamW -Classic SGD Optimizer -~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: SGD - Classic RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: RMSProp +Classic SGD Optimizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: SGD + ------ Differentiable Meta-Optimizers @@ -104,10 +119,16 @@ Differentiable Meta-Optimizers .. autosummary:: + MetaAdagrad MetaAdam - MetaSGD - MetaRMSProp MetaAdamW + MetaRMSProp + MetaSGD + +Differentiable Meta-AdaGrad Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaAdagrad Differentiable Meta-Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -119,16 +140,16 @@ Differentiable Meta-AdamW Optimizer .. autoclass:: MetaAdamW -Differentiable Meta-SGD Optimizer -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: MetaSGD - Differentiable Meta-RMSProp Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MetaRMSProp +Differentiable Meta-SGD Optimizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MetaSGD + ------ Implicit Differentiation diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst index 89c38df6..57e153d7 100644 --- a/docs/source/explicit_diff/explicit_diff.rst +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -53,10 +53,11 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho .. autosummary:: torchopt.MetaOptimizer + torchopt.MetaAdagrad torchopt.MetaAdam - torchopt.MetaSGD - torchopt.MetaRMSProp torchopt.MetaAdamW + torchopt.MetaRMSProp + torchopt.MetaSGD By combining low-level API :class:`torchopt.MetaOptimizer` with the previous functional optimizer, we can achieve high-level API: diff --git a/docs/source/implicit_diff/implicit_diff.rst b/docs/source/implicit_diff/implicit_diff.rst index 294d58d1..5544c25f 100644 --- a/docs/source/implicit_diff/implicit_diff.rst +++ b/docs/source/implicit_diff/implicit_diff.rst @@ -38,7 +38,7 @@ In `IMAML `_, the function :math:`F` in the fi Fixed-point Iteration ~~~~~~~~~~~~~~~~~~~~~ -Sometimes the inner-level optimal solution can also be achieved by fixed point where the optionality :math:`T` takes the form: +Sometimes the inner-level optimal solution can also be achieved by fixed point where the optimality :math:`T` takes the form: .. math:: diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst index 850bc8c7..a96707f1 100644 --- a/docs/source/optimizer/optim.rst +++ b/docs/source/optimizer/optim.rst @@ -18,10 +18,11 @@ Currently, TorchOpt supports 4 functional optimizers: :func:`sgd`, :func:`adam`, .. autosummary:: torchopt.FuncOptimizer + torchopt.adagrad torchopt.adam - torchopt.sgd - torchopt.rmsprop torchopt.adamw + torchopt.rmsprop + torchopt.sgd Apply Parameter Updates ----------------------- @@ -84,10 +85,12 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi .. 
autosummary:: torchopt.Optimizer + torchopt.Adagrad torchopt.Adam - torchopt.SGD - torchopt.RMSProp torchopt.AdamW + torchopt.RMSProp + torchopt.SGD + By combining low-level API :class:`torchopt.Optimizer` with the previous functional optimizer, we can achieve high-level API: diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index 021871a1..49fdbb69 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -173,3 +173,5 @@ ABCMeta subclasscheck ctx Duchi +invertible +AdaGrad diff --git a/tests/test_alias.py b/tests/test_alias.py index 90cd6354..d4274e97 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -77,54 +77,50 @@ def test_empty( @helpers.parametrize( - dtype=[torch.float64, torch.float32], - lr=[1e-2, 1e-3, 1e-4], - momentum=[0.0, 0.1], - dampening=[0.0, 0.5], - nesterov=[False, True], + dtype=[torch.float64], + lr=[1e-2], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], use_chain_flat=[True, False], ) -def test_sgd( +def test_adagrad( dtype: torch.dtype, lr: float, - momentum: float, - dampening: float, - nesterov: bool, + lr_decay: float, + initial_accumulator_value: float, + eps: float, inplace: bool, weight_decay: float, maximize: bool, use_chain_flat: bool, ) -> None: - if nesterov and (momentum <= 0.0 or dampening != 0.0): - pytest.skip('Nesterov momentum requires a momentum and zero dampening.') - _set_use_chain_flat(use_chain_flat) model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.sgd( - lr, - momentum=momentum, - dampening=dampening, - nesterov=nesterov, + optim = torchopt.adagrad( + lr=lr, + lr_decay=lr_decay, weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, maximize=maximize, ) optim_state = optim.init(params) - optim_ref = torch.optim.SGD( + optim_ref = torch.optim.Adagrad( model_ref.parameters(), - lr, - momentum=momentum, - dampening=dampening, - nesterov=nesterov, + lr=lr, + lr_decay=lr_decay, weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, maximize=maximize, ) - for xs, ys in loader: xs = xs.to(dtype=dtype) pred = fmodel(params, buffers, xs) @@ -145,48 +141,50 @@ def test_sgd( @helpers.parametrize( - dtype=[torch.float64], + dtype=[torch.float64, torch.float32], lr=[1e-2, 1e-3, 1e-4], - betas=[(0.9, 0.999), (0.95, 0.9995)], - eps=[1e-8], + momentum=[0.0, 0.1], + dampening=[0.0, 0.5], + nesterov=[False, True], inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], - use_accelerated_op=[False, True], use_chain_flat=[True, False], ) -def test_adam( +def test_sgd( dtype: torch.dtype, lr: float, - betas: tuple[float, float], - eps: float, + momentum: float, + dampening: float, + nesterov: bool, inplace: bool, weight_decay: float, maximize: bool, - use_accelerated_op: bool, use_chain_flat: bool, ) -> None: + if nesterov and (momentum <= 0.0 or dampening != 0.0): + pytest.skip('Nesterov momentum requires a momentum and zero dampening.') + _set_use_chain_flat(use_chain_flat) model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adam( + optim = torchopt.sgd( lr, - betas=betas, - eps=eps, - eps_root=0.0, + momentum=momentum, + dampening=dampening, + 
nesterov=nesterov, weight_decay=weight_decay, maximize=maximize, - use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) - optim_ref = torch.optim.Adam( + optim_ref = torch.optim.SGD( model_ref.parameters(), lr, - betas=betas, - eps=eps, - amsgrad=False, + momentum=momentum, + dampening=dampening, + nesterov=nesterov, weight_decay=weight_decay, maximize=maximize, ) @@ -212,24 +210,24 @@ def test_adam( @helpers.parametrize( dtype=[torch.float64], - lr=[1e-2], - lr_decay=[0.0, 1e-2], - initial_accumulator_value=[0.0, 1e-1], - eps=[1e-10, 1e-7], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], inplace=[True, False], - weight_decay=[1e-2], + weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], use_chain_flat=[True, False], ) -def test_adagrad( +def test_adam( dtype: torch.dtype, lr: float, - lr_decay: float, - initial_accumulator_value: float, + betas: tuple[float, float], eps: float, inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, use_chain_flat: bool, ) -> None: _set_use_chain_flat(use_chain_flat) @@ -237,24 +235,26 @@ def test_adagrad( model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adagrad( - lr=lr, - lr_decay=lr_decay, - weight_decay=weight_decay, - initial_accumulator_value=initial_accumulator_value, + optim = torchopt.adam( + lr, + betas=betas, eps=eps, + eps_root=0.0, + weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) - optim_ref = torch.optim.Adagrad( + optim_ref = torch.optim.Adam( model_ref.parameters(), - lr=lr, - lr_decay=lr_decay, - weight_decay=weight_decay, - initial_accumulator_value=initial_accumulator_value, + lr, + betas=betas, eps=eps, + amsgrad=False, + weight_decay=weight_decay, maximize=maximize, ) + for xs, ys in loader: xs = xs.to(dtype=dtype) pred = fmodel(params, buffers, xs) diff --git a/tests/test_optim.py b/tests/test_optim.py index b2be7500..c2131e37 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -86,6 +86,65 @@ def test_SGD( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + maximize=[False, True], +) +def test_Adagrad( + dtype: torch.dtype, + lr: float, + lr_decay: float, + initial_accumulator_value: float, + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, +) -> None: + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + optim = torchopt.Adagrad( + model.parameters(), + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + optim_ref = torch.optim.Adagrad( + model_ref.parameters(), + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = model(xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + optim.zero_grad() + loss.backward() + optim.step() + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close(model, 
model_ref, model_base, dtype=dtype) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], diff --git a/tests/test_schedule.py b/tests/test_schedule.py index ae714875..9daa49fe 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -27,6 +27,36 @@ from torchopt.alias.utils import _set_use_chain_flat +@helpers.parametrize( + init_value=[], + decay_rate=[], + transition_begin=[], + transition_steps=[], + staircase=[False, True], + end_value=[], +) +def test_exponential_decay( + init_value: float, + decay_rate: float, + transition_begin: int, + transition_steps: int, + staircase: bool, + end_value: float, +) -> None: + schedule = torchopt.schedule.exponential_decay( + init_value=init_value, + decay_rate=decay_rate, + transition_steps=transition_steps, + transition_begin=transition_begin, + staircase=staircase, + end_value=end_value, + ) + for i in range(transition_begin, transition_steps): + lr = schedule(i) + lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps)) + assert np.allclose(lr, lr_gt) + + def test_linear_schedule() -> None: init_value = 1.0 end_value = 0.0 diff --git a/torchopt/__init__.py b/torchopt/__init__.py index 0374c3bf..30fc7945 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -37,9 +37,10 @@ from torchopt.clip import clip_grad_norm from torchopt.combine import chain from torchopt.hook import register_hook -from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop +from torchopt.optim import SGD, Adagrad, Adam, AdamW, Optimizer, RMSProp, RMSprop from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdagrad, MetaAdam, MetaAdamW, MetaOptimizer, @@ -61,6 +62,7 @@ __all__ = [ 'accelerated_op_available', + 'adagrad', 'adam', 'adamw', 'rmsprop', @@ -73,12 +75,14 @@ 'SGD', 'Adam', 'AdamW', + 'Adagrad', 'RMSProp', 'RMSprop', 'MetaOptimizer', 'MetaSGD', 'MetaAdam', 'MetaAdamW', + 'MetaAdagrad', 'MetaRMSProp', 'MetaRMSprop', 'FuncOptimizer', diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 8038866b..306635a4 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -72,7 +72,7 @@ def adagrad( initial_accumulator_value: (default: :const:`0.0`) Initial value for the accumulator. eps: (default: :const:`1e-10`) - A small constant applied to denominator outside of the square root (as in the Adam + A small constant applied to denominator outside of the square root (as in the AdaGrad paper) to avoid dividing by zero when rescaling. maximize: (default: :data:`False`) Maximize the params based on the objective, instead of minimizing. 
@@ -102,7 +102,7 @@ def adagrad( flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay adagrad_scaler_fn = scale_by_rss scale_by_neg_lr_fn = scale_by_neg_lr - schedule_fn = polynomial_schedule.adagrad + schedule_fn = polynomial_schedule.adagrad # type: ignore[attr-defined] return chain_fn( flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index b75da23c..f2344464 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -15,6 +15,7 @@ """object oriented optimizer implementations.""" from torchopt.optim import meta +from torchopt.optim.adagrad import Adagrad from torchopt.optim.adam import Adam from torchopt.optim.adamw import AdamW from torchopt.optim.base import Optimizer diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py new file mode 100644 index 00000000..95bb34ec --- /dev/null +++ b/torchopt/optim/adagrad.py @@ -0,0 +1,80 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AdaGrad optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import Scalar + + +__all__ = ['Adagrad'] + + +class Adagrad(Optimizer): + """The classic AdaGrad optimizer. + + See Also: + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. + - The differentiable meta AdaGrad optimizer: :class:`torchopt.MetaAdagrad`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: Scalar = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + maximize: bool = False, + ) -> None: + r"""Initialize the AdaGrad optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr: (default: :const:`1e-2`) + This is a fixed global scaling factor. + lr_decay: (default: :const:`0.0`) + Learning rate decay. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + initial_accumulator_value: (default: :const:`0.0`) + Initial value for the accumulator. + eps: (default: :const:`1e-10`) + A small constant applied to denominator outside of the square root (as in the AdaGrad + paper) to avoid dividing by zero when rescaling. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. 
+ """ + super().__init__( + params, + alias.adagrad( + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ), + ) diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index ba486d6d..74f881b7 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Differentiable Meta-Optimizers.""" +from torchopt.optim.meta.adagrad import MetaAdagrad from torchopt.optim.meta.adam import MetaAdam from torchopt.optim.meta.adamw import MetaAdamW from torchopt.optim.meta.base import MetaOptimizer diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py new file mode 100644 index 00000000..5785a9c0 --- /dev/null +++ b/torchopt/optim/meta/adagrad.py @@ -0,0 +1,77 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differentiable AdaGrad optimizer.""" + +from __future__ import annotations + +import torch.nn as nn + +from torchopt import alias +from torchopt.optim.meta.base import MetaOptimizer +from torchopt.typing import Scalar + + +__all__ = ['MetaAdagrad'] + + +class MetaAdagrad(MetaOptimizer): + """The differentiable AdaGrad optimizer. + + See Also: + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`.m + - The classic AdaGrad optimizer: :class:`torchopt.Adagrad`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + module: nn.Module, + lr: Scalar = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + eps: float = 1e-10, + *, + maximize: bool = False, + ) -> None: + """Initialize the meta AdaGrad optimizer. + + Args: + module (nn.Module): A network whose parameters should be optimized. + lr: (default: :const:`1e-2`) + This is a fixed global scaling factor. + lr_decay: (default: :const:`0.0`) + Learning rate decay. + weight_decay: (default: :const:`0.0`) + Weight decay, add L2 penalty to parameters. + initial_accumulator_value: (default: :const:`0.0`) + Initial value for the accumulator. + eps: (default: :const:`1e-10`) + A small constant applied to denominator outside of the square root (as in the AdaGrad + paper) to avoid dividing by zero when rescaling. + maximize: (default: :data:`False`) + Maximize the params based on the objective, instead of minimizing. 
+ """ + super().__init__( + module, + alias.adagrad( + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ), + ) diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 666fcd9c..07dc812d 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -32,10 +32,9 @@ """Exponential learning rate decay.""" import logging +import math from typing import Optional -import torch - from torchopt.typing import Numeric, Scalar, Schedule @@ -97,26 +96,17 @@ def exponential_decay( transition_begin = 0 if end_value is not None: - clip_fn = torch.maximum if decay_rate < 1.0 else torch.minimum + clip_fn = max if decay_rate < 1.0 else min def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin if transition_steps is not None: p = decreased_count / transition_steps - if staircase: - p = torch.floor(torch.tensor(p)) - - decayed_value = torch.where( - torch.tensor(decreased_count) <= 0, - torch.tensor(init_value), - torch.tensor(init_value) * torch.pow(torch.tensor(decay_rate), p), - ) + p = math.floor(p) + decayed_value = init_value if decreased_count <= 0.0 else init_value * (decay_rate**p) else: - decayed_value = torch.tensor(init_value) * torch.pow( - torch.tensor(decay_rate), - decreased_count, - ) + decayed_value = init_value * (decay_rate**decreased_count) if end_value is not None: return clip_fn(decayed_value, end_value) return decayed_value diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index e0a888e7..664c5693 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -145,4 +145,4 @@ def linear_schedule( ) -polynomial_schedule.adagrad = _adagrad_lr_decay +polynomial_schedule.adagrad = _adagrad_lr_decay # type: ignore[attr-defined] From 85709e34522fa3ae4b2b10f693296877447a65e4 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 19 Mar 2023 03:51:16 +0800 Subject: [PATCH 21/37] feat: adagrad integration --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b6ee6ffe..de7ba8dd 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ COPYRIGHT = "MetaOPT Team. All Rights Reserved." PROJECT_PATH = $(PROJECT_NAME) SHELL = /bin/bash SOURCE_FOLDERS = $(PROJECT_PATH) examples include src tests docs +RST_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.rst" -o -name "*.rst") PYTHON_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.py" -o -name "*.pyi") CXX_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.h" -o -name "*.cpp") CUDA_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.cuh" -o -name "*.cu") @@ -113,7 +114,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ - $(PYTHON) -m pytest -k "test_adagrad" --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . 
From 56369922c8fe5e9b6094179ccea6004a023d90e6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 19 Mar 2023 03:59:11 +0800 Subject: [PATCH 22/37] feat: adagrad integration --- tests/test_schedule.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 9daa49fe..abe8b15f 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -15,6 +15,7 @@ from __future__ import annotations +import math from typing import Callable import functorch @@ -28,12 +29,12 @@ @helpers.parametrize( - init_value=[], - decay_rate=[], - transition_begin=[], - transition_steps=[], + init_value=[1.0], + decay_rate=[1e-2], + transition_begin=[1], + transition_steps=[10], staircase=[False, True], - end_value=[], + end_value=[0.0], ) def test_exponential_decay( init_value: float, @@ -53,7 +54,12 @@ def test_exponential_decay( ) for i in range(transition_begin, transition_steps): lr = schedule(i) - lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps)) + if staircase: + lr_gt = init_value * ( + decay_rate ** math.floor((i - transition_begin) / transition_steps) + ) + else: + lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps)) assert np.allclose(lr, lr_gt) From 3ede2b4fe402bf8e4719062c5e4e2df96a2ecc7c Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 19 Mar 2023 04:26:42 +0800 Subject: [PATCH 23/37] feat: adagrad integration --- torchopt/alias/adagrad.py | 10 +++++----- torchopt/schedule/exponential_decay.py | 8 ++++---- torchopt/schedule/polynomial.py | 16 ++++++++-------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 306635a4..9c51659b 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -86,15 +86,15 @@ def adagrad( The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. """ # pylint: disable=unneeded-not - if not (callable(lr) or lr >= 0.0): + if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not eps >= 0.0: + if not eps >= 0.0: # pragma: no cover raise ValueError(f'Invalid epsilon value: {eps}') - if not lr_decay >= 0.0: + if not lr_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid lr_decay value: {lr_decay}') - if not weight_decay >= 0.0: + if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') - if not initial_accumulator_value >= 0.0: + if not initial_accumulator_value >= 0.0: # pragma: no cover raise ValueError(f'Invalid initial_accumulator_value value: {initial_accumulator_value}') # pylint: enable=unneeded-not diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 07dc812d..4a689c4f 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -73,7 +73,7 @@ def exponential_decay( Returns: schedule: A function that maps step counts to values. 
""" - if transition_steps is not None and transition_steps <= 0: + if transition_steps is not None and transition_steps <= 0: # pragma: no cover logging.info( 'An exponential schedule was set with a non-positive `transition_steps`' ' value; this will result in a constant schedule with value ' @@ -81,21 +81,21 @@ def exponential_decay( ) return lambda count: init_value - if decay_rate == 0: + if decay_rate == 0: # pragma: no cover logging.info( 'An exponential schedule was set with a zero `decay_rate` value; ' 'this will result in a constant schedule with value `init_value`.', ) return lambda count: init_value - if transition_begin < 0: + if transition_begin < 0: # pragma: no cover logging.info( 'An exponential schedule was set with a negative `transition_begin` ' 'value; this will result in `transition_begin` falling back to `0`.', ) transition_begin = 0 - if end_value is not None: + if end_value is not None: # pragma: no cover clip_fn = max if decay_rate < 1.0 else min def schedule(count: Numeric) -> Numeric: diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 664c5693..2f491aef 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -65,12 +65,12 @@ def _adagrad_lr_decay( Returns: schedule: A function that maps step counts to values. """ - if transition_begin < 0: - logging.info( - 'An exponential schedule was set with a negative `transition_begin` ' - 'value; this will result in `transition_begin` falling back to `0`.', - ) - transition_begin = 0 + if transition_begin < 0: # pragma: no cover + logging.info( # pragma: no cover + 'An exponential schedule was set with a negative `transition_begin` ' # pragma: no cover + 'value; this will result in `transition_begin` falling back to `0`.', # pragma: no cover + ) # pragma: no cover + transition_begin = 0 # pragma: no cover def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin @@ -105,14 +105,14 @@ def polynomial_schedule( schedule: A function that maps step counts to values. 
""" - if transition_steps <= 0: + if transition_steps <= 0: # pragma: no cover logging.info( 'A polynomial schedule was set with a non-positive `transition_steps` value; this ' 'results in a constant schedule with value `init_value`.', ) return lambda count: init_value - if transition_begin < 0: + if transition_begin < 0: # pragma: no cover logging.info( 'An exponential schedule was set with a negative `transition_begin` value; this will ' 'result in `transition_begin` falling back to `0`.', From 4c4e1a3d40dc61e3ae66ade429d27e6844a06ecf Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sun, 19 Mar 2023 05:05:56 +0800 Subject: [PATCH 24/37] feat: adagrad integration --- tests/test_schedule.py | 8 ++++---- torchopt/schedule/exponential_decay.py | 13 +++++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/test_schedule.py b/tests/test_schedule.py index abe8b15f..d2dcca9c 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -16,7 +16,7 @@ from __future__ import annotations import math -from typing import Callable +from typing import Callable, Optional import functorch import numpy as np @@ -34,15 +34,15 @@ transition_begin=[1], transition_steps=[10], staircase=[False, True], - end_value=[0.0], + end_value=[0.0, None], ) def test_exponential_decay( init_value: float, decay_rate: float, transition_begin: int, - transition_steps: int, + transition_steps: int | None, staircase: bool, - end_value: float, + end_value: float | None, ) -> None: schedule = torchopt.schedule.exponential_decay( init_value=init_value, diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 4a689c4f..06351a09 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -46,7 +46,7 @@ def exponential_decay( init_value: Scalar, decay_rate: Scalar, transition_begin: int = 0, - transition_steps: Optional[int] = None, + transition_steps: int = 1, staircase: bool = False, end_value: Optional[float] = None, ) -> Schedule: @@ -100,13 +100,10 @@ def exponential_decay( def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin - if transition_steps is not None: - p = decreased_count / transition_steps - if staircase: - p = math.floor(p) - decayed_value = init_value if decreased_count <= 0.0 else init_value * (decay_rate**p) - else: - decayed_value = init_value * (decay_rate**decreased_count) + p = decreased_count / transition_steps + if staircase: + p = math.floor(p) + decayed_value = init_value if decreased_count <= 0.0 else init_value * (decay_rate**p) if end_value is not None: return clip_fn(decayed_value, end_value) return decayed_value From d431749009e3172a6b9c59c5f32dd9db2ef70495 Mon Sep 17 00:00:00 2001 From: Bo Liu Date: Mon, 20 Mar 2023 21:52:42 +0900 Subject: [PATCH 25/37] feat(torchopt.optim): update torchopt/alias/adagrad.py Co-authored-by: Xuehai Pan --- torchopt/alias/adagrad.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 9c51659b..0171166a 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -55,9 +55,11 @@ def adagrad( AdaGrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training. 
- WARNING: AdaGrad's main limit is the monotonic accumulation of squared gradients in the - denominator: since all terms are >0, the sum keeps growing during training and the learning rate - eventually becomes very small. + + .. warning:: + AdaGrad's main limit is the monotonic accumulation of squared gradients in the denominator. + Since all terms are ``> 0``, the sum keeps growing during training, and the learning rate + eventually becomes very small. References: Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html From 03e5b42679eb12e9ef5d4977db17e44eaf98e912 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Mon, 20 Mar 2023 21:14:50 +0800 Subject: [PATCH 26/37] feat: adagrad integration --- Makefile | 2 +- tests/test_alias.py | 2 +- tests/test_optim.py | 2 +- tests/test_schedule.py | 14 +++-- torchopt/alias/__init__.py | 2 +- torchopt/alias/adagrad.py | 86 ++++++++++++++++++++------ torchopt/optim/__init__.py | 2 +- torchopt/optim/adagrad.py | 32 +++++----- torchopt/optim/meta/__init__.py | 2 +- torchopt/optim/meta/adagrad.py | 2 +- torchopt/schedule/__init__.py | 2 +- torchopt/schedule/exponential_decay.py | 23 ++++--- torchopt/schedule/polynomial.py | 42 +------------ torchopt/transform/scale_by_rss.py | 2 +- torchopt/version.py | 2 +- 15 files changed, 115 insertions(+), 102 deletions(-) diff --git a/Makefile b/Makefile index de7ba8dd..c387fd54 100644 --- a/Makefile +++ b/Makefile @@ -114,7 +114,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ - $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest -k "test_exponential_decay" --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . 
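The behaviour called out in the `.. warning::` directive above is easy to see numerically: `scale_by_rss` keeps a running sum of squared gradients, and its rescaling factor `1 / (sqrt(sum) + eps)` can only shrink as that sum grows. A rough illustration with a constant gradient (the tensors and loop below are illustrative only, not part of the patches):

```
import torch

grad = torch.ones(1)             # constant gradient of 1.0
sum_of_squares = torch.zeros(1)  # the accumulator tracked by scale_by_rss

for step in range(1, 6):
    sum_of_squares += grad**2    # monotonically increasing, never resets
    scale = 1.0 / (sum_of_squares.sqrt() + 1e-10)
    print(step, scale.item())    # 1.0, 0.707..., 0.577..., 0.5, 0.447...
```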
diff --git a/tests/test_alias.py b/tests/test_alias.py index d4274e97..ae1cfb70 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -78,7 +78,7 @@ def test_empty( @helpers.parametrize( dtype=[torch.float64], - lr=[1e-2], + lr=[1e-2, 1e-3, 1e-4], lr_decay=[0.0, 1e-2], initial_accumulator_value=[0.0, 1e-1], eps=[1e-8], diff --git a/tests/test_optim.py b/tests/test_optim.py index c2131e37..6f3e9208 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -88,7 +88,7 @@ def test_SGD( @helpers.parametrize( dtype=[torch.float64], - lr=[1e-2], + lr=[1e-2, 1e-3, 1e-4], lr_decay=[0.0, 1e-2], initial_accumulator_value=[0.0, 1e-1], eps=[1e-8], diff --git a/tests/test_schedule.py b/tests/test_schedule.py index d2dcca9c..c3a57f4c 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -29,12 +29,12 @@ @helpers.parametrize( - init_value=[1.0], - decay_rate=[1e-2], - transition_begin=[1], - transition_steps=[10], + init_value=[1.0, 1e-1], + decay_rate=[1e-2, 1e-3], + transition_begin=[1, 5], + transition_steps=[10, 100], staircase=[False, True], - end_value=[0.0, None], + end_value=[0.0, None, 8e-1], ) def test_exponential_decay( init_value: float, @@ -52,6 +52,8 @@ def test_exponential_decay( staircase=staircase, end_value=end_value, ) + if end_value is not None: # pragma: no cover + clip_fn = max if decay_rate < 1.0 else min for i in range(transition_begin, transition_steps): lr = schedule(i) if staircase: @@ -60,6 +62,8 @@ def test_exponential_decay( ) else: lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps)) + if end_value is not None: + lr_gt = clip_fn(lr_gt, end_value) assert np.allclose(lr, lr_gt) diff --git a/torchopt/alias/__init__.py b/torchopt/alias/__init__.py index 058ac5db..ae7dd2b5 100644 --- a/torchopt/alias/__init__.py +++ b/torchopt/alias/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 9c51659b..546c2104 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,16 +31,58 @@ # ============================================================================== """Preset :class:`GradientTransformation` for the AdaGrad optimizer.""" -from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr +import logging + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) from torchopt.combine import chain -from torchopt.schedule import polynomial_schedule from torchopt.transform import scale_by_rss -from torchopt.typing import GradientTransformation, Scalar +from torchopt.typing import GradientTransformation, Numeric, Scalar, Schedule __all__ = ['adagrad'] +# pylint: disable-next=too-many-arguments +def _adagrad_lr_decay( + init_value: Scalar, + decay_rate: Scalar, + transition_begin: int = 0, +) -> Schedule: + """Constructs a schedule dedicated to AdaGrad optimizer. + + This function applies an learning rate decay function to a provided initial + value. 
The function returns the decayed value as follows: + ``` + decayed_value = init_value / 1 + count * decay_rate + ``` + + Args: + init_value: the initial learning rate. + decay_rate: The decay rate. + transition_begin: must be positive. After how many steps to start annealing + (before this many steps the scalar value is held fixed at `init_value`). + + Returns: + schedule: A function that maps step counts to values. + """ + if transition_begin < 0: # pragma: no cover + logging.info( # pragma: no cover + 'The AdaGrad learning rate schedule was set with a negative `transition_begin` ' # pragma: no cover + 'value; this will result in `transition_begin` falling back to `0`.', # pragma: no cover + ) # pragma: no cover + transition_begin = 0 # pragma: no cover + + def schedule(count: Numeric) -> Numeric: + decreased_count = count - transition_begin + return init_value / (1 + decay_rate * decreased_count) + + return schedule + + # pylint: disable-next=too-many-arguments def adagrad( lr: Scalar = 1e-2, @@ -63,21 +105,19 @@ def adagrad( Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html Args: - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - lr_decay: (default: :const:`0.0`) - Learning rate decay. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - initial_accumulator_value: (default: :const:`0.0`) - Initial value for the accumulator. - eps: (default: :const:`1e-10`) - A small constant applied to denominator outside of the square root (as in the AdaGrad - paper) to avoid dividing by zero when rescaling. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. - use_accelerated_op: (default: :data:`False`) - If :data:`True` use our implemented fused operator. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. + (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + initial_accumulator_value (float, optional): Initial value for the accumulator. + (default: :const:`0.0`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-10`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) Returns: The corresponding :class:`GradientTransformation` instance. @@ -102,7 +142,13 @@ def adagrad( flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay adagrad_scaler_fn = scale_by_rss scale_by_neg_lr_fn = scale_by_neg_lr - schedule_fn = polynomial_schedule.adagrad # type: ignore[attr-defined] + schedule_fn = _adagrad_lr_decay + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adagrad_scaler_fn = adagrad_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] return chain_fn( flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index f2344464..038e1d95 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. 
All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 95bb34ec..4b71f558 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ from torchopt import alias from torchopt.optim.base import Optimizer -from torchopt.typing import Scalar +from torchopt.typing import ScalarOrSchedule __all__ = ['Adagrad'] @@ -40,7 +40,7 @@ class Adagrad(Optimizer): def __init__( self, params: Iterable[torch.Tensor], - lr: Scalar = 1e-2, + lr: ScalarOrSchedule = 1e-2, lr_decay: float = 0.0, weight_decay: float = 0.0, initial_accumulator_value: float = 0.0, @@ -53,19 +53,19 @@ def __init__( Args: params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - lr_decay: (default: :const:`0.0`) - Learning rate decay. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - initial_accumulator_value: (default: :const:`0.0`) - Initial value for the accumulator. - eps: (default: :const:`1e-10`) - A small constant applied to denominator outside of the square root (as in the AdaGrad - paper) to avoid dividing by zero when rescaling. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. + (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + initial_accumulator_value (float, optional): Initial value for the accumulator. + (default: :const:`0.0`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-10`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index 74f881b7..2a09a024 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 5785a9c0..60051792 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/torchopt/schedule/__init__.py b/torchopt/schedule/__init__.py index 35121c1a..b9916783 100644 --- a/torchopt/schedule/__init__.py +++ b/torchopt/schedule/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index 06351a09..b2c3460f 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,15 +60,18 @@ def exponential_decay( If the argument `staircase` is `True`, then `count / transition_steps` is an integer division and the decayed value follows a staircase function. Args: - init_value: the initial learning rate. - decay_rate: must not be zero. The decay rate. - transition_begin: must be positive. After how many steps to start annealing - (before this many steps the scalar value is held fixed at `init_value`). - transition_steps: must be positive. See the decay computation above. - staircase: if `True`, decay the values at discrete intervals. - end_value: the value at which the exponential decay stops. When - `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise as - an upper bound. Has no effect when `decay_rate` = 0. + init_value (float or Tensor): Initial value for the scalar to be annealed. + decay_rate (float or Tensor): The decay rate. + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing (before this many steps the scalar value is held fixed at ``init_value``). + (default: :const:`0`) + transition_steps (int): Number of steps over which annealing takes place, the scalar starts + changing at ``transition_begin`` steps and completes the transition by + ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the + entire annealing process is disabled and the value is held fixed at ``init_value``. + (default: :const:`1`) + staircase (bool): If ``True``, decay the scalar at discrete intervals. + end_value (float or Tensor): End value of the scalar to be annealed. Returns: schedule: A function that maps step counts to values. diff --git a/torchopt/schedule/polynomial.py b/torchopt/schedule/polynomial.py index 2f491aef..39629c38 100644 --- a/torchopt/schedule/polynomial.py +++ b/torchopt/schedule/polynomial.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,43 +42,6 @@ __all__ = ['polynomial_schedule', 'linear_schedule'] -# pylint: disable-next=too-many-arguments -def _adagrad_lr_decay( - init_value: Scalar, - decay_rate: Scalar, - transition_begin: int = 0, -) -> Schedule: - """Constructs a schedule dedicated to AdaGrad optimizer. - - This function applies an learning rate decay function to a provided initial - value. 
The function returns the decayed value as follows: - ``` - decayed_value = init_value / 1 + count * decay_rate - ``` - - Args: - init_value: the initial learning rate. - decay_rate: The decay rate. - transition_begin: must be positive. After how many steps to start annealing - (before this many steps the scalar value is held fixed at `init_value`). - - Returns: - schedule: A function that maps step counts to values. - """ - if transition_begin < 0: # pragma: no cover - logging.info( # pragma: no cover - 'An exponential schedule was set with a negative `transition_begin` ' # pragma: no cover - 'value; this will result in `transition_begin` falling back to `0`.', # pragma: no cover - ) # pragma: no cover - transition_begin = 0 # pragma: no cover - - def schedule(count: Numeric) -> Numeric: - decreased_count = count - transition_begin - return init_value / (1 + decay_rate * decreased_count) - - return schedule - - def polynomial_schedule( init_value: Scalar, end_value: Scalar, @@ -143,6 +106,3 @@ def linear_schedule( transition_steps=transition_steps, transition_begin=transition_begin, ) - - -polynomial_schedule.adagrad = _adagrad_lr_decay # type: ignore[attr-defined] diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index ad23c923..a85713a9 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/torchopt/version.py b/torchopt/version.py index b8136a22..a65aad3e 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From ace6b5efa517c38b4b88737d08d2a9d349ac2f56 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Tue, 21 Mar 2023 23:06:49 +0800 Subject: [PATCH 27/37] feat: adagrad integration --- Makefile | 2 +- torchopt/alias/adagrad.py | 26 +++++++++++++------------- torchopt/schedule/exponential_decay.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index c387fd54..de7ba8dd 100644 --- a/Makefile +++ b/Makefile @@ -114,7 +114,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_NAME)' && \ - $(PYTHON) -m pytest -k "test_exponential_decay" --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ --cov="$(PROJECT_NAME)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . 
diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 2e6fe980..62030dfc 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -39,16 +39,15 @@ scale_by_neg_lr, ) from torchopt.combine import chain -from torchopt.transform import scale_by_rss -from torchopt.typing import GradientTransformation, Numeric, Scalar, Schedule +from torchopt.transform import scale_by_rss, scale_by_schedule +from torchopt.typing import GradientTransformation, Numeric, Scalar, ScalarOrSchedule, Schedule __all__ = ['adagrad'] # pylint: disable-next=too-many-arguments -def _adagrad_lr_decay( - init_value: Scalar, +def _adagrad_lr_schedule( decay_rate: Scalar, transition_begin: int = 0, ) -> Schedule: @@ -61,10 +60,10 @@ def _adagrad_lr_decay( ``` Args: - init_value: the initial learning rate. - decay_rate: The decay rate. - transition_begin: must be positive. After how many steps to start annealing + decay_rate (float, optional): The decay rate. + transition_begin (int, optional): must be positive. After how many steps to start annealing (before this many steps the scalar value is held fixed at `init_value`). + (default: :const:`1`) Returns: schedule: A function that maps step counts to values. @@ -78,14 +77,14 @@ def _adagrad_lr_decay( def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin - return init_value / (1 + decay_rate * decreased_count) + return 1 / (1 + decay_rate * decreased_count) return schedule # pylint: disable-next=too-many-arguments def adagrad( - lr: Scalar = 1e-2, + lr: ScalarOrSchedule = 1e-2, lr_decay: float = 0.0, weight_decay: float = 0.0, initial_accumulator_value: float = 0.0, @@ -144,13 +143,15 @@ def adagrad( flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay adagrad_scaler_fn = scale_by_rss scale_by_neg_lr_fn = scale_by_neg_lr - schedule_fn = _adagrad_lr_decay + step_size_fn = _adagrad_lr_schedule + scale_by_schedule_fn = scale_by_schedule if _get_use_chain_flat(): # default behavior chain_fn = chain_fn.flat # type: ignore[attr-defined] flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] adagrad_scaler_fn = adagrad_scaler_fn.flat # type: ignore[attr-defined] scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + scale_by_schedule_fn = scale_by_schedule_fn.flat # type: ignore[attr-defined] return chain_fn( flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), @@ -158,7 +159,6 @@ def adagrad( initial_accumulator_value=initial_accumulator_value, eps=eps, ), - scale_by_neg_lr_fn( - schedule_fn(init_value=lr, decay_rate=lr_decay, transition_begin=0), - ), + scale_by_schedule_fn(step_size_fn=step_size_fn(decay_rate=lr_decay, transition_begin=0)), + scale_by_neg_lr_fn(lr), ) diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index b2c3460f..ae6683ce 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -71,7 +71,7 @@ def exponential_decay( entire annealing process is disabled and the value is held fixed at ``init_value``. (default: :const:`1`) staircase (bool): If ``True``, decay the scalar at discrete intervals. - end_value (float or Tensor): End value of the scalar to be annealed. + end_value (float or Tensor, optional): End value of the scalar to be annealed. Returns: schedule: A function that maps step counts to values. 
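With this refactor the learning-rate decay becomes its own `scale_by_schedule` step, so the effective step size decays as `lr / (1 + lr_decay * count)`, which is the rule `torch.optim.Adagrad` uses and the one the alias tests compare against. A minimal sketch of the functional API in the spirit of `tests/test_alias.py` (the toy tensors are illustrative, and the `update` / `apply_updates` call pattern is assumed to match the existing SGD/Adam tests):

```
import torch
import torchopt

# A toy "model": one flat tuple of parameters and matching gradients.
params = (torch.zeros(3),)
grads = (torch.ones(3),)

optim = torchopt.adagrad(lr=1e-2, lr_decay=1e-2, eps=1e-10)
opt_state = optim.init(params)

# One functional update: gradients in, parameter updates out.
updates, opt_state = optim.update(grads, opt_state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
```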
From a937d6bd595c2e88f0b94fb3034c00d4f3beb456 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Tue, 21 Mar 2023 23:14:39 +0800 Subject: [PATCH 28/37] feat: adagrad integration --- torchopt/alias/adagrad.py | 3 +-- torchopt/optim/meta/adagrad.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 62030dfc..6470d6ea 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -60,9 +60,8 @@ def _adagrad_lr_schedule( ``` Args: - decay_rate (float, optional): The decay rate. + decay_rate (float): The decay rate. transition_begin (int, optional): must be positive. After how many steps to start annealing - (before this many steps the scalar value is held fixed at `init_value`). (default: :const:`1`) Returns: diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 60051792..eb91a6a6 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -50,19 +50,19 @@ def __init__( Args: module (nn.Module): A network whose parameters should be optimized. - lr: (default: :const:`1e-2`) - This is a fixed global scaling factor. - lr_decay: (default: :const:`0.0`) - Learning rate decay. - weight_decay: (default: :const:`0.0`) - Weight decay, add L2 penalty to parameters. - initial_accumulator_value: (default: :const:`0.0`) - Initial value for the accumulator. - eps: (default: :const:`1e-10`) - A small constant applied to denominator outside of the square root (as in the AdaGrad - paper) to avoid dividing by zero when rescaling. - maximize: (default: :data:`False`) - Maximize the params based on the objective, instead of minimizing. + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. + (default: :const:`0.0`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + initial_accumulator_value (float, optional): Initial value for the accumulator. + (default: :const:`0.0`) + eps (float, optional): A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-10`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) """ super().__init__( module, From beab33996243c8cab58cc1280c02821f88eca3be Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 22 Mar 2023 00:06:59 +0800 Subject: [PATCH 29/37] feat: adagrad integration --- docs/source/api/api.rst | 8 ++++---- docs/source/explicit_diff/explicit_diff.rst | 2 +- docs/source/optimizer/optim.rst | 2 +- tests/test_optim.py | 2 +- torchopt/optim/adagrad.py | 9 ++++++--- torchopt/optim/meta/adagrad.py | 9 ++++++--- 6 files changed, 19 insertions(+), 13 deletions(-) diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index a9937f57..4aad53a7 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -79,7 +79,7 @@ Classic Optimizers .. autosummary:: - Adagrad + AdaGrad Adam AdamW RMSProp @@ -88,7 +88,7 @@ Classic Optimizers Classic AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: Adagrad +.. autoclass:: AdaGrad Classic Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~ @@ -119,7 +119,7 @@ Differentiable Meta-Optimizers .. 
autosummary:: - MetaAdagrad + MetaAdaGrad MetaAdam MetaAdamW MetaRMSProp @@ -128,7 +128,7 @@ Differentiable Meta-Optimizers Differentiable Meta-AdaGrad Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: MetaAdagrad +.. autoclass:: MetaAdaGrad Differentiable Meta-Adam Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/explicit_diff/explicit_diff.rst b/docs/source/explicit_diff/explicit_diff.rst index 57e153d7..f6b82826 100644 --- a/docs/source/explicit_diff/explicit_diff.rst +++ b/docs/source/explicit_diff/explicit_diff.rst @@ -53,7 +53,7 @@ For PyTorch-like API (e.g., ``step()``), we designed a base class :class:`torcho .. autosummary:: torchopt.MetaOptimizer - torchopt.MetaAdagrad + torchopt.MetaAdaGrad torchopt.MetaAdam torchopt.MetaAdamW torchopt.MetaRMSProp diff --git a/docs/source/optimizer/optim.rst b/docs/source/optimizer/optim.rst index a96707f1..54c8ef71 100644 --- a/docs/source/optimizer/optim.rst +++ b/docs/source/optimizer/optim.rst @@ -85,7 +85,7 @@ We offer original PyTorch APIs (e.g., ``zero_grad()`` or ``step()``) for traditi .. autosummary:: torchopt.Optimizer - torchopt.Adagrad + torchopt.AdaGrad torchopt.Adam torchopt.AdamW torchopt.RMSProp diff --git a/tests/test_optim.py b/tests/test_optim.py index 6f3e9208..c29385ea 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -96,7 +96,7 @@ def test_SGD( weight_decay=[0.0, 1e-2], maximize=[False, True], ) -def test_Adagrad( +def test_AdaGrad( dtype: torch.dtype, lr: float, lr_decay: float, diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 4b71f558..260b809c 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -25,15 +25,15 @@ from torchopt.typing import ScalarOrSchedule -__all__ = ['Adagrad'] +__all__ = ['AdaGrad', 'Adagrad'] -class Adagrad(Optimizer): +class AdaGrad(Optimizer): """The classic AdaGrad optimizer. See Also: - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - - The differentiable meta AdaGrad optimizer: :class:`torchopt.MetaAdagrad`. + - The differentiable meta AdaGrad optimizer: :class:`torchopt.MetaAdaGrad`. """ # pylint: disable-next=too-many-arguments @@ -78,3 +78,6 @@ def __init__( maximize=maximize, ), ) + + +Adagrad = AdaGrad # alias for PyTorch compatibility diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index eb91a6a6..09a5022e 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -23,15 +23,15 @@ from torchopt.typing import Scalar -__all__ = ['MetaAdagrad'] +__all__ = ['MetaAdaGrad', 'MetaAdagrad'] -class MetaAdagrad(MetaOptimizer): +class MetaAdaGrad(MetaOptimizer): """The differentiable AdaGrad optimizer. See Also: - The functional AdaGrad optimizer: :func:`torchopt.adagrad`.m - - The classic AdaGrad optimizer: :class:`torchopt.Adagrad`. + - The classic AdaGrad optimizer: :class:`torchopt.AdaGrad`. 
""" # pylint: disable-next=too-many-arguments @@ -75,3 +75,6 @@ def __init__( maximize=maximize, ), ) + + +MetaAdagrad = MetaAdaGrad # alias for PyTorch compatibility From 26afa122ed46a55bfec8ad787eba7da97f7fcd0f Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 22 Mar 2023 00:12:06 +0800 Subject: [PATCH 30/37] feat: adagrad integration --- tests/test_optim.py | 2 +- torchopt/optim/__init__.py | 2 +- torchopt/optim/meta/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index c29385ea..7d0f0919 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -108,7 +108,7 @@ def test_AdaGrad( ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - optim = torchopt.Adagrad( + optim = torchopt.AdaGrad( model.parameters(), lr=lr, lr_decay=lr_decay, diff --git a/torchopt/optim/__init__.py b/torchopt/optim/__init__.py index 038e1d95..8e390a5c 100644 --- a/torchopt/optim/__init__.py +++ b/torchopt/optim/__init__.py @@ -15,7 +15,7 @@ """object oriented optimizer implementations.""" from torchopt.optim import meta -from torchopt.optim.adagrad import Adagrad +from torchopt.optim.adagrad import AdaGrad, Adagrad from torchopt.optim.adam import Adam from torchopt.optim.adamw import AdamW from torchopt.optim.base import Optimizer diff --git a/torchopt/optim/meta/__init__.py b/torchopt/optim/meta/__init__.py index 2a09a024..28f374cc 100644 --- a/torchopt/optim/meta/__init__.py +++ b/torchopt/optim/meta/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """Differentiable Meta-Optimizers.""" -from torchopt.optim.meta.adagrad import MetaAdagrad +from torchopt.optim.meta.adagrad import MetaAdaGrad, MetaAdagrad from torchopt.optim.meta.adam import MetaAdam from torchopt.optim.meta.adamw import MetaAdamW from torchopt.optim.meta.base import MetaOptimizer From 91bb5c235d8c777d243d57322f39ad03b3ffc9df Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 22 Mar 2023 00:31:46 +0800 Subject: [PATCH 31/37] feat: adagrad integration --- torchopt/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchopt/__init__.py b/torchopt/__init__.py index 30fc7945..c799d7fc 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -37,9 +37,10 @@ from torchopt.clip import clip_grad_norm from torchopt.combine import chain from torchopt.hook import register_hook -from torchopt.optim import SGD, Adagrad, Adam, AdamW, Optimizer, RMSProp, RMSprop +from torchopt.optim import SGD, AdaGrad, Adagrad, Adam, AdamW, Optimizer, RMSProp, RMSprop from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdaGrad, MetaAdagrad, MetaAdam, MetaAdamW, @@ -75,6 +76,7 @@ 'SGD', 'Adam', 'AdamW', + 'AdaGrad', 'Adagrad', 'RMSProp', 'RMSprop', @@ -82,6 +84,7 @@ 'MetaSGD', 'MetaAdam', 'MetaAdamW', + 'MetaAdaGrad', 'MetaAdagrad', 'MetaRMSProp', 'MetaRMSprop', From 38482bff733f802dbee874690896612a2f53db29 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 22 Mar 2023 06:16:07 +0000 Subject: [PATCH 32/37] =?UTF-8?q?fix:=20ca=20pi=20gu=20=F0=9F=92=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 1 - docs/source/api/api.rst | 3 - tests/test_alias.py | 128 ++++++++++++------------- tests/test_import.py | 19 +++- tests/test_optim.py | 113 +++++++++++----------- tests/test_schedule.py | 10 +- torchopt/__init__.py | 2 +- torchopt/alias/adagrad.py | 45 +++++---- 
torchopt/alias/adam.py | 2 +- torchopt/alias/adamw.py | 4 +- torchopt/alias/sgd.py | 2 +- torchopt/optim/adagrad.py | 15 ++- torchopt/optim/adam.py | 4 +- torchopt/optim/adamw.py | 4 +- torchopt/optim/func/base.py | 3 +- torchopt/optim/meta/adagrad.py | 21 ++-- torchopt/optim/meta/adamw.py | 4 +- torchopt/schedule/exponential_decay.py | 29 +++--- torchopt/transform/scale_by_adam.py | 4 +- torchopt/transform/scale_by_rms.py | 2 +- torchopt/transform/scale_by_rss.py | 15 +-- torchopt/transform/scale_by_stddev.py | 2 +- 22 files changed, 223 insertions(+), 209 deletions(-) diff --git a/Makefile b/Makefile index de7ba8dd..bd839812 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,6 @@ COPYRIGHT = "MetaOPT Team. All Rights Reserved." PROJECT_PATH = $(PROJECT_NAME) SHELL = /bin/bash SOURCE_FOLDERS = $(PROJECT_PATH) examples include src tests docs -RST_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.rst" -o -name "*.rst") PYTHON_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.py" -o -name "*.pyi") CXX_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.h" -o -name "*.cpp") CUDA_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.cuh" -o -name "*.cu") diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 4aad53a7..d00e2333 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -36,9 +36,6 @@ Functional Optimizers rmsprop sgd - - - Wrapper for Function Optimizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/test_alias.py b/tests/test_alias.py index 03ea9c71..a0a78129 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -76,70 +76,6 @@ def test_empty( _set_use_chain_flat(True) -@helpers.parametrize( - dtype=[torch.float64], - lr=[1e-2, 1e-3, 1e-4], - lr_decay=[0.0, 1e-2], - initial_accumulator_value=[0.0, 1e-1], - eps=[1e-8], - inplace=[True, False], - weight_decay=[0.0, 1e-2], - maximize=[False, True], - use_chain_flat=[True, False], -) -def test_adagrad( - dtype: torch.dtype, - lr: float, - lr_decay: float, - initial_accumulator_value: float, - eps: float, - inplace: bool, - weight_decay: float, - maximize: bool, - use_chain_flat: bool, -) -> None: - _set_use_chain_flat(use_chain_flat) - - model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - - fmodel, params, buffers = functorch.make_functional_with_buffers(model) - optim = torchopt.adagrad( - lr=lr, - lr_decay=lr_decay, - weight_decay=weight_decay, - initial_accumulator_value=initial_accumulator_value, - eps=eps, - maximize=maximize, - ) - optim_state = optim.init(params) - optim_ref = torch.optim.Adagrad( - model_ref.parameters(), - lr=lr, - lr_decay=lr_decay, - weight_decay=weight_decay, - initial_accumulator_value=initial_accumulator_value, - eps=eps, - maximize=maximize, - ) - for xs, ys in loader: - xs = xs.to(dtype=dtype) - pred = fmodel(params, buffers, xs) - pred_ref = model_ref(xs) - loss = F.cross_entropy(pred, ys) - loss_ref = F.cross_entropy(pred_ref, ys) - - grads = torch.autograd.grad(loss, params, allow_unused=True) - updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) - params = torchopt.apply_updates(params, updates, inplace=inplace) - - optim_ref.zero_grad() - loss_ref.backward() - optim_ref.step() - - helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) - _set_use_chain_flat(True) - - @helpers.parametrize( dtype=[torch.float64, torch.float32], lr=[1e-2, 1e-3, 1e-4], @@ -514,6 +450,70 @@ def test_adam_accelerated_cuda( _set_use_chain_flat(True) 
+@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_chain_flat=[True, False], +) +def test_adagrad( + dtype: torch.dtype, + lr: float, + lr_decay: float, + initial_accumulator_value: float, + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adagrad( + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.Adagrad( + model_ref.parameters(), + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, + maximize=maximize, + ) + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], diff --git a/tests/test_import.py b/tests/test_import.py index 30cf914e..1b6dea38 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -25,6 +25,7 @@ def test_accelerated_op_import() -> None: def test_alias_import() -> None: + torchopt.adagrad torchopt.adam torchopt.adamw torchopt.rmsprop @@ -33,8 +34,8 @@ def test_alias_import() -> None: torchopt.alias.adamw torchopt.alias.rmsprop torchopt.alias.sgd - from torchopt import adam, adamw, rmsprop, sgd - from torchopt.alias import adam, adamw, rmsprop, sgd + from torchopt import adagrad, adam, adamw, rmsprop, sgd + from torchopt.alias import adagrad, adam, adamw, rmsprop, sgd def test_diff_import() -> None: @@ -107,17 +108,23 @@ def test_nn_import() -> None: def test_optim_import() -> None: torchopt.FuncOptimizer + torchopt.MetaAdaGrad + torchopt.MetaAdagrad torchopt.MetaAdam torchopt.MetaAdamW torchopt.MetaRMSProp torchopt.MetaRMSprop torchopt.MetaSGD + torchopt.AdaGrad + torchopt.Adagrad torchopt.Adam torchopt.AdamW torchopt.Optimizer torchopt.RMSProp torchopt.RMSprop torchopt.SGD + torchopt.optim.meta.MetaAdaGrad + torchopt.optim.meta.MetaAdagrad torchopt.optim.meta.MetaAdam torchopt.optim.meta.MetaAdamW torchopt.optim.meta.MetaRMSProp @@ -132,14 +139,18 @@ def test_optim_import() -> None: torchopt.optim.func.FuncOptimizer from torchopt import ( SGD, + AdaGrad, + Adagrad, Adam, AdamW, FuncOptimizer, + MetaAdaGrad, + MetaAdagrad, MetaAdam, MetaAdamW, MetaOptimizer, - MetaRMSProp, MetaRMSprop, + MetaRMSProp, MetaSGD, Optimizer, RMSProp, @@ -147,6 +158,8 @@ def test_optim_import() -> None: from torchopt.optim import SGD, Adam, AdamW, FuncOptimizer, Optimizer, RMSProp from torchopt.optim.func import FuncOptimizer from torchopt.optim.meta import ( + MetaAdaGrad, + MetaAdagrad, MetaAdam, 
MetaAdamW, MetaOptimizer, diff --git a/tests/test_optim.py b/tests/test_optim.py index 7d0f0919..9fd0a072 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -89,41 +89,40 @@ def test_SGD( @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], - lr_decay=[0.0, 1e-2], - initial_accumulator_value=[0.0, 1e-1], + betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], - inplace=[True, False], weight_decay=[0.0, 1e-2], maximize=[False, True], + use_accelerated_op=[False, True], ) -def test_AdaGrad( +def test_Adam( dtype: torch.dtype, lr: float, - lr_decay: float, - initial_accumulator_value: float, + betas: tuple[float, float], eps: float, - inplace: bool, weight_decay: float, maximize: bool, + use_accelerated_op: bool, ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - optim = torchopt.AdaGrad( + optim = torchopt.Adam( model.parameters(), - lr=lr, - lr_decay=lr_decay, - weight_decay=weight_decay, - initial_accumulator_value=initial_accumulator_value, + lr, + betas=betas, eps=eps, + eps_root=0.0, + weight_decay=weight_decay, maximize=maximize, + use_accelerated_op=use_accelerated_op, ) - optim_ref = torch.optim.Adagrad( + optim_ref = torch.optim.Adam( model_ref.parameters(), - lr=lr, - lr_decay=lr_decay, - weight_decay=weight_decay, - initial_accumulator_value=initial_accumulator_value, + lr, + betas=betas, eps=eps, + amsgrad=False, + weight_decay=weight_decay, maximize=maximize, ) @@ -150,11 +149,11 @@ def test_AdaGrad( lr=[1e-2, 1e-3, 1e-4], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], - weight_decay=[0.0, 1e-2], + weight_decay=[1e-2, 1e-1], maximize=[False, True], use_accelerated_op=[False, True], ) -def test_Adam( +def test_AdamW( dtype: torch.dtype, lr: float, betas: tuple[float, float], @@ -165,7 +164,7 @@ def test_Adam( ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - optim = torchopt.Adam( + optim = torchopt.AdamW( model.parameters(), lr, betas=betas, @@ -175,7 +174,7 @@ def test_Adam( maximize=maximize, use_accelerated_op=use_accelerated_op, ) - optim_ref = torch.optim.Adam( + optim_ref = torch.optim.AdamW( model_ref.parameters(), lr, betas=betas, @@ -203,27 +202,34 @@ def test_Adam( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) +@pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.') @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], + optimizers=[ + (torchopt.Adam, torch.optim.Adam), + (torchopt.AdamW, torch.optim.AdamW), + ], betas=[(0.9, 0.999), (0.95, 0.9995)], eps=[1e-8], - weight_decay=[1e-2, 1e-1], + weight_decay=[0.0, 1e-2], maximize=[False, True], - use_accelerated_op=[False, True], ) -def test_AdamW( +def test_Adam_accelerated_cuda( dtype: torch.dtype, lr: float, + optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer], betas: tuple[float, float], eps: float, weight_decay: float, maximize: bool, - use_accelerated_op: bool, ) -> None: - model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + device = 'cuda' + model, model_ref, model_base, loader = helpers.get_models(device=device, dtype=dtype) - optim = torchopt.AdamW( + torchopt_optimizer, torch_optimizer = optimizers + + optim = torchopt_optimizer( model.parameters(), lr, betas=betas, @@ -231,9 +237,9 @@ def test_AdamW( eps_root=0.0, weight_decay=weight_decay, maximize=maximize, - use_accelerated_op=use_accelerated_op, + use_accelerated_op=True, ) - optim_ref = torch.optim.AdamW( + optim_ref = 
torch_optimizer( model_ref.parameters(), lr, betas=betas, @@ -244,7 +250,8 @@ def test_AdamW( ) for xs, ys in loader: - xs = xs.to(dtype=dtype) + xs = xs.to(device=device, dtype=dtype) + ys = ys.to(device=device) pred = model(xs) pred_ref = model_ref(xs) loss = F.cross_entropy(pred, ys) @@ -261,56 +268,47 @@ def test_AdamW( helpers.assert_model_all_close(model, model_ref, model_base, dtype=dtype) -@pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.') @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], - optimizers=[ - (torchopt.Adam, torch.optim.Adam), - (torchopt.AdamW, torch.optim.AdamW), - ], - betas=[(0.9, 0.999), (0.95, 0.9995)], - eps=[1e-8], + lr_decay=[0.0, 1e-2], + initial_accumulator_value=[0.0, 1e-1], + eps=[1e-10], weight_decay=[0.0, 1e-2], maximize=[False, True], ) -def test_Adam_accelerated_cuda( +def test_AdaGrad( dtype: torch.dtype, lr: float, - optimizers: tuple[torchopt.Optimizer, torch.optim.Optimizer], - betas: tuple[float, float], + lr_decay: float, + initial_accumulator_value: float, eps: float, weight_decay: float, maximize: bool, ) -> None: - device = 'cuda' - model, model_ref, model_base, loader = helpers.get_models(device=device, dtype=dtype) - - torchopt_optimizer, torch_optimizer = optimizers + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - optim = torchopt_optimizer( + optim = torchopt.AdaGrad( model.parameters(), - lr, - betas=betas, - eps=eps, - eps_root=0.0, + lr=lr, + lr_decay=lr_decay, weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, maximize=maximize, - use_accelerated_op=True, ) - optim_ref = torch_optimizer( + optim_ref = torch.optim.Adagrad( model_ref.parameters(), - lr, - betas=betas, - eps=eps, - amsgrad=False, + lr=lr, + lr_decay=lr_decay, weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + eps=eps, maximize=maximize, ) for xs, ys in loader: - xs = xs.to(device=device, dtype=dtype) - ys = ys.to(device=device) + xs = xs.to(dtype=dtype) pred = model(xs) pred_ref = model_ref(xs) loss = F.cross_entropy(pred, ys) @@ -392,6 +390,7 @@ def test_RMSProp( (torchopt.sgd, torch.optim.SGD), (torchopt.adam, torch.optim.Adam), (torchopt.adamw, torch.optim.AdamW), + (torchopt.adagrad, torch.optim.Adagrad), (torchopt.rmsprop, torch.optim.RMSprop), ], inplace=[True, False], diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 38e9b1ac..c7ba8c5f 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -15,8 +15,7 @@ from __future__ import annotations -import math -from typing import Callable, Optional +from typing import Callable import functorch import numpy as np @@ -52,14 +51,12 @@ def test_exponential_decay( staircase=staircase, end_value=end_value, ) - if end_value is not None: # pragma: no cover + if end_value is not None: clip_fn = max if decay_rate < 1.0 else min for i in range(transition_begin, transition_steps): lr = schedule(i) if staircase: - lr_gt = init_value * ( - decay_rate ** math.floor((i - transition_begin) / transition_steps) - ) + lr_gt = init_value * (decay_rate ** np.floor((i - transition_begin) / transition_steps)) else: lr_gt = init_value * (decay_rate ** ((i - transition_begin) / transition_steps)) if end_value is not None: @@ -94,6 +91,7 @@ def test_linear_schedule() -> None: (torchopt.sgd, torch.optim.SGD), (torchopt.adam, torch.optim.Adam), (torchopt.adamw, torch.optim.AdamW), + (torchopt.adagrad, torch.optim.Adagrad), (torchopt.rmsprop, 
torch.optim.RMSprop), ], inplace=[True, False], diff --git a/torchopt/__init__.py b/torchopt/__init__.py index c799d7fc..a8c9fa1d 100644 --- a/torchopt/__init__.py +++ b/torchopt/__init__.py @@ -63,9 +63,9 @@ __all__ = [ 'accelerated_op_available', - 'adagrad', 'adam', 'adamw', + 'adagrad', 'rmsprop', 'sgd', 'clip_grad_norm', diff --git a/torchopt/alias/adagrad.py b/torchopt/alias/adagrad.py index 6470d6ea..25910abd 100644 --- a/torchopt/alias/adagrad.py +++ b/torchopt/alias/adagrad.py @@ -46,33 +46,33 @@ __all__ = ['adagrad'] -# pylint: disable-next=too-many-arguments def _adagrad_lr_schedule( decay_rate: Scalar, transition_begin: int = 0, ) -> Schedule: - """Constructs a schedule dedicated to AdaGrad optimizer. + """Construct a schedule dedicated to AdaGrad optimizer. + + This function applies an learning rate decay function to a provided initial value. The function + returns the decayed value as follows: + + .. code-block:: python - This function applies an learning rate decay function to a provided initial - value. The function returns the decayed value as follows: - ``` - decayed_value = init_value / 1 + count * decay_rate - ``` + decayed_value = init_value / (1 + count * decay_rate) Args: decay_rate (float): The decay rate. - transition_begin (int, optional): must be positive. After how many steps to start annealing - (default: :const:`1`) + transition_begin (int, optional): Must be *positive*. After how many steps to start + annealing. (default: :const:`0`) Returns: schedule: A function that maps step counts to values. """ if transition_begin < 0: # pragma: no cover - logging.info( # pragma: no cover - 'The AdaGrad learning rate schedule was set with a negative `transition_begin` ' # pragma: no cover - 'value; this will result in `transition_begin` falling back to `0`.', # pragma: no cover - ) # pragma: no cover - transition_begin = 0 # pragma: no cover + logging.info( + 'The AdaGrad learning rate schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.', + ) + transition_begin = 0 def schedule(count: Numeric) -> Numeric: decreased_count = count - transition_begin @@ -102,13 +102,12 @@ def adagrad( eventually becomes very small. References: - Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html + Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html Args: lr (float or callable, optional): This is a fixed global scaling factor or a learning rate scheduler. (default: :const:`1e-2`) - lr_decay (float, optional): Learning rate decay. - (default: :const:`0.0`) + lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`) weight_decay (float, optional): Weight decay, add L2 penalty to parameters. (default: :const:`0.0`) initial_accumulator_value (float, optional): Initial value for the accumulator. 
@@ -128,21 +127,20 @@ def adagrad( # pylint: disable=unneeded-not if not (callable(lr) or lr >= 0.0): # pragma: no cover raise ValueError(f'Invalid learning rate: {lr}') - if not eps >= 0.0: # pragma: no cover - raise ValueError(f'Invalid epsilon value: {eps}') if not lr_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid lr_decay value: {lr_decay}') if not weight_decay >= 0.0: # pragma: no cover raise ValueError(f'Invalid weight_decay value: {weight_decay}') if not initial_accumulator_value >= 0.0: # pragma: no cover raise ValueError(f'Invalid initial_accumulator_value value: {initial_accumulator_value}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') # pylint: enable=unneeded-not chain_fn = chain flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay adagrad_scaler_fn = scale_by_rss scale_by_neg_lr_fn = scale_by_neg_lr - step_size_fn = _adagrad_lr_schedule scale_by_schedule_fn = scale_by_schedule if _get_use_chain_flat(): # default behavior @@ -158,6 +156,11 @@ def adagrad( initial_accumulator_value=initial_accumulator_value, eps=eps, ), - scale_by_schedule_fn(step_size_fn=step_size_fn(decay_rate=lr_decay, transition_begin=0)), + scale_by_schedule_fn( + step_size_fn=_adagrad_lr_schedule( + decay_rate=lr_decay, + transition_begin=0, + ), + ), scale_by_neg_lr_fn(lr), ) diff --git a/torchopt/alias/adam.py b/torchopt/alias/adam.py index be58e49e..dc889285 100644 --- a/torchopt/alias/adam.py +++ b/torchopt/alias/adam.py @@ -65,7 +65,7 @@ def adam( exponential moving averages). References: - - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 Args: lr (float or callable, optional): This is a fixed global scaling factor or a learning rate diff --git a/torchopt/alias/adamw.py b/torchopt/alias/adamw.py index 4503381c..e8bed2ab 100644 --- a/torchopt/alias/adamw.py +++ b/torchopt/alias/adamw.py @@ -69,7 +69,7 @@ def adamw( does not behave as intended for adaptive gradient algorithms such as Adam. References: - - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 + - Loshchilov et al., 2019: https://arxiv.org/abs/1711.05101 Args: lr (float or callable, optional): This is a fixed global scaling factor or a learning rate @@ -81,7 +81,7 @@ def adamw( (default: :const:`1e-8`) weight_decay (float, optional): Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other - frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight + frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. (default: :const:`1e-2`) eps_root (float, optional): A small constant applied to denominator inside the square root diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index c2d37292..6fb3c6db 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -61,7 +61,7 @@ def sgd( deep neural networks. 
References: - - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf + - Sutskever et al., 2013: http://proceedings.mlr.press/v28/sutskever13.pdf Args: lr (float or callable): This is a fixed global scaling factor or a learning rate diff --git a/torchopt/optim/adagrad.py b/torchopt/optim/adagrad.py index 260b809c..055e0ad5 100644 --- a/torchopt/optim/adagrad.py +++ b/torchopt/optim/adagrad.py @@ -53,19 +53,18 @@ def __init__( Args: params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what tensors should be optimized. - lr (float or callable, optional): This is a fixed global scaling factor or a learning rate - scheduler. (default: :const:`1e-2`) - lr_decay (float, optional): Learning rate decay. - (default: :const:`0.0`) + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`) weight_decay (float, optional): Weight decay, add L2 penalty to parameters. (default: :const:`0.0`) initial_accumulator_value (float, optional): Initial value for the accumulator. (default: :const:`0.0`) - eps (float, optional): A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: :const:`1e-10`) - maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. - (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( params, diff --git a/torchopt/optim/adam.py b/torchopt/optim/adam.py index 640eea1d..5d85cbdc 100644 --- a/torchopt/optim/adam.py +++ b/torchopt/optim/adam.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ class Adam(Optimizer): def __init__( self, params: Iterable[torch.Tensor], - lr: ScalarOrSchedule, + lr: ScalarOrSchedule = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, diff --git a/torchopt/optim/adamw.py b/torchopt/optim/adamw.py index 7db5e750..be8c6727 100644 --- a/torchopt/optim/adamw.py +++ b/torchopt/optim/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ def __init__( (default: :const:`1e-8`) weight_decay (float, optional): Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. 
(default: :const:`1e-2`) eps_root (float, optional): A small constant applied to denominator inside the square diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py index 2ca50f6a..94038464 100644 --- a/torchopt/optim/func/base.py +++ b/torchopt/optim/func/base.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ class FuncOptimizer: # pylint: disable=too-few-public-methods and update the parameters. See Also: + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - The functional Adam optimizer: :func:`torchopt.adam`. - The functional AdamW optimizer: :func:`torchopt.adamw`. - The functional RMSprop optimizer: :func:`torchopt.rmsprop`. diff --git a/torchopt/optim/meta/adagrad.py b/torchopt/optim/meta/adagrad.py index 09a5022e..079d76db 100644 --- a/torchopt/optim/meta/adagrad.py +++ b/torchopt/optim/meta/adagrad.py @@ -20,7 +20,7 @@ from torchopt import alias from torchopt.optim.meta.base import MetaOptimizer -from torchopt.typing import Scalar +from torchopt.typing import ScalarOrSchedule __all__ = ['MetaAdaGrad', 'MetaAdagrad'] @@ -30,7 +30,7 @@ class MetaAdaGrad(MetaOptimizer): """The differentiable AdaGrad optimizer. See Also: - - The functional AdaGrad optimizer: :func:`torchopt.adagrad`.m + - The functional AdaGrad optimizer: :func:`torchopt.adagrad`. - The classic AdaGrad optimizer: :class:`torchopt.AdaGrad`. """ @@ -38,7 +38,7 @@ class MetaAdaGrad(MetaOptimizer): def __init__( self, module: nn.Module, - lr: Scalar = 1e-2, + lr: ScalarOrSchedule = 1e-2, lr_decay: float = 0.0, weight_decay: float = 0.0, initial_accumulator_value: float = 0.0, @@ -50,19 +50,18 @@ def __init__( Args: module (nn.Module): A network whose parameters should be optimized. - lr (float or callable, optional): This is a fixed global scaling factor or a learning rate - scheduler. (default: :const:`1e-2`) - lr_decay (float, optional): Learning rate decay. - (default: :const:`0.0`) + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-2`) + lr_decay (float, optional): Learning rate decay. (default: :const:`0.0`) weight_decay (float, optional): Weight decay, add L2 penalty to parameters. (default: :const:`0.0`) initial_accumulator_value (float, optional): Initial value for the accumulator. (default: :const:`0.0`) - eps (float, optional): A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. (default: :const:`1e-10`) - maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. - (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) """ super().__init__( module, diff --git a/torchopt/optim/meta/adamw.py b/torchopt/optim/meta/adamw.py index c8a8ef9c..204a5428 100644 --- a/torchopt/optim/meta/adamw.py +++ b/torchopt/optim/meta/adamw.py @@ -1,4 +1,4 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ def __init__( (default: :const:`1e-8`) weight_decay (float, optional): Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where + other frameworks such as PyTorch, but different from (Loshchilov et al., 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. (default: :const:`1e-2`) eps_root (float, optional): A small constant applied to denominator inside the square diff --git a/torchopt/schedule/exponential_decay.py b/torchopt/schedule/exponential_decay.py index ae6683ce..8811b353 100644 --- a/torchopt/schedule/exponential_decay.py +++ b/torchopt/schedule/exponential_decay.py @@ -50,28 +50,33 @@ def exponential_decay( staircase: bool = False, end_value: Optional[float] = None, ) -> Schedule: - """Constructs a schedule with either continuous or discrete exponential decay. - - This function applies an exponential decay function to a provided initial - value. The function returns the decayed value as follows: - ``` - decayed_value = init_value * decay_rate ^ (count / transition_steps) - ``` - If the argument `staircase` is `True`, then `count / transition_steps` is - an integer division and the decayed value follows a staircase function. + """Construct a schedule with either continuous or discrete exponential decay. + + This function applies an exponential decay function to a provided initial value. The function + returns the decayed value as follows: + + .. code-block:: python + + decayed_value = init_value * decay_rate**(count / transition_steps) + + If the argument ``staircase`` is :data:`True`, then ``count / transition_steps`` is an integer + division and the decayed value follows a staircase function. + Args: init_value (float or Tensor): Initial value for the scalar to be annealed. decay_rate (float or Tensor): The decay rate. transition_begin (int, optional): Must be *positive*. After how many steps to start annealing (before this many steps the scalar value is held fixed at ``init_value``). (default: :const:`0`) - transition_steps (int): Number of steps over which annealing takes place, the scalar starts - changing at ``transition_begin`` steps and completes the transition by + transition_steps (int, optional): Number of steps over which annealing takes place, the + scalar starts changing at ``transition_begin`` steps and completes the transition by ``transition_begin + transition_steps`` steps. If ``transition_steps <= 0``, then the entire annealing process is disabled and the value is held fixed at ``init_value``. (default: :const:`1`) - staircase (bool): If ``True``, decay the scalar at discrete intervals. + staircase (bool, optional): If :data:`True`, decay the scalar at discrete intervals. + (default: :data:`False`) end_value (float or Tensor, optional): End value of the scalar to be annealed. + (default: :data:`None`) Returns: schedule: A function that maps step counts to values. diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index 4ea35f74..c3c6254e 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -87,7 +87,7 @@ def scale_by_adam( """Rescale updates according to the Adam algorithm. 
References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 Args: b1 (float, optional): Decay rate for the exponentially weighted average of grads. @@ -238,7 +238,7 @@ def scale_by_accelerated_adam( This function is accelerated by using some fused accelerated operators. References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 Args: b1 (float, optional): Decay rate for the exponentially weighted average of grads. diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index dbd6d621..ac2fef16 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -60,7 +60,7 @@ def scale_by_rms( """Rescale updates by the root of the exp. moving avg of the square. References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf Args: alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index a85713a9..a3d500ef 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -59,12 +59,14 @@ def scale_by_rss( """Rescale updates by the root of the sum of all squared gradients to date. References: - [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) - [McMahan et al., 2010](https://arxiv.org/abs/1002.4908) + - Duchi et al., 2011: https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf + - McMahan et al., 2010: https://arxiv.org/abs/1002.4908 Args: - initial_accumulator_value: Starting value for accumulators, must be >= 0. - eps: A small floating point value to avoid zero denominator. + initial_accumulator_value (float, optional): Starting value for accumulators, must be + ``>= 0``. (default: :const:`0.0`) + eps (float, optional): A small floating point value to avoid zero denominator. + (default: :const:`1e-10`) Returns: An (init_fn, update_fn) tuple. @@ -112,10 +114,9 @@ def init_fn(params: Params) -> OptState: def update_fn( updates: Updates, state: OptState, - params: Params | None = None, + params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, - ) -> tuple[Updates, OptState]: # pylint: disable=unused-argument - del params + ) -> tuple[Updates, OptState]: sum_of_squares = tree_map( lambda g, t: t + (g.conj() * g).real, updates, diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index 8c9ab07d..bbbfb384 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -63,7 +63,7 @@ def scale_by_stddev( """Rescale updates by the root of the centered exponential moving average of squares. References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf Args: alpha (float, optional): Decay rate for the exponentially weighted average of squared grads. 
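For reference, a minimal sketch of driving the functional AdaGrad alias assembled in the patch above, following the ``init``/``update``/``apply_updates`` pattern exercised in ``tests/test_alias.py``. The toy parameter and quadratic loss are illustrative only and are not part of the patch series.

.. code-block:: python

    import torch
    import torchopt

    # Toy parameter tuple and loss (illustrative only).
    params = (torch.tensor([1.0, -2.0], requires_grad=True),)

    # With lr_decay > 0, the effective global step size follows the schedule
    # documented above: roughly lr / (1 + count * lr_decay).
    optim = torchopt.adagrad(lr=1e-2, lr_decay=1e-2, initial_accumulator_value=0.0, eps=1e-10)
    opt_state = optim.init(params)

    loss = (params[0] ** 2).sum()
    grads = torch.autograd.grad(loss, params)
    updates, opt_state = optim.update(grads, opt_state, params=params, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)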
From 61234cd173e008f92d0342d5179e1950b6887d3e Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 22 Mar 2023 07:28:09 +0000 Subject: [PATCH 33/37] refactor: refactor scale_by_rss --- torchopt/transform/scale_by_rss.py | 28 +++++++++++---------- torchopt/transform/utils.py | 40 ++++++++++++++++++++++++------ 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index a3d500ef..68021e5e 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -39,7 +39,7 @@ from torchopt import pytree from torchopt.base import GradientTransformation -from torchopt.transform.utils import tree_map_flat +from torchopt.transform.utils import tree_map_flat, update_moment from torchopt.typing import OptState, Params, Updates @@ -117,32 +117,34 @@ def update_fn( params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, ) -> tuple[Updates, OptState]: - sum_of_squares = tree_map( - lambda g, t: t + (g.conj() * g).real, + sum_of_squares = update_moment.impl( # type: ignore[attr-defined] updates, state.sum_of_squares, + decay=1.0, + order=2, + inplace=inplace, + already_flattened=already_flattened, ) if inplace: - def f(t: torch.Tensor) -> torch.Tensor: + def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: return torch.where( - t > 0.0, - torch.ones_like(t).div_(t.sqrt().add_(eps)), - torch.tensor(0.0), + sos > 0.0, + g.div_(sos.sqrt().add_(eps)), + 0.0, ) else: - def f(t: torch.Tensor) -> torch.Tensor: + def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: return torch.where( - t > 0.0, - torch.ones_like(t).div(t.sqrt().add(eps)), - torch.tensor(0.0), + sos > 0.0, + g.div(sos.sqrt().add(eps)), + 0.0, ) - inv_sqrt_g_square = tree_map(f, sum_of_squares) - updates = tree_map(lambda scale, g: g * scale, inv_sqrt_g_square, updates) + updates = tree_map(f, updates, sum_of_squares) return updates, ScaleByRssState(sum_of_squares=sum_of_squares) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 57abe7ec..8c67fd7e 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -173,25 +173,49 @@ def _update_moment( if inplace: if order == 2: + if decay != 1.0: - def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: - return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul_(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + + else: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.addcmul_(g, g) if g is not None else t else: + if decay != 1.0: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t + + else: - def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: - return t.mul_(decay).add_(g, alpha=1 - decay) if g is not None else t + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.add_(g) if g is not None else t else: if order == 2: + if decay != 1.0: - def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: - return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul(decay).addcmul_(g, g, value=1 - decay) if g is not None else t + + else: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: 
+ return t.addcmul(g, g) if g is not None else t else: + if decay != 1.0: + + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t + + else: - def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: - return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t + def f(g: torch.Tensor | None, t: torch.Tensor) -> torch.Tensor: + return t.add(g) if g is not None else t if already_flattened: return tree_map_flat(f, updates, moments, none_is_leaf=True) From d34c534685d2b4e51f3e4006780e7c7947b1ba24 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 22 Mar 2023 08:46:00 +0000 Subject: [PATCH 34/37] test: fix eps value for AdaGrad --- tests/test_optim.py | 16 +++++++++------- tests/test_schedule.py | 18 ++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 9fd0a072..6ec81918 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -273,7 +273,7 @@ def test_Adam_accelerated_cuda( lr=[1e-2, 1e-3, 1e-4], lr_decay=[0.0, 1e-2], initial_accumulator_value=[0.0, 1e-1], - eps=[1e-10], + eps=[1e-8], weight_decay=[0.0, 1e-2], maximize=[False, True], ) @@ -387,11 +387,11 @@ def test_RMSProp( dtype=[torch.float64, torch.float32], lr=[1e-2, 1e-3], optimizers=[ - (torchopt.sgd, torch.optim.SGD), - (torchopt.adam, torch.optim.Adam), - (torchopt.adamw, torch.optim.AdamW), - (torchopt.adagrad, torch.optim.Adagrad), - (torchopt.rmsprop, torch.optim.RMSprop), + (torchopt.sgd, torch.optim.SGD, {}), + (torchopt.adam, torch.optim.Adam, {}), + (torchopt.adamw, torch.optim.AdamW, {}), + (torchopt.adagrad, torch.optim.Adagrad, {'eps': 1e-8}), + (torchopt.rmsprop, torch.optim.RMSprop, {}), ], inplace=[True, False], weight_decay=[0.0, 1e-2], @@ -405,13 +405,14 @@ def test_FuncOptimizer( ) -> None: model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - torchopt_optimizer, torch_optimizer = optimizers + torchopt_optimizer, torch_optimizer, optimizer_kwargs = optimizers fmodel, params, buffers = functorch.make_functional_with_buffers(model) optim = torchopt.FuncOptimizer( torchopt_optimizer( lr=lr, weight_decay=weight_decay, + **optimizer_kwargs, ), inplace=inplace, ) @@ -419,6 +420,7 @@ def test_FuncOptimizer( model_ref.parameters(), lr, weight_decay=weight_decay, + **optimizer_kwargs, ) for xs, ys in loader: diff --git a/tests/test_schedule.py b/tests/test_schedule.py index c7ba8c5f..1fdc4669 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import Callable +from typing import Any, Callable import functorch import numpy as np @@ -88,11 +88,11 @@ def test_linear_schedule() -> None: lr=[1e-2, 1e-3], total_iters=[helpers.NUM_UPDATES, helpers.NUM_UPDATES * 2], optimizers=[ - (torchopt.sgd, torch.optim.SGD), - (torchopt.adam, torch.optim.Adam), - (torchopt.adamw, torch.optim.AdamW), - (torchopt.adagrad, torch.optim.Adagrad), - (torchopt.rmsprop, torch.optim.RMSprop), + (torchopt.sgd, torch.optim.SGD, {}), + (torchopt.adam, torch.optim.Adam, {}), + (torchopt.adamw, torch.optim.AdamW, {}), + (torchopt.adagrad, torch.optim.Adagrad, {'eps': 1e-8}), + (torchopt.rmsprop, torch.optim.RMSprop, {}), ], inplace=[True, False], weight_decay=[0.0, 1e-2], @@ -102,7 +102,7 @@ def test_lr_linear_schedule( dtype: torch.dtype, lr: float, total_iters: int, - optimizers: tuple[Callable, torch.optim.Optimizer], + optimizers: tuple[Callable, 
torch.optim.Optimizer, dict[str, Any]], inplace: bool, weight_decay: float, use_chain_flat: bool, @@ -111,7 +111,7 @@ def test_lr_linear_schedule( model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) - torchopt_optimizer, torch_optimizer = optimizers + torchopt_optimizer, torch_optimizer, optimizer_kwargs = optimizers fmodel, params, buffers = functorch.make_functional_with_buffers(model) optim = torchopt_optimizer( @@ -122,12 +122,14 @@ def test_lr_linear_schedule( transition_begin=0, ), weight_decay=weight_decay, + **optimizer_kwargs, ) optim_state = optim.init(params) optim_ref = torch_optimizer( model_ref.parameters(), lr, weight_decay=weight_decay, + **optimizer_kwargs, ) torch_scheduler = torch.optim.lr_scheduler.LinearLR( optim_ref, From ccf87e226e1aa145f04f7b5055ebd125e4d08f8a Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 22 Mar 2023 08:56:25 +0000 Subject: [PATCH 35/37] docs(CHANGELOG): update CHANGELOG.md --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf64375d..1afddb7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,10 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Implement AdaGrad optimizer and exponential learning rate schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs). +- Implement AdaGrad optimizer and exponential learning rate schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#80](https://github.com/metaopt/torchopt/pull/80). - Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140). - Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139). -- Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan). +- Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#143](https://github.com/metaopt/torchopt/pull/143). ### Changed From c1888393365551179c9c094ce6624e577c0598ea Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 22 Mar 2023 09:26:59 +0000 Subject: [PATCH 36/37] revert: revert change in torchopt/version.py --- torchopt/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchopt/version.py b/torchopt/version.py index a65aad3e..b8136a22 100644 --- a/torchopt/version.py +++ b/torchopt/version.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# Copyright 2022 MetaOPT Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
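For reference, a self-contained sketch of the per-tensor arithmetic performed by the refactored ``scale_by_rss`` update above, assuming real-valued gradients and ignoring the pytree flattening and in-place branches. The helper name ``scale_by_rss_step`` is illustrative and not part of the library.

.. code-block:: python

    from __future__ import annotations

    import torch


    def scale_by_rss_step(
        grad: torch.Tensor,
        sum_of_squares: torch.Tensor,
        eps: float = 1e-10,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Accumulate squared gradients with decay == 1.0, i.e. a plain running sum.
        sum_of_squares = sum_of_squares + grad * grad
        # Scale the update by 1 / (sqrt(sum) + eps), zeroing entries whose
        # accumulator is still exactly zero.
        update = torch.where(
            sum_of_squares > 0.0,
            grad / (sum_of_squares.sqrt() + eps),
            torch.zeros_like(grad),
        )
        return update, sum_of_squares

Repeated calls reproduce the AdaGrad denominator ``sqrt(sum_t g_t**2) + eps`` that the tests check against ``torch.optim.Adagrad``.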
From fbb68f891186eededa9a49926bec0b525ae41fb3 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 22 Mar 2023 21:25:41 +0800 Subject: [PATCH 37/37] chore: update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1afddb7a..58c4ae74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Implement AdaGrad optimizer and exponential learning rate schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#80](https://github.com/metaopt/torchopt/pull/80). +- Implement AdaGrad optimizer and exponential learning rate decay schedule by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#80](https://github.com/metaopt/torchopt/pull/80). - Enable tests on Windows by [@XuehaiPan](https://github.com/XuehaiPan) in [#140](https://github.com/metaopt/torchopt/pull/140). - Add `ruff` and `flake8` plugins integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#138](https://github.com/metaopt/torchopt/pull/138) and [#139](https://github.com/metaopt/torchopt/pull/139). - Add more documentation on implicit differentiation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#143](https://github.com/metaopt/torchopt/pull/143).
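To close out the series, a short usage sketch that combines the two additions recorded in the CHANGELOG: the classic ``torchopt.AdaGrad`` wrapper and the ``exponential_decay`` schedule passed as its ``lr`` argument (the class docstring documents ``lr`` as a float or callable). The toy model and batch are made up; the ``zero_grad()``/``backward()``/``step()`` pattern follows ``docs/source/optimizer/optim.rst`` and ``tests/test_optim.py``.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import torchopt
    from torchopt.schedule import exponential_decay

    # Toy model and batch (illustrative only).
    model = torch.nn.Linear(4, 2)
    xs, ys = torch.randn(8, 4), torch.randint(0, 2, (8,))

    optim = torchopt.AdaGrad(
        model.parameters(),
        # The schedule scales the global step size down as training proceeds.
        lr=exponential_decay(init_value=1e-2, decay_rate=0.9, transition_steps=100),
        eps=1e-10,
    )

    for _ in range(3):
        loss = F.cross_entropy(model(xs), ys)
        optim.zero_grad()
        loss.backward()
        optim.step()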