Silently changing JAX default config #496

Open
invemichele opened this issue Mar 31, 2023 · 10 comments

Comments

@invemichele

I ran into an error while trying to load a JAX NN model, and it took me a while to realize the problem was caused by import pymbar. pymbar's source changes the global JAX default to x64 (64-bit) mode, which was incompatible with my stored model.

I solved it by setting force_no_jax = True in the pymbar source, but it would probably be nice to have a warning about this global config change, or a mention of it in the docs.
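A minimal sketch of the effect described above, assuming only that a library calls jax.config.update("jax_enable_x64", True) at some point during its import (pymbar itself is not imported here). The flag is process-wide, so the default dtype of arrays created by unrelated code changes too:

```python
import jax
import jax.numpy as jnp


def default_float_dtype() -> str:
    """Return the dtype name JAX picks for a new floating-point array."""
    return jnp.ones(3).dtype.name


# With JAX's defaults, 64-bit mode is off and new floats are float32.
before = default_float_dtype()

# A library making this call (e.g. during its own import) flips the
# setting for the whole process, not just for its own functions.
jax.config.update("jax_enable_x64", True)
after = default_float_dtype()

print(before, "->", after)

# Restore the default so the rest of this script is unaffected.
jax.config.update("jax_enable_x64", False)
```

Under default settings (no JAX_ENABLE_X64 environment variable set), this should print float32 -> float64.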

@mikemhenry
Contributor

mikemhenry commented Mar 31, 2023

Thanks for opening this bug report @invemichele ❤️
@mrshirts or @Lnaden, you are both more familiar with JAX than I am: do we need to enable 64-bit support?

(At first I thought this was about CPU architecture, but it is about 64-bit floats: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)

I'm testing what happens in our CI if I remove 64-bit floats here: #497

If we do need them, we should document this behavior and also print a warning. This is important because I am not sure what happens if someone has already started JAX and we then try to change the config, since the docs say:

To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.

@mikemhenry
Contributor

Looks like we do need 64-bit floats (which is what I thought).

@invemichele I can add a warning and improve the documentation; is that sufficient? Unfortunately, it doesn't look like we can dynamically change JAX global configs after startup.

@invemichele
Author

Yes, that would be useful information for people who want to use JAX and pymbar in the same script.

@Lnaden
Contributor

Lnaden commented Jun 13, 2023

I've been tinkering with this some more, and I don't think there is a reasonable way to expect PyMBAR to operate in 32-bit mode. Implementing a 32-bit mode is a bit tricky, but I have an implementation. The problem is that there is no reliable way to get useful outputs: you can get results, but they are not reliably accurate or converged.

@invemichele (and others): is a loud warning sufficient for your use case here, or would it be extremely useful to be able to force 32-bit mode for PyMBAR, with the understanding that converged results are not guaranteed? If that's the case, I would also issue a very loud warning about PyMBAR using 32-bit floats.
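A generic illustration of the precision argument (plain NumPy, not pymbar's solver): float32 resolves only about 7 significant digits, so relative changes much smaller than its machine epsilon vanish, and an iterative solve cannot reliably be driven below that level:

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable number.
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps

# An increment far below float32 resolution is lost entirely in 32-bit
# arithmetic, but survives in 64-bit arithmetic.
tiny = 1e-10
lost_in_32 = np.float32(1.0) + np.float32(tiny) == np.float32(1.0)
kept_in_64 = np.float64(1.0) + np.float64(tiny) != np.float64(1.0)

print(f"float32 eps = {eps32:.3e}, float64 eps = {eps64:.3e}")
print("1 + 1e-10 == 1 in float32:", lost_in_32)
print("1 + 1e-10 != 1 in float64:", kept_in_64)
```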

@invemichele
Author

To be clear, I am not interested in pymbar using float32. The issue is that by changing the global JAX default, pymbar is in practice incompatible with other JAX code. My notebook with a JAX neural network stopped working as soon as I imported pymbar, and since the error came from one of my own JAX lines, it was not clear to me that pymbar was the problem until I went through its source code.

A warning about the changed JAX global setting would be very useful for debugging.
One solution would be to make it possible to set force_no_jax=True at import time, instead of having to edit and reinstall pymbar.
Another possible solution would be to automatically fall back to the non-JAX implementation if JAX has already been imported outside of pymbar. This would be a safety-first approach, since according to the documentation jax_enable_x64 should be set at startup, and changing it later could create problems.
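The two suggestions above can be sketched as a backend-selection policy. Everything here is hypothetical: choose_backend and the PYMBAR_FORCE_NO_JAX environment variable are illustrative names, not part of pymbar. The "jax" in sys.modules check is only meaningful if it runs before the library imports JAX itself:

```python
import importlib.util
import os
import sys
from collections.abc import Mapping


def choose_backend(
    environ: Mapping[str, str] = os.environ,
    modules: Mapping[str, object] = sys.modules,
) -> str:
    """Return "jax" or "numpy" for a library that needs 64-bit JAX.

    Hypothetical policy combining the two suggestions:
    1. an explicit opt-out via an (assumed) PYMBAR_FORCE_NO_JAX variable;
    2. a safety-first fallback when the user already imported JAX,
       since jax_enable_x64 is meant to be set at startup.
    """
    if environ.get("PYMBAR_FORCE_NO_JAX", "").lower() in {"1", "true", "yes"}:
        return "numpy"
    if "jax" in modules:
        # Someone else is already using JAX; leave its global config alone.
        return "numpy"
    if importlib.util.find_spec("jax") is None:
        # JAX is not installed at all.
        return "numpy"
    return "jax"
```

The explicit opt-out is checked first so that a user can always force the non-JAX path, regardless of import order.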

@Lnaden
Contributor

Lnaden commented Jun 14, 2023

Another possible solution could be to automatically fall back to the non-jax implementation if JAX had already been imported outside of pymbar

That would be a viable option. I think we can expand on it to get the safest and most user-controllable approach. From what I can tell, as long as the JIT'd functions haven't been called yet, we can still set 64-bit mode. So how does this sound:

  • On import of pymbar (and JAX), check whether 64-bit mode is set. If it isn't, issue a warning that pymbar needs it and that running pymbar code will enable it unless disabled, but otherwise don't change anything yet.
  • Add an API option to the main MBAR code that disables the JAX code paths, so users can toggle them as needed.
  • Create a function in mbar_solvers that toggles JAX globally for its functions, which the main MBAR code can also call. This is the simplest implementation, but it relies on a module-level "global" variable.
  • Issue another warning when (and if) 64-bit JAX is enabled.

@Lnaden
Copy link
Contributor

Lnaden commented Jun 15, 2023

@invemichele I've added the warning for this in #504. Functionally, the JAX config is not set until right before the first JIT call, and pymbar will issue this pair of warnings:

On import (if 32-bit JAX):

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

On change to 64-bit mode:

******* JAX 64-bit mode is now on! *******
*     JAX is now set to 64-bit mode!     *
*   This MAY cause problems with other   *
*      uses of JAX in the same code.     *
******************************************

I realize I still haven't added the API call, but I wanted to do the warnings first in this PR so I don't break the API in testing on top of changing the import logic.

Lnaden closed this as completed in 72b18b0 Jun 15, 2023
Lnaden reopened this Jun 15, 2023
@Lnaden
Contributor

Lnaden commented Jun 15, 2023

A magic word in the commit message closed this; my mistake. It's not ready to close until the API to toggle JAX is in.

@Lnaden
Contributor

Lnaden commented Jun 15, 2023

While developing the API side of this, I realized this warning doesn't do any real good: because of how the JIT decorators work, they all activate on import, before any of the actual functions are called. I can disable JIT compilation of the functions with a global parameter, but I don't know how to check each function once it's called, set the x64 flag, and then JIT it. I need to delay the decorator's actions until execution. Rethinking the code now.

Even though the currently merged version doesn't prevent the 64-bit setting on import, for now it will at least warn you very loudly.
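One way to "delay the actions of the decorator until execution" is a wrapper that defers both the config change and jax.jit to the first call. This is a minimal sketch under that assumption; lazy_jit and logsumexp_sketch are hypothetical names, not the code from #505:

```python
import functools

import jax
import jax.numpy as jnp


def lazy_jit(fn):
    """Defer both the x64 switch and jax.jit until the first call.

    Hypothetical decorator: applying it at import time only wraps `fn`;
    nothing touches JAX's global config until the function actually runs.
    """
    compiled = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal compiled
        if compiled is None:
            # Enable 64-bit mode right before the first JIT compilation.
            jax.config.update("jax_enable_x64", True)
            compiled = jax.jit(fn)
        return compiled(*args, **kwargs)

    return wrapper


@lazy_jit
def logsumexp_sketch(x):
    # Stand-in for a solver kernel: numerically stable log-sum-exp.
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))
```

jax.jit itself only traces and compiles on the first call, so what this wrapper really defers is the global jax_enable_x64 update; a module-level toggle like the one discussed above could also be checked inside wrapper to skip JAX entirely.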

@Lnaden
Contributor

Lnaden commented Jun 15, 2023

I've got a fix in #505. Once that's in, I can carry this over to the API.


3 participants