Skip to content

Commit

Permalink
CLN: (re-)enable infer_dtype to catch complex (pandas-dev#25382)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-vetinari authored and Pingviinituutti committed Feb 28, 2019
1 parent 1b76b2a commit e6bef98
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ _TYPE_MAP = {
'float32': 'floating',
'float64': 'floating',
'f': 'floating',
'complex64': 'complex',
'complex128': 'complex',
'c': 'complex',
'string': 'string' if PY2 else 'bytes',
Expand Down Expand Up @@ -1305,6 +1306,9 @@ def infer_dtype(value: object, skipna: object=None) -> str:
elif is_decimal(val):
return 'decimal'

elif is_complex(val):
return 'complex'

elif util.is_float_object(val):
if is_float_array(values):
return 'floating'
Expand Down
31 changes: 31 additions & 0 deletions pandas/tests/dtypes/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,37 @@ def test_decimals(self):
result = lib.infer_dtype(arr, skipna=True)
assert result == 'decimal'

# complex is compatible with nan, so skipna has no effect
@pytest.mark.parametrize('skipna', [True, False])
def test_complex(self, skipna):
# gets cast to complex on array construction
arr = np.array([1.0, 2.0, 1 + 1j])
result = lib.infer_dtype(arr, skipna=skipna)
assert result == 'complex'

arr = np.array([1.0, 2.0, 1 + 1j], dtype='O')
result = lib.infer_dtype(arr, skipna=skipna)
assert result == 'mixed'

# gets cast to complex on array construction
arr = np.array([1, np.nan, 1 + 1j])
result = lib.infer_dtype(arr, skipna=skipna)
assert result == 'complex'

arr = np.array([1.0, np.nan, 1 + 1j], dtype='O')
result = lib.infer_dtype(arr, skipna=skipna)
assert result == 'mixed'

# complex with nans stays complex
arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype='O')
result = lib.infer_dtype(arr, skipna=skipna)
assert result == 'complex'

# test smaller complex dtype; will pass through _try_infer_map fastpath
arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype=np.complex64)
result = lib.infer_dtype(arr, skipna=skipna)
assert result == 'complex'

def test_string(self):
pass

Expand Down

0 comments on commit e6bef98

Please sign in to comment.