From bc9168362c5d012521f5ba1862ee68cae5b37796 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 20 Nov 2019 05:12:12 -0800 Subject: [PATCH] TYP: more annotations for io.pytables (#29703) --- pandas/io/pytables.py | 131 ++++++++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 37 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 193b8f5053d653..95898320954747 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -520,16 +520,16 @@ def root(self): def filename(self): return self._path - def __getitem__(self, key): + def __getitem__(self, key: str): return self.get(key) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value): self.put(key, value) - def __delitem__(self, key): + def __delitem__(self, key: str): return self.remove(key) - def __getattr__(self, name): + def __getattr__(self, name: str): """ allow attribute access to get stores """ try: return self.get(name) @@ -791,7 +791,12 @@ def func(_start, _stop, _where): return it.get_result() def select_as_coordinates( - self, key: str, where=None, start=None, stop=None, **kwargs + self, + key: str, + where=None, + start: Optional[int] = None, + stop: Optional[int] = None, + **kwargs, ): """ return the selection as an Index @@ -943,13 +948,13 @@ def func(_start, _stop, _where): return it.get_result(coordinates=True) - def put(self, key, value, format=None, append=False, **kwargs): + def put(self, key: str, value, format=None, append=False, **kwargs): """ Store object in HDFStore. Parameters ---------- - key : object + key : str value : {Series, DataFrame} format : 'fixed(f)|table(t)', default is 'fixed' fixed(f) : Fixed format @@ -1028,7 +1033,14 @@ def remove(self, key: str, where=None, start=None, stop=None): return s.delete(where=where, start=start, stop=stop) def append( - self, key, value, format=None, append=True, columns=None, dropna=None, **kwargs + self, + key: str, + value, + format=None, + append=True, + columns=None, + dropna=None, + **kwargs, ): """ Append to Table in file. Node must already exist and be Table @@ -1036,7 +1048,7 @@ def append( Parameters ---------- - key : object + key : str value : {Series, DataFrame} format : 'table' is the default table(t) : table format @@ -1077,7 +1089,14 @@ def append( self._write_to_group(key, value, append=append, dropna=dropna, **kwargs) def append_to_multiple( - self, d, value, selector, data_columns=None, axes=None, dropna=False, **kwargs + self, + d: Dict, + value, + selector, + data_columns=None, + axes=None, + dropna=False, + **kwargs, ): """ Append to multiple tables @@ -1123,7 +1142,7 @@ def append_to_multiple( # figure out how to split the value remain_key = None - remain_values = [] + remain_values: List = [] for k, v in d.items(): if v is None: if remain_key is not None: @@ -1871,7 +1890,7 @@ def validate(self, handler, append): def validate_names(self): pass - def validate_and_set(self, handler, append): + def validate_and_set(self, handler: "AppendableTable", append: bool): self.set_table(handler.table) self.validate_col() self.validate_attr(append) @@ -1901,7 +1920,7 @@ def validate_col(self, itemsize=None): return None - def validate_attr(self, append): + def validate_attr(self, append: bool): # check for backwards incompatibility if append: existing_kind = getattr(self.attrs, self.kind_attr, None) @@ -1967,7 +1986,7 @@ def read_metadata(self, handler): """ retrieve the metadata for this columns """ self.metadata = handler.read_metadata(self.cname) - def validate_metadata(self, handler): + def validate_metadata(self, handler: "AppendableTable"): """ validate that kind=category does not change the categories """ if self.meta == "category": new_metadata = self.metadata @@ -1982,7 +2001,7 @@ def validate_metadata(self, handler): "different categories to the existing" ) - def write_metadata(self, handler): + def write_metadata(self, handler: "AppendableTable"): """ set the meta data """ if self.metadata is not None: handler.write_metadata(self.cname, self.metadata) @@ -1995,7 +2014,15 @@ class GenericIndexCol(IndexCol): def is_indexed(self) -> bool: return False - def convert(self, values, nan_rep, encoding, errors, start=None, stop=None): + def convert( + self, + values, + nan_rep, + encoding, + errors, + start: Optional[int] = None, + stop: Optional[int] = None, + ): """ set the values from this selection: take = take ownership Parameters @@ -2012,9 +2039,9 @@ def convert(self, values, nan_rep, encoding, errors, start=None, stop=None): the underlying table's row count are normalized to that. """ - start = start if start is not None else 0 - stop = min(stop, self.table.nrows) if stop is not None else self.table.nrows - self.values = Int64Index(np.arange(stop - start)) + _start = start if start is not None else 0 + _stop = min(stop, self.table.nrows) if stop is not None else self.table.nrows + self.values = Int64Index(np.arange(_stop - _start)) return self @@ -2749,7 +2776,9 @@ def get_attrs(self): def write(self, obj, **kwargs): self.set_attrs() - def read_array(self, key: str, start=None, stop=None): + def read_array( + self, key: str, start: Optional[int] = None, stop: Optional[int] = None + ): """ read an array for the specified node (off of group """ import tables @@ -2836,7 +2865,7 @@ def write_block_index(self, key, index): self.write_array("{key}_blengths".format(key=key), index.blengths) setattr(self.attrs, "{key}_length".format(key=key), index.length) - def read_block_index(self, key, **kwargs): + def read_block_index(self, key, **kwargs) -> BlockIndex: length = getattr(self.attrs, "{key}_length".format(key=key)) blocs = self.read_array("{key}_blocs".format(key=key), **kwargs) blengths = self.read_array("{key}_blengths".format(key=key), **kwargs) @@ -2846,7 +2875,7 @@ def write_sparse_intindex(self, key, index): self.write_array("{key}_indices".format(key=key), index.indices) setattr(self.attrs, "{key}_length".format(key=key), index.length) - def read_sparse_intindex(self, key, **kwargs): + def read_sparse_intindex(self, key, **kwargs) -> IntIndex: length = getattr(self.attrs, "{key}_length".format(key=key)) indices = self.read_array("{key}_indices".format(key=key), **kwargs) return IntIndex(length, indices) @@ -2878,7 +2907,7 @@ def write_multi_index(self, key, index): label_key = "{key}_label{idx}".format(key=key, idx=i) self.write_array(label_key, level_codes) - def read_multi_index(self, key, **kwargs): + def read_multi_index(self, key, **kwargs) -> MultiIndex: nlevels = getattr(self.attrs, "{key}_nlevels".format(key=key)) levels = [] @@ -2898,7 +2927,9 @@ def read_multi_index(self, key, **kwargs): levels=levels, codes=codes, names=names, verify_integrity=True ) - def read_index_node(self, node, start=None, stop=None): + def read_index_node( + self, node, start: Optional[int] = None, stop: Optional[int] = None + ): data = node[start:stop] # If the index was an empty array write_array_empty() will # have written a sentinel. Here we relace it with the original. @@ -2953,7 +2984,7 @@ def read_index_node(self, node, start=None, stop=None): return name, index - def write_array_empty(self, key, value): + def write_array_empty(self, key: str, value): """ write a 0-len array """ # ugly hack for length 0 axes @@ -2966,7 +2997,7 @@ def _is_empty_array(self, shape) -> bool: """Returns true if any axis is zero length.""" return any(x == 0 for x in shape) - def write_array(self, key, value, items=None): + def write_array(self, key: str, value, items=None): if key in self.group: self._handle.remove_node(self.group, key) @@ -3052,7 +3083,9 @@ def write_array(self, key, value, items=None): class LegacyFixed(GenericFixed): - def read_index_legacy(self, key, start=None, stop=None): + def read_index_legacy( + self, key: str, start: Optional[int] = None, stop: Optional[int] = None + ): node = getattr(self.group, key) data = node[start:stop] kind = node._v_attrs.kind @@ -3237,7 +3270,7 @@ def __init__(self, *args, **kwargs): self.selection = None @property - def table_type_short(self): + def table_type_short(self) -> str: return self.table_type.split("_")[0] @property @@ -3311,7 +3344,7 @@ def validate(self, other): ) @property - def is_multi_index(self): + def is_multi_index(self) -> bool: """the levels attribute is 1 or a list in the case of a multi-index""" return isinstance(self.levels, list) @@ -3335,7 +3368,7 @@ def validate_multiindex(self, obj): ) @property - def nrows_expected(self): + def nrows_expected(self) -> int: """ based on our axes, compute the expected nrows """ return np.prod([i.cvalues.shape[0] for i in self.index_axes]) @@ -3691,7 +3724,7 @@ def create_axes( self, axes, obj, - validate=True, + validate: bool = True, nan_rep=None, data_columns=None, min_itemsize=None, @@ -4000,7 +4033,13 @@ def create_description( return d - def read_coordinates(self, where=None, start=None, stop=None, **kwargs): + def read_coordinates( + self, + where=None, + start: Optional[int] = None, + stop: Optional[int] = None, + **kwargs, + ): """select coordinates (row numbers) from a table; return the coordinates object """ @@ -4013,7 +4052,7 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs): return False # create the selection - self.selection = Selection(self, where=where, start=start, stop=stop, **kwargs) + self.selection = Selection(self, where=where, start=start, stop=stop) coords = self.selection.select_coords() if self.selection.filter is not None: for field, op, filt in self.selection.filter.format(): @@ -4024,7 +4063,13 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs): return Index(coords) - def read_column(self, column: str, where=None, start=None, stop=None): + def read_column( + self, + column: str, + where=None, + start: Optional[int] = None, + stop: Optional[int] = None, + ): """return a single column from the table, generally only indexables are interesting """ @@ -4302,7 +4347,13 @@ def write_data_chunk(self, rows, indexes, mask, values): "tables cannot write this data -> {detail}".format(detail=detail) ) - def delete(self, where=None, start=None, stop=None, **kwargs): + def delete( + self, + where=None, + start: Optional[int] = None, + stop: Optional[int] = None, + **kwargs, + ): # delete all rows (and return the nrows) if where is None or not len(where): @@ -4323,7 +4374,7 @@ def delete(self, where=None, start=None, stop=None, **kwargs): # create the selection table = self.table - self.selection = Selection(self, where, start=start, stop=stop, **kwargs) + self.selection = Selection(self, where, start=start, stop=stop) values = self.selection.select_coords() # delete the rows in reverse order @@ -4913,7 +4964,13 @@ class Selection: """ - def __init__(self, table, where=None, start=None, stop=None): + def __init__( + self, + table: Table, + where=None, + start: Optional[int] = None, + stop: Optional[int] = None, + ): self.table = table self.where = where self.start = start