
[Feature Request] Support for functorch transforms #223

Open · marvinfriede opened this issue May 27, 2024 · 0 comments
Labels: enhancement (New feature or request)
marvinfriede commented May 27, 2024

Motivation

I am interested in Jacobians and Hessians of implicitly differentiated root-finding problems, something that comes up regularly in scientific computing. With JAX, this already works out of the box via function transforms (e.g., jacrev). Is this something you plan to support in torchopt, too?

Solution

I already tried, but several pieces of code currently prevent this:

  • missing setup_context for vmap rule in ImplicitMetaGradient (very easy to adapt)
  • .item() in _vdot_real_kernel
  • make_rmatvec in normal_cg
  • conditionals in _cg_solve
  • tree operations in _cg_solve
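To illustrate the second bullet: a `.item()` call extracts a Python scalar from a tensor, which `torch.func.vmap` cannot trace through, so any kernel containing it breaks under the transform. A minimal sketch (the `dot_with_item` function below is hypothetical and only mimics what a kernel like `_vdot_real_kernel` does):

```python
import torch

def dot_with_item(x, y):
    # .item() pulls out a Python float; eager mode is fine,
    # but vmap cannot batch through a data-dependent scalar.
    return (x * y).sum().item()

x = torch.ones(3)
y = torch.ones(3)
print(dot_with_item(x, y))  # works eagerly

# Under vmap the .item() call raises a RuntimeError:
try:
    torch.func.vmap(dot_with_item)(torch.ones(2, 3), torch.ones(2, 3))
except RuntimeError as e:
    print("vmap failed:", type(e).__name__)
```

The fix on the library side is typically to keep such intermediates as 0-d tensors instead of converting them to Python scalars.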

Alternatives

For comparison, the jaxopt version:
import jax
import jax.numpy as jnp
from jaxopt.implicit_diff import custom_root
from jaxopt import Bisection

jax.config.update("jax_platform_name", "cpu")


def F(x, factor):
    return factor * x ** 3 - x - 2


def bisection_root_solver(init_x, factor):
    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
    return bisec.run(factor=factor).params


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad

    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))

print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))

custom_jac_fcn = jax.jacrev(custom_root_solver, argnums=1)
print(jax.jacrev(custom_jac_fcn, argnums=1)(x_init, fac))
bisection_jac_fcn = jax.jacrev(bisection_root_solver, argnums=1)
print(jax.jacrev(bisection_jac_fcn, argnums=1)(x_init, fac))
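For reference, the unrolled (non-implicit) counterpart of the gradient-descent solver above already works with `torch.func` today; what is missing is the memory-efficient implicit-diff path through torchopt. A minimal sketch, assuming PyTorch >= 2.0 (where functorch lives under `torch.func`):

```python
import torch
from torch.func import grad, jacrev

def F(x, factor):
    return factor * x ** 3 - x - 2

def unrolled_root_solver(init_x, factor):
    # Same gradient-descent root solve as the JAX example, but
    # differentiated by unrolling the loop rather than via the
    # implicit function theorem.
    x = init_x
    for _ in range(100):
        x = x - 1e-1 * F(x, factor)
    return x

x_init = torch.tensor(3.0)
fac = torch.tensor(2.0)

root = unrolled_root_solver(x_init, fac)  # root of 2x^3 - x - 2, approx. 1.165
droot_dfac = grad(unrolled_root_solver, argnums=1)(x_init, fac)
jac = jacrev(unrolled_root_solver, argnums=1)(x_init, fac)

# Implicit-function-theorem reference: dx*/dfactor = -x*^3 / (3*factor*x*^2 - 1)
ref = -root ** 3 / (3 * fac * root ** 2 - 1)
print(root, droot_dfac, jac, ref)
```

Because the fixed-point iteration is a contraction near the root, the unrolled gradient agrees with the implicit-diff value here; the point of torchopt's implicit differentiation is to get the same result without backpropagating through all 100 iterations.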

Additional context

No response

marvinfriede added the enhancement (New feature or request) label on May 27, 2024