From eddd9f09a76628e1842cc19634b3f9b9f3b0fe83 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 20 Nov 2019 09:15:59 -0800 Subject: [PATCH] REF: ensure name and cname are always str (#29692) --- pandas/io/pytables.py | 108 ++++++++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 36 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 9589832095474..4c9e10e0f4601 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -1710,29 +1710,37 @@ class IndexCol: is_data_indexable = True _info_fields = ["freq", "tz", "index_name"] + name: str + cname: str + kind_attr: str + def __init__( self, + name: str, values=None, kind=None, typ=None, - cname=None, + cname: Optional[str] = None, itemsize=None, - name=None, axis=None, - kind_attr=None, + kind_attr: Optional[str] = None, pos=None, freq=None, tz=None, index_name=None, **kwargs, ): + + if not isinstance(name, str): + raise ValueError("`name` must be a str.") + self.values = values self.kind = kind self.typ = typ self.itemsize = itemsize self.name = name - self.cname = cname - self.kind_attr = kind_attr + self.cname = cname or name + self.kind_attr = kind_attr or f"{name}_kind" self.axis = axis self.pos = pos self.freq = freq @@ -1742,19 +1750,14 @@ def __init__( self.meta = None self.metadata = None - if name is not None: - self.set_name(name, kind_attr) if pos is not None: self.set_pos(pos) - def set_name(self, name, kind_attr=None): - """ set the name of this indexer """ - self.name = name - self.kind_attr = kind_attr or "{name}_kind".format(name=name) - if self.cname is None: - self.cname = name - - return self + # These are ensured as long as the passed arguments match the + # constructor annotations. + assert isinstance(self.name, str) + assert isinstance(self.cname, str) + assert isinstance(self.kind_attr, str) def set_axis(self, axis: int): """ set the axis over which I index """ @@ -1771,7 +1774,6 @@ def set_pos(self, pos: int): def set_table(self, table): self.table = table - return self def __repr__(self) -> str: temp = tuple( @@ -1797,10 +1799,13 @@ def __ne__(self, other) -> bool: @property def is_indexed(self) -> bool: """ return whether I am an indexed column """ - try: - return getattr(self.table.cols, self.cname).is_indexed - except AttributeError: + if not hasattr(self.table, "cols"): + # e.g. if self.set_table hasn't been called yet, self.table + # will be None. return False + # GH#29692 mypy doesn't recognize self.table as having a "cols" attribute + # 'error: "None" has no attribute "cols"' + return getattr(self.table.cols, self.cname).is_indexed # type: ignore def copy(self): new_self = copy.copy(self) @@ -2508,6 +2513,7 @@ class DataIndexableCol(DataCol): def validate_names(self): if not Index(self.values).is_object(): + # TODO: should the message here be more specifically non-str? raise ValueError("cannot have non-object label DataIndexableCol") def get_atom_string(self, block, itemsize): @@ -2842,8 +2848,8 @@ def write_index(self, key, index): else: setattr(self.attrs, "{key}_variety".format(key=key), "regular") converted = _convert_index( - index, self.encoding, self.errors, self.format_type - ).set_name("index") + "index", index, self.encoding, self.errors, self.format_type + ) self.write_array(key, converted.values) @@ -2893,8 +2899,8 @@ def write_multi_index(self, key, index): ) level_key = "{key}_level{idx}".format(key=key, idx=i) conv_level = _convert_index( - lev, self.encoding, self.errors, self.format_type - ).set_name(level_key) + level_key, lev, self.encoding, self.errors, self.format_type + ) self.write_array(level_key, conv_level.values) node = getattr(self.group, level_key) node._v_attrs.kind = conv_level.kind @@ -3436,9 +3442,10 @@ def queryables(self): def index_cols(self): """ return a list of my index cols """ + # Note: each `i.cname` below is assured to be a str. return [(i.axis, i.cname) for i in self.index_axes] - def values_cols(self): + def values_cols(self) -> List[str]: """ return a list of my values cols """ return [i.cname for i in self.values_axes] @@ -3540,6 +3547,8 @@ def indexables(self): self._indexables = [] + # Note: each of the `name` kwargs below are str, ensured + # by the definition in index_cols. # index columns self._indexables.extend( [ @@ -3553,6 +3562,7 @@ def indexables(self): base_pos = len(self._indexables) def f(i, c): + assert isinstance(c, str) klass = DataCol if c in dc: klass = DataIndexableCol @@ -3560,6 +3570,8 @@ def f(i, c): i=i, name=c, pos=base_pos + i, version=self.version ) + # Note: the definition of `values_cols` ensures that each + # `c` below is a str. self._indexables.extend( [f(i, c) for i, c in enumerate(self.attrs.values_cols)] ) @@ -3797,11 +3809,9 @@ def create_axes( if i in axes: name = obj._AXIS_NAMES[i] - index_axes_map[i] = ( - _convert_index(a, self.encoding, self.errors, self.format_type) - .set_name(name) - .set_axis(i) - ) + index_axes_map[i] = _convert_index( + name, a, self.encoding, self.errors, self.format_type + ).set_axis(i) else: # we might be able to change the axes on the appending data if @@ -3900,6 +3910,9 @@ def get_blk_items(mgr, blocks): if data_columns and len(b_items) == 1 and b_items[0] in data_columns: klass = DataIndexableCol name = b_items[0] + if not (name is None or isinstance(name, str)): + # TODO: should the message here be more specifically non-str? + raise ValueError("cannot have non-object label DataIndexableCol") self.data_columns.append(name) # make sure that we match up the existing columns @@ -4582,6 +4595,7 @@ def indexables(self): self._indexables = [GenericIndexCol(name="index", axis=0)] for i, n in enumerate(d._v_names): + assert isinstance(n, str) dc = GenericDataIndexableCol( name=n, pos=i, values=[n], version=self.version @@ -4700,12 +4714,15 @@ def _set_tz(values, tz, preserve_UTC: bool = False, coerce: bool = False): return values -def _convert_index(index, encoding=None, errors="strict", format_type=None): +def _convert_index(name: str, index, encoding=None, errors="strict", format_type=None): + assert isinstance(name, str) + index_name = getattr(index, "name", None) if isinstance(index, DatetimeIndex): converted = index.asi8 return IndexCol( + name, converted, "datetime64", _tables().Int64Col(), @@ -4716,6 +4733,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): elif isinstance(index, TimedeltaIndex): converted = index.asi8 return IndexCol( + name, converted, "timedelta64", _tables().Int64Col(), @@ -4726,6 +4744,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): atom = _tables().Int64Col() # avoid to store ndarray of Period objects return IndexCol( + name, index._ndarray_values, "integer", atom, @@ -4743,6 +4762,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): if inferred_type == "datetime64": converted = values.view("i8") return IndexCol( + name, converted, "datetime64", _tables().Int64Col(), @@ -4753,6 +4773,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): elif inferred_type == "timedelta64": converted = values.view("i8") return IndexCol( + name, converted, "timedelta64", _tables().Int64Col(), @@ -4765,11 +4786,13 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): dtype=np.float64, ) return IndexCol( - converted, "datetime", _tables().Time64Col(), index_name=index_name + name, converted, "datetime", _tables().Time64Col(), index_name=index_name ) elif inferred_type == "date": converted = np.asarray([v.toordinal() for v in values], dtype=np.int32) - return IndexCol(converted, "date", _tables().Time32Col(), index_name=index_name) + return IndexCol( + name, converted, "date", _tables().Time32Col(), index_name=index_name, + ) elif inferred_type == "string": # atom = _tables().ObjectAtom() # return np.asarray(values, dtype='O'), 'object', atom @@ -4777,6 +4800,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): converted = _convert_string_array(values, encoding, errors) itemsize = converted.dtype.itemsize return IndexCol( + name, converted, "string", _tables().StringCol(itemsize), @@ -4787,7 +4811,11 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): if format_type == "fixed": atom = _tables().ObjectAtom() return IndexCol( - np.asarray(values, dtype="O"), "object", atom, index_name=index_name + name, + np.asarray(values, dtype="O"), + "object", + atom, + index_name=index_name, ) raise TypeError( "[unicode] is not supported as a in index type for [{0}] formats".format( @@ -4799,17 +4827,25 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): # take a guess for now, hope the values fit atom = _tables().Int64Col() return IndexCol( - np.asarray(values, dtype=np.int64), "integer", atom, index_name=index_name + name, + np.asarray(values, dtype=np.int64), + "integer", + atom, + index_name=index_name, ) elif inferred_type == "floating": atom = _tables().Float64Col() return IndexCol( - np.asarray(values, dtype=np.float64), "float", atom, index_name=index_name + name, + np.asarray(values, dtype=np.float64), + "float", + atom, + index_name=index_name, ) else: # pragma: no cover atom = _tables().ObjectAtom() return IndexCol( - np.asarray(values, dtype="O"), "object", atom, index_name=index_name + name, np.asarray(values, dtype="O"), "object", atom, index_name=index_name, )