Skip to content

Commit

Permalink
[Pallas TPU] Fix dtype_bitwidth for int in util.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675357560
  • Loading branch information
bythew3i authored and Google-ML-Automation committed Sep 17, 2024
1 parent 9408606 commit d27fce6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/pallas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def next_power_of_2(x: int) -> int:
return 1 if x == 0 else 2 ** (x - 1).bit_length()

def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int:
if isinstance(dtype, jnp.integer):
if jnp.issubdtype(dtype, jnp.integer):
return jnp.iinfo(dtype).bits
return np.dtype(dtype).itemsize * 8

Expand Down

0 comments on commit d27fce6

Please sign in to comment.