Skip to content

Commit

Permalink
Delegate dtype functions to backend array API (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Mar 5, 2024
1 parent b59e281 commit 395ee54
Show file tree
Hide file tree
Showing 17 changed files with 76 additions and 287 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,14 @@ jobs:
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]
# From https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/numpy.yml
# https://github.com/numpy/numpy/issues/18881
array_api_tests/test_creation_functions.py::test_linspace
# https://github.com/numpy/numpy/issues/20870
array_api_tests/test_data_type_functions.py::test_can_cast
EOF
pytest -v -rxXfEA --max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --ci --hypothesis-disable-deadline --cov=cubed.array_api --cov-report=term-missing
7 changes: 0 additions & 7 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# Suppress numpy.array_api experimental warning
import sys
import warnings

if not sys.warnoptions:
warnings.filterwarnings("ignore", category=UserWarning)

from importlib.metadata import version as _version

try:
Expand Down
14 changes: 7 additions & 7 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,37 +225,37 @@ def __eq__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
return elemwise(nxp.equal, self, other, dtype=np.bool_)
return elemwise(nxp.equal, self, other, dtype=nxp.bool)

def __ge__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ge__")
if other is NotImplemented:
return other
return elemwise(nxp.greater_equal, self, other, dtype=np.bool_)
return elemwise(nxp.greater_equal, self, other, dtype=nxp.bool)

def __gt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__gt__")
if other is NotImplemented:
return other
return elemwise(nxp.greater, self, other, dtype=np.bool_)
return elemwise(nxp.greater, self, other, dtype=nxp.bool)

def __le__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__le__")
if other is NotImplemented:
return other
return elemwise(nxp.less_equal, self, other, dtype=np.bool_)
return elemwise(nxp.less_equal, self, other, dtype=nxp.bool)

def __lt__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__lt__")
if other is NotImplemented:
return other
return elemwise(nxp.less, self, other, dtype=np.bool_)
return elemwise(nxp.less, self, other, dtype=nxp.bool)

def __ne__(self, other, /):
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
return elemwise(nxp.not_equal, self, other, dtype=np.bool_)
return elemwise(nxp.not_equal, self, other, dtype=nxp.bool)

# Reflected Operators

Expand Down Expand Up @@ -425,7 +425,7 @@ def _promote_scalar(self, scalar):
"Python int scalars cannot be promoted with bool arrays"
)
if self.dtype in _integer_dtypes:
info = np.iinfo(self.dtype)
info = nxp.iinfo(self.dtype)
if not (info.min <= scalar <= info.max):
raise OverflowError(
"Python int scalars must be within the bounds of the dtype for integer arrays"
Expand Down
12 changes: 6 additions & 6 deletions cubed/array_api/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from cubed.backend_array_api import namespace as nxp

e = np.e
inf = np.inf
nan = np.nan
newaxis = None
pi = np.pi
e = nxp.e
inf = nxp.inf
nan = nxp.nan
newaxis = nxp.newaxis
pi = nxp.pi
17 changes: 8 additions & 9 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import math
from typing import TYPE_CHECKING, Iterable, List

import numpy as np
from zarr.util import normalize_shape

from cubed.backend_array_api import namespace as nxp
Expand Down Expand Up @@ -93,7 +92,7 @@ def empty_virtual_array(
shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
) -> "Array":
if dtype is None:
dtype = np.float64
dtype = nxp.float64

chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
name = gensym()
Expand All @@ -111,7 +110,7 @@ def eye(
if n_cols is None:
n_cols = n_rows
if dtype is None:
dtype = np.float64
dtype = nxp.float64

shape = (n_rows, n_cols)
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
Expand Down Expand Up @@ -143,11 +142,11 @@ def full(
if dtype is None:
# check bool first since True/False are instances of int and float
if isinstance(fill_value, bool):
dtype = np.bool_
dtype = nxp.bool
elif isinstance(fill_value, int):
dtype = np.int64
dtype = nxp.int64
elif isinstance(fill_value, float):
dtype = np.float64
dtype = nxp.float64
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
Expand Down Expand Up @@ -194,7 +193,7 @@ def linspace(
div = 1
step = float(range_) / div
if dtype is None:
dtype = np.float64
dtype = nxp.float64
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

Expand Down Expand Up @@ -258,7 +257,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:

def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = np.float64
dtype = nxp.float64
return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)


Expand Down Expand Up @@ -304,7 +303,7 @@ def _tri_mask(N, M, k, chunks, spec):

def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = np.float64
dtype = nxp.float64
return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)


Expand Down
125 changes: 7 additions & 118 deletions cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
from dataclasses import dataclass

import numpy as np
from numpy.array_api._typing import Dtype

from cubed.backend_array_api import namespace as nxp
from cubed.core import CoreArray, map_blocks

from .dtypes import (
_all_dtypes,
_boolean_dtypes,
_complex_floating_dtypes,
_integer_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_result_type,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
)


def astype(x, dtype, /, *, copy=True):
if not copy and dtype == x.dtype:
Expand All @@ -30,118 +13,24 @@ def _astype(a, astype_dtype):


def can_cast(from_, to, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.can_cast` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(from_, CoreArray):
from_ = from_.dtype
elif from_ not in _all_dtypes:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
try:
# We promote `from_` and `to` together. We then check if the promoted
# dtype is `to`, which indicates if `from_` can (up)cast to `to`.
dtype = _result_type(from_, to)
return to == dtype
except TypeError:
# _result_type() raises if the dtypes don't promote together
return False


@dataclass
class finfo_object:
bits: int
eps: float
max: float
min: float
smallest_normal: float
dtype: Dtype


@dataclass
class iinfo_object:
bits: int
max: int
min: int
dtype: Dtype
return nxp.can_cast(from_, to)


def finfo(type, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.finfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

fi = np.finfo(type)
return finfo_object(
fi.bits,
float(fi.eps),
float(fi.max),
float(fi.min),
float(fi.smallest_normal),
fi.dtype,
)
return nxp.finfo(type)


def iinfo(type, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.iinfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)
return nxp.iinfo(type)


def isdtype(dtype, kind):
# Copied from numpy.array_api
# TODO: replace with `nxp.isdtype(dtype, kind)` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == "bool":
return dtype in _boolean_dtypes
elif kind == "signed integer":
return dtype in _signed_integer_dtypes
elif kind == "unsigned integer":
return dtype in _unsigned_integer_dtypes
elif kind == "integral":
return dtype in _integer_dtypes
elif kind == "real floating":
return dtype in _real_floating_dtypes
elif kind == "complex floating":
return dtype in _complex_floating_dtypes
elif kind == "numeric":
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(
f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}"
)
return nxp.isdtype(dtype, kind)


def result_type(*arrays_and_dtypes):
# Copied from numpy.array_api
# TODO: replace with `nxp.result_type` when NumPy 1.25 is widely used (e.g. in Xarray)

A = []
for a in arrays_and_dtypes:
if isinstance(a, CoreArray):
a = a.dtype
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

if len(A) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
return nxp.result_type(
*(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes)
)
Loading

0 comments on commit 395ee54

Please sign in to comment.