Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: more annotations for io.pytables #29703

Merged
merged 18 commits into from
Nov 20, 2019
131 changes: 94 additions & 37 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,16 @@ def root(self):
def filename(self):
return self._path

def __getitem__(self, key):
def __getitem__(self, key: str):
return self.get(key)

def __setitem__(self, key, value):
def __setitem__(self, key: str, value):
self.put(key, value)

def __delitem__(self, key):
def __delitem__(self, key: str):
return self.remove(key)

def __getattr__(self, name):
def __getattr__(self, name: str):
""" allow attribute access to get stores """
try:
return self.get(name)
Expand Down Expand Up @@ -791,7 +791,12 @@ def func(_start, _stop, _where):
return it.get_result()

def select_as_coordinates(
self, key: str, where=None, start=None, stop=None, **kwargs
self,
key: str,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):
"""
return the selection as an Index
Expand Down Expand Up @@ -943,13 +948,13 @@ def func(_start, _stop, _where):

return it.get_result(coordinates=True)

def put(self, key, value, format=None, append=False, **kwargs):
def put(self, key: str, value, format=None, append=False, **kwargs):
"""
Store object in HDFStore.
Parameters
----------
key : object
key : str
value : {Series, DataFrame}
format : 'fixed(f)|table(t)', default is 'fixed'
fixed(f) : Fixed format
Expand Down Expand Up @@ -1028,15 +1033,22 @@ def remove(self, key: str, where=None, start=None, stop=None):
return s.delete(where=where, start=start, stop=stop)

def append(
self, key, value, format=None, append=True, columns=None, dropna=None, **kwargs
self,
key: str,
value,
format=None,
append=True,
columns=None,
dropna=None,
**kwargs,
):
"""
Append to Table in file. Node must already exist and be Table
format.
Parameters
----------
key : object
key : str
value : {Series, DataFrame}
format : 'table' is the default
table(t) : table format
Expand Down Expand Up @@ -1077,7 +1089,14 @@ def append(
self._write_to_group(key, value, append=append, dropna=dropna, **kwargs)

def append_to_multiple(
self, d, value, selector, data_columns=None, axes=None, dropna=False, **kwargs
self,
d: Dict,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment

value,
selector,
data_columns=None,
axes=None,
dropna=False,
**kwargs,
):
"""
Append to multiple tables
Expand Down Expand Up @@ -1123,7 +1142,7 @@ def append_to_multiple(

# figure out how to split the value
remain_key = None
remain_values = []
remain_values: List = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance of adding a subtype here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here and below: i'm pretty sure we'll be able to be more specific in these after #29692

for k, v in d.items():
if v is None:
if remain_key is not None:
Expand Down Expand Up @@ -1871,7 +1890,7 @@ def validate(self, handler, append):
def validate_names(self):
pass

def validate_and_set(self, handler, append):
def validate_and_set(self, handler: "AppendableTable", append: bool):
self.set_table(handler.table)
self.validate_col()
self.validate_attr(append)
Expand Down Expand Up @@ -1901,7 +1920,7 @@ def validate_col(self, itemsize=None):

return None

def validate_attr(self, append):
def validate_attr(self, append: bool):
# check for backwards incompatibility
if append:
existing_kind = getattr(self.attrs, self.kind_attr, None)
Expand Down Expand Up @@ -1967,7 +1986,7 @@ def read_metadata(self, handler):
""" retrieve the metadata for this columns """
self.metadata = handler.read_metadata(self.cname)

def validate_metadata(self, handler):
def validate_metadata(self, handler: "AppendableTable"):
""" validate that kind=category does not change the categories """
if self.meta == "category":
new_metadata = self.metadata
Expand All @@ -1982,7 +2001,7 @@ def validate_metadata(self, handler):
"different categories to the existing"
)

def write_metadata(self, handler):
def write_metadata(self, handler: "AppendableTable"):
""" set the meta data """
if self.metadata is not None:
handler.write_metadata(self.cname, self.metadata)
Expand All @@ -1995,7 +2014,15 @@ class GenericIndexCol(IndexCol):
def is_indexed(self) -> bool:
return False

def convert(self, values, nan_rep, encoding, errors, start=None, stop=None):
def convert(
self,
values,
nan_rep,
encoding,
errors,
start: Optional[int] = None,
stop: Optional[int] = None,
):
""" set the values from this selection: take = take ownership
Parameters
Expand All @@ -2012,9 +2039,9 @@ def convert(self, values, nan_rep, encoding, errors, start=None, stop=None):
the underlying table's row count are normalized to that.
"""

start = start if start is not None else 0
stop = min(stop, self.table.nrows) if stop is not None else self.table.nrows
self.values = Int64Index(np.arange(stop - start))
_start = start if start is not None else 0
_stop = min(stop, self.table.nrows) if stop is not None else self.table.nrows
self.values = Int64Index(np.arange(_stop - _start))

return self

Expand Down Expand Up @@ -2749,7 +2776,9 @@ def get_attrs(self):
def write(self, obj, **kwargs):
self.set_attrs()

def read_array(self, key: str, start=None, stop=None):
def read_array(
self, key: str, start: Optional[int] = None, stop: Optional[int] = None
):
""" read an array for the specified node (off of group """
import tables

Expand Down Expand Up @@ -2836,7 +2865,7 @@ def write_block_index(self, key, index):
self.write_array("{key}_blengths".format(key=key), index.blengths)
setattr(self.attrs, "{key}_length".format(key=key), index.length)

