Skip to content

Commit

Permalink
Complete elemwise broadcasting and add tests recommended by @mrocklin
Browse files Browse the repository at this point in the history
…and @shoyer.
  • Loading branch information
hameerabbasi committed Feb 17, 2018
1 parent 84e4e6c commit a715e9c
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 50 deletions.
117 changes: 69 additions & 48 deletions sparse/coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,7 +2059,7 @@ def elemwise(func, *args, **kwargs):

# Filter out scalars as they are 'baked' into the function.
func = PositinalArgumentPartial(func, pos, posargs)
args = list(filter(lambda arg: not isscalar(arg), args))
args = [arg for arg in args if not isscalar(arg)]

if len(args) == 0:
return func(**kwargs)
Expand Down Expand Up @@ -2109,11 +2109,12 @@ def _elemwise_n_ary(func, *args, **kwargs):
data_list = []
coords_list = []

cache = {}
for mask in product([True, False], repeat=len(args)):
if not any(mask):
continue

ci, di = _unmatch_coo(func, args, mask, **kwargs)
ci, di = _unmatch_coo(func, args, mask, cache, **kwargs)

coords_list.extend(ci)
data_list.extend(di)
Expand All @@ -2132,7 +2133,7 @@ def _elemwise_n_ary(func, *args, **kwargs):
return COO(coords, data, shape=result_shape, has_duplicates=False)


def _match_coo(*args):
def _match_coo(*args, **kwargs):
"""
Matches the coordinates for any number of input :obj:`COO` arrays.
Equivalent to "sparse" broadcasting for all arrays.
Expand All @@ -2141,66 +2142,87 @@ def _match_coo(*args):
----------
args : Tuple[COO]
The input :obj:`COO` arrays.
return_midx : bool
Whether to return matched indices or matched arrays. Matching
only supported for two arrays. ``False`` by default.
cache : dict
Cache of things already matched. No cache by default.
Returns
-------
matched_idx : List[ndarray]
The indices of matched elements in the original arrays.
The indices of matched elements in the original arrays. Only returned if
``return_midx`` is ``True``.
matched_arrays : List[COO]
The expanded, matched :obj:`COO` objects.
The expanded, matched :obj:`COO` objects. Only returned if
``return_midx`` is ``False``.
"""
# If there's only a single input, return as-is.
if len(args) == 1:
return [np.arange(args[0].nnz)], [args[0]]

shapes = [arg.shape for arg in args]
matched_shape = _get_nary_broadcast_shape(*shapes)
broadcast_params = [_get_broadcast_parameters(shape, matched_shape)
for shape in shapes]

combined_params = [all(params) for params in zip(*broadcast_params)]

reduced_coords = [_get_reduced_coords(arg.coords,
combined_params[-arg.ndim:])
for arg in args]
reduced_shape = _get_reduced_shape(args[0].shape,
combined_params[-args[0].ndim:])

reduced_linear = [_linear_loc(coord, reduced_shape)
for coord in reduced_coords]
sorted_indices = [np.argsort(lin) for lin in reduced_linear]

reduced_linear = [lin[idx] for lin, idx in zip(reduced_linear, sorted_indices)]

coords = [arg.coords[:, idx] for arg, idx in zip(args, sorted_indices)]
data = [arg.data[idx] for arg, idx in zip(args, sorted_indices)]
return_midx = kwargs.pop('return_midx', False)
cache = kwargs.pop('cache', None)

if kwargs:
raise ValueError('Unknown kwargs %s' % kwargs.keys())

if return_midx and (len(args) != 2 or cache is not None):
raise NotImplementedError('Matching only supported for two args, and no cache.')

matched_arrays = [args[0]]
cache_key = [id(args[0])]
for arg2 in args[1:]:
cache_key.append(id(arg2))
key = tuple(cache_key)
if cache is not None and key in cache:
matched_arrays = cache[key]
continue

# Find matches between self.coords and other.coords
matched, matched_indices = _nary_match(*reduced_linear)
cargs = [matched_arrays[0], arg2]
current_shape = _get_broadcast_shape(matched_arrays[0].shape, arg2.shape)
params = [_get_broadcast_parameters(arg.shape, current_shape) for arg in cargs]
reduced_params = [all(p) for p in zip(*params)]
reduced_shape = _get_reduced_shape(arg2.shape,
reduced_params[-arg2.ndim:])

