Skip to content

Commit

Permalink
BUG: read_csv not applying dtype to index col (#44632)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Nov 28, 2021
1 parent d8068e5 commit 8ffa2a9
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 18 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ I/O
- Bug in :func:`json_normalize` where multi-character ``sep`` parameter is incorrectly prefixed to every key (:issue:`43831`)
- Bug in :func:`json_normalize` where reading data with missing multi-level metadata would not respect ``errors="ignore"`` (:issue:`44312`)
- Bug in :func:`read_csv` with :code:`float_precision="round_trip"` which did not skip initial/trailing whitespace (:issue:`43713`)
- Bug in :func:`read_csv` not applying dtype for ``index_col`` (:issue:`9435`)
- Bug in dumping/loading a :class:`DataFrame` with ``yaml.dump(frame)`` (:issue:`42748`)
- Bug in :class:`ExcelWriter`, where ``engine_kwargs`` were not passed through to all engines (:issue:`43442`)
- Bug in :func:`read_csv` raising ``ValueError`` when ``parse_dates`` was used with ``MultiIndex`` columns (:issue:`8991`)
Expand Down
28 changes: 27 additions & 1 deletion pandas/io/parsers/base_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from copy import copy
import csv
import datetime
from enum import Enum
Expand Down Expand Up @@ -149,6 +150,8 @@ def __init__(self, kwds):
self.na_filter = kwds.get("na_filter", False)
self.keep_default_na = kwds.get("keep_default_na", True)

self.dtype = copy(kwds.get("dtype", None))

self.true_values = kwds.get("true_values")
self.false_values = kwds.get("false_values")
self.mangle_dupe_cols = kwds.get("mangle_dupe_cols", True)
Expand Down Expand Up @@ -511,6 +514,19 @@ def _get_name(icol):

return index

def _clean_mapping(self, mapping):
"""converts col numbers to names"""
if not isinstance(mapping, dict):
return mapping
clean = {}
for col, v in mapping.items():
# for mypy
assert self.orig_names is not None
if isinstance(col, int) and col not in self.orig_names:
col = self.orig_names[col]
clean[col] = v
return clean

@final
def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
arrays = []
Expand All @@ -535,7 +551,17 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
col_name, self.na_values, self.na_fvalues, self.keep_default_na
)

arr, _ = self._infer_types(arr, col_na_values | col_na_fvalues)
clean_dtypes = self._clean_mapping(self.dtype)

cast_type = None
if isinstance(clean_dtypes, dict) and self.index_names is not None:
cast_type = clean_dtypes.get(self.index_names[i], None)

try_num_bool = not (cast_type and is_string_dtype(cast_type))

arr, _ = self._infer_types(
arr, col_na_values | col_na_fvalues, try_num_bool
)
arrays.append(arr)

names = self.index_names
Expand Down
19 changes: 2 additions & 17 deletions pandas/io/parsers/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
abc,
defaultdict,
)
from copy import copy
import csv
from io import StringIO
import re
Expand Down Expand Up @@ -89,7 +88,6 @@ def __init__(
self.verbose = kwds["verbose"]
self.converters = kwds["converters"]

self.dtype = copy(kwds["dtype"])
self.thousands = kwds["thousands"]
self.decimal = kwds["decimal"]

Expand Down Expand Up @@ -308,21 +306,8 @@ def get_chunk(self, size=None):

def _convert_data(self, data):
# apply converters
def _clean_mapping(mapping):
"""converts col numbers to names"""
clean = {}
for col, v in mapping.items():
if isinstance(col, int) and col not in self.orig_names:
col = self.orig_names[col]
clean[col] = v
return clean

clean_conv = _clean_mapping(self.converters)
if not isinstance(self.dtype, dict):
# handles single dtype applied to all columns
clean_dtypes = self.dtype
else:
clean_dtypes = _clean_mapping(self.dtype)
clean_conv = self._clean_mapping(self.converters)
clean_dtypes = self._clean_mapping(self.dtype)

# Apply NA values.
clean_na_values = {}
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/io/parser/test_index_col.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,14 @@ def test_infer_types_boolean_sum(all_parsers):
# index column of dtype 'object', and the Python parser will return an
# index column of dtype 'int64'.
tm.assert_frame_equal(result, expected, check_index_type=False)


@skip_pyarrow
@pytest.mark.parametrize("dtype, val", [(object, "01"), ("int64", 1)])
def test_specify_dtype_for_index_col(all_parsers, dtype, val):
    """Check that a ``dtype`` entry given for the ``index_col`` column is applied."""
    # GH#9435
    csv_text = "a,b\n01,2"
    expected = DataFrame({"b": [2]}, index=Index([val], name="a"))
    result = all_parsers.read_csv(
        StringIO(csv_text), index_col="a", dtype={"a": dtype}
    )
    tm.assert_frame_equal(result, expected)

0 comments on commit 8ffa2a9

Please sign in to comment.