fix(transform): fix momentum trace #58

Merged · 4 commits · Aug 10, 2022
Changes from all commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
##### Project specific #####
!torchopt/_src/
!torchopt/_lib/

##### Python.gitignore #####
# Byte-compiled / optimized / DLL files
__pycache__/
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,7 @@
# Changelog

<!-- markdownlint-disable no-duplicate-header -->

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
@@ -19,10 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix momentum tracing by [@XuehaiPan](https://github.com/XuehaiPan) in [#58](https://github.com/metaopt/TorchOpt/pull/58).
- Fix CUDA build for accelerated OP [@XuehaiPan](https://github.com/XuehaiPan) in [#53](https://github.com/metaopt/TorchOpt/pull/53).
- Fix gamma error in MAML-RL implementation [@Benjamin-eecs](https://github.com/Benjamin-eecs) [#47](https://github.com/metaopt/TorchOpt/pull/47).


### Removed

------
9 changes: 4 additions & 5 deletions torchopt/_src/alias.py
@@ -34,10 +34,9 @@

from typing import Optional

import jax

from torchopt._src import base, combine, transform
from torchopt._src.typing import ScalarOrSchedule
from torchopt._src.utils import pytree


def _scale_by_lr(lr: ScalarOrSchedule, flip_sign=True):
@@ -48,13 +47,13 @@ def schedule_wrapper(count):
def f(scaled_lr):
return sign * scaled_lr

return jax.tree_map(f, lr(count)) # type: ignore
return pytree.tree_map(f, lr(count)) # type: ignore

return transform.scale_by_schedule(schedule_wrapper)
return transform.scale(sign * lr)


# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def adam(
lr: ScalarOrSchedule,
b1: float = 0.9,
@@ -151,7 +150,7 @@ def sgd(
)


# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def rmsprop(
lr: ScalarOrSchedule,
decay: float = 0.9,
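The second hunk above is the core of the jax-to-pytree change in this file. Putting it back into context, `_scale_by_lr` reads roughly as follows after the PR; the `callable(lr)` branch and the indentation are reconstructed from the surrounding lines rather than copied verbatim, so treat this as a sketch:

```python
# Sketch of _scale_by_lr after this PR; the imports are the ones visible in
# the hunk above, and the callable(lr) branch is inferred from context.
from torchopt._src import transform
from torchopt._src.typing import ScalarOrSchedule
from torchopt._src.utils import pytree


def _scale_by_lr(lr: ScalarOrSchedule, flip_sign=True):
    sign = -1 if flip_sign else 1

    if callable(lr):  # assumed: a schedule is passed as a callable of the step count

        def schedule_wrapper(count):
            def f(scaled_lr):
                return sign * scaled_lr

            # The fix swaps jax.tree_map for the project-local pytree helper.
            return pytree.tree_map(f, lr(count))  # type: ignore

        return transform.scale_by_schedule(schedule_wrapper)

    return transform.scale(sign * lr)
```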
4 changes: 2 additions & 2 deletions torchopt/_src/clip.py
@@ -16,11 +16,11 @@
# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/clip_grad.py
# ==============================================================================

import jax
import torch
from torch._six import inf

from torchopt._src import base
from torchopt._src.utils import pytree


ClipState = base.EmptyState
@@ -80,7 +80,7 @@ def f(g):
def f(g):
return g.mul(clip_coef_clamped) if g is not None else None

new_updates = jax.tree_map(f, updates)
new_updates = pytree.tree_map(f, updates)
return new_updates, state

return base.GradientTransformation(init_fn, update_fn)
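The `update_fn` above applies the clamped clip coefficient to every gradient leaf through `tree_map`. Below is a toy, self-contained illustration of that leaf-wise pattern; it uses `torch.utils._pytree` as a stand-in for `torchopt._src.utils.pytree` (an assumption, the project helper may wrap a different backend), and the gradient tree is made up:

```python
import torch
import torch.utils._pytree as pytree  # stand-in for torchopt's pytree helper

# Made-up gradient tree; the clip transformation maps one scaling function
# over every leaf, skipping leaves whose gradient is None.
updates = {"weight": torch.randn(3, 3), "bias": torch.randn(3)}
clip_coef_clamped = 0.5

new_updates = pytree.tree_map(
    lambda g: g.mul(clip_coef_clamped) if g is not None else None,
    updates,
)
print(new_updates["bias"])  # the original bias gradient scaled by 0.5
```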
4 changes: 2 additions & 2 deletions torchopt/_src/hook.py
@@ -13,10 +13,10 @@
# limitations under the License.
# ==============================================================================

import jax
import torch

from torchopt._src.base import EmptyState, GradientTransformation
from torchopt._src.utils import pytree


def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
@@ -40,7 +40,7 @@ def update_fn(updates, state, inplace=False): # pylint: disable=unused-argument
def f(g):
return g.register_hook(hook) if g is not None else None

jax.tree_map(f, updates)
pytree.tree_map(f, updates)
return updates, state

return GradientTransformation(init_fn, update_fn)
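The diff only touches the tree-mapping call, so `zero_nan_hook` itself is not shown beyond its signature. The body below is therefore an assumption about what such a hook does (zero out NaN gradient entries), paired with a minimal registration example that mirrors the `g.register_hook(hook)` call in `update_fn`:

```python
import torch


def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
    # Assumed body: replace NaN entries of the incoming gradient with zeros.
    return torch.where(torch.isnan(g), torch.zeros_like(g), g)


x = torch.tensor([1.0, 2.0], requires_grad=True)
x.register_hook(zero_nan_hook)  # same pattern as g.register_hook(hook) above
(3 * x).sum().backward()
print(x.grad)  # tensor([3., 3.]); any NaN entries would have been zeroed
```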
2 changes: 1 addition & 1 deletion torchopt/_src/optimizer/adam.py
@@ -30,7 +30,7 @@ class Adam(Optimizer):
- The differentiable meta-Adam optimizer: :class:`torchopt.MetaAdam`.
"""

# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def __init__(
self,
params: Iterable[torch.Tensor],
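Several files in this PR replace the bare `# pylint: disable=...` comment with `# pylint: disable-next=...`. A small sketch of the difference in scope (the `disable-next` form needs pylint 2.10 or newer; the function below is made up for illustration):

```python
# pylint: disable=too-many-arguments
# The bare form stays in effect from here to the end of the enclosing scope,
# so it can accidentally silence the check for unrelated code further down.

# pylint: disable-next=too-many-arguments
def adam_like(lr, b1, b2, eps, eps_root, moment_requires_grad):  # made-up signature
    # The disable-next form applies only to the single line that follows it.
    return (lr, b1, b2, eps, eps_root, moment_requires_grad)
```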
10 changes: 5 additions & 5 deletions torchopt/_src/optimizer/base.py
@@ -15,11 +15,11 @@

from typing import Iterable

import jax
import torch

from torchopt._src.base import GradientTransformation
from torchopt._src.update import apply_updates
from torchopt._src.utils import pytree


class Optimizer:
@@ -70,7 +70,7 @@ def f(p):
p.grad.requires_grad_(False)
p.grad.zero_()

jax.tree_map(f, group)
pytree.tree_map(f, group)

def state_dict(self):
"""Returns the state of the optimizer."""
@@ -102,16 +102,16 @@ def f(p):
return p.grad

for param, state in zip(self.param_groups, self.state_groups):
grad = jax.tree_map(f, param)
grad = pytree.tree_map(f, param)
updates, _ = self.impl.update(grad, state)
apply_updates(param, updates)

return loss

def add_param_group(self, params):
"""Add a param group to the optimizer's :attr:`param_groups`."""
params, tree = jax.tree_flatten(params)
params, params_tree = pytree.tree_flatten(params)
params = tuple(params)
self.param_groups.append(params)
self.param_tree_groups.append(tree)
self.param_tree_groups.append(params_tree)
self.state_groups.append(self.impl.init(params))
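`add_param_group` flattens whatever parameter container it receives and keeps the treedef so the flat tuple can later be related back to the original structure. A toy round trip of that flatten/unflatten pattern, again using `torch.utils._pytree` as a stand-in for the project's `pytree` module (an assumption) and a made-up container:

```python
import torch
import torch.utils._pytree as pytree  # stand-in for torchopt's pytree helper

# Flatten a nested parameter container into leaves plus a treedef,
# mirroring the add_param_group code above.
params = {"linear": {"weight": torch.randn(2, 2), "bias": torch.randn(2)}}
leaves, params_tree = pytree.tree_flatten(params)

# The treedef rebuilds the original container from the flat leaves, which is
# how flat per-parameter state can be mapped back onto the structured params.
rebuilt = pytree.tree_unflatten(leaves, params_tree)
assert torch.equal(rebuilt["linear"]["weight"], params["linear"]["weight"])
```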
2 changes: 1 addition & 1 deletion torchopt/_src/optimizer/meta/adam.py
@@ -28,7 +28,7 @@ class MetaAdam(MetaOptimizer):
- The classic Adam optimizer: :class:`torchopt.Adam`.
"""

# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def __init__(
self,
net: nn.Module,
8 changes: 4 additions & 4 deletions torchopt/_src/optimizer/meta/base.py
@@ -13,12 +13,12 @@
# limitations under the License.
# ==============================================================================

import jax
import torch
import torch.nn as nn

from torchopt._src.base import GradientTransformation
from torchopt._src.update import apply_updates
from torchopt._src.utils import pytree


class MetaOptimizer:
@@ -57,7 +57,7 @@ def step(self, loss: torch.Tensor):
for idx, (state, param_containers) in enumerate(
zip(self.state_groups, self.param_containers_groups)
):
flatten_params, containers_tree = jax.tree_util.tree_flatten(param_containers)
flatten_params, containers_tree = pytree.tree_flatten(param_containers)
flatten_params = tuple(flatten_params)
grad = torch.autograd.grad(loss, flatten_params, create_graph=True, allow_unused=True)
updates, state = self.impl.update(grad, state, False)
@@ -69,11 +69,11 @@ def step(self, loss: torch.Tensor):

def add_param_group(self, net):
"""Add a param group to the optimizer's :attr:`state_groups`."""
# pylint: disable=import-outside-toplevel,cyclic-import
# pylint: disable-next=import-outside-toplevel,cyclic-import
from torchopt._src.utils import _extract_container

net_container = _extract_container(net, with_buffer=False)
flatten_param, _ = jax.tree_util.tree_flatten(net_container)
flatten_param, _ = pytree.tree_flatten(net_container)
flatten_param = tuple(flatten_param)
optim_state = self.impl.init(flatten_param)
self.state_groups.append(optim_state)
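`MetaOptimizer.step` takes gradients with `create_graph=True` so that the parameter update itself stays on the autograd graph and an outer loss can differentiate through it. A minimal, self-contained toy of that pattern (the tensors, the 0.1 step size, and the losses are all made up):

```python
import torch

# Inner gradient with create_graph=True keeps the update differentiable.
w = torch.randn(3, requires_grad=True)
inner_loss = (w ** 2).sum()
(grad,) = torch.autograd.grad(inner_loss, (w,), create_graph=True, allow_unused=True)

new_w = w - 0.1 * grad   # a differentiable SGD-style update
outer_loss = new_w.sum()
outer_loss.backward()    # gradients flow back through the update into w.grad
print(w.grad)
```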
2 changes: 1 addition & 1 deletion torchopt/_src/optimizer/meta/rmsprop.py
@@ -30,7 +30,7 @@ class MetaRMSProp(MetaOptimizer):
- The classic RMSProp optimizer: :class:`torchopt.RMSProp`.
"""

# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def __init__(
self,
net: nn.Module,
2 changes: 1 addition & 1 deletion torchopt/_src/optimizer/meta/sgd.py
@@ -30,7 +30,7 @@ class MetaSGD(MetaOptimizer):
- The classic SGD optimizer: :class:`torchopt.SGD`.
"""

# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def __init__(
self,
net: nn.Module,
2 changes: 1 addition & 1 deletion torchopt/_src/optimizer/rmsprop.py
@@ -30,7 +30,7 @@ class RMSProp(Optimizer):
- The differentiable meta-RMSProp optimizer: :class:`torchopt.MetaRMSProp`.
"""

# pylint: disable=too-many-arguments
# pylint: disable-next=too-many-arguments
def __init__(
self,
params: Iterable[torch.Tensor],
4 changes: 2 additions & 2 deletions torchopt/_src/schedule.py
@@ -30,12 +30,12 @@
# limitations under the License.
# ==============================================================================

import jax
import numpy as np
from absl import logging

from torchopt._src import base
from torchopt._src.typing import Scalar
from torchopt._src.utils import pytree


def polynomial_schedule(
@@ -85,7 +85,7 @@ def impl(count):
frac = 1 - count / transition_steps
return (init_value - end_value) * (frac**power) + end_value

return jax.tree_map(impl, count)
return pytree.tree_map(impl, count)

return schedule
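The hunk above shows the core arithmetic of `polynomial_schedule`. A worked instance of that formula in plain Python (it only restates the two lines shown and ignores whatever clamping of `count` the full implementation performs beforehand; the parameter values are made up):

```python
# Parameter names mirror the hunk above; the values are arbitrary.
init_value, end_value, power, transition_steps = 1.0, 0.1, 2, 100


def poly_value(count: int) -> float:
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac ** power) + end_value


assert abs(poly_value(0) - 1.0) < 1e-9     # starts at init_value
assert abs(poly_value(100) - 0.1) < 1e-9   # ends at end_value
assert abs(poly_value(50) - 0.325) < 1e-9  # 0.9 * 0.5**2 + 0.1
```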
