Skip to content

Commit

Permalink
Add back test disabled for debugging.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Feb 10, 2018
1 parent 9ec04e6 commit 84e4e6c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 55 deletions.
91 changes: 37 additions & 54 deletions sparse/coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy.lib.mixins import NDArrayOperatorsMixin

from .slicing import normalize_index
from .utils import _zero_of_dtype, isscalar
from .utils import _zero_of_dtype, isscalar, PositinalArgumentPartial
from .sparse_array import SparseArray
from .compatibility import int, zip_longest, range, zip

Expand Down Expand Up @@ -1895,6 +1895,24 @@ def tril(x, k=0):


def _nary_match(*arrays):
"""
Matches coordinates from N different 1-D arrays. Equivalent to
an SQL outer join.
Parameters
----------
arrays : tuple[numpy.ndarray]
Input arrays to match, must be sorted.
Returns
-------
matched : numpy.ndarray
The overall matched array.
matched_idx : list[numpy.ndarray]
The indices for the matched coordinates in each array.
"""

matched = arrays[0]
matched_idx = [np.arange(arrays[0].shape[0],
dtype=np.min_scalar_type(arrays[0].shape[0] - 1))]
Expand Down Expand Up @@ -2022,28 +2040,25 @@ def elemwise(func, *args, **kwargs):
args = list(args)
posargs = []
pos = []
for i in range(len(args)):
if isinstance(args[i], scipy.sparse.spmatrix):
args[i] = COO.from_scipy_sparse(args[i])

if isscalar(args[i]) or (isinstance(args[i], np.ndarray)
and not args[i].shape):
for i, arg in enumerate(args):
if isinstance(arg, scipy.sparse.spmatrix):
args[i] = COO.from_scipy_sparse(arg)
elif isscalar(arg) or (isinstance(arg, np.ndarray)
and not arg.shape):
# Faster and more reliable to pass ()-shaped ndarrays as scalars.
if isinstance(args[i], np.ndarray):
args[i] = args[i][()]
if isinstance(arg, np.ndarray):
args[i] = arg[()]

# The -scalars factor is there because we need to account for already
# added scalars in the function.
pos.append(i)
posargs.append(args[i])
elif isinstance(args[i], SparseArray) and not isinstance(args[i], COO):
args[i] = COO(args[i])
elif not isinstance(args[i], COO):
elif isinstance(arg, SparseArray) and not isinstance(arg, COO):
args[i] = COO(arg)
elif not isinstance(arg, COO):
raise ValueError("Performing this operation would produce "
"a dense result: %s" % str(func))

# Filter out scalars as they are 'baked' into the function.
func = _posarg_partial(func, pos, posargs)
func = PositinalArgumentPartial(func, pos, posargs)
args = list(filter(lambda arg: not isscalar(arg), args))

if len(args) == 0:
Expand Down Expand Up @@ -2124,14 +2139,14 @@ def _match_coo(*args):
Parameters
----------
args : tuple[COO]
args : Tuple[COO]
The input :obj:`COO` arrays.
Returns
-------
matched_idx : list[ndarray]
matched_idx : List[ndarray]
The indices of matched elements in the original arrays.
matched_arrays : list[COO]
matched_arrays : List[COO]
The expanded, matched :obj:`COO` objects.
"""
# If there's only a single input, return as-is.
Expand Down Expand Up @@ -2166,7 +2181,7 @@ def _match_coo(*args):
matched_coords = [c[:, idx] for c, idx in zip(coords, matched_indices)]

# Add the matched part.
matched_coords = _get_nary_matching_coords(matched_coords, broadcast_params, matched_shape)
matched_coords = _get_matching_coords(matched_coords, broadcast_params, matched_shape)
matched_data = [d[idx] for d, idx in zip(data, matched_indices)]

matched_indices = [sidx[midx] for sidx, midx in zip(sorted_indices, matched_indices)]
Expand Down Expand Up @@ -2206,7 +2221,7 @@ def _unmatch_coo(func, args, mask, **kwargs):
posargs = [_zero_of_dtype(arg.dtype)[()] for arg, m in zip(args, mask) if not m]
result_shape = _get_nary_broadcast_shape(*[arg.shape for arg in args])

partial = _posarg_partial(func, pos, posargs)
partial = PositinalArgumentPartial(func, pos, posargs)
matched_func = partial(*[a.data for a in matched_arrays], **kwargs)

unmatched_mask = matched_func != _zero_of_dtype(matched_func.dtype)
Expand Down Expand Up @@ -2468,7 +2483,7 @@ def _elemwise_unary(func, self, *args, **kwargs):
sorted=self.sorted)


def _get_nary_matching_coords(coords, params, shape):
def _get_matching_coords(coords, params, shape):
"""
Get the matching coords across a number of broadcast operands.
Expand All @@ -2481,7 +2496,7 @@ def _get_nary_matching_coords(coords, params, shape):
Returns
-------
numpy.ndarray
The broacasted coordinates.
The broacasted coordinates
"""
matching_coords = []
dims = np.zeros(len(params), dtype=np.uint8)
Expand Down Expand Up @@ -2516,35 +2531,3 @@ def _linear_loc(coords, shape, signed=False):
np.add(tmp, out, out=out)
strides *= d
return out


def _posarg_partial(func, pos, posargs):
if not isinstance(pos, Iterable):
pos = (pos,)
posargs = (posargs,)

n_partial_args = len(pos)

class Partial(object):
def __call__(self, *args, **kwargs):
j = 0
totargs = []

for i in range(len(args) + n_partial_args):
if j >= n_partial_args or i != pos[j]:
totargs.append(args[i - j])
else:
totargs.append(posargs[j])
j += 1

return func(*totargs, **kwargs)

def __str__(self):
return str(func)

def __repr__(self):
return repr(func)

__doc__ = func.__doc__

return Partial()
2 changes: 1 addition & 1 deletion sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_elemwise_scalar(func, scalar, convert_to_np_number):
(operator.le, 3),
(operator.eq, 1),
])
@pytest.mark.parametrize('convert_to_np_number', [True])
@pytest.mark.parametrize('convert_to_np_number', [True, False])
def test_leftside_elemwise_scalar(func, scalar, convert_to_np_number):
xs = sparse.random((2, 3, 4), density=0.5)
if convert_to_np_number:
Expand Down
37 changes: 37 additions & 0 deletions sparse/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from numbers import Integral
from collections import Iterable


def assert_eq(x, y, **kwargs):
Expand Down Expand Up @@ -151,3 +152,39 @@ def random(
def isscalar(x):
from .sparse_array import SparseArray
return not isinstance(x, SparseArray) and np.isscalar(x)


class PositinalArgumentPartial(object):
def __init__(self, func, pos, posargs):
if not isinstance(pos, Iterable):
pos = (pos,)
posargs = (posargs,)

n_partial_args = len(pos)

self.pos = pos
self.posargs = posargs
self.func = func

self.n = n_partial_args

self.__doc__ = func.__doc__

def __call__(self, *args, **kwargs):
j = 0
totargs = []

for i in range(len(args) + self.n):
if j >= self.n or i != self.pos[j]:
totargs.append(args[i - j])
else:
totargs.append(self.posargs[j])
j += 1

return self.func(*totargs, **kwargs)

def __str__(self):
return str(self.func)

def __repr__(self):
return repr(self.func)

0 comments on commit 84e4e6c

Please sign in to comment.