Skip to content

Commit

Permalink
COMPAT: unique() should preserve the dtype of the input (#27874)
Browse files Browse the repository at this point in the history
  • Loading branch information
stuarteberg authored and jreback committed Oct 7, 2019
1 parent a1dba2c commit af498fe
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ Other API changes

- :meth:`pandas.api.types.infer_dtype` will now return "integer-na" for integer and ``np.nan`` mix (:issue:`27283`)
- :meth:`MultiIndex.from_arrays` will no longer infer names from arrays if ``names=None`` is explicitly provided (:issue:`27292`)
- The returned dtype of ::func:`pd.unique` now matches the input dtype. (:issue:`27874`)
-

.. _whatsnew_1000.api.documentation:
Expand Down
8 changes: 4 additions & 4 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def _reconstruct_data(values, dtype, original):
if is_extension_array_dtype(dtype):
values = dtype.construct_array_type()._from_sequence(values)
elif is_bool_dtype(dtype):
values = values.astype(dtype)
values = values.astype(dtype, copy=False)

# we only support object dtypes bool Index
if isinstance(original, ABCIndexClass):
values = values.astype(object)
values = values.astype(object, copy=False)
elif dtype is not None:
values = values.astype(dtype)
values = values.astype(dtype, copy=False)

return values

Expand Down Expand Up @@ -396,7 +396,7 @@ def unique(values):

table = htable(len(values))
uniques = table.unique(values)
uniques = _reconstruct_data(uniques, dtype, original)
uniques = _reconstruct_data(uniques, original.dtype, original)
return uniques


Expand Down
27 changes: 23 additions & 4 deletions pandas/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def test_memory_usage(self):
class Ops:
def _allow_na_ops(self, obj):
"""Whether to skip test cases including NaN"""
if isinstance(obj, Index) and (obj.is_boolean() or not obj._can_hold_na):
# don't test boolean / int64 index
if (isinstance(obj, Index) and obj.is_boolean()) or not obj._can_hold_na:
# don't test boolean / integer dtypes
return False
return True

Expand All @@ -187,7 +187,24 @@ def setup_method(self, method):
types = ["bool", "int", "float", "dt", "dt_tz", "period", "string", "unicode"]
self.indexes = [getattr(self, "{}_index".format(t)) for t in types]
self.series = [getattr(self, "{}_series".format(t)) for t in types]
self.objs = self.indexes + self.series

# To test narrow dtypes, we use narrower *data* elements, not *index* elements
index = self.int_index
self.float32_series = Series(arr.astype(np.float32), index=index, name="a")

arr_int = np.random.choice(10, size=10, replace=False)
self.int8_series = Series(arr_int.astype(np.int8), index=index, name="a")
self.int16_series = Series(arr_int.astype(np.int16), index=index, name="a")
self.int32_series = Series(arr_int.astype(np.int32), index=index, name="a")

self.uint8_series = Series(arr_int.astype(np.uint8), index=index, name="a")
self.uint16_series = Series(arr_int.astype(np.uint16), index=index, name="a")
self.uint32_series = Series(arr_int.astype(np.uint32), index=index, name="a")

nrw_types = ["float32", "int8", "int16", "int32", "uint8", "uint16", "uint32"]
self.narrow_series = [getattr(self, "{}_series".format(t)) for t in nrw_types]

self.objs = self.indexes + self.series + self.narrow_series

def check_ops_properties(self, props, filter=None, ignore_failures=False):
for op in props:
Expand Down Expand Up @@ -385,6 +402,7 @@ def test_value_counts_unique_nunique(self):
if isinstance(o, Index):
assert isinstance(result, o.__class__)
tm.assert_index_equal(result, orig)
assert result.dtype == orig.dtype
elif is_datetime64tz_dtype(o):
# datetimetz Series returns array of Timestamp
assert result[0] == orig[0]
Expand All @@ -396,6 +414,7 @@ def test_value_counts_unique_nunique(self):
)
else:
tm.assert_numpy_array_equal(result, orig.values)
assert result.dtype == orig.dtype

assert o.nunique() == len(np.unique(o.values))

Expand Down Expand Up @@ -904,7 +923,7 @@ def test_fillna(self):

expected = [fill_value] * 2 + list(values[2:])

expected = klass(expected)
expected = klass(expected, dtype=orig.dtype)
o = klass(values)

# check values has the same dtype as the original
Expand Down

0 comments on commit af498fe

Please sign in to comment.