Skip to content

Commit

Permalink
Merge pull request #707 from mit-ll-responsible-ai/numpy-fix
Browse files Browse the repository at this point in the history
Fix numpy 2.0 compat
  • Loading branch information
rsokl authored Jul 12, 2024
2 parents bb11ed4 + 4e2576a commit cdee868
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
7 changes: 4 additions & 3 deletions docs/source/tutorials/using_scikit_learn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,14 @@ stored in job directories and plot the results.
row = i // 10
col = i % 10
# ax[row, col].set_axis_off()
ax[row, col].imshow(img)
_ax = ax[row, col] # type: ignore
_ax.imshow(img)
if row == 0:
ax[row, col].set_title(cname)
_ax.set_title(cname)
if col == 0:
ax[row, col].set_ylabel(dname)
_ax.set_ylabel(dname)
The resulting figure should be:

Expand Down
16 changes: 8 additions & 8 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,23 +520,27 @@ class Just:

def _is_ufunc(value: Any) -> bool:
# checks without importing numpy
numpy = sys.modules.get("numpy")
if numpy is None: # pragma: no cover
if (numpy := sys.modules.get("numpy")) is None: # pragma: no cover
# we do actually cover this branch some runs of our CI,
# but our coverage job installs numpy
return False
return isinstance(value, numpy.ufunc)


def _is_numpy_array_func_dispatcher(value: Any) -> bool:
if (numpy := sys.modules.get("numpy")) is None: # pragma: no cover
return False
return isinstance(value, type(numpy.sum))


def _check_instance(*target_types: str, value: "Any", module: str): # pragma: no cover
"""Checks if value is an instance of any of the target types imported
from the specified module.
Returns `False` if module/target type doesn't exists (e.g. not installed).
This is useful for gracefully handling specialized logic for optional dependencies.
"""
mod = sys.modules.get(module)
if mod is None:
if (mod := sys.modules.get(module)) is None:
return False

types = []
Expand All @@ -555,10 +559,6 @@ def _check_instance(*target_types: str, value: "Any", module: str): # pragma: n
return any(value is t for t in types)


_is_numpy_array_func_dispatcher = functools.partial(
_check_instance, "_ArrayFunctionDispatcher", module="numpy.core._multiarray_umath"
)

_is_jax_compiled_func = functools.partial(
_check_instance, "CompiledFunction", "PjitFunction", module="jaxlib.xla_extension"
)
Expand Down

0 comments on commit cdee868

Please sign in to comment.