matched_coords = [c[:, idx] for c, idx in zip(coords, matched_indices)]
reduced_coords = [_get_reduced_coords(arg.coords, reduced_params[-arg.ndim:])
for arg in cargs]

# Add the matched part.
matched_coords = _get_matching_coords(matched_coords, broadcast_params, matched_shape)
matched_data = [d[idx] for d, idx in zip(data, matched_indices)]
linear = [_linear_loc(rc, reduced_shape) for rc in reduced_coords]
sorted_idx = [np.argsort(idx) for idx in linear]
linear = [idx[s] for idx, s in zip(linear, sorted_idx)]
coords = [arg.coords[:, s] for arg, s in zip(cargs, sorted_idx)]
matched_idx = _match_arrays(*linear)
mcoords = [c[:, idx] for c, idx in zip(coords, matched_idx)]
mcoords = _get_matching_coords(mcoords, params, current_shape)
mdata = [arg.data[sorted_idx[0]][matched_idx[0]] for arg in matched_arrays]
mdata.append(arg2.data[sorted_idx[1]][matched_idx[1]])
matched_arrays = [COO(mcoords, md, shape=current_shape) for md in mdata]

matched_indices = [sidx[midx] for sidx, midx in zip(sorted_indices, matched_indices)]
if cache is not None:
cache[key] = matched_arrays

matched_arrays = [COO(matched_coords, d, shape=matched_shape) for d in matched_data]
if return_midx:
matched_idx = [sidx[midx] for sidx, midx in zip(sorted_idx, matched_idx)]
return matched_idx

return matched_indices, matched_arrays
return matched_arrays


def _unmatch_coo(func, args, mask, **kwargs):
def _unmatch_coo(func, args, mask, cache, **kwargs):
"""
Matches the coordinates for any number of input :obj:`COO` arrays.
First computes the matches, then filters out the non-matches.
Parameters
----------
func : Callable
The function to compute matches
args : tuple[COO]
The input :obj:`COO` arrays.
mask : tuple[bool]
Specifies the inputs that are zero and the ones that are
nonzero.
kwargs: dict
Extra keyword arguments to pass to func.
Returns
-------
Expand All @@ -2214,33 +2236,32 @@ def _unmatch_coo(func, args, mask, **kwargs):
matched_args = [a for a, m in zip(args, mask) if m]
unmatched_args = [a for a, m in zip(args, mask) if not m]

matched_arrays = _match_coo(*matched_args)[1]
matched_arrays = _match_coo(*matched_args, cache=cache)

pos, = np.where([not m for m in mask])
pos = tuple(pos)
pos = tuple(i for i, m in enumerate(mask) if not m)
posargs = [_zero_of_dtype(arg.dtype)[()] for arg, m in zip(args, mask) if not m]
result_shape = _get_nary_broadcast_shape(*[arg.shape for arg in args])

partial = PositinalArgumentPartial(func, pos, posargs)
matched_func = partial(*[a.data for a in matched_arrays], **kwargs)

if (matched_func == _zero_of_dtype(matched_func.dtype)).all():
return [], []

unmatched_mask = matched_func != _zero_of_dtype(matched_func.dtype)

func_data = matched_func[unmatched_mask]
func_coords = matched_arrays[0].coords[:, unmatched_mask]

func_array = COO(func_coords, func_data, shape=matched_arrays[0].shape).broadcast_to(result_shape)

if func_array.nnz == 0:
return [], []

if all(mask):
return [func_array.coords], [func_array.data]

unmatched_mask = np.ones(func_array.nnz, dtype=np.bool)

for arg in unmatched_args:
matched_idx = _match_coo(func_array, arg)[0][0]
matched_idx = _match_coo(func_array, arg, return_midx=True)[0]
unmatched_mask[matched_idx] = False

coords = np.asarray(func_array.coords[:, unmatched_mask], order='C')
Expand Down Expand Up @@ -2499,7 +2520,7 @@ def _get_matching_coords(coords, params, shape):
The broadcasted coordinates
"""
matching_coords = []
dims = np.zeros(len(params), dtype=np.uint8)
dims = np.zeros(len(coords), dtype=np.uint8)

for p_all in zip(*params):
for i, p in enumerate(p_all):
Expand Down
92 changes: 90 additions & 2 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def test_elemwise(func):
x = s.todense()

