From 7c83fcd1aa28a659cbd03f8007727079b8905647 Mon Sep 17 00:00:00 2001 From: Ryan Soklaski Date: Fri, 12 Jul 2024 10:10:53 -0400 Subject: [PATCH 1/3] Fix numpy 2.0 compat --- .../structured_configs/_implementations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/hydra_zen/structured_configs/_implementations.py b/src/hydra_zen/structured_configs/_implementations.py index cb88d1a9..a3e56967 100644 --- a/src/hydra_zen/structured_configs/_implementations.py +++ b/src/hydra_zen/structured_configs/_implementations.py @@ -520,14 +520,19 @@ class Just: def _is_ufunc(value: Any) -> bool: # checks without importing numpy - numpy = sys.modules.get("numpy") - if numpy is None: # pragma: no cover + if (numpy := sys.modules.get("numpy")) is None: # we do actually cover this branch some runs of our CI, # but our coverage job installs numpy return False return isinstance(value, numpy.ufunc) +def _is_numpy_array_func_dispatcher(value: Any) -> bool: + if (numpy := sys.modules.get("numpy")) is None: # pragma: no cover + return False + return isinstance(value, type(numpy.sum)) + + def _check_instance(*target_types: str, value: "Any", module: str): # pragma: no cover """Checks if value is an instance of any of the target types imported from the specified module. @@ -535,8 +540,7 @@ def _check_instance(*target_types: str, value: "Any", module: str): # pragma: n Returns `False` if module/target type doesn't exists (e.g. not installed). This is useful for gracefully handling specialized logic for optional dependencies. """ - mod = sys.modules.get(module) - if mod is None: + if (mod := sys.modules.get(module)) is None: return False types = [] @@ -555,10 +559,6 @@ def _check_instance(*target_types: str, value: "Any", module: str): # pragma: n return any(value is t for t in types) -_is_numpy_array_func_dispatcher = functools.partial( - _check_instance, "_ArrayFunctionDispatcher", module="numpy.core._multiarray_umath" -) - _is_jax_compiled_func = functools.partial( _check_instance, "CompiledFunction", "PjitFunction", module="jaxlib.xla_extension" ) From 7b5445e1c4818acf0d82749e29c5c8146d00bf51 Mon Sep 17 00:00:00 2001 From: Ryan Soklaski Date: Fri, 12 Jul 2024 10:16:45 -0400 Subject: [PATCH 2/3] coverage --- src/hydra_zen/structured_configs/_implementations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hydra_zen/structured_configs/_implementations.py b/src/hydra_zen/structured_configs/_implementations.py index a3e56967..fa0e8020 100644 --- a/src/hydra_zen/structured_configs/_implementations.py +++ b/src/hydra_zen/structured_configs/_implementations.py @@ -520,7 +520,7 @@ class Just: def _is_ufunc(value: Any) -> bool: # checks without importing numpy - if (numpy := sys.modules.get("numpy")) is None: + if (numpy := sys.modules.get("numpy")) is None: # pragma: no cover # we do actually cover this branch some runs of our CI, # but our coverage job installs numpy return False From 4e2576ae2b30f8364c6351c371f5026499ae231c Mon Sep 17 00:00:00 2001 From: Ryan Soklaski Date: Fri, 12 Jul 2024 10:23:42 -0400 Subject: [PATCH 3/3] pyright scan --- docs/source/tutorials/using_scikit_learn.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/using_scikit_learn.rst b/docs/source/tutorials/using_scikit_learn.rst index 4ce05624..b6262913 100644 --- a/docs/source/tutorials/using_scikit_learn.rst +++ b/docs/source/tutorials/using_scikit_learn.rst @@ -417,13 +417,14 @@ stored in job directories and plot the results. row = i // 10 col = i % 10 # ax[row, col].set_axis_off() - ax[row, col].imshow(img) + _ax = ax[row, col] # type: ignore + _ax.imshow(img) if row == 0: - ax[row, col].set_title(cname) + _ax.set_title(cname) if col == 0: - ax[row, col].set_ylabel(dname) + _ax.set_ylabel(dname) The resulting figure should be: