Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add N-ary broadcasting operations. #98

Merged
merged 12 commits into from
Feb 18, 2018
149 changes: 69 additions & 80 deletions sparse/coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,38 +1894,6 @@ def tril(x, k=0):
return COO(coords, data, x.shape, x.has_duplicates, x.sorted)


def _nary_match(*arrays):
"""
Matches coordinates from N different 1-D arrays. Equivalent to
an SQL outer join.

Parameters
----------
arrays : tuple[numpy.ndarray]
Input arrays to match, must be sorted.

Returns
-------
matched : numpy.ndarray
The overall matched array.

matched_idx : list[numpy.ndarray]
The indices for the matched coordinates in each array.
"""

matched = arrays[0]
matched_idx = [np.arange(arrays[0].shape[0],
dtype=np.min_scalar_type(arrays[0].shape[0] - 1))]

for arr in arrays[1:]:
old_idx, new_idx = _match_arrays(matched, arr)
matched = matched[old_idx]
matched_idx = [idx[old_idx] for idx in matched_idx]
matched_idx.append(new_idx)

return matched, matched_idx


# (c) Paul Panzer
# Taken from https://stackoverflow.com/a/47833496/774273
# License: https://creativecommons.org/licenses/by-sa/3.0/
Expand Down Expand Up @@ -2059,7 +2027,7 @@ def elemwise(func, *args, **kwargs):

# Filter out scalars as they are 'baked' into the function.
func = PositinalArgumentPartial(func, pos, posargs)
args = list(filter(lambda arg: not isscalar(arg), args))
args = [arg for arg in args if not isscalar(arg)]