fs = func(s)

assert isinstance(fs, COO)
assert fs.nnz <= s.nnz

assert_eq(func(x), fs)

Expand Down Expand Up @@ -250,7 +250,94 @@ def test_elemwise_trinary(func, shape):
y = ys.todense()
z = zs.todense()

assert_eq(sparse.elemwise(func, xs, ys, zs), func(x, y, z))
fs = sparse.elemwise(func, xs, ys, zs)
assert isinstance(fs, COO)

assert_eq(fs, func(x, y, z))


@pytest.mark.parametrize('shapes, func', [
    ([(2,), (3, 2), (4, 3, 2)], lambda x, y, z: (x + y) * z),
    ([(3,), (2, 3), (2, 2, 3)], lambda x, y, z: x * (y + z)),
    ([(2,), (2, 2), (2, 2, 2)], lambda x, y, z: x * y * z),
    ([(4,), (4, 4), (4, 4, 4)], lambda x, y, z: x + y + z),
    ([(4,), (4, 4), (4, 4, 4)], lambda x, y, z: x + y - z),
    ([(4,), (4, 4), (4, 4, 4)], lambda x, y, z: x - y + z),
])
def test_nary_broadcasting(shapes, func):
    """
    Check three-argument ``sparse.elemwise`` against the dense computation.

    Parameters
    ----------
    shapes : List[Tuple[int]]
        One shape per operand; the operands have increasing rank so that
        the element-wise call has to broadcast them against each other.
    func : Callable
        The ternary function applied both to the sparse operands (through
        ``sparse.elemwise``) and to their dense counterparts.
    """
    sparse_operands = [sparse.random(shape, density=0.5) for shape in shapes]
    dense_operands = [operand.todense() for operand in sparse_operands]

    result = sparse.elemwise(func, *sparse_operands)

    # The result must stay sparse, not silently densify.
    assert isinstance(result, COO)

    expected = func(*dense_operands)
    assert_eq(result, expected)


@pytest.mark.parametrize('shapes, func', [
    ([
        (2,),
        (3, 2),
        (4, 3, 2),
    ], lambda x, y, z: (x + y) * z),
    ([
        (3,),
        (2, 3),
        (2, 2, 3),
    ], lambda x, y, z: x * (y + z)),
    ([
        (2,),
        (2, 2),
        (2, 2, 2),
    ], lambda x, y, z: x * y * z),
    ([
        (4,),
        (4, 4),
        (4, 4, 4),
    ], lambda x, y, z: x + y + z),
])
@pytest.mark.parametrize('value', [
    np.nan,
    np.inf,
    -np.inf
])
def test_nary_broadcasting_pathological(shapes, func, value):
    """
    Check broadcasting ``sparse.elemwise`` when the data is NaN or infinite.

    Parameters
    ----------
    shapes : List[Tuple[int]]
        One shape per operand, with increasing rank so the operands have
        to be broadcast against each other.
    func : Callable
        The ternary function applied to the sparse operands (through
        ``sparse.elemwise``) and to their dense counterparts.
    value : float
        The non-finite value (``nan``, ``inf`` or ``-inf``) used to fill
        the arrays produced by the ``data_rvs`` callback.
    """
    def value_array(n):
        # ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is the
        # same dtype and works on every NumPy version.
        return np.full((n,), value, dtype=np.float64)

    # Presumably ``data_rvs`` supplies the stored values, so each operand
    # holds only ``value`` where it is nonzero -- verify against
    # ``sparse.random``.
    args = [sparse.random(s, density=0.5, data_rvs=value_array) for s in shapes]
    dense_args = [arg.todense() for arg in args]

    fs = sparse.elemwise(func, *args)
    assert isinstance(fs, COO)

    # NaN never compares equal to itself, so NaN positions must be
    # treated as matching explicitly.
    assert_eq(fs, func(*dense_args), equal_nan=True)


@pytest.mark.parametrize('func', [
Expand Down Expand Up @@ -755,6 +842,7 @@ def test_broadcasting(func, shape1, shape2):
expected = func(x, y)
actual = func(xs, ys)

assert isinstance(actual, COO)
assert_eq(expected, actual)

assert np.count_nonzero(expected) == actual.nnz
Expand Down

0 comments on commit a715e9c

Please sign in to comment.