Skip to content

Commit

Permalink
REF: ensure name and cname are always str (#29692)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jreback committed Nov 20, 2019
1 parent 84fcbb8 commit eddd9f0
Showing 1 changed file with 72 additions and 36 deletions.
108 changes: 72 additions & 36 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,29 +1710,37 @@ class IndexCol:
is_data_indexable = True
_info_fields = ["freq", "tz", "index_name"]

name: str
cname: str
kind_attr: str

def __init__(
self,
name: str,
values=None,
kind=None,
typ=None,
cname=None,
cname: Optional[str] = None,
itemsize=None,
name=None,
axis=None,
kind_attr=None,
kind_attr: Optional[str] = None,
pos=None,
freq=None,
tz=None,
index_name=None,
**kwargs,
):

if not isinstance(name, str):
raise ValueError("`name` must be a str.")

self.values = values
self.kind = kind
self.typ = typ
self.itemsize = itemsize
self.name = name
self.cname = cname
self.kind_attr = kind_attr
self.cname = cname or name
self.kind_attr = kind_attr or f"{name}_kind"
self.axis = axis
self.pos = pos
self.freq = freq
Expand All @@ -1742,19 +1750,14 @@ def __init__(
self.meta = None
self.metadata = None

if name is not None:
self.set_name(name, kind_attr)
if pos is not None:
self.set_pos(pos)

def set_name(self, name, kind_attr=None):
""" set the name of this indexer """
self.name = name
self.kind_attr = kind_attr or "{name}_kind".format(name=name)
if self.cname is None:
self.cname = name

return self
# These are ensured as long as the passed arguments match the
# constructor annotations.
assert isinstance(self.name, str)
assert isinstance(self.cname, str)
assert isinstance(self.kind_attr, str)

def set_axis(self, axis: int):
""" set the axis over which I index """
Expand All @@ -1771,7 +1774,6 @@ def set_pos(self, pos: int):

def set_table(self, table):
self.table = table
return self

def __repr__(self) -> str:
temp = tuple(
Expand All @@ -1797,10 +1799,13 @@ def __ne__(self, other) -> bool:
@property
def is_indexed(self) -> bool:
""" return whether I am an indexed column """
try:
return getattr(self.table.cols, self.cname).is_indexed
except AttributeError:
if not hasattr(self.table, "cols"):
# e.g. if self.set_table hasn't been called yet, self.table
# will be None.
return False
# GH#29692 mypy doesn't recognize self.table as having a "cols" attribute
# 'error: "None" has no attribute "cols"'
return getattr(self.table.cols, self.cname).is_indexed # type: ignore

def copy(self):
new_self = copy.copy(self)
Expand Down Expand Up @@ -2508,6 +2513,7 @@ class DataIndexableCol(DataCol):

def validate_names(self):
if not Index(self.values).is_object():
# TODO: should the message here be more specifically non-str?
raise ValueError("cannot have non-object label DataIndexableCol")

def get_atom_string(self, block, itemsize):
Expand Down Expand Up @@ -2842,8 +2848,8 @@ def write_index(self, key, index):
else:
setattr(self.attrs, "{key}_variety".format(key=key), "regular")
converted = _convert_index(
index, self.encoding, self.errors, self.format_type
).set_name("index")
"index", index, self.encoding, self.errors, self.format_type
)

self.write_array(key, converted.values)

Expand Down Expand Up @@ -2893,8 +2899,8 @@ def write_multi_index(self, key, index):
)
level_key = "{key}_level{idx}".format(key=key, idx=i)
conv_level = _convert_index(
lev, self.encoding, self.errors, self.format_type
).set_name(level_key)
level_key, lev, self.encoding, self.errors, self.format_type
)
self.write_array(level_key, conv_level.values)
node = getattr(self.group, level_key)
node._v_attrs.kind = conv_level.kind
Expand Down Expand Up @@ -3436,9 +3442,10 @@ def queryables(self):

def index_cols(self):
""" return a list of my index cols """
# Note: each `i.cname` below is assured to be a str.
return [(i.axis, i.cname) for i in self.index_axes]

def values_cols(self):
def values_cols(self) -> List[str]:
""" return a list of my values cols """
return [i.cname for i in self.values_axes]

Expand Down Expand Up @@ -3540,6 +3547,8 @@ def indexables(self):

self._indexables = []

