-
-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add N-ary broadcasting operations. #98
Changes from 5 commits
6c2ecf9
fa65902
9ec04e6
84e4e6c
a715e9c
ebd69f2
4ba99d1
2fabb72
0b3111f
031fb67
e10fd4a
97353c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1894,38 +1894,6 @@ def tril(x, k=0): | |
return COO(coords, data, x.shape, x.has_duplicates, x.sorted) | ||
|
||
|
||
def _nary_match(*arrays): | ||
""" | ||
Matches coordinates from N different 1-D arrays. Equivalent to | ||
an SQL outer join. | ||
|
||
Parameters | ||
---------- | ||
arrays : tuple[numpy.ndarray] | ||
Input arrays to match, must be sorted. | ||
|
||
Returns | ||
------- | ||
matched : numpy.ndarray | ||
The overall matched array. | ||
|
||
matched_idx : list[numpy.ndarray] | ||
The indices for the matched coordinates in each array. | ||
""" | ||
|
||
matched = arrays[0] | ||
matched_idx = [np.arange(arrays[0].shape[0], | ||
dtype=np.min_scalar_type(arrays[0].shape[0] - 1))] | ||
|
||
for arr in arrays[1:]: | ||
old_idx, new_idx = _match_arrays(matched, arr) | ||
matched = matched[old_idx] | ||
matched_idx = [idx[old_idx] for idx in matched_idx] | ||
matched_idx.append(new_idx) | ||
|
||
return matched, matched_idx | ||
|
||
|
||
# (c) Paul Panzer | ||
# Taken from https://stackoverflow.com/a/47833496/774273 | ||
# License: https://creativecommons.org/licenses/by-sa/3.0/ | ||
|
@@ -2059,7 +2027,7 @@ def elemwise(func, *args, **kwargs): | |
|
||
# Filter out scalars as they are 'baked' into the function. | ||
func = PositinalArgumentPartial(func, pos, posargs) | ||
args = list(filter(lambda arg: not isscalar(arg), args)) | ||
args = [arg for arg in args if not isscalar(arg)] | ||
|
||
if len(args) == 0: | ||
return func(**kwargs) | ||
|
@@ -2109,11 +2077,12 @@ def _elemwise_n_ary(func, *args, **kwargs): | |
data_list = [] | ||
coords_list = [] | ||
|
||
cache = {} | ||
for mask in product([True, False], repeat=len(args)): | ||
if not any(mask): | ||
continue | ||
|
||
ci, di = _unmatch_coo(func, args, mask, **kwargs) | ||
ci, di = _unmatch_coo(func, args, mask, cache, **kwargs) | ||
|
||
coords_list.extend(ci) | ||
data_list.extend(di) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This confuses me and seems concerning. I see that this was a main point of your conversation with @shoyer earlier. I probably have some thinking to do on this problem before I'm able to reasonably comment on this. |
||
|
@@ -2132,7 +2101,7 @@ def _elemwise_n_ary(func, *args, **kwargs): | |
return COO(coords, data, shape=result_shape, has_duplicates=False) | ||
|
||
|
||
def _match_coo(*args): | ||
def _match_coo(*args, **kwargs): | ||
""" | ||
Matches the coordinates for any number of input :obj:`COO` arrays. | ||
Equivalent to "sparse" broadcasting for all arrays. | ||
|
@@ -2141,66 +2110,87 @@ def _match_coo(*args): | |
---------- | ||
args : Tuple[COO] | ||
The input :obj:`COO` arrays. | ||
return_midx : bool | ||
Whether to return matched indices or matched arrays. Matching | ||
only supported for two arrays. ``False`` by default. | ||
cache : dict | ||
Cache of things already matched. No cache by default. | ||
|
||
Returns | ||
------- | ||
matched_idx : List[ndarray] | ||
The indices of matched elements in the original arrays. | ||
The indices of matched elements in the original arrays. Only returned if | ||
``return_midx`` is ``True``. | ||
matched_arrays : List[COO] | ||
The expanded, matched :obj:`COO` objects. | ||
The expanded, matched :obj:`COO` objects. Only returned if | ||
``return_midx`` is ``False``. | ||
""" | ||
# If there's only a single input, return as-is. | ||
if len(args) == 1: | ||
return [np.arange(args[0].nnz)], [args[0]] | ||
|
||
shapes = [arg.shape for arg in args] | ||
matched_shape = _get_nary_broadcast_shape(*shapes) | ||
broadcast_params = [_get_broadcast_parameters(shape, matched_shape) | ||
for shape in shapes] | ||
|
||
combined_params = [all(params) for params in zip(*broadcast_params)] | ||
|
||
reduced_coords = [_get_reduced_coords(arg.coords, | ||
combined_params[-arg.ndim:]) | ||
for arg in args] | ||
reduced_shape = _get_reduced_shape(args[0].shape, | ||
combined_params[-args[0].ndim:]) | ||
|
||
reduced_linear = [_linear_loc(coord, reduced_shape) | ||
for coord in reduced_coords] | ||
sorted_indices = [np.argsort(lin) for lin in reduced_linear] | ||
|
||
reduced_linear = [lin[idx] for lin, idx in zip(reduced_linear, sorted_indices)] | ||
|
||
coords = [arg.coords[:, idx] for arg, idx in zip(args, sorted_indices)] | ||
data = [arg.data[idx] for arg, idx in zip(args, sorted_indices)] | ||
return_midx = kwargs.pop('return_midx', False) | ||
cache = kwargs.pop('cache', None) | ||
|
||
if kwargs: | ||
raise ValueError('Unknown kwargs %s' % kwargs.keys()) | ||
|
||
if return_midx and (len(args) != 2 or cache is not None): | ||
raise NotImplementedError('Matching only supported for two args, and no cache.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need this option? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, we don't. I'm not omniscient, so I went ahead and added this check in case someone tried to trigger caching on |
||
|
||
matched_arrays = [args[0]] | ||
cache_key = [id(args[0])] | ||
for arg2 in args[1:]: | ||
cache_key.append(id(arg2)) | ||
key = tuple(cache_key) | ||
if cache is not None and key in cache: | ||
matched_arrays = cache[key] | ||
continue | ||
|
||
# Find matches between self.coords and other.coords | ||
matched, matched_indices = _nary_match(*reduced_linear) | ||
cargs = [matched_arrays[0], arg2] | ||
current_shape = _get_broadcast_shape(matched_arrays[0].shape, arg2.shape) | ||
params = [_get_broadcast_parameters(arg.shape, current_shape) for arg in cargs] | ||
reduced_params = [all(p) for p in zip(*params)] | ||
reduced_shape = _get_reduced_shape(arg2.shape, | ||
reduced_params[-arg2.ndim:]) | ||
|
||
matched_coords = [c[:, idx] for c, idx in zip(coords, matched_indices)] | ||
reduced_coords = [_get_reduced_coords(arg.coords, reduced_params[-arg.ndim:]) | ||
for arg in cargs] | ||
|
||
# Add the matched part. | ||
matched_coords = _get_matching_coords(matched_coords, broadcast_params, matched_shape) | ||
matched_data = [d[idx] for d, idx in zip(data, matched_indices)] | ||
linear = [_linear_loc(rc, reduced_shape) for rc in reduced_coords] | ||
sorted_idx = [np.argsort(idx) for idx in linear] | ||
linear = [idx[s] for idx, s in zip(linear, sorted_idx)] | ||
coords = [arg.coords[:, s] for arg, s in zip(cargs, sorted_idx)] | ||
matched_idx = _match_arrays(*linear) | ||
mcoords = [c[:, idx] for c, idx in zip(coords, matched_idx)] | ||
mcoords = _get_matching_coords(mcoords, params, current_shape) | ||
mdata = [arg.data[sorted_idx[0]][matched_idx[0]] for arg in matched_arrays] | ||
mdata.append(arg2.data[sorted_idx[1]][matched_idx[1]]) | ||
matched_arrays = [COO(mcoords, md, shape=current_shape) for md in mdata] | ||
|
||
matched_indices = [sidx[midx] for sidx, midx in zip(sorted_indices, matched_indices)] | ||
if cache is not None: | ||
cache[key] = matched_arrays | ||
|
||
matched_arrays = [COO(matched_coords, d, shape=matched_shape) for d in matched_data] | ||
if return_midx: | ||
matched_idx = [sidx[midx] for sidx, midx in zip(sorted_idx, matched_idx)] | ||
return matched_idx | ||
|
||
return matched_indices, matched_arrays | ||
return matched_arrays | ||
|
||
|
||
def _unmatch_coo(func, args, mask, **kwargs): | ||
def _unmatch_coo(func, args, mask, cache, **kwargs): | ||
""" | ||
Matches the coordinates for any number of input :obj:`COO` arrays. | ||
|
||
First computes the matches, then filters out the non-matches. | ||
|
||
Parameters | ||
---------- | ||
func : Callable | ||
The function to compute matches | ||
args : tuple[COO] | ||
The input :obj:`COO` arrays. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
mask : tuple[bool] | ||
Specifies the inputs that are zero and the ones that are | ||
nonzero. | ||
kwargs: dict | ||
Extra keyword arguments to pass to func. | ||
|
||
Returns | ||
------- | ||
|
@@ -2214,33 +2204,32 @@ def _unmatch_coo(func, args, mask, **kwargs): | |
matched_args = [a for a, m in zip(args, mask) if m] | ||
unmatched_args = [a for a, m in zip(args, mask) if not m] | ||
|
||
matched_arrays = _match_coo(*matched_args)[1] | ||
matched_arrays = _match_coo(*matched_args, cache=cache) | ||
|
||
pos, = np.where([not m for m in mask]) | ||
pos = tuple(pos) | ||
pos = tuple(i for i, m in enumerate(mask) if not m) | ||
posargs = [_zero_of_dtype(arg.dtype)[()] for arg, m in zip(args, mask) if not m] | ||
result_shape = _get_nary_broadcast_shape(*[arg.shape for arg in args]) | ||
|
||
partial = PositinalArgumentPartial(func, pos, posargs) | ||
matched_func = partial(*[a.data for a in matched_arrays], **kwargs) | ||
|
||
if (matched_func == _zero_of_dtype(matched_func.dtype)).all(): | ||
return [], [] | ||
|
||
unmatched_mask = matched_func != _zero_of_dtype(matched_func.dtype) | ||
|
||
func_data = matched_func[unmatched_mask] | ||
func_coords = matched_arrays[0].coords[:, unmatched_mask] | ||
|
||
func_array = COO(func_coords, func_data, shape=matched_arrays[0].shape).broadcast_to(result_shape) | ||
|
||
if func_array.nnz == 0: | ||
return [], [] | ||
|
||
if all(mask): | ||
return [func_array.coords], [func_array.data] | ||
|
||
unmatched_mask = np.ones(func_array.nnz, dtype=np.bool) | ||
|
||
for arg in unmatched_args: | ||
matched_idx = _match_coo(func_array, arg)[0][0] | ||
matched_idx = _match_coo(func_array, arg, return_midx=True)[0] | ||
unmatched_mask[matched_idx] = False | ||
|
||
coords = np.asarray(func_array.coords[:, unmatched_mask], order='C') | ||
|
@@ -2499,7 +2488,7 @@ def _get_matching_coords(coords, params, shape): | |
The broacasted coordinates | ||
""" | ||
matching_coords = [] | ||
dims = np.zeros(len(params), dtype=np.uint8) | ||
dims = np.zeros(len(coords), dtype=np.uint8) | ||
|
||
for p_all in zip(*params): | ||
for i, p in enumerate(p_all): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here. No test operates on no args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added another small test for this.