diff --git a/tests/test_series.py b/tests/test_series.py index 61e2b8bb..35e9ee06 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -3230,3 +3230,47 @@ def test_operator_constistency() -> None: pd.Series, pd.Timedelta, ) + + +def test_map() -> None: + s = pd.Series([1, 2, 3]) + + mapping = {1: "a", 2: "b", 3: "c"} + check( + assert_type(s.map(mapping, na_action="ignore"), "pd.Series[str]"), + pd.Series, + str, + ) + + def callable(x: int) -> str: + return str(x) + + check( + assert_type(s.map(callable, na_action="ignore"), "pd.Series[str]"), + pd.Series, + str, + ) + + series = pd.Series(["a", "b", "c"]) + check( + assert_type(s.map(series, na_action="ignore"), "pd.Series[str]"), pd.Series, str + ) + + +def test_map_na() -> None: + s: pd.Series[int] = pd.Series([1, pd.NA, 3]) + + mapping = {1: "a", 2: "b", 3: "c"} + check(assert_type(s.map(mapping, na_action=None), "pd.Series[str]"), pd.Series, str) + + def callable(x: int | NAType) -> str | NAType: + if isinstance(x, int): + return str(x) + return x + + check( + assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str + ) + + series = pd.Series(["a", "b", "c"]) + check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str)