Skip to content

Commit

Permalink
Implementation of std and var.
Browse files Browse the repository at this point in the history
Fixes cubed-dev#29. Here, I use existing cubed operations to implement `var` and `std`. Please let me know if I should reimplement the primitives as pure reductions.
  • Loading branch information
alxmrs committed Sep 26, 2024
1 parent b4e94b0 commit 97514e3
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@

__all__ += ["argmax", "argmin", "where"]

from .statistical_functions import max, mean, min, prod, sum
from .statistical_functions import max, mean, min, prod, sum, std, var

__all__ += ["max", "mean", "min", "prod", "sum"]
__all__ += ["max", "mean", "min", "prod", "sum", "std", "var"]

from .utility_functions import all, any

Expand Down
10 changes: 10 additions & 0 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction
from cubed.array_api.elementwise_functions import subtract, square, sqrt


def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
Expand Down Expand Up @@ -184,3 +185,12 @@ def sum(
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)


def var(x, /, *, axis=None, keepdims=False):
mu = mean(x, axis=axis, keepdims=True)
return mean(square(subtract(x, mu)), axis=axis, keepdims=keepdims)


def std(x, /, *, axis=None, keepdims=False):
return sqrt(var(x, axis=axis, keepdims=keepdims))
66 changes: 66 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,72 @@ def test_sum_axis_0(spec, executor):
assert_array_equal(b.compute(executor=executor), np.array([12, 15, 18]))


def test_var(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.var(a)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(),
)


def test_var_axis_0(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.var(a, axis=0)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(axis=0),
)


def test_var_axis_1(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.var(a, axis=1)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(axis=1),
)


def test_var_keepdims_true(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.var(a, keepdims=True)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).var(keepdims=True),
)


def test_std(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.std(a)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(),
)


def test_std_axis_0(spec, executor):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.std(a, axis=0)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).std(axis=0),
)


# Utility functions


Expand Down

0 comments on commit 97514e3

Please sign in to comment.