diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index ea0a8c2b..7347418a 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -262,9 +262,9 @@ __all__ += ["argmax", "argmin", "where"] -from .statistical_functions import max, mean, min, prod, sum +from .statistical_functions import max, mean, min, prod, std, sum, var -__all__ += ["max", "mean", "min", "prod", "sum"] +__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"] from .utility_functions import all, any diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 7ee6525e..46e960f9 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -15,6 +15,7 @@ uint64, ) from cubed.backend_array_api import namespace as nxp +from cubed.array_api.elementwise_functions import sqrt, square, subtract from cubed.core import reduction @@ -184,3 +185,12 @@ def sum( split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) + + +def var(x, /, *, axis=None, keepdims=False): + mu = mean(x, axis=axis, keepdims=True) + return mean(square(subtract(x, mu)), axis=axis, keepdims=keepdims) + + +def std(x, /, *, axis=None, keepdims=False): + return sqrt(var(x, axis=axis, keepdims=keepdims)) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index b7764caa..c4f0697b 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -728,6 +728,74 @@ def test_sum_axis_0(spec, executor): assert_array_equal(b.compute(executor=executor), np.array([12, 15, 18])) +def test_var(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(), + ) + + +def test_var_axis_0(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a, axis=0) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(axis=0), + ) + + +def test_var_axis_1(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a, axis=1) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(axis=1), + ) + + +def test_var_keepdims_true(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a, keepdims=True) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var( + keepdims=True + ), + ) + + +def test_std(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.std(a) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(), + ) + + +def test_std_axis_0(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.std(a, axis=0) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(axis=0), + ) + + # Utility functions