diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index 8ef442ef..b7ed20ea 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -244,9 +244,9 @@ __all__ += ["argmax", "argmin", "where"] -from .statistical_functions import max, mean, min, prod, sum +from .statistical_functions import max, mean, min, prod, sum, std, var -__all__ += ["max", "mean", "min", "prod", "sum"] +__all__ += ["max", "mean", "min", "prod", "sum", "std", "var"] from .utility_functions import all, any diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index b4ff7694..3cda825e 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -16,6 +16,7 @@ uint64, ) from cubed.core import reduction +from cubed.array_api.elementwise_functions import subtract, square, sqrt def max(x, /, *, axis=None, keepdims=False): @@ -152,3 +153,12 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): keepdims=keepdims, extra_func_kwargs=extra_func_kwargs, ) + + +def var(x, /, *, axis=None, keepdims=False): + mu = mean(x, axis=axis, keepdims=True) + return mean(square(subtract(x, mu)), axis=axis, keepdims=keepdims) + + +def std(x, /, *, axis=None, keepdims=False): + return sqrt(var(x, axis=axis, keepdims=keepdims)) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 365ada4b..98f87046 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -563,6 +563,72 @@ def test_sum_axis_0(spec, executor): assert_array_equal(b.compute(executor=executor), np.array([12, 15, 18])) +def test_var(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(), + ) + + +def test_var_axis_0(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a, axis=0) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(axis=0), + ) + + +def test_var_axis_1(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a, axis=1) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(axis=1), + ) + + +def test_var_keepdims_true(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.var(a, keepdims=True) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(keepdims=True), + ) + + +def test_std(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.std(a) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(), + ) + + +def test_std_axis_0(spec, executor): + a = xp.asarray( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec + ) + b = xp.std(a, axis=0) + assert_array_equal( + b.compute(executor=executor), + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(axis=0), + ) + + # Utility functions