ENH: Get sampling working using Apple Silicon GPU via jax backend #7332
It seems to fail at something very basic, trying to call |
Perhaps try with |
Adding

```python
import pytensor
pytensor.config.floatX = "float32"
```

results in a different error
|
This actually works fine:

```python
>>> jax.numpy.array(x, dtype="float32")
Array([ 0.42651013,  1.9349691 ,  0.43221945, -0.24343772,  2.760918  ,
        1.2610279 , -1.5116365 ,  0.9801455 ,  0.5613332 ,  0.6750525 ], dtype=float32)
```
|
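An aside (my illustration, not from the thread): jax-metal currently supports only 32-bit floats, while NumPy draws default to float64, which is why the explicit cast above matters. A NumPy-only sketch of the downcast:

```python
import numpy as np

# NumPy samples are float64 by default, which the Metal backend cannot
# handle; the explicit float32 cast is what pytensor.config.floatX =
# "float32" arranges for compiled graphs.
x = np.random.default_rng(0).normal(size=10)
print(x.dtype)            # -> float64
x32 = x.astype("float32")
print(x32.dtype)          # -> float32
```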
Doesn't work for me; gives
|
Yes that's the error you were getting first |
You're now getting errors deep inside numpyro, and have left PyMC/PyTensor. Could you try something simpler first? |
That fails with
Well, there's more than that which I can give if you really want. |
|
@drbenvincent try something even simpler, a model just with |
```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)

model.compile_logp(sum=False, mode="JAX")
```

Does not error out :) Gives me |
You need to eval it still, with

```python
import pytensor
pytensor.config.floatX = "float32"

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)

model.compile_logp(sum=False, mode="JAX")(model.initial_point())
```
|
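As a sanity check on that output (my aside, not from the thread): at the initial point `mu = 0`, the log-density of a standard normal is `-0.5 * log(2π)`, which matches the `-0.9189385` reported later in the thread:

```python
import math

# Expected value of model.compile_logp(...)(model.initial_point())
# for mu ~ Normal(0, 1) evaluated at mu = 0:
# log N(0 | 0, 1) = -0.5 * log(2 * pi)
logp = -0.5 * math.log(2 * math.pi)
print(round(logp, 7))  # -> -0.9189385
```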
It works
|
Great. So the next question is, can you use numpyro to sample from that very simple model? |
Looks like the answer is no, so far:

```python
with model:
    idata = pm.sample(draws=10_000, nuts_sampler="numpyro", chains=1)
```

gives
Full traceback
|
Can you try with |
Same error I'm afraid |
So it's still pretty broken |
Does installing |
I tried a bunch of things and no luck yet. Specifically on that suggestion, I tried
|
FYI I created a topic in the Pyro discourse so we'll see if there are any plans for support from the pyro side of things. |
Upgraded `jax` and `jaxlib` and set `ENABLE_PJRT_COMPATIBILITY=1`, which allows using more recent jax versions. |
Tried it myself with that setting, and the newer JAX version does not produce the previous error, but I'm still getting the MPS error.
|
It does work with
But seems to sample just fine so far. |
Same error with blackjax, but using vectorization gives:
|
@junpenglao Is there a way to not have blackjax spawn/fork a new process? |
Current summary: it seems that running |
A quick recap of the steps I'm taking here. With an environment file:

```yaml
name: metal_test_env
channels:
  - conda-forge
dependencies:
  - jupyterlab
  - numpy
  - numpyro
  - pip
  - pymc
  - python<3.11
  - pip:
      - blackjax
      - jax==0.4.28
      - jax-metal
      - jaxlib==0.4.28
      - ml-dtypes==0.2.0
```

built with:

```shell
conda env create -f metal_test_env.yaml
conda activate metal_test_env
```

Then run:

```python
import os
os.environ["ENABLE_PJRT_COMPATIBILITY"] = "1"  # set before jax initialises a backend (safest: before importing jax)

import jax
import numpy as np
import pymc as pm
import pytensor

pytensor.config.floatX = "float32"
jax.print_environment_info()  # gives positive output, metal device is recognised
```

This works fine:

```python
with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)

model.compile_logp(sum=False, mode="JAX")(model.initial_point())
# returns something like [Array(-0.9189385, dtype=float32)]
```

`pm.sample` works, but doesn't utilise the GPU (going by Activity Monitor):

```python
with model:
    idata = pm.sample(draws=10_000, cores=1, chains=1)
```

And sampling with the numpyro backend still gives the MPS error:

```python
with model:
    idata = pm.sample(draws=1_000, nuts_sampler="numpyro", chains=1, cores=1)
# XlaRuntimeError: INTERNAL: Unable to serialize MPS module
```

Trying the blackjax backend I get the same errors as you:

- `XlaRuntimeError: INTERNAL: Unable to serialize MPS module`
- `ValueError: EmitPythonCallback not supported on METAL backend.` when setting `nuts_sampler_kwargs={"chain_method": "vectorized"}`

So as far as I understand, nobody so far has been able to use the GPU to sample. |
Yeah that's how far I got. I don't think it's going to work.
|
Fingers crossed it just starts working randomly as the metal support for jax improves. |
I wouldn't hold my breath.
But via pytorch backend might be possible.
|
Tried with the new jax-metal 0.1.0 and jax and jaxlib 0.4.30. Setting

```python
import pytensor
pytensor.config.floatX = "float32"

import numpy as np
import pymc as pm

x = np.random.normal(size=10)

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("x_obs", mu=mu, sigma=1, observed=x)
    idata = pm.sample(
        nuts_sampler="blackjax",
        cores=1,
        chains=1,
        nuts_sampler_kwargs={"chain_method": "vectorized"},
    )
```

gives:
Running with numpyro:
Not sure why it's always trying to fork even when using vectorize. |
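A speculative aside (not verified against blackjax internals): GPU state generally does not survive `fork()`, so if chains are parallelised via Python multiprocessing, the "spawn" start method would be needed. A stdlib sketch of requesting it:

```python
import multiprocessing as mp

# A forked child inherits the parent's (now unusable) GPU context;
# "spawn" starts a fresh interpreter per worker instead.
ctx = mp.get_context("spawn")
print(ctx.get_start_method())  # -> spawn
```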
Updated to Sequoia 15 Beta 7 and am now getting (with the numpyro sampler):
|
It would be great to utilise the GPU on Apple Silicon chips. The lowest-resistance way of doing this is probably through the jax backend; see https://jax.readthedocs.io/en/latest/installation.html#apple-silicon-gpu-arm-based and the Apple docs, Accelerated JAX training on Mac.
I don't have the stats, but some sizeable portion of PyMC users run code on hardware with Apple Silicon, and this will increase over time as more people upgrade from Intel to Apple Silicon. Full utilisation of those chips (i.e. the GPU component) would likely unlock some speed-ups in sampling.
So far I have partial progress (ht to @twiecki). I have the following environment file, `metal_test_env.yaml`. NOTE: it seems that pinning `python<3.11` is a necessity at this point in time. I build that with:
Then in an ipython session we can confirm that jax has detected the Apple Silicon GPU: `jax.print_environment_info()` gives positive output, and the key line is `jax.devices (1 total, 1 local): [METAL(id=0)]`.
So the next step is to see if we can do sampling, which as of now results in this traceback:
For all I know the problem is on the jax side, and may require issues to be filed in that repo. But I think it makes sense to have a pymc issue to raise this goal as a priority, and perhaps to coordinate any additional issues on the pymc or jax side.