Skip to content

Commit

Permalink
ENH: Add sort parameter to set operations for some Indexes and adjust… (
Browse files Browse the repository at this point in the history
  • Loading branch information
reidy-p authored and Pingviinituutti committed Feb 28, 2019
1 parent e5b643a commit 35da4eb
Show file tree
Hide file tree
Showing 15 changed files with 389 additions and 226 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Other Enhancements
- :func:`read_fwf` now accepts keyword ``infer_nrows`` (:issue:`15138`).
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`, :issue:`24466`)
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
- :meth:`Index.difference`, :meth:`Index.intersection`, :meth:`Index.union`, and :meth:`Index.symmetric_difference` now have an optional ``sort`` parameter to control whether the results should be sorted if possible (:issue:`17839`, :issue:`24471`)
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed string columns to Stata strl format (:issue:`23633`)
Expand Down
25 changes: 20 additions & 5 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,21 @@ def item_from_zerodim(val: object) -> object:

@cython.wraparound(False)
@cython.boundscheck(False)
def fast_unique_multiple(list arrays):
def fast_unique_multiple(list arrays, sort: bool=True):
"""
Generate a list of unique values from a list of arrays.
Parameters
----------
arrays : list of array-like
A list of array-like objects
sort : boolean
Whether or not to sort the resulting unique list
Returns
-------
unique_list : list of unique values
"""
cdef:
ndarray[object] buf
Py_ssize_t k = len(arrays)
Expand All @@ -217,10 +231,11 @@ def fast_unique_multiple(list arrays):
if val not in table:
table[val] = stub
uniques.append(val)
try:
uniques.sort()
except Exception:
pass
if sort:
try:
uniques.sort()
except Exception:
pass

return uniques

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _get_combined_index(indexes, intersect=False, sort=False):
elif intersect:
index = indexes[0]
for other in indexes[1:]:
index = index.intersection(other)
index = index.intersection(other, sort=sort)
else:
index = _union_indexes(indexes, sort=sort)
index = ensure_index(index)
Expand Down
67 changes: 46 additions & 21 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2241,13 +2241,17 @@ def _get_reconciled_name_object(self, other):
return self._shallow_copy(name=name)
return self

def union(self, other):
def union(self, other, sort=True):
"""
Form the union of two Index objects and sorts if possible.
Form the union of two Index objects.
Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible
.. versionadded:: 0.24.0
Returns
-------
Expand Down Expand Up @@ -2277,7 +2281,7 @@ def union(self, other):
if not is_dtype_union_equal(self.dtype, other.dtype):
this = self.astype('O')
other = other.astype('O')
return this.union(other)
return this.union(other, sort=sort)

# TODO(EA): setops-refactor, clean all this up
if is_period_dtype(self) or is_datetime64tz_dtype(self):
Expand Down Expand Up @@ -2311,29 +2315,33 @@ def union(self, other):
else:
result = lvals

try:
result = sorting.safe_sort(result)
except TypeError as e:
warnings.warn("%s, sort order is undefined for "
"incomparable objects" % e, RuntimeWarning,
stacklevel=3)
if sort:
try:
result = sorting.safe_sort(result)
except TypeError as e:
warnings.warn("{}, sort order is undefined for "
"incomparable objects".format(e),
RuntimeWarning, stacklevel=3)

# for subclasses
return self._wrap_setop_result(other, result)

# Hook used by the set operations to wrap a raw result: builds a new index
# through ``self._constructor`` and names it with
# ``get_op_result_name(self, other)``.
def _wrap_setop_result(self, other, result):
return self._constructor(result, name=get_op_result_name(self, other))

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Form the intersection of two Index objects.
This returns a new Index with elements common to the index and `other`,
preserving the order of the calling index.
This returns a new Index with elements common to the index and `other`.
Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible
.. versionadded:: 0.24.0
Returns
-------
Expand All @@ -2356,7 +2364,7 @@ def intersection(self, other):
if not is_dtype_equal(self.dtype, other.dtype):
this = self.astype('O')
other = other.astype('O')
return this.intersection(other)
return this.intersection(other, sort=sort)

# TODO(EA): setops-refactor, clean all this up
if is_period_dtype(self):
Expand Down Expand Up @@ -2385,8 +2393,18 @@ def intersection(self, other):
indexer = indexer[indexer != -1]

taken = other.take(indexer)

if sort:
taken = sorting.safe_sort(taken.values)
if self.name != other.name:
name = None
else:
name = self.name
return self._shallow_copy(taken, name=name)

if self.name != other.name:
taken.name = None

return taken

def difference(self, other, sort=True):
Expand Down Expand Up @@ -2442,16 +2460,18 @@ def difference(self, other, sort=True):

return this._shallow_copy(the_diff, name=result_name, freq=None)

def symmetric_difference(self, other, result_name=None):
def symmetric_difference(self, other, result_name=None, sort=True):
"""
Compute the symmetric difference of two Index objects.
It's sorted if sorting is possible.
Parameters
----------
other : Index or array-like
result_name : str
sort : bool, default True
Sort the resulting index if possible
.. versionadded:: 0.24.0
Returns
-------
Expand Down Expand Up @@ -2496,10 +2516,11 @@ def symmetric_difference(self, other, result_name=None):
right_diff = other.values.take(right_indexer)

the_diff = _concat._concat_compat([left_diff, right_diff])
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass
if sort:
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass

attribs = self._get_attributes_dict()
attribs['name'] = result_name
Expand Down Expand Up @@ -3226,8 +3247,12 @@ def join(self, other, how='left', level=None, return_indexers=False,
elif how == 'right':
join_index = other
elif how == 'inner':
join_index = self.intersection(other)
# TODO: sort=False here for backwards compat. It may
# be better to use the sort parameter passed into join
join_index = self.intersection(other, sort=False)
elif how == 'outer':
# TODO: sort=True here for backwards compat. It may
# be better to use the sort parameter passed into join
join_index = self.union(other)

if sort:
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _wrap_setop_result(self, other, result):
name = get_op_result_name(self, other)
return self._shallow_copy(result, name=name, freq=None, tz=self.tz)

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Specialized intersection for DatetimeIndex objects. May be much faster
than Index.intersection
Expand All @@ -617,7 +617,7 @@ def intersection(self, other):
other = DatetimeIndex(other)
except (TypeError, ValueError):
pass
result = Index.intersection(self, other)
result = Index.intersection(self, other, sort=sort)
if isinstance(result, DatetimeIndex):
if result.freq is None:
result.freq = to_offset(result.inferred_freq)
Expand All @@ -627,7 +627,7 @@ def intersection(self, other):
other.freq != self.freq or
not other.freq.isAnchored() or
(not self.is_monotonic or not other.is_monotonic)):
result = Index.intersection(self, other)
result = Index.intersection(self, other, sort=sort)
# Invalidate the freq of `result`, which may not be correct at
# this point, depending on the values.
result.freq = None
Expand Down
7 changes: 2 additions & 5 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,11 +1104,8 @@ def func(self, other, sort=True):
'objects that have compatible dtypes')
raise TypeError(msg.format(op=op_name))

if op_name == 'difference':
result = getattr(self._multiindex, op_name)(other._multiindex,
sort)
else:
result = getattr(self._multiindex, op_name)(other._multiindex)
result = getattr(self._multiindex, op_name)(other._multiindex,
sort=sort)
result_name = get_op_result_name(self, other)

# GH 19101: ensure empty results have correct dtype
Expand Down
28 changes: 21 additions & 7 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,13 +2879,17 @@ def equal_levels(self, other):
return False
return True

def union(self, other):
def union(self, other, sort=True):
"""
Form the union of two MultiIndex objects, sorting if possible
Form the union of two MultiIndex objects
Parameters
----------
other : MultiIndex or array / Index of tuples
sort : bool, default True
Sort the resulting MultiIndex if possible
.. versionadded:: 0.24.0
Returns
-------
Expand All @@ -2900,17 +2904,23 @@ def union(self, other):
return self

uniq_tuples = lib.fast_unique_multiple([self._ndarray_values,
other._ndarray_values])
other._ndarray_values],
sort=sort)

return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Form the intersection of two MultiIndex objects, sorting if possible
Form the intersection of two MultiIndex objects.
Parameters
----------
other : MultiIndex or array / Index of tuples
sort : bool, default True
Sort the resulting MultiIndex if possible
.. versionadded:: 0.24.0
Returns
-------
Expand All @@ -2924,7 +2934,11 @@ def intersection(self, other):

self_tuples = self._ndarray_values
other_tuples = other._ndarray_values
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
uniq_tuples = set(self_tuples) & set(other_tuples)

if sort:
uniq_tuples = sorted(uniq_tuples)

if len(uniq_tuples) == 0:
return MultiIndex(levels=self.levels,
codes=[[]] * self.nlevels,
Expand All @@ -2935,7 +2949,7 @@ def intersection(self, other):

def difference(self, other, sort=True):
"""
Compute sorted set difference of two MultiIndex objects
Compute set difference of two MultiIndex objects
Parameters
----------
Expand Down
13 changes: 9 additions & 4 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,17 @@ def equals(self, other):

return super(RangeIndex, self).equals(other)

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Form the intersection of two Index objects. Sortedness of the result is
not guaranteed
Form the intersection of two Index objects.
Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible
.. versionadded:: 0.24.0
Returns
-------
Expand All @@ -361,7 +364,7 @@ def intersection(self, other):
return self._get_reconciled_name_object(other)

if not isinstance(other, RangeIndex):
return super(RangeIndex, self).intersection(other)
return super(RangeIndex, self).intersection(other, sort=sort)

if not len(self) or not len(other):
return RangeIndex._simple_new(None)
Expand Down Expand Up @@ -398,6 +401,8 @@ def intersection(self, other):

if (self._step < 0 and other._step < 0) is not (new_index._step < 0):
new_index = new_index[::-1]
if sort:
new_index = new_index.sort_values()
return new_index

def _min_fitting_element(self, lower_limit):
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4473,7 +4473,7 @@ def _reindex_axis(obj, axis, labels, other=None):

labels = ensure_index(labels.unique())
if other is not None:
labels = ensure_index(other.unique()) & labels
labels = ensure_index(other.unique()).intersection(labels, sort=False)
if not labels.equals(ax):
slicer = [slice(None, None)] * obj.ndim
slicer[axis] = labels
Expand Down
7 changes: 5 additions & 2 deletions pandas/tests/indexes/datetimes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def test_intersection2(self):

@pytest.mark.parametrize("tz", [None, 'Asia/Tokyo', 'US/Eastern',
'dateutil/US/Pacific'])
def test_intersection(self, tz):
@pytest.mark.parametrize("sort", [True, False])
def test_intersection(self, tz, sort):
# GH 4690 (with tz)
base = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')

Expand Down Expand Up @@ -185,7 +186,9 @@ def test_intersection(self, tz):

for (rng, expected) in [(rng2, expected2), (rng3, expected3),
(rng4, expected4)]:
result = base.intersection(rng)
result = base.intersection(rng, sort=sort)
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result, expected)
assert result.name == expected.name
assert result.freq is None
Expand Down
Loading

0 comments on commit 35da4eb

Please sign in to comment.