Skip to content

Commit

Permalink
TYP: more annotations for io.pytables (pandas-dev#29703)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jacobaustin123 committed Nov 20, 2019
1 parent 4d89693 commit 0379fcd
Showing 1 changed file with 94 additions and 37 deletions.
131 changes: 94 additions & 37 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,16 @@ def root(self):
def filename(self):
return self._path

def __getitem__(self, key):
def __getitem__(self, key: str):
return self.get(key)

def __setitem__(self, key, value):
def __setitem__(self, key: str, value):
self.put(key, value)

def __delitem__(self, key):
def __delitem__(self, key: str):
return self.remove(key)

def __getattr__(self, name):
def __getattr__(self, name: str):
""" allow attribute access to get stores """
try:
return self.get(name)
Expand Down Expand Up @@ -791,7 +791,12 @@ def func(_start, _stop, _where):
return it.get_result()

def select_as_coordinates(
self, key: str, where=None, start=None, stop=None, **kwargs
self,
key: str,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):
"""
return the selection as an Index
Expand Down Expand Up @@ -943,13 +948,13 @@ def func(_start, _stop, _where):

return it.get_result(coordinates=True)

def put(self, key, value, format=None, append=False, **kwargs):
def put(self, key: str, value, format=None, append=False, **kwargs):
"""
Store object in HDFStore.
Parameters
----------
key : object
key : str
value : {Series, DataFrame}
format : 'fixed(f)|table(t)', default is 'fixed'
fixed(f) : Fixed format
Expand Down Expand Up @@ -1028,15 +1033,22 @@ def remove(self, key: str, where=None, start=None, stop=None):
return s.delete(where=where, start=start, stop=stop)

def append(
self, key, value, format=None, append=True, columns=None, dropna=None, **kwargs
self,
key: str,
value,
format=None,
append=True,
columns=None,
dropna=None,
**kwargs,
):
"""
Append to Table in file. Node must already exist and be Table
format.
Parameters
----------
key : object
key : str
value : {Series, DataFrame}
format : 'table' is the default
table(t) : table format
Expand Down Expand Up @@ -1077,7 +1089,14 @@ def append(
self._write_to_group(key, value, append=append, dropna=dropna, **kwargs)

def append_to_multiple(
self, d, value, selector, data_columns=None, axes=None, dropna=False, **kwargs
self,
d: Dict,
value,
selector,
data_columns=None,
axes=None,
dropna=False,
**kwargs,
):
"""
Append to multiple tables
Expand Down Expand Up @@ -1123,7 +1142,7 @@ def append_to_multiple(

# figure out how to split the value
remain_key = None
remain_values = []
remain_values: List = []
for k, v in d.items():
if v is None:
if remain_key is not None:
Expand Down Expand Up @@ -1871,7 +1890,7 @@ def validate(self, handler, append):
def validate_names(self):
pass

def validate_and_set(self, handler, append):
def validate_and_set(self, handler: "AppendableTable", append: bool):
self.set_table(handler.table)
self.validate_col()
self.validate_attr(append)
Expand Down Expand Up @@ -1901,7 +1920,7 @@ def validate_col(self, itemsize=None):

return None

def validate_attr(self, append):
def validate_attr(self, append: bool):
# check for backwards incompatibility
if append:
existing_kind = getattr(self.attrs, self.kind_attr, None)
Expand Down Expand Up @@ -1967,7 +1986,7 @@ def read_metadata(self, handler):
""" retrieve the metadata for this columns """
self.metadata = handler.read_metadata(self.cname)

def validate_metadata(self, handler):
def validate_metadata(self, handler: "AppendableTable"):
""" validate that kind=category does not change the categories """
if self.meta == "category":
new_metadata = self.metadata
Expand All @@ -1982,7 +2001,7 @@ def validate_metadata(self, handler):
"different categories to the existing"
)

def write_metadata(self, handler):
def write_metadata(self, handler: "AppendableTable"):
""" set the meta data """
if self.metadata is not None:
handler.write_metadata(self.cname, self.metadata)
Expand All @@ -1995,7 +2014,15 @@ class GenericIndexCol(IndexCol):
def is_indexed(self) -> bool:
return False

def convert(self, values, nan_rep, encoding, errors, start=None, stop=None):
def convert(
self,
values,
nan_rep,
encoding,
errors,
start: Optional[int] = None,
stop: Optional[int] = None,
):
""" set the values from this selection: take = take ownership
Parameters
Expand All @@ -2012,9 +2039,9 @@ def convert(self, values, nan_rep, encoding, errors, start=None, stop=None):
the underlying table's row count are normalized to that.
"""

start = start if start is not None else 0
stop = min(stop, self.table.nrows) if stop is not None else self.table.nrows
self.values = Int64Index(np.arange(stop - start))
_start = start if start is not None else 0
_stop = min(stop, self.table.nrows) if stop is not None else self.table.nrows
self.values = Int64Index(np.arange(_stop - _start))

return self

Expand Down Expand Up @@ -2749,7 +2776,9 @@ def get_attrs(self):
def write(self, obj, **kwargs):
self.set_attrs()

def read_array(self, key: str, start=None, stop=None):
def read_array(
self, key: str, start: Optional[int] = None, stop: Optional[int] = None
):
""" read an array for the specified node (off of group """
import tables

Expand Down Expand Up @@ -2836,7 +2865,7 @@ def write_block_index(self, key, index):
self.write_array("{key}_blengths".format(key=key), index.blengths)
setattr(self.attrs, "{key}_length".format(key=key), index.length)

def read_block_index(self, key, **kwargs):
def read_block_index(self, key, **kwargs) -> BlockIndex:
length = getattr(self.attrs, "{key}_length".format(key=key))
blocs = self.read_array("{key}_blocs".format(key=key), **kwargs)
blengths = self.read_array("{key}_blengths".format(key=key), **kwargs)
Expand All @@ -2846,7 +2875,7 @@ def write_sparse_intindex(self, key, index):
self.write_array("{key}_indices".format(key=key), index.indices)
setattr(self.attrs, "{key}_length".format(key=key), index.length)

def read_sparse_intindex(self, key, **kwargs):
def read_sparse_intindex(self, key, **kwargs) -> IntIndex:
length = getattr(self.attrs, "{key}_length".format(key=key))
indices = self.read_array("{key}_indices".format(key=key), **kwargs)
return IntIndex(length, indices)
Expand Down Expand Up @@ -2878,7 +2907,7 @@ def write_multi_index(self, key, index):
label_key = "{key}_label{idx}".format(key=key, idx=i)
self.write_array(label_key, level_codes)

def read_multi_index(self, key, **kwargs):
def read_multi_index(self, key, **kwargs) -> MultiIndex:
nlevels = getattr(self.attrs, "{key}_nlevels".format(key=key))

levels = []
Expand All @@ -2898,7 +2927,9 @@ def read_multi_index(self, key, **kwargs):
levels=levels, codes=codes, names=names, verify_integrity=True
)