def read_block_index(self, key, **kwargs):
def read_block_index(self, key, **kwargs) -> BlockIndex:
length = getattr(self.attrs, "{key}_length".format(key=key))
blocs = self.read_array("{key}_blocs".format(key=key), **kwargs)
blengths = self.read_array("{key}_blengths".format(key=key), **kwargs)
Expand All @@ -2846,7 +2875,7 @@ def write_sparse_intindex(self, key, index):
self.write_array("{key}_indices".format(key=key), index.indices)
setattr(self.attrs, "{key}_length".format(key=key), index.length)

def read_sparse_intindex(self, key, **kwargs):
def read_sparse_intindex(self, key, **kwargs) -> IntIndex:
length = getattr(self.attrs, "{key}_length".format(key=key))
indices = self.read_array("{key}_indices".format(key=key), **kwargs)
return IntIndex(length, indices)
Expand Down Expand Up @@ -2878,7 +2907,7 @@ def write_multi_index(self, key, index):
label_key = "{key}_label{idx}".format(key=key, idx=i)
self.write_array(label_key, level_codes)

def read_multi_index(self, key, **kwargs):
def read_multi_index(self, key, **kwargs) -> MultiIndex:
nlevels = getattr(self.attrs, "{key}_nlevels".format(key=key))

levels = []
Expand All @@ -2898,7 +2927,9 @@ def read_multi_index(self, key, **kwargs):
levels=levels, codes=codes, names=names, verify_integrity=True
)

def read_index_node(self, node, start=None, stop=None):
def read_index_node(
self, node, start: Optional[int] = None, stop: Optional[int] = None
):
data = node[start:stop]
# If the index was an empty array write_array_empty() will
# have written a sentinel. Here we replace it with the original.
Expand Down Expand Up @@ -2953,7 +2984,7 @@ def read_index_node(self, node, start=None, stop=None):

return name, index

def write_array_empty(self, key, value):
def write_array_empty(self, key: str, value):
""" write a 0-len array """

# ugly hack for length 0 axes
Expand All @@ -2966,7 +2997,7 @@ def _is_empty_array(self, shape) -> bool:
"""Returns true if any axis is zero length."""
return any(x == 0 for x in shape)

def write_array(self, key, value, items=None):
def write_array(self, key: str, value, items=None):
if key in self.group:
self._handle.remove_node(self.group, key)

Expand Down Expand Up @@ -3052,7 +3083,9 @@ def write_array(self, key, value, items=None):


class LegacyFixed(GenericFixed):
def read_index_legacy(self, key, start=None, stop=None):
def read_index_legacy(
self, key: str, start: Optional[int] = None, stop: Optional[int] = None
):
node = getattr(self.group, key)
data = node[start:stop]
kind = node._v_attrs.kind
Expand Down Expand Up @@ -3237,7 +3270,7 @@ def __init__(self, *args, **kwargs):
self.selection = None

@property
def table_type_short(self):
def table_type_short(self) -> str:
return self.table_type.split("_")[0]

@property
Expand Down Expand Up @@ -3311,7 +3344,7 @@ def validate(self, other):
)

@property
def is_multi_index(self):
def is_multi_index(self) -> bool:
"""the levels attribute is 1 or a list in the case of a multi-index"""
return isinstance(self.levels, list)

Expand All @@ -3335,7 +3368,7 @@ def validate_multiindex(self, obj):
)

@property
def nrows_expected(self):
def nrows_expected(self) -> int:
""" based on our axes, compute the expected nrows """
return np.prod([i.cvalues.shape[0] for i in self.index_axes])

Expand Down Expand Up @@ -3691,7 +3724,7 @@ def create_axes(
self,
axes,
obj,
validate=True,
validate: bool = True,
nan_rep=None,
data_columns=None,
min_itemsize=None,
Expand Down Expand Up @@ -4000,7 +4033,13 @@ def create_description(

return d

def read_coordinates(self, where=None, start=None, stop=None, **kwargs):
def read_coordinates(
self,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):
"""select coordinates (row numbers) from a table; return the
coordinates object
"""
Expand All @@ -4013,7 +4052,7 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs):
return False

# create the selection
self.selection = Selection(self, where=where, start=start, stop=stop, **kwargs)
self.selection = Selection(self, where=where, start=start, stop=stop)
coords = self.selection.select_coords()
if self.selection.filter is not None:
for field, op, filt in self.selection.filter.format():
Expand All @@ -4024,7 +4063,13 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs):

return Index(coords)

def read_column(self, column: str, where=None, start=None, stop=None):
def read_column(
self,
column: str,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
):
"""return a single column from the table, generally only indexables
are interesting
"""
Expand Down Expand Up @@ -4302,7 +4347,13 @@ def write_data_chunk(self, rows, indexes, mask, values):
"tables cannot write this data -> {detail}".format(detail=detail)
)

def delete(self, where=None, start=None, stop=None, **kwargs):
def delete(
self,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
**kwargs,
):

# delete all rows (and return the nrows)
if where is None or not len(where):
Expand All @@ -4323,7 +4374,7 @@ def delete(self, where=None, start=None, stop=None, **kwargs):

# create the selection
table = self.table
self.selection = Selection(self, where, start=start, stop=stop, **kwargs)
self.selection = Selection(self, where, start=start, stop=stop)
values = self.selection.select_coords()

# delete the rows in reverse order
Expand Down Expand Up @@ -4913,7 +4964,13 @@ class Selection:
"""

def __init__(self, table, where=None, start=None, stop=None):
def __init__(
self,
table: Table,
where=None,
start: Optional[int] = None,
stop: Optional[int] = None,
):
self.table = table
self.where = where
self.start = start
Expand Down