# Note: each of the `name` kwargs below are str, ensured
# by the definition in index_cols.
# index columns
self._indexables.extend(
[
Expand All @@ -3553,13 +3562,16 @@ def indexables(self):
base_pos = len(self._indexables)

def f(i, c):
assert isinstance(c, str)
klass = DataCol
if c in dc:
klass = DataIndexableCol
return klass.create_for_block(
i=i, name=c, pos=base_pos + i, version=self.version
)

# Note: the definition of `values_cols` ensures that each
# `c` below is a str.
self._indexables.extend(
[f(i, c) for i, c in enumerate(self.attrs.values_cols)]
)
Expand Down Expand Up @@ -3797,11 +3809,9 @@ def create_axes(

if i in axes:
name = obj._AXIS_NAMES[i]
index_axes_map[i] = (
_convert_index(a, self.encoding, self.errors, self.format_type)
.set_name(name)
.set_axis(i)
)
index_axes_map[i] = _convert_index(
name, a, self.encoding, self.errors, self.format_type
).set_axis(i)
else:

# we might be able to change the axes on the appending data if
Expand Down Expand Up @@ -3900,6 +3910,9 @@ def get_blk_items(mgr, blocks):
if data_columns and len(b_items) == 1 and b_items[0] in data_columns:
klass = DataIndexableCol
name = b_items[0]
if not (name is None or isinstance(name, str)):
# TODO: should the message here be more specifically non-str?
raise ValueError("cannot have non-object label DataIndexableCol")
self.data_columns.append(name)

# make sure that we match up the existing columns
Expand Down Expand Up @@ -4582,6 +4595,7 @@ def indexables(self):
self._indexables = [GenericIndexCol(name="index", axis=0)]

for i, n in enumerate(d._v_names):
assert isinstance(n, str)

dc = GenericDataIndexableCol(
name=n, pos=i, values=[n], version=self.version
Expand Down Expand Up @@ -4700,12 +4714,15 @@ def _set_tz(values, tz, preserve_UTC: bool = False, coerce: bool = False):
return values


def _convert_index(index, encoding=None, errors="strict", format_type=None):
def _convert_index(name: str, index, encoding=None, errors="strict", format_type=None):
assert isinstance(name, str)

index_name = getattr(index, "name", None)

if isinstance(index, DatetimeIndex):
converted = index.asi8
return IndexCol(
name,
converted,
"datetime64",
_tables().Int64Col(),
Expand All @@ -4716,6 +4733,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
elif isinstance(index, TimedeltaIndex):
converted = index.asi8
return IndexCol(
name,
converted,
"timedelta64",
_tables().Int64Col(),
Expand All @@ -4726,6 +4744,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
atom = _tables().Int64Col()
# avoid to store ndarray of Period objects
return IndexCol(
name,
index._ndarray_values,
"integer",
atom,
Expand All @@ -4743,6 +4762,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
if inferred_type == "datetime64":
converted = values.view("i8")
return IndexCol(
name,
converted,
"datetime64",
_tables().Int64Col(),
Expand All @@ -4753,6 +4773,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
elif inferred_type == "timedelta64":
converted = values.view("i8")
return IndexCol(
name,
converted,
"timedelta64",
_tables().Int64Col(),
Expand All @@ -4765,18 +4786,21 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
dtype=np.float64,
)
return IndexCol(
converted, "datetime", _tables().Time64Col(), index_name=index_name
name, converted, "datetime", _tables().Time64Col(), index_name=index_name
)
elif inferred_type == "date":
converted = np.asarray([v.toordinal() for v in values], dtype=np.int32)
return IndexCol(converted, "date", _tables().Time32Col(), index_name=index_name)
return IndexCol(
name, converted, "date", _tables().Time32Col(), index_name=index_name,
)
elif inferred_type == "string":
# atom = _tables().ObjectAtom()
# return np.asarray(values, dtype='O'), 'object', atom

converted = _convert_string_array(values, encoding, errors)
itemsize = converted.dtype.itemsize
return IndexCol(
name,
converted,
"string",
_tables().StringCol(itemsize),
Expand All @@ -4787,7 +4811,11 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
if format_type == "fixed":
atom = _tables().ObjectAtom()
return IndexCol(
np.asarray(values, dtype="O"), "object", atom, index_name=index_name
name,
np.asarray(values, dtype="O"),
"object",
atom,
index_name=index_name,
)
raise TypeError(
"[unicode] is not supported as a in index type for [{0}] formats".format(
Expand All @@ -4799,17 +4827,25 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
# take a guess for now, hope the values fit
atom = _tables().Int64Col()
return IndexCol(
np.asarray(values, dtype=np.int64), "integer", atom, index_name=index_name
name,
np.asarray(values, dtype=np.int64),
"integer",
atom,
index_name=index_name,
)
elif inferred_type == "floating":
atom = _tables().Float64Col()
return IndexCol(
np.asarray(values, dtype=np.float64), "float", atom, index_name=index_name
name,
np.asarray(values, dtype=np.float64),
"float",
atom,
index_name=index_name,
)
else: # pragma: no cover
atom = _tables().ObjectAtom()
return IndexCol(
np.asarray(values, dtype="O"), "object", atom, index_name=index_name
name, np.asarray(values, dtype="O"), "object", atom, index_name=index_name,
)


Expand Down

0 comments on commit eddd9f0

Please sign in to comment.