def read_index_node(self, node, start=None, stop=None):
def read_index_node(
self, node, start: Optional[int] = None, stop: Optional[int] = None
):
data = node[start:stop]
# If the index was an empty array write_array_empty() will
# have written a sentinel. Here we relace it with the original.
Expand Down Expand Up @@ -2953,7 +2984,7 @@ def read_index_node(self, node, start=None, stop=None):

return name, index

def write_array_empty(self, key, value):
def write_array_empty(self, key: str, value):
""" write a 0-len array """

# ugly hack for length 0 axes
Expand All @@ -2966,7 +2997,7 @@ def _is_empty_array(self, shape) -> bool:
"""Returns true if any axis is zero length."""
return any(x == 0 for x in shape)

def write_array(self, key, value, items=None):
def write_array(self, key: str, value, items=None):
if key in self.group:
self._handle.remove_node(self.group, key)

Expand Down Expand Up @@ -3052,7 +3083,9 @@ def write_array(self, key, value, items=None):


class LegacyFixed(GenericFixed):
def read_index_legacy(self, key, start=None, stop=None):
def read_index_legacy(
self, key: str, start: Optional[int] = None, stop: Optional[int] = None
):
node = getattr(self.group, key)
data = node[start:stop]
kind = node._v_attrs.kind
Expand Down Expand Up @@ -3237,7 +3270,7 @@ def __init__(self, *args, **kwargs):
self.selection = None

@property
def table_type_short(self):
def table_type_short(self) -> str:
return self.table_type.split("_")[0]

@property
Expand Down Expand Up @@ -3311,7 +3344,7 @@ def validate(self, other):
)

@property
def is_multi_index(self):
def is_multi_index(self) -> bool:
"""the levels attribute is 1 or a list in the case of a multi-index"""
return isinstance(self.levels, list)

Expand All @@ -3335,7 +3368,7 @@ def validate_multiindex(self, obj):
)

@property
def nrows_expected(self):
def nrows_expected(self) -> int:
""" based on our axes, compute the expected nrows """
return np.prod([i.cvalues.shape[0] for i in self.index_axes])

Expand Down Expand Up @@ -3691,7 +3724,7 @@ def create_axes(
self,
axes,
obj,
validate=True,
validate: bool = True,
nan_rep=None,
data_columns=None,
min_itemsize=None,
Expand Down Expand Up @@ -4000,7 +4033,13 @@ def create_description(

return d

def read_coordinates(self, where=None, start=None, stop=None, **kwargs):
def read_coordinates(
self,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):
"""select coordinates (row numbers) from a table; return the
coordinates object
"""
Expand All @@ -4013,7 +4052,7 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs):
return False

# create the selection
self.selection = Selection(self, where=where, start=start, stop=stop, **kwargs)
self.selection = Selection(self, where=where, start=start, stop=stop)
coords = self.selection.select_coords()
if self.selection.filter is not None:
for field, op, filt in self.selection.filter.format():
Expand All @@ -4024,7 +4063,13 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs):

return Index(coords)

def read_column(self, column: str, where=None, start=None, stop=None):
def read_column(
self,
column: str,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
):
"""return a single column from the table, generally only indexables
are interesting
"""
Expand Down Expand Up @@ -4302,7 +4347,13 @@ def write_data_chunk(self, rows, indexes, mask, values):
"tables cannot write this data -> {detail}".format(detail=detail)
)

def delete(self, where=None, start=None, stop=None, **kwargs):
def delete(
self,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):

# delete all rows (and return the nrows)
if where is None or not len(where):
Expand All @@ -4323,7 +4374,7 @@ def delete(self, where=None, start=None, stop=None, **kwargs):

# create the selection
table = self.table
self.selection = Selection(self, where, start=start, stop=stop, **kwargs)
self.selection = Selection(self, where, start=start, stop=stop)
values = self.selection.select_coords()

# delete the rows in reverse order
Expand Down Expand Up @@ -4913,7 +4964,13 @@ class Selection:
"""

def __init__(self, table, where=None, start=None, stop=None):
def __init__(
self,
table: Table,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
):
self.table = table
self.where = where
self.start = start
Expand Down

0 comments on commit 0379fcd

Please sign in to comment.