diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 840e79c6c9ebe..9749297efd004 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -445,6 +445,8 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray: elif isinstance(values, ABCMultiIndex): # Avoid raising in extract_array values = np.array(values) + else: + values = extract_array(values, extract_numpy=True) comps = _ensure_arraylike(comps) comps = extract_array(comps, extract_numpy=True) @@ -459,11 +461,14 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray: elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps.dtype): # e.g. comps are integers and values are datetime64s return np.zeros(comps.shape, dtype=bool) + # TODO: not quite right ... Sparse/Categorical + elif needs_i8_conversion(values.dtype): + return isin(comps, values.astype(object)) - comps, dtype = _ensure_data(comps) - values, _ = _ensure_data(values, dtype=dtype) - - f = htable.ismember_object + elif is_extension_array_dtype(comps.dtype) or is_extension_array_dtype( + values.dtype + ): + return isin(np.asarray(comps), np.asarray(values)) # GH16012 # Ensure np.in1d doesn't get object types or it *may* throw an exception @@ -476,23 +481,15 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray: f = lambda c, v: np.logical_or(np.in1d(c, v), np.isnan(c)) else: f = np.in1d - elif is_integer_dtype(comps.dtype): - try: - values = values.astype("int64", copy=False) - comps = comps.astype("int64", copy=False) - f = htable.ismember_int64 - except (TypeError, ValueError, OverflowError): - values = values.astype(object) - comps = comps.astype(object) - - elif is_float_dtype(comps.dtype): - try: - values = values.astype("float64", copy=False) - comps = comps.astype("float64", copy=False) - f = htable.ismember_float64 - except (TypeError, ValueError): - values = values.astype(object) - comps = comps.astype(object) + + else: + common = np.find_common_type([values.dtype, comps.dtype], []) + values = values.astype(common, copy=False) + comps = comps.astype(common, copy=False) + name = common.name + if name == "bool": + name = "uint8" + f = getattr(htable, f"ismember_{name}") return f(comps, values) diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index d836ca7a53249..35411d7e9cfb7 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -1044,7 +1044,6 @@ def test_different_nans_as_float64(self): expected = np.array([True, True]) tm.assert_numpy_array_equal(result, expected) - @pytest.mark.xfail(reason="problem related with issue #34125") def test_isin_int_df_string_search(self): """Comparing df with int`s (1,2) with a string at isin() ("1") -> should not match values because int 1 is not equal str 1""" @@ -1053,7 +1052,6 @@ def test_isin_int_df_string_search(self): expected_false = DataFrame({"values": [False, False]}) tm.assert_frame_equal(result, expected_false) - @pytest.mark.xfail(reason="problem related with issue #34125") def test_isin_nan_df_string_search(self): """Comparing df with nan value (np.nan,2) with a string at isin() ("NaN") -> should not match values because np.nan is not equal str NaN""" @@ -1062,7 +1060,6 @@ def test_isin_nan_df_string_search(self): expected_false = DataFrame({"values": [False, False]}) tm.assert_frame_equal(result, expected_false) - @pytest.mark.xfail(reason="problem related with issue #34125") def test_isin_float_df_string_search(self): """Comparing df with floats (1.4245,2.32441) with a string at isin() ("1.4245") -> should not match values because float 1.4245 is not equal str 1.4245"""