Skip to content

Commit

Permalink
BUG: kind parameter on categorical argsort (pandas-dev#16834)
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-b1 authored and jreback committed Jul 7, 2017
1 parent 8d197ba commit 5cc1025
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.20.3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ Numeric
Categorical
^^^^^^^^^^^

- Bug in ``DataFrame.sort_values`` not respecting the ``kind`` parameter with categorical data (:issue:`16793`)

Other
^^^^^
10 changes: 9 additions & 1 deletion pandas/compat/numpy/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def validate_argmax_with_skipna(skipna, args, kwargs):
validate_argsort = CompatValidator(ARGSORT_DEFAULTS, fname='argsort',
max_fname_arg_count=0, method='both')

# argsort is validated against two different signatures. This second
# validator is for callers whose own argsort accepts the `kind` parameter:
# `kind` is deliberately absent from the defaults registered below (only
# `axis` and `order` are listed).
# NOTE(review): presumably CompatValidator rejects non-default values only
# for the keys listed here -- confirm against its implementation.
ARGSORT_DEFAULTS_KIND = OrderedDict()
ARGSORT_DEFAULTS_KIND['axis'] = -1
ARGSORT_DEFAULTS_KIND['order'] = None
validate_argsort_kind = CompatValidator(ARGSORT_DEFAULTS_KIND, fname='argsort',
max_fname_arg_count=0, method='both')


def validate_argsort_with_ascending(ascending, args, kwargs):
"""
Expand All @@ -121,7 +129,7 @@ def validate_argsort_with_ascending(ascending, args, kwargs):
args = (ascending,) + args
ascending = True

validate_argsort(args, kwargs, max_fname_arg_count=1)
validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
return ascending


Expand Down
4 changes: 2 additions & 2 deletions pandas/core/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,7 @@ def check_for_ordered(self, op):
"you can use .as_ordered() to change the "
"Categorical to an ordered one\n".format(op=op))

def argsort(self, ascending=True, *args, **kwargs):
def argsort(self, ascending=True, kind='quicksort', *args, **kwargs):
"""
Returns the indices that would sort the Categorical instance if
'sort_values' was called. This function is implemented to provide
Expand All @@ -1309,7 +1309,7 @@ def argsort(self, ascending=True, *args, **kwargs):
numpy.ndarray.argsort
"""
ascending = nv.validate_argsort_with_ascending(ascending, args, kwargs)
result = np.argsort(self._codes.copy(), **kwargs)
result = np.argsort(self._codes.copy(), kind=kind, **kwargs)
if not ascending:
result = result[::-1]
return result
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def nargsort(items, kind='quicksort', ascending=True, na_position='last'):

# specially handle Categorical
if is_categorical_dtype(items):
return items.argsort(ascending=ascending)
return items.argsort(ascending=ascending, kind=kind)

items = np.asanyarray(items)
idx = np.arange(len(items))
Expand Down
9 changes: 9 additions & 0 deletions pandas/tests/frame/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ def test_stable_descending_multicolumn_sort(self):
kind='mergesort')
assert_frame_equal(sorted_df, expected)

def test_stable_categorial(self):
# GH 16793
# The column is built already in sorted order and every category value
# repeats, so a stable sort (kind='mergesort') must keep the tied rows in
# their original order and return the frame unchanged.
df = DataFrame({
'x': pd.Categorical(np.repeat([1, 2, 3, 4], 5), ordered=True)
})
expected = df.copy()
sorted_df = df.sort_values('x', kind='mergesort')
assert_frame_equal(sorted_df, expected)

def test_sort_datetimes(self):

# GH 3461, argsort / lexsort differences for a datetime column
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,8 @@ def test_numpy_argsort(self):
tm.assert_numpy_array_equal(np.argsort(c), expected,
check_dtype=False)

msg = "the 'kind' parameter is not supported"
tm.assert_raises_regex(ValueError, msg, np.argsort,
c, kind='mergesort')
tm.assert_numpy_array_equal(np.argsort(c, kind='mergesort'), expected,
check_dtype=False)

msg = "the 'axis' parameter is not supported"
tm.assert_raises_regex(ValueError, msg, np.argsort,
Expand Down

0 comments on commit 5cc1025

Please sign in to comment.