
Jaxopt 2 #402

Merged 13 commits from jaxop_2 into main on Nov 7, 2023
Conversation

henrymoss
Contributor

@henrymoss commented Oct 22, 2023

Just so we make some progress and support BFGS in some capacity (we really should!), I have put together a quick PR adding a separate fit function, fit_scipy, that uses scipy (BFGS, or L-BFGS depending on problem size), for those keen to use it. This is a very small non-breaking change, unlike the other PR, which was breaking and involved 10x the code.
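For context, a minimal sketch of the "BFGS or L-BFGS depending on problem size" idea, built directly on scipy.optimize.minimize. The fit_scipy name matches the PR, but the size threshold, the error-raising behaviour, and the toy objective below are illustrative assumptions, not the actual GPJax implementation.

```python
# A hypothetical sketch of the method-selection logic: BFGS for small
# parameter vectors, L-BFGS-B for large ones. Threshold and objective
# are illustrative assumptions, not GPJax's actual code.
import numpy as np
from scipy.optimize import minimize


def fit_scipy(objective, init_params, size_threshold=100):
    # Pick the solver based on the number of free parameters.
    method = "BFGS" if init_params.size <= size_threshold else "L-BFGS-B"
    result = minimize(objective, init_params, method=method)
    if not result.success:
        # Surface optimizer failures instead of logging them loosely.
        raise RuntimeError(f"{method} failed: {result.message}")
    return result.x


# Toy convex objective with its minimum at params == 3.
params = fit_scipy(lambda p: np.sum((p - 3.0) ** 2), np.zeros(5))
```

With only five parameters this dispatches to plain BFGS; a vector longer than the threshold would take the limited-memory path instead.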

Note: I also played around with optimistix, but it didn't really make things easier.

@daniel-dodd, by having it as a separate function, I only need to instantiate the jaxopt solver object once, so no immediate horrors pop up.

A couple of interesting observations from updating the notebooks with the scipy optimizer:

  • Adam was failing to optimize a kernel in the "intro to kernels" notebook, but you would never notice this from Adam alone. The scipy implementation threw an error, which helped me spot the bad fit we were getting with Adam. It turns out we just needed a more reasonable initial lengthscale, and all is good.
  • @thomaspinder's circular kernel also wasn't fitting properly. If you look at the hosted notebooks, you can see the posterior behaving really badly, with no uncertainty. This was because the bijector constraining tau was set up wrong: it wasn't actually stopping tau from getting big. I've now fixed this, allowing both of our optimizers (Adam and scipy) to converge.

One weird thing for @daniel-dodd: in the barycentre notebook, a strange error pops up from calling fit a few times in a for loop. It's entirely incomprehensible.

@henrymoss
Contributor Author

Some Qs:

  • the logging for the scipy optimizer is a bit loose at the moment. Would you rather I raise an error if the optimization doesn't work?

@daniel-dodd
Member

daniel-dodd commented Oct 27, 2023

Thanks @henrymoss, this PR looks great to me.

In the barycentres notebook, it will actually break on:

fit_gp(x, ys[0])
fit_gp(x, ys[1])

Using .step instead of relying on __call__ works, i.e., objective=jax.jit(gpx.ConjugateMLL(negative=True).step) instead of objective=jax.jit(gpx.ConjugateMLL(negative=True)).

It would probably be safer to remove the following lines of code on AbstractObjective:

    ...
    def __hash__(self):
        return hash(tuple(jtu.tree_leaves(self)))  # Probably put this on the Module!


    def __call__(self, *args, **kwargs) -> ScalarFloat:
        return self.step(*args, **kwargs)

So that jax.jit(gpx.ConjugateMLL(negative=True)) errors. It's kinda hacky anyway, since really you should just be jitting a function.

Objectives could still be passed as objective=gpx.ConjugateMLL(negative=True) without the jit, which is not really needed in the first place, as the code is traced within the lax.scan.
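A plain-Python sketch of the two call patterns being compared. The ConjugateMLL below is a mock stand-in, not the real GPJax class, and the jax.jit wrappers are replaced by plain references, since only the interface difference matters here: passing the bound method .step hands the transform an ordinary function, while passing the instance relies on __call__ and the leaf-based __hash__ above.

```python
# Mock stand-in for the two patterns discussed; NOT the real GPJax class.
# `jax.jit` is omitted since only the call interface is being illustrated.
class ConjugateMLL:
    def __init__(self, negative=True):
        self.negative = negative

    def step(self, x):
        # Toy stand-in for the marginal log-likelihood computation.
        loss = x * x
        return loss if self.negative else -loss

    def __call__(self, x):
        return self.step(x)


mll = ConjugateMLL(negative=True)

# Safer: hand the transform a plain function (stands in for jax.jit(mll.step)).
objective = mll.step
# Fragile: hand over the callable instance (stands in for jax.jit(mll)),
# which depends on __call__ and the hacky tree-leaf __hash__.
objective_obj = mll

assert objective(2.0) == objective_obj(2.0) == 4.0
```

Both forms compute the same value; removing __call__ (and __hash__) simply makes the fragile second form fail loudly instead of misbehaving deep inside a fit loop.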

Member

@daniel-dodd left a comment


super-duper. 🍾

@henrymoss henrymoss merged commit 5758238 into main Nov 7, 2023
14 checks passed
@st-- deleted the jaxop_2 branch on November 30, 2023.