if len(args) == 0:
return func(**kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here. No test operates on no args

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added another small test for this.

Expand Down Expand Up @@ -2109,11 +2077,12 @@ def _elemwise_n_ary(func, *args, **kwargs):
data_list = []
coords_list = []

cache = {}
for mask in product([True, False], repeat=len(args)):
if not any(mask):
continue

ci, di = _unmatch_coo(func, args, mask, **kwargs)
ci, di = _unmatch_coo(func, args, mask, cache, **kwargs)

coords_list.extend(ci)
data_list.extend(di)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This confuses me and seems concerning. I see that this was a main point of your conversation with @shoyer earlier. I probably have some thinking to do on this problem before I'm able to reasonably comment on this.

Expand All @@ -2132,7 +2101,7 @@ def _elemwise_n_ary(func, *args, **kwargs):
return COO(coords, data, shape=result_shape, has_duplicates=False)


def _match_coo(*args):
def _match_coo(*args, **kwargs):
"""
Matches the coordinates for any number of input :obj:`COO` arrays.
Equivalent to "sparse" broadcasting for all arrays.
Expand All @@ -2141,66 +2110,87 @@ def _match_coo(*args):
----------
args : Tuple[COO]
The input :obj:`COO` arrays.
return_midx : bool
Whether to return matched indices or matched arrays. Matching
only supported for two arrays. ``False`` by default.
cache : dict
Cache of things already matched. No cache by default.

Returns
-------
matched_idx : List[ndarray]
The indices of matched elements in the original arrays.
The indices of matched elements in the original arrays. Only returned if
``return_midx`` is ``True``.
matched_arrays : List[COO]
The expanded, matched :obj:`COO` objects.
The expanded, matched :obj:`COO` objects. Only returned if
``return_midx`` is ``False``.
"""
# If there's only a single input, return as-is.
if len(args) == 1:
return [np.arange(args[0].nnz)], [args[0]]

shapes = [arg.shape for arg in args]
matched_shape = _get_nary_broadcast_shape(*shapes)
broadcast_params = [_get_broadcast_parameters(shape, matched_shape)
for shape in shapes]

combined_params = [all(params) for params in zip(*broadcast_params)]

reduced_coords = [_get_reduced_coords(arg.coords,
combined_params[-arg.ndim:])
for arg in args]
reduced_shape = _get_reduced_shape(args[0].shape,
combined_params[-args[0].ndim:])

reduced_linear = [_linear_loc(coord, reduced_shape)
for coord in reduced_coords]
sorted_indices = [np.argsort(lin) for lin in reduced_linear]

reduced_linear = [lin[idx] for lin, idx in zip(reduced_linear, sorted_indices)]

coords = [arg.coords[:, idx] for arg, idx in zip(args, sorted_indices)]
data = [arg.data[idx] for arg, idx in zip(args, sorted_indices)]
return_midx = kwargs.pop('return_midx', False)
cache = kwargs.pop('cache', None)

if kwargs:
raise ValueError('Unknown kwargs %s' % kwargs.keys())

if return_midx and (len(args) != 2 or cache is not None):
raise NotImplementedError('Matching only supported for two args, and no cache.')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this option?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we don't. I'm not omniscient, so I went ahead and added this check in case someone tried to trigger caching on return_midx (which we don't cache, it's never repeated); or tried to match indices for len(args) != 2 (I'm not sure if we'll need this in the future, but we might, and it's useful to err rather than have it return incorrect results).


matched_arrays = [args[0]]
cache_key = [id(args[0])]
for arg2 in args[1:]:
cache_key.append(id(arg2))
key = tuple(cache_key)
if cache is not None and key in cache:
matched_arrays = cache[key]
continue

# Find matches between self.coords and other.coords
matched, matched_indices = _nary_match(*reduced_linear)
cargs = [matched_arrays[0], arg2]
current_shape = _get_broadcast_shape(matched_arrays[0].shape, arg2.shape)
params = [_get_broadcast_parameters(arg.shape, current_shape) for arg in cargs]
reduced_params = [all(p) for p in zip(*params)]
reduced_shape = _get_reduced_shape(arg2.shape,
reduced_params[-arg2.ndim:])

matched_coords = [c[:, idx] for c, idx in zip(coords, matched_indices)]
reduced_coords = [_get_reduced_coords(arg.coords, reduced_params[-arg.ndim:])
for arg in cargs]

# Add the matched part.
matched_coords = _get_matching_coords(matched_coords, broadcast_params, matched_shape)
matched_data = [d[idx] for d, idx in zip(data, matched_indices)]
linear = [_linear_loc(rc, reduced_shape) for rc in reduced_coords]
sorted_idx = [np.argsort(idx) for idx in linear]
linear = [idx[s] for idx, s in zip(linear, sorted_idx)]
coords = [arg.coords[:, s] for arg, s in zip(cargs, sorted_idx)]
matched_idx = _match_arrays(*linear)
mcoords = [c[:, idx] for c, idx in zip(coords, matched_idx)]
mcoords = _get_matching_coords(mcoords, params, current_shape)
mdata = [arg.data[sorted_idx[0]][matched_idx[0]] for arg in matched_arrays]
mdata.append(arg2.data[sorted_idx[1]][matched_idx[1]])
matched_arrays = [COO(mcoords, md, shape=current_shape) for md in mdata]

matched_indices = [sidx[midx] for sidx, midx in zip(sorted_indices, matched_indices)]
if cache is not None:
cache[key] = matched_arrays

matched_arrays = [COO(matched_coords, d, shape=matched_shape) for d in matched_data]
if return_midx:
matched_idx = [sidx[midx] for sidx, midx in zip(sorted_idx, matched_idx)]
return matched_idx

return matched_indices, matched_arrays
return matched_arrays


def _unmatch_coo(func, args, mask, **kwargs):
def _unmatch_coo(func, args, mask, cache, **kwargs):
"""
Matches the coordinates for any number of input :obj:`COO` arrays.

First computes the matches, then filters out the non-matches.

Parameters
----------
func : Callable
The function to compute matches
args : tuple[COO]
The input :obj:`COO` arrays.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add in func, mask and **kwargs to the docstring?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

mask : tuple[bool]
Specifies the inputs that are zero and the ones that are
nonzero.
kwargs: dict
Extra keyword arguments to pass to func.

Returns
-------
Expand All @@ -2214,33 +2204,32 @@ def _unmatch_coo(func, args, mask, **kwargs):
matched_args = [a for a, m in zip(args, mask) if m]
unmatched_args = [a for a, m in zip(args, mask) if not m]

matched_arrays = _match_coo(*matched_args)[1]
matched_arrays = _match_coo(*matched_args, cache=cache)

pos, = np.where([not m for m in mask])
pos = tuple(pos)
pos = tuple(i for i, m in enumerate(mask) if not m)
posargs = [_zero_of_dtype(arg.dtype)[()] for arg, m in zip(args, mask) if not m]
result_shape = _get_nary_broadcast_shape(*[arg.shape for arg in args])

partial = PositinalArgumentPartial(func, pos, posargs)
matched_func = partial(*[a.data for a in matched_arrays], **kwargs)

if (matched_func == _zero_of_dtype(matched_func.dtype)).all():
return [], []

unmatched_mask = matched_func != _zero_of_dtype(matched_func.dtype)

func_data = matched_func[unmatched_mask]
func_coords = matched_arrays[0].coords[:, unmatched_mask]

func_array = COO(func_coords, func_data, shape=matched_arrays[0].shape).broadcast_to(result_shape)

if func_array.nnz == 0:
return [], []

if all(mask):
return [func_array.coords], [func_array.data]

unmatched_mask = np.ones(func_array.nnz, dtype=np.bool)

for arg in unmatched_args:
matched_idx = _match_coo(func_array, arg)[0][0]
matched_idx = _match_coo(func_array, arg, return_midx=True)[0]
unmatched_mask[matched_idx] = False

coords = np.asarray(func_array.coords[:, unmatched_mask], order='C')
Expand Down Expand Up @@ -2499,7 +2488,7 @@ def _get_matching_coords(coords, params, shape):
The broacasted coordinates
"""
matching_coords = []
dims = np.zeros(len(params), dtype=np.uint8)
dims = np.zeros(len(coords), dtype=np.uint8)

for p_all in zip(*params):
for i, p in enumerate(p_all):
Expand Down
Loading