diff --git a/pyproject.toml b/pyproject.toml
index ec5c6e0541..355c4ff483 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -237,6 +237,7 @@ select = [
     "PLR0917", # Ban APIs with too many positional parameters
     "FBT", # No positional boolean parameters
     "PT", # Pytest style
+    "SIM", # Simplify control flow
 ]
 ignore = [
     # line too long -> we accept long comment lines; black gets rid of long code lines
diff --git a/src/scanpy/_settings.py b/src/scanpy/_settings.py
index b090261d1f..63f91d2279 100644
--- a/src/scanpy/_settings.py
+++ b/src/scanpy/_settings.py
@@ -371,7 +371,7 @@ def logpath(self) -> Path | None:
     def logpath(self, logpath: Path | str | None):
         _type_check(logpath, "logfile", (str, Path))
         # set via “file object” branch of logfile.setter
-        self.logfile = Path(logpath).open("a")
+        self.logfile = Path(logpath).open("a")  # noqa: SIM115
         self._logpath = Path(logpath)

     @property
@@ -519,7 +519,7 @@ def __str__(self) -> str:
         return "\n".join(
             f"{k} = {v!r}"
             for k, v in inspect.getmembers(self)
-            if not k.startswith("_") and not k == "getdoc"
+            if not k.startswith("_") and k != "getdoc"
         )
diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py
index 46d62bcde6..b8513d87ba 100644
--- a/src/scanpy/_utils/__init__.py
+++ b/src/scanpy/_utils/__init__.py
@@ -13,7 +13,7 @@
 import sys
 import warnings
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from enum import Enum
 from functools import partial, singledispatch, wraps
 from operator import mul, truediv
@@ -281,10 +281,8 @@ def get_igraph_from_adjacency(adjacency, directed=None):
     g = ig.Graph(directed=directed)
     g.add_vertices(adjacency.shape[0])  # this adds adjacency.shape[0] vertices
     g.add_edges(list(zip(sources, targets)))
-    try:
+    with suppress(KeyError):
         g.es["weight"] = weights
-    except KeyError:
-        pass
     if g.vcount() != adjacency.shape[0]:
         logg.warning(
             f"The constructed graph has only {g.vcount()} nodes. "
@@ -613,11 +611,10 @@ def _(
     out: sparse.csr_matrix | sparse.csc_matrix | None = None,
 ) -> sparse.csr_matrix | sparse.csc_matrix:
     check_op(op)
-    if out is not None:
-        if X.data is not out.data:
-            raise ValueError(
-                "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
-            )
+    if out is not None and X.data is not out.data:
+        raise ValueError(
+            "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
+        )
     if not allow_divide_by_zero and op is truediv:
         scaling_array = scaling_array.copy() + (scaling_array == 0)
@@ -684,7 +681,7 @@ def _(
     column_scale = axis == 1

     if isinstance(scaling_array, DaskArray):
-        if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or (
+        if (row_scale and X.chunksize[0] != scaling_array.chunksize[0]) or (
             column_scale
             and (
                 (
diff --git a/src/scanpy/_utils/compute/is_constant.py b/src/scanpy/_utils/compute/is_constant.py
index 7dac03b40a..80f6581980 100644
--- a/src/scanpy/_utils/compute/is_constant.py
+++ b/src/scanpy/_utils/compute/is_constant.py
@@ -121,10 +121,7 @@ def _is_constant_csr_rows(
     for i in range(n):
         start = indptr[i]
         stop = indptr[i + 1]
-        if stop - start == shape[1]:
-            val = data[start]
-        else:
-            val = 0
+        val = data[start] if stop - start == shape[1] else 0
         for j in range(start, stop):
             if data[j] != val:
                 result[i] = False
diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py
index ccbc9a3bb3..41b23160d6 100644
--- a/src/scanpy/datasets/_datasets.py
+++ b/src/scanpy/datasets/_datasets.py
@@ -219,7 +219,7 @@ def moignard15() -> AnnData:
     }
     # annotate each observation/cell
     adata.obs["exp_groups"] = [
-        next(gname for gname in groups.keys() if sname.startswith(gname))
+        next(gname for gname in groups if sname.startswith(gname))
         for sname in adata.obs_names
     ]
     # fix the order and colors of names in "groups"
diff --git a/src/scanpy/datasets/_ebi_expression_atlas.py b/src/scanpy/datasets/_ebi_expression_atlas.py
index 9f3bcb81ad..b7e1886e71 100644
--- a/src/scanpy/datasets/_ebi_expression_atlas.py
+++ b/src/scanpy/datasets/_ebi_expression_atlas.py
@@ -65,10 +65,7 @@ def read_mtx_from_stream(stream: BinaryIO) -> sparse.csr_matrix:
     n, m, _ = (int(x) for x in curline[:-1].split(b" "))

     max_int32 = np.iinfo(np.int32).max
-    if n > max_int32 or m > max_int32:
-        coord_dtype = np.int64
-    else:
-        coord_dtype = np.int32
+    coord_dtype = np.int64 if n > max_int32 or m > max_int32 else np.int32

     data = pd.read_csv(
         stream,
diff --git a/src/scanpy/external/pl.py b/src/scanpy/external/pl.py
index 662bc88eb3..a3d1767cee 100644
--- a/src/scanpy/external/pl.py
+++ b/src/scanpy/external/pl.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextlib
 from typing import TYPE_CHECKING

 import matplotlib.pyplot as plt
@@ -214,10 +215,8 @@ def sam(
         return axes

     if isinstance(c, str):
-        try:
+        with contextlib.suppress(KeyError):
             c = np.array(list(adata.obs[c]))
-        except KeyError:
-            pass

     if isinstance(c[0], (str, np.str_)) and isinstance(c, (np.ndarray, list)):
         import samalg.utilities as ut
diff --git a/src/scanpy/get/get.py b/src/scanpy/get/get.py
index c5e95de1ab..0c1272ae62 100644
--- a/src/scanpy/get/get.py
+++ b/src/scanpy/get/get.py
@@ -121,10 +121,7 @@ def _check_indices(
     use_raw: bool = False,
 ) -> tuple[list[str], list[str], list[str]]:
     """Common logic for checking indices for obs_df and var_df."""
-    if use_raw:
-        alt_repr = "adata.raw"
-    else:
-        alt_repr = "adata"
+    alt_repr = "adata.raw" if use_raw else "adata"

     alt_dim = ("obs", "var")[dim == "obs"]

@@ -288,10 +285,7 @@ def obs_df(
         var = adata.raw.var
     else:
         var = adata.var
-    if gene_symbols is not None:
-        alias_index = pd.Index(var[gene_symbols])
-    else:
-        alias_index = None
+    alias_index = pd.Index(var[gene_symbols]) if gene_symbols is not None else None

     obs_cols, var_idx_keys, var_symbols = _check_indices(
         adata.obs,
diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py
index 0666a8b3ed..64b14cf112 100644
--- a/src/scanpy/neighbors/__init__.py
+++ b/src/scanpy/neighbors/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextlib
 from collections.abc import Mapping
 from textwrap import indent
 from types import MappingProxyType
@@ -469,10 +470,7 @@ def transitions(self) -> np.ndarray | csr_matrix:
         -----
         This has not been tested, in contrast to `transitions_sym`.
         """
-        if issparse(self.Z):
-            Zinv = self.Z.power(-1)
-        else:
-            Zinv = np.diag(1.0 / np.diag(self.Z))
+        Zinv = self.Z.power(-1) if issparse(self.Z) else np.diag(1.0 / np.diag(self.Z))
         return self.Z @ self.transitions_sym @ Zinv

     @property
@@ -591,10 +589,9 @@ def compute_neighbors(
         if isinstance(index, NNDescent):
             # very cautious here
-            try:
+            # TODO catch the correct exception
+            with contextlib.suppress(Exception):
                 self._rp_forest = _make_forest_dict(index)
-            except Exception:  # TODO catch the correct exception
-                pass
         start_connect = logg.debug("computed neighbors", time=start_neighbors)

         if method == "umap":
diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py
index 5bd3de8188..5dfea0ae29 100755
--- a/src/scanpy/plotting/_anndata.py
+++ b/src/scanpy/plotting/_anndata.py
@@ -157,15 +157,15 @@ def scatter(
     if x is None or y is None:
         raise ValueError("Either provide a `basis` or `x` and `y`.")
     if (
-        (x in adata.obs.keys() or x in var_index)
-        and (y in adata.obs.keys() or y in var_index)
-        and (color is None or color in adata.obs.keys() or color in var_index)
+        (x in adata.obs.columns or x in var_index)
+        and (y in adata.obs.columns or y in var_index)
+        and (color is None or color in adata.obs.columns or color in var_index)
     ):
         return _scatter_obs(**args)
     if (
-        (x in adata.var.keys() or x in adata.obs.index)
-        and (y in adata.var.keys() or y in adata.obs.index)
-        and (color is None or color in adata.var.keys() or color in adata.obs.index)
+        (x in adata.var.columns or x in adata.obs.index)
+        and (y in adata.var.columns or y in adata.obs.index)
+        and (color is None or color in adata.var.columns or color in adata.obs.index)
     ):
         adata_T = adata.T
         axs = _scatter_obs(
@@ -217,14 +217,12 @@ def _scatter_obs(
     use_raw = _check_use_raw(adata, use_raw)

     # Process layers
-    if layers in ["X", None] or (
-        isinstance(layers, str) and layers in adata.layers.keys()
-    ):
+    if layers in ["X", None] or (isinstance(layers, str) and layers in adata.layers):
         layers = (layers, layers, layers)
     elif isinstance(layers, Collection) and len(layers) == 3:
         layers = tuple(layers)
         for layer in layers:
-            if layer not in adata.layers.keys() and layer not in ["X", None]:
+            if layer not in adata.layers and layer not in ["X", None]:
                 raise ValueError(
                     "`layers` should have elements that are "
                     "either None or in adata.layers.keys()."
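Note: the try/except/pass blocks rewritten above in _utils/__init__.py, external/pl.py, and neighbors/__init__.py all rely on the equivalence that ruff's SIM105 rule targets. A minimal sketch of that equivalence; the `mapping` dict is illustrative and not part of the patch:

```python
from contextlib import suppress

mapping: dict[str, float] = {}  # illustrative stand-in for e.g. g.es

# before: swallow a missing key with an empty except block
try:
    weight = mapping["weight"]
except KeyError:
    pass

# after: identical behavior, but the intent is stated up front
with suppress(KeyError):
    weight = mapping["weight"]
```

Both forms catch exactly the same exception; `suppress` adds one context-manager enter/exit, a cost that is negligible in the code paths patched here.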
@@ -256,7 +254,7 @@ def _scatter_obs(
     )
     if title is not None and isinstance(title, str):
         title = [title]
-    highlights = adata.uns["highlights"] if "highlights" in adata.uns else []
+    highlights = adata.uns.get("highlights", [])
     if basis is not None:
         try:
             # ignore the '0th' diffusion component
@@ -292,19 +290,14 @@ def _scatter_obs(
     n = Y.shape[0]
     size = 120000 / n

-    if legend_loc.startswith("on data") and legend_fontsize is None:
-        legend_fontsize = rcParams["legend.fontsize"]
-    elif legend_fontsize is None:
+    if legend_fontsize is None:
         legend_fontsize = rcParams["legend.fontsize"]

     palette_was_none = False
     if palette is None:
         palette_was_none = True
     if isinstance(palette, Sequence) and not isinstance(palette, str):
-        if not is_color_like(palette[0]):
-            palettes = palette
-        else:
-            palettes = [palette]
+        palettes = palette if not is_color_like(palette[0]) else [palette]
     else:
         palettes = [palette for _ in range(len(keys))]
     palettes = [_utils.default_palette(palette) for palette in palettes]
@@ -328,7 +321,7 @@ def _scatter_obs(
     else:
         component_name = None
     axis_labels = (x, y) if component_name is None else None
-    show_ticks = True if component_name is None else False
+    show_ticks = component_name is None

     # generate the colors
     color_ids: list[np.ndarray | ColorLike] = []
@@ -364,9 +357,8 @@ def _scatter_obs(
             categoricals.append(ikey)
         color_ids.append(c)

-    if right_margin is None and len(categoricals) > 0:
-        if legend_loc == "right margin":
-            right_margin = 0.5
+    if right_margin is None and len(categoricals) > 0 and legend_loc == "right margin":
+        right_margin = 0.5
     if title is None and keys[0] is not None:
         title = [
             key.replace("_", " ") if not is_color_like(key) else "" for key in keys
@@ -488,10 +480,7 @@ def add_centroid(centroids, name, Y, mask):

         all_pos = np.zeros((len(adata.obs[key].cat.categories), 2))
         for iname, name in enumerate(adata.obs[key].cat.categories):
-            if name in centroids:
-                all_pos[iname] = centroids[name]
-            else:
-                all_pos[iname] = [np.nan, np.nan]
+            all_pos[iname] = centroids.get(name, [np.nan, np.nan])
         if legend_loc == "on data export":
             filename = settings.writedir / "pos.csv"
             logg.warning(f"exporting label positions to {filename}")
@@ -1245,10 +1234,7 @@ def heatmap(
     groupby_width = 0.2 if categorical else 0
     if figsize is None:
         height = 6
-        if show_gene_labels:
-            heatmap_width = len(var_names) * 0.3
-        else:
-            heatmap_width = 8
+        heatmap_width = len(var_names) * 0.3 if show_gene_labels else 8
         width = heatmap_width + dendro_width + groupby_width
     else:
         width, height = figsize
@@ -1352,10 +1338,7 @@ def heatmap(
         dendro_height = 0.8 if dendrogram else 0
         groupby_height = 0.13 if categorical else 0
         if figsize is None:
-            if show_gene_labels:
-                heatmap_height = len(var_names) * 0.18
-            else:
-                heatmap_height = 4
+            heatmap_height = len(var_names) * 0.18 if show_gene_labels else 4
             width = 10
             height = heatmap_height + dendro_height + groupby_height
         else:
@@ -1440,10 +1423,7 @@ def heatmap(
             for idx, (label, pos) in enumerate(
                 zip(var_group_labels, var_group_positions)
             ):
-                if var_groups_subset_of_groupby:
-                    label_code = label2code[label]
-                else:
-                    label_code = idx
+                label_code = label2code[label] if var_groups_subset_of_groupby else idx
                 arr += [label_code] * (pos[1] + 1 - pos[0])
             gene_groups_ax.imshow(
                 np.array([arr]).T, aspect="auto", cmap=groupby_cmap, norm=norm
@@ -1892,10 +1872,7 @@ def correlation_matrix(
     labels = adata.obs[groupby].cat.categories
     num_rows = corr_matrix.shape[0]
     colorbar_height = 0.2
-    if dendrogram:
-        dendrogram_width = 1.8
-    else:
-        dendrogram_width = 0
+    dendrogram_width = 1.8 if dendrogram else 0
     if figsize is None:
         corr_matrix_height = num_rows * 0.6
         height = corr_matrix_height + colorbar_height
@@ -2052,7 +2029,7 @@ def _prepare_dataframe(
             "groupby has to be a valid observation. "
             f"Given {group}, is not in observations: {adata.obs_keys()}" + msg
         )
-    if group in adata.obs.keys() and group == adata.obs.index.name:
+    if group in adata.obs.columns and group == adata.obs.index.name:
         raise ValueError(
             f"Given group {group} is both and index and a column level, "
             "which is ambiguous."
@@ -2171,10 +2148,7 @@ def _plot_gene_groups_brackets(
     if orientation == "top":
         # rotate labels if any of them is longer than 4 characters
         if rotation is None and group_labels:
-            if max([len(x) for x in group_labels]) > 4:
-                rotation = 90
-            else:
-                rotation = 0
+            rotation = 90 if max([len(x) for x in group_labels]) > 4 else 0
         for idx in range(len(left)):
             verts.append((left[idx], 0))  # lower-left
             verts.append((left[idx], 0.6))  # upper-left
@@ -2600,11 +2574,8 @@ def _plot_categories_as_colorblocks(
     )
     if len(labels) > 1:
         groupby_ax.set_xticks(ticks)
-        if max([len(str(x)) for x in labels]) < 3:
-            # if the labels are small do not rotate them
-            rotation = 0
-        else:
-            rotation = 90
+        # if the labels are small do not rotate them
+        rotation = 0 if max(len(str(x)) for x in labels) < 3 else 90
         groupby_ax.set_xticklabels(labels, rotation=rotation)

     # remove x ticks
diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py
index 6e5c8cd2c5..e68cc07727 100644
--- a/src/scanpy/plotting/_baseplot_class.py
+++ b/src/scanpy/plotting/_baseplot_class.py
@@ -137,9 +137,7 @@ def __init__(
         self.width, self.height = figsize if figsize is not None else (None, None)

         self.has_var_groups = (
-            True
-            if var_group_positions is not None and len(var_group_positions) > 0
-            else False
+            var_group_positions is not None and len(var_group_positions) > 0
         )

         self._update_var_groups()
@@ -160,18 +158,19 @@ def __init__(
                 "Plot would be very large."
             )
-        if categories_order is not None:
-            if set(self.obs_tidy.index.categories) != set(categories_order):
-                logg.error(
-                    "Please check that the categories given by "
-                    "the `order` parameter match the categories that "
-                    "want to be reordered.\n\n"
-                    "Mismatch: "
-                    f"{set(self.obs_tidy.index.categories).difference(categories_order)}\n\n"
-                    f"Given order categories: {categories_order}\n\n"
-                    f"{groupby} categories: {list(self.obs_tidy.index.categories)}\n"
-                )
-                return
+        if categories_order is not None and (
+            set(self.obs_tidy.index.categories) != set(categories_order)
+        ):
+            logg.error(
+                "Please check that the categories given by "
+                "the `order` parameter match the categories that "
+                "want to be reordered.\n\n"
+                "Mismatch: "
+                f"{set(self.obs_tidy.index.categories).difference(categories_order)}\n\n"
+                f"Given order categories: {categories_order}\n\n"
+                f"{groupby} categories: {list(self.obs_tidy.index.categories)}\n"
+            )
+            return

         self.adata = adata
         self.groupby = [groupby] if isinstance(groupby, str) else groupby
@@ -388,8 +387,8 @@ def add_totals(
             self.group_extra_size = 0
             return self

-        _sort = True if sort is not None else False
-        _ascending = True if sort == "ascending" else False
+        _sort = sort is not None
+        _ascending = sort == "ascending"
         counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending)

         if _sort:
@@ -489,10 +488,7 @@ def _plot_totals(
         if self.categories_order is not None:
             counts_df = counts_df.loc[self.categories_order]
         if params["color"] is None:
-            if f"{self.groupby}_colors" in self.adata.uns:
-                color = self.adata.uns[f"{self.groupby}_colors"]
-            else:
-                color = "salmon"
+            color = self.adata.uns.get(f"{self.groupby}_colors", "salmon")
         else:
             color = params["color"]

@@ -1027,10 +1023,7 @@ def _plot_var_groups_brackets(
     if orientation == "top":
         # rotate labels if any of them is longer than 4 characters
         if rotation is None and group_labels:
-            if max([len(x) for x in group_labels]) > 4:
-                rotation = 90
-            else:
-                rotation = 0
+            rotation = 90 if max([len(x) for x in group_labels]) > 4 else 0
         for idx, (left_coor, right_coor) in enumerate(zip(left, right)):
             verts.append((left_coor, 0))  # lower-left
             verts.append((left_coor, 0.6))  # upper-left
diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py
index eec202d0a5..837d3791e8 100644
--- a/src/scanpy/plotting/_tools/__init__.py
+++ b/src/scanpy/plotting/_tools/__init__.py
@@ -93,9 +93,7 @@ def pca_overview(adata: AnnData, **params):
     --------
     pp.pca
     """
-    show = params["show"] if "show" in params else None
-    if "show" in params:
-        del params["show"]
+    show = params.pop("show", None)
     pca(adata, **params, show=False)
     pca_loadings(adata, show=False)
     pca_variance_ratio(adata, show=show)
@@ -398,10 +396,7 @@ def rank_genes_groups(
     tl.rank_genes_groups

     """
-    if "n_panels_per_row" in kwds:
-        n_panels_per_row = kwds["n_panels_per_row"]
-    else:
-        n_panels_per_row = ncols
+    n_panels_per_row = kwds.get("n_panels_per_row", ncols)
     if n_genes < 1:
         raise NotImplementedError(
             "Specifying a negative number for n_genes has not been implemented for "
@@ -567,10 +562,7 @@ def _rank_genes_groups_plot(
             if len(genes_list) == 0:
                 logg.warning(f"No genes found for group {group}")
                 continue
-            if n_genes < 0:
-                genes_list = genes_list[n_genes:]
-            else:
-                genes_list = genes_list[:n_genes]
+            genes_list = genes_list[n_genes:] if n_genes < 0 else genes_list[:n_genes]
             var_names[group] = genes_list
             var_names_list.extend(genes_list)
@@ -1566,10 +1558,7 @@ def embedding_density(

     # turn group into a list if needed
     if group == "all":
-        if groupby is None:
-            group = None
-        else:
-            group = list(adata.obs[groupby].cat.categories)
+        group = None if groupby is None else list(adata.obs[groupby].cat.categories)
     elif isinstance(group, str):
         group = [group]
@@ -1633,10 +1622,7 @@ def embedding_density(
             adata.obs[density_col_name] = dens_values
             dot_sizes[group_mask] = np.ones(sum(group_mask)) * fg_dotsize

-        if title is None:
-            _title = group_name
-        else:
-            _title = title
+        _title = group_name if title is None else title

         ax = embedding(
             adata,
@@ -1773,16 +1759,15 @@ def _get_values_to_plot(
             df["names"] = df[gene_symbols]
         # check that all genes are present in the df as sc.tl.rank_genes_groups
         # can be called with only top genes
-        if not check_done:
-            if df.shape[0] < adata.shape[1]:
-                message = (
-                    "Please run `sc.tl.rank_genes_groups` with "
-                    "'n_genes=adata.shape[1]' to save all gene "
-                    f"scores. Currently, only {df.shape[0]} "
-                    "are found"
-                )
-                logg.error(message)
-                raise ValueError(message)
+        if not check_done and df.shape[0] < adata.shape[1]:
+            message = (
+                "Please run `sc.tl.rank_genes_groups` with "
+                "'n_genes=adata.shape[1]' to save all gene "
+                f"scores. Currently, only {df.shape[0]} "
+                "are found"
+            )
+            logg.error(message)
+            raise ValueError(message)
         df["group"] = group
         df_list.append(df)
diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py
index f0d45e9a80..159be79913 100644
--- a/src/scanpy/plotting/_tools/paga.py
+++ b/src/scanpy/plotting/_tools/paga.py
@@ -1131,7 +1131,7 @@ def paga_path(
     groups_key = adata.uns["paga"]["groups"]
     groups_names = adata.obs[groups_key].cat.categories

-    if "dpt_pseudotime" not in adata.obs.keys():
+    if "dpt_pseudotime" not in adata.obs.columns:
         raise ValueError(
             "`pl.paga_path` requires computation of a pseudotime `tl.dpt` "
             "for ordering at single-cell resolution"
diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py
index 769514a69e..5c15fa8df4 100644
--- a/src/scanpy/plotting/_tools/scatterplots.py
+++ b/src/scanpy/plotting/_tools/scatterplots.py
@@ -886,7 +886,7 @@ def pca(
         return embedding(
             adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs
         )
-    if "pca" not in adata.obsm.keys() and "X_pca" not in adata.obsm.keys():
+    if "pca" not in adata.obsm and "X_pca" not in adata.obsm:
         raise KeyError(
             f"Could not find entry in `obsm` for 'pca'.\n"
             f"Available keys are: {list(adata.obsm.keys())}."
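Note: the `.keys()` removals in this file and the ones above (`layers in adata.layers`, `"pca" not in adata.obsm`, `"dpt_pseudotime" not in adata.obs.columns`) are instances of ruff's SIM118: membership tests and iteration work on a mapping directly, so the extra key-view object is never needed. A small self-contained sketch with an illustrative dict:

```python
markers = {"B": {"CD79A"}, "T": {"CD3D"}}  # illustrative mapping

# before: builds a dict_keys view only to ask it the same question
names = [name for name in markers.keys()]
has_b = "B" in markers.keys()

# after: the dict itself supports iteration and `in` over its keys
names = list(markers)
has_b = "B" in markers

assert names == ["B", "T"] and has_b
```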
@@ -1011,10 +1011,7 @@ def spatial(
     crop_coord = _check_crop_coord(crop_coord, scale_factor)
     na_color = _check_na_color(na_color, img=img)

-    if bw:
-        cmap_img = "gray"
-    else:
-        cmap_img = None
+    cmap_img = "gray" if bw else None
     circle_radius = size * scale_factor * spot_size * 0.5

     axs = embedding(
@@ -1342,10 +1339,7 @@ def _check_spatial_data(
         library_id = list(spatial_mapping.keys())[0]
     else:
         library_id = None
-    if library_id is not None:
-        spatial_data = spatial_mapping[library_id]
-    else:
-        spatial_data = None
+    spatial_data = spatial_mapping[library_id] if library_id is not None else None
     return library_id, spatial_data

@@ -1387,10 +1381,7 @@ def _check_na_color(
     na_color: ColorLike | None, *, img: np.ndarray | None = None
 ) -> ColorLike:
     if na_color is None:
-        if img is not None:
-            na_color = (0.0, 0.0, 0.0, 0.0)
-        else:
-            na_color = "lightgray"
+        na_color = (0.0, 0.0, 0.0, 0.0) if img is not None else "lightgray"
     return na_color

diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py
index ea6aa0cb10..13832658f5 100644
--- a/src/scanpy/plotting/_utils.py
+++ b/src/scanpy/plotting/_utils.py
@@ -1280,10 +1280,10 @@ def fix_kwds(kwds_dict, **kwargs):

 def _get_basis(adata: AnnData, basis: str):
-    if basis in adata.obsm.keys():
+    if basis in adata.obsm:
         basis_key = basis
-    elif f"X_{basis}" in adata.obsm.keys():
+    elif f"X_{basis}" in adata.obsm:
         basis_key = f"X_{basis}"

     return basis_key
diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py
index b8487e4e7e..ef193d38b4 100644
--- a/src/scanpy/preprocessing/_combat.py
+++ b/src/scanpy/preprocessing/_combat.py
@@ -196,10 +196,7 @@ def combat(
         raise ValueError("Covariates must be unique")

     # only works on dense matrices so far
-    if issparse(adata.X):
-        X = adata.X.toarray().T
-    else:
-        X = adata.X.T
+    X = adata.X.toarray().T if issparse(adata.X) else adata.X.T
     data = pd.DataFrame(data=X, index=adata.var_names, columns=adata.obs_names)

     sanitize_anndata(adata)
diff --git a/src/scanpy/preprocessing/_deprecated/__init__.py b/src/scanpy/preprocessing/_deprecated/__init__.py
index bb944b874a..7cd7520171 100644
--- a/src/scanpy/preprocessing/_deprecated/__init__.py
+++ b/src/scanpy/preprocessing/_deprecated/__init__.py
@@ -7,7 +7,7 @@

 @legacy_api("max_fraction", "mult_with_mean")
 def normalize_per_cell_weinreb16_deprecated(
-    X: np.ndarray,
+    x: np.ndarray,
     *,
     max_fraction: float = 1,
     mult_with_mean: bool = False,
@@ -37,23 +37,23 @@
     if max_fraction < 0 or max_fraction > 1:
         raise ValueError("Choose max_fraction between 0 and 1.")

-    counts_per_cell = X.sum(1).A1 if issparse(X) else X.sum(1)
-    gene_subset = np.all(X <= counts_per_cell[:, None] * max_fraction, axis=0)
-    if issparse(X):
+    counts_per_cell = x.sum(1).A1 if issparse(x) else x.sum(1)
+    gene_subset = np.all(x <= counts_per_cell[:, None] * max_fraction, axis=0)
+    if issparse(x):
         gene_subset = gene_subset.A1
     tc_include = (
-        X[:, gene_subset].sum(1).A1 if issparse(X) else X[:, gene_subset].sum(1)
+        x[:, gene_subset].sum(1).A1 if issparse(x) else x[:, gene_subset].sum(1)
     )

-    X_norm = (
-        X.multiply(csr_matrix(1 / tc_include[:, None]))
-        if issparse(X)
-        else X / tc_include[:, None]
+    x_norm = (
+        x.multiply(csr_matrix(1 / tc_include[:, None]))
+        if issparse(x)
+        else x / tc_include[:, None]
     )
     if mult_with_mean:
-        X_norm *= np.mean(counts_per_cell)
+        x_norm *= np.mean(counts_per_cell)

-    return X_norm
+    return x_norm

 def zscore_deprecated(X: np.ndarray) -> np.ndarray:
diff --git a/src/scanpy/preprocessing/_normalization.py b/src/scanpy/preprocessing/_normalization.py
index c6fccfb70a..686c69b224 100644
--- a/src/scanpy/preprocessing/_normalization.py
+++ b/src/scanpy/preprocessing/_normalization.py
@@ -206,25 +206,25 @@ def normalize_total(

     view_to_actual(adata)

-    X = _get_obs_rep(adata, layer=layer)
+    x = _get_obs_rep(adata, layer=layer)

     gene_subset = None
     msg = "normalizing counts per cell"

-    counts_per_cell = axis_sum(X, axis=1)
+    counts_per_cell = axis_sum(x, axis=1)
     if exclude_highly_expressed:
         counts_per_cell = np.ravel(counts_per_cell)

         # at least one cell as more than max_fraction of counts per cell
-        gene_subset = axis_sum((X > counts_per_cell[:, None] * max_fraction), axis=0)
+        gene_subset = axis_sum((x > counts_per_cell[:, None] * max_fraction), axis=0)
         gene_subset = np.asarray(np.ravel(gene_subset) == 0)

         msg += (
             ". The following highly-expressed genes are not considered during "
             f"normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}"
         )
-        counts_per_cell = axis_sum(X[:, gene_subset], axis=1)
+        counts_per_cell = axis_sum(x[:, gene_subset], axis=1)

     start = logg.info(msg)
     counts_per_cell = np.ravel(counts_per_cell)
@@ -237,12 +237,12 @@ def normalize_total(
     if key_added is not None:
         adata.obs[key_added] = counts_per_cell
         _set_obs_rep(
-            adata, _normalize_data(X, counts_per_cell, target_sum), layer=layer
+            adata, _normalize_data(x, counts_per_cell, target_sum), layer=layer
         )
     else:
         # not recarray because need to support sparse
         dat = dict(
-            X=_normalize_data(X, counts_per_cell, target_sum, copy=True),
+            X=_normalize_data(x, counts_per_cell, target_sum, copy=True),
             norm_factor=counts_per_cell,
         )
diff --git a/src/scanpy/preprocessing/_pca.py b/src/scanpy/preprocessing/_pca.py
index 5b5706c123..93432841f8 100644
--- a/src/scanpy/preprocessing/_pca.py
+++ b/src/scanpy/preprocessing/_pca.py
@@ -202,10 +202,7 @@ def pca(

     if n_comps is None:
         min_dim = min(adata_comp.n_vars, adata_comp.n_obs)
-        if settings.N_PCS >= min_dim:
-            n_comps = min_dim - 1
-        else:
-            n_comps = settings.N_PCS
+        n_comps = min_dim - 1 if min_dim <= settings.N_PCS else settings.N_PCS

     logg.info(f"    with n_comps={n_comps}")
@@ -395,7 +392,7 @@ def _handle_mask_var(
     if use_highly_variable or (
         use_highly_variable is None
         and mask_var is _empty
-        and "highly_variable" in adata.var.keys()
+        and "highly_variable" in adata.var.columns
     ):
         mask_var = "highly_variable"
diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py
index 976dafe89f..d57eb81750 100644
--- a/src/scanpy/preprocessing/_scrublet/__init__.py
+++ b/src/scanpy/preprocessing/_scrublet/__init__.py
@@ -245,7 +245,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         return {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}

     if batch_key is not None:
-        if batch_key not in adata.obs.keys():
+        if batch_key not in adata.obs.columns:
             msg = (
                 "`batch_key` must be a column of .obs in the input AnnData object,"
                 f"but {batch_key!r} is not in {adata.obs.keys()!r}."
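Note: the bulk of this patch consists of ruff SIM108 rewrites, where an if/else whose only job is to bind one name collapses into a single conditional expression. The sketch below replays the read_mtx_from_stream change from _ebi_expression_atlas.py in isolation; the dimensions are made up for illustration:

```python
import numpy as np

n, m = 100_000, 3_000_000_000  # illustrative matrix dimensions
max_int32 = np.iinfo(np.int32).max

# before: four lines, one assignment
if n > max_int32 or m > max_int32:
    coord_dtype = np.int64
else:
    coord_dtype = np.int32

# after: one statement with identical semantics
coord_dtype = np.int64 if n > max_int32 or m > max_int32 else np.int32

assert coord_dtype is np.int64  # m overflows int32, so 64-bit coordinates
```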
diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py
index 6cf42cccb7..073af955e9 100644
--- a/src/scanpy/preprocessing/_simple.py
+++ b/src/scanpy/preprocessing/_simple.py
@@ -376,10 +376,7 @@ def log1p_array(X: np.ndarray, *, base: Number | None = None, copy: bool = False
     # Can force arrays to be np.ndarrays, but would be useful to not
     # X = check_array(X, dtype=(np.float64, np.float32), ensure_2d=False, copy=copy)
     if copy:
-        if not np.issubdtype(X.dtype, np.floating):
-            X = X.astype(float)
-        else:
-            X = X.copy()
+        X = X.astype(float) if not np.issubdtype(X.dtype, np.floating) else X.copy()
     elif not (np.issubdtype(X.dtype, np.floating) or np.issubdtype(X.dtype, complex)):
         X = X.astype(float)
     np.log1p(X, out=X)
@@ -700,10 +697,7 @@ def regress_out(
     # regress on one or several ordinal variables
     else:
         # create data frame with selected keys (if given)
-        if keys:
-            regressors = adata.obs[keys]
-        else:
-            regressors = adata.obs.copy()
+        regressors = adata.obs[keys] if keys else adata.obs.copy()

         # add column of ones at index 0 (first column)
         regressors.insert(0, "ones", 1.0)
@@ -720,10 +714,7 @@ def regress_out(
     for idx, data_chunk in enumerate(chunk_list):
         # each task is a tuple of a data_chunk eg. (adata.X[:,0:100]) and
         # the regressors. This data will be passed to each of the jobs.
-        if variable_is_categorical:
-            regres = regressors_chunk[idx]
-        else:
-            regres = regressors
+        regres = regressors_chunk[idx] if variable_is_categorical else regressors
         tasks.append(tuple((data_chunk, regres, variable_is_categorical)))

     from joblib import Parallel, delayed
diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py
index 90179daf46..9e0a298e18 100644
--- a/src/scanpy/readwrite.py
+++ b/src/scanpy/readwrite.py
@@ -703,13 +703,12 @@ def read_params(

     params = OrderedDict([])
     for line in filename.open():
-        if "=" in line:
-            if not as_header or line.startswith("#"):
-                line = line[1:] if line.startswith("#") else line
-                key, val = line.split("=")
-                key = key.strip()
-                val = val.strip()
-                params[key] = convert_string(val)
+        if "=" in line and (not as_header or line.startswith("#")):
+            line = line[1:] if line.startswith("#") else line
+            key, val = line.split("=")
+            key = key.strip()
+            val = val.strip()
+            params[key] = convert_string(val)
     return params
diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py
index 8c5f7d857b..c0fa59262f 100644
--- a/src/scanpy/tools/_dpt.py
+++ b/src/scanpy/tools/_dpt.py
@@ -137,7 +137,7 @@ def dpt(
             "    adata.uns['iroot'] = root_cell_index\n"
             "    adata.var['xroot'] = adata[root_cell_name, :].X"
         )
-    if "X_diffmap" not in adata.obsm.keys():
+    if "X_diffmap" not in adata.obsm:
        logg.warning(
            "Trying to run `tl.dpt` without prior call of `tl.diffmap`. "
            "Falling back to `tl.diffmap` with default parameters."
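Note: several hunks replace a membership test plus lookup with a single dict call: adata.uns.get("highlights", []) and params.pop("show", None) above, neighbors["params"].get("metric_kwds", {}) in _ingest.py below. ruff's SIM401 covers the .get() case. A self-contained sketch with illustrative dicts:

```python
params = {"show": True, "color": "phase"}  # illustrative kwargs dict
uns = {}  # illustrative metadata dict

# before: test membership, index, then delete, three lookups in total
show = params["show"] if "show" in params else None
if "show" in params:
    del params["show"]
highlights = uns["highlights"] if "highlights" in uns else []

# after: one lookup each, with the same fallback values
params = {"show": True, "color": "phase"}  # reset so the rewrite sees the same input
show = params.pop("show", None)
highlights = uns.get("highlights", [])

assert show is True and highlights == [] and "show" not in params
```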
diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py
index b36638b5b5..4e8c91fb1f 100644
--- a/src/scanpy/tools/_draw_graph.py
+++ b/src/scanpy/tools/_draw_graph.py
@@ -130,7 +130,7 @@ def draw_graph(
     if adjacency is None:
         adjacency = _choose_graph(adata, obsp, neighbors_key)
     # init coordinates
-    if init_pos in adata.obsm.keys():
+    if init_pos in adata.obsm:
         init_coords = adata.obsm[init_pos]
     elif init_pos == "paga" or init_pos:
         init_coords = get_init_pos_from_paga(
diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py
index 136f58af46..9ced418888 100644
--- a/src/scanpy/tools/_ingest.py
+++ b/src/scanpy/tools/_ingest.py
@@ -297,16 +297,12 @@ def _init_neighbors(self, adata, neighbors_key):
             self._use_rep = "X_pca"
             self._n_pcs = neighbors["params"]["n_pcs"]
             self._rep = adata.obsm["X_pca"][:, : self._n_pcs]
-        elif adata.n_vars > settings.N_PCS and "X_pca" in adata.obsm.keys():
+        elif adata.n_vars > settings.N_PCS and "X_pca" in adata.obsm:
             self._use_rep = "X_pca"
             self._rep = adata.obsm["X_pca"][:, : settings.N_PCS]
             self._n_pcs = self._rep.shape[1]

-        if "metric_kwds" in neighbors["params"]:
-            self._metric_kwds = neighbors["params"]["metric_kwds"]
-        else:
-            self._metric_kwds = {}
-
+        self._metric_kwds = neighbors["params"].get("metric_kwds", {})
         self._metric = neighbors["params"]["metric"]

         self._neigh_random_state = neighbors["params"].get("random_state", 0)
@@ -317,7 +313,7 @@ def _init_pca(self, adata):
         self._pca_use_hvg = adata.uns["pca"]["params"]["use_highly_variable"]

         mask = "highly_variable"
-        if self._pca_use_hvg and mask not in adata.var.keys():
+        if self._pca_use_hvg and mask not in adata.var.columns:
             msg = f"Did not find `adata.var[{mask!r}']`."
             raise ValueError(msg)
@@ -376,7 +372,7 @@ def _same_rep(self):
             return self._pca(self._n_pcs)
         if self._use_rep == "X":
             return adata.X
-        if self._use_rep in adata.obsm.keys():
+        if self._use_rep in adata.obsm:
             return adata.obsm[self._use_rep]
         return adata.X
diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py
index e6259f035d..d3e616a850 100644
--- a/src/scanpy/tools/_louvain.py
+++ b/src/scanpy/tools/_louvain.py
@@ -163,10 +163,7 @@ def louvain(
         if not directed:
             logg.debug("    using the undirected graph")
         g = _utils.get_igraph_from_adjacency(adjacency, directed=directed)
-        if use_weights:
-            weights = np.array(g.es["weight"]).astype(np.float64)
-        else:
-            weights = None
+        weights = np.array(g.es["weight"]).astype(np.float64) if use_weights else None
         if flavor == "vtraag":
             import louvain
diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py
index 1c286e9333..a1d71cf993 100644
--- a/src/scanpy/tools/_marker_gene_overlap.py
+++ b/src/scanpy/tools/_marker_gene_overlap.py
@@ -30,10 +30,7 @@ def _calc_overlap_count(markers1: dict, markers2: dict):
     overlaps = np.zeros((len(markers1), len(markers2)))

     for j, marker_group in enumerate(markers1):
-        tmp = [
-            len(markers2[i].intersection(markers1[marker_group]))
-            for i in markers2.keys()
-        ]
+        tmp = [len(markers2[i].intersection(markers1[marker_group])) for i in markers2]
         overlaps[j, :] = tmp

     return overlaps
@@ -51,7 +48,7 @@ def _calc_overlap_coef(markers1: dict, markers2: dict):
         tmp = [
             len(markers2[i].intersection(markers1[marker_group]))
             / max(min(len(markers2[i]), len(markers1[marker_group])), 1)
-            for i in markers2.keys()
+            for i in markers2
         ]
         overlap_coef[j, :] = tmp

@@ -70,7 +67,7 @@ def _calc_jaccard(markers1: dict, markers2: dict):
         tmp = [
             len(markers2[i].intersection(markers1[marker_group]))
             / len(markers2[i].union(markers1[marker_group]))
-            for i in markers2.keys()
+            for i in markers2
         ]
         jacc_results[j, :] = tmp

diff --git a/src/scanpy/tools/_paga.py b/src/scanpy/tools/_paga.py
index 98f0ac622b..98146b83e2 100644
--- a/src/scanpy/tools/_paga.py
+++ b/src/scanpy/tools/_paga.py
@@ -196,10 +196,7 @@ def _compute_connectivities_v1_2(self):
         inter_es = inter_es.tocoo()
         for i, j, v in zip(inter_es.row, inter_es.col, inter_es.data):
             expected_random_null = (es[i] * ns[j] + es[j] * ns[i]) / (n - 1)
-            if expected_random_null != 0:
-                scaled_value = v / expected_random_null
-            else:
-                scaled_value = 1
+            scaled_value = v / expected_random_null if expected_random_null != 0 else 1
             if scaled_value > 1:
                 scaled_value = 1
             connectivities[i, j] = scaled_value
@@ -229,10 +226,7 @@ def _compute_connectivities_v1_0(self):
         for i, j, v in zip(inter_es.row, inter_es.col, inter_es.data):
             # have n_neighbors**2 inside sqrt for backwards compat
             geom_mean_approx_knn = np.sqrt(n_neighbors_sq * ns[i] * ns[j])
-            if geom_mean_approx_knn != 0:
-                scaled_value = v / geom_mean_approx_knn
-            else:
-                scaled_value = 1
+            scaled_value = v / geom_mean_approx_knn if geom_mean_approx_knn != 0 else 1
             connectivities[i, j] = scaled_value
         # set attributes
         self.ns = ns
diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py
index 3327fe7501..56b71d55eb 100644
--- a/src/scanpy/tools/_rank_genes_groups.py
+++ b/src/scanpy/tools/_rank_genes_groups.py
@@ -287,10 +287,7 @@ def wilcoxon(
         # initialize space for z-scores
         scores = np.zeros(n_genes)
         # initialize space for tie correction coefficients
-        if tie_correct:
-            T = np.zeros(n_genes)
-        else:
-            T = 1
+        T = np.zeros(n_genes) if tie_correct else 1

         for group_index, mask_obs in enumerate(self.groups_masks_obs):
             if group_index == self.ireference:
@@ -346,10 +343,7 @@ def wilcoxon(
             for group_index, mask_obs in enumerate(self.groups_masks_obs):
                 n_active = np.count_nonzero(mask_obs)

-                if tie_correct:
-                    T_i = T[group_index]
-                else:
-                    T_i = 1
+                T_i = T[group_index] if tie_correct else 1

                 std_dev = np.sqrt(
                     T_i * n_active * (n_cells - n_active) * (n_cells + 1) / 12.0
@@ -733,10 +727,7 @@ def rank_genes_groups(

 def _calc_frac(X):
-    if issparse(X):
-        n_nonzero = X.getnnz(axis=0)
-    else:
-        n_nonzero = np.count_nonzero(X, axis=0)
+    n_nonzero = X.getnnz(axis=0) if issparse(X) else np.count_nonzero(X, axis=0)
     return n_nonzero / X.shape[0]
diff --git a/src/scanpy/tools/_sim.py b/src/scanpy/tools/_sim.py
index bf562b8408..7410442952 100644
--- a/src/scanpy/tools/_sim.py
+++ b/src/scanpy/tools/_sim.py
@@ -198,7 +198,7 @@ def sample_dynamic_data(**params):
                 X[::step],
                 dir=writedir,
                 noiseObs=noiseObs,
-                append=(False if restart == 0 else True),
+                append=restart != 0,
                 branching=branching,
                 nrRealizations=nrRealizations,
             )
@@ -208,7 +208,7 @@ def sample_dynamic_data(**params):
                 noiseDyn * np.random.randn(500, 3),
                 dir=writedir,
                 noiseObs=noiseObs,
-                append=(False if restart == 0 else True),
+                append=restart != 0,
                 branching=branching,
                 nrRealizations=nrRealizations,
             )
@@ -270,7 +270,7 @@ def sample_dynamic_data(**params):
                 X[::step],
                 dir=writedir,
                 noiseObs=noiseObs,
-                append=(False if restart == 0 else True),
+                append=restart != 0,
                 branching=branching,
                 nrRealizations=nrRealizations,
             )
@@ -367,7 +367,7 @@ def write_data(
     # variable names
     if varNames:
         header += f'{"it":>2} '
-        for v in varNames.keys():
+        for v in varNames:
             header += f"{v:>7} "
     with (dir / f"sim_{id}.txt").open("ab" if append else "wb") as f:
         np.savetxt(
@@ -430,7 +430,7 @@ def __init__(
         # checks
         if initType not in ["branch", "random"]:
             raise RuntimeError("initType must be either: branch, random")
-        if model not in self.availModels.keys():
+        if model not in self.availModels:
             message = "model not among predefined models \n"  # noqa: F841
        # TODO FIX
        # read from file
        from .. import sim_models
@@ -605,12 +605,12 @@ def set_coupl(self, Coupl=None):
         or via sampling.
         """
         self.varNames = {str(i): i for i in range(self.dim)}
-        if self.model not in self.availModels.keys() and Coupl is None:
+        if self.model not in self.availModels and Coupl is None:
             self.read_model()
         elif "var" in self.model.name:
             # vector auto regressive process
             self.Coupl = Coupl
-            self.boolRules = {s: "" for s in self.varNames.keys()}
+            self.boolRules = {s: "" for s in self.varNames}
             names = list(self.varNames.keys())
             for gp in range(self.dim):
                 pas = []
@@ -819,7 +819,7 @@ def parents_from_boolRule(self, rule):
         pa_old = []
         pa_delete = []
         for pa in rule_pa:
-            if pa not in self.varNames.keys():
+            if pa not in self.varNames:
                 settings.m(0, "list of available variables:")
                 settings.m(0, list(self.varNames.keys()))
                 message = (
@@ -842,12 +842,11 @@ def build_boolCoeff(self):
         """Compute coefficients for tuple space."""
         # coefficients for hill functions from boolean update rules
-        self.boolCoeff = {s: [] for s in self.varNames.keys()}
+        self.boolCoeff = {s: [] for s in self.varNames}
         # parents
-        self.pas = {s: [] for s in self.varNames.keys()}
+        self.pas = {s: [] for s in self.varNames}
         #
-        for key in self.boolRules.keys():
-            rule = self.boolRules[key]
+        for key, rule in self.boolRules.items():
             self.pas[key] = self.parents_from_boolRule(rule)
             pasIndices = [self.varNames[pa] for pa in self.pas[key]]
             # check whether there are coupling matrix entries for each parent
@@ -1150,11 +1149,10 @@ def sim_givenAdj(self, Adj: np.ndarray, model="line"):
             # if there is more than a child with a single parent
             # order these children (there are two in three dim)
             # by distance to the source/parent
-            if nrchildren_par[1] > 1:
-                if Adj[children_sorted[0], parents[0]] == 0:
-                    help = children_sorted[0]
-                    children_sorted[0] = children_sorted[1]
-                    children_sorted[1] = help
+            if nrchildren_par[1] > 1 and Adj[children_sorted[0], parents[0]] == 0:
+                help = children_sorted[0]
+                children_sorted[0] = children_sorted[1]
+                children_sorted[1] = help

             for gp in children_sorted:
                 for g in range(dim):
diff --git a/src/scanpy/tools/_top_genes.py b/src/scanpy/tools/_top_genes.py
index 00dc764d82..d66e9232f0 100644
--- a/src/scanpy/tools/_top_genes.py
+++ b/src/scanpy/tools/_top_genes.py
@@ -181,10 +181,7 @@ def ROC_AUC_analysis(
     y_true = mask
     for i, j in enumerate(name_list):
         vec = adata[:, [j]].X
-        if issparse(vec):
-            y_score = vec.todense()
-        else:
-            y_score = vec
+        y_score = vec.todense() if issparse(vec) else vec

         (
             fpr[name_list[i]],
diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py
index c23b5551fa..4f225da2a1 100644
--- a/src/scanpy/tools/_umap.py
+++ b/src/scanpy/tools/_umap.py
@@ -187,7 +187,7 @@ def umap(
     if a is None or b is None:
         a, b = find_ab_params(spread, min_dist)
     adata.uns[key_uns] = dict(params=dict(a=a, b=b))
-    if isinstance(init_pos, str) and init_pos in adata.obsm.keys():
+    if isinstance(init_pos, str) and init_pos in adata.obsm:
         init_coords = adata.obsm[init_pos]
     elif isinstance(init_pos, str) and init_pos == "paga":
         init_coords = get_init_pos_from_paga(
diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py
index b3ffa7d324..97e2de0df1 100644
--- a/src/scanpy/tools/_utils.py
+++ b/src/scanpy/tools/_utils.py
@@ -30,7 +30,7 @@ def _choose_representation(
         use_rep = "X"
     if use_rep is None:
         if adata.n_vars > settings.N_PCS:
-            if "X_pca" in adata.obsm.keys():
+            if "X_pca" in adata.obsm:
                 if n_pcs is not None and n_pcs > adata.obsm["X_pca"].shape[1]:
                     raise ValueError(
                         "`X_pca` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`."
@@ -50,7 +50,7 @@ def _choose_representation(
             logg.info("    using data matrix X directly")
             X = adata.X
     else:
-        if use_rep in adata.obsm.keys() and n_pcs is not None:
+        if use_rep in adata.obsm and n_pcs is not None:
             if n_pcs > adata.obsm[use_rep].shape[1]:
                 raise ValueError(
                     f"{use_rep} does not have enough Dimensions. Provide a "
@@ -58,7 +58,7 @@ def _choose_representation(
                     "`n_pcs` or lower `n_pcs` "
                 )
             X = adata.obsm[use_rep][:, :n_pcs]
-        elif use_rep in adata.obsm.keys() and n_pcs is None:
+        elif use_rep in adata.obsm and n_pcs is None:
             X = adata.obsm[use_rep]
         elif use_rep == "X":
             X = adata.X
diff --git a/tests/notebooks/test_paga_paul15_subsampled.py b/tests/notebooks/test_paga_paul15_subsampled.py
index 6d4ee886ba..9ce6ea8319 100644
--- a/tests/notebooks/test_paga_paul15_subsampled.py
+++ b/tests/notebooks/test_paga_paul15_subsampled.py
@@ -129,7 +129,7 @@ def test_paga_paul15_subsampled(image_comparer, plt):
             left_margin=0.15,
             n_avg=50,
             annotations=["distance"],
-            show_yticks=True if ipath == 0 else False,
+            show_yticks=ipath == 0,
             show_colorbar=False,
             color_map="Greys",
             color_maps_annotations={"distance": "viridis"},
diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py
index f3b9298505..cdd5238c70 100644
--- a/tests/test_highly_variable_genes.py
+++ b/tests/test_highly_variable_genes.py
@@ -255,7 +255,7 @@ def test_pearson_residuals_general(
         "residual_variances",
         "highly_variable_rank",
     ]:
-        assert key in output_df.keys()
+        assert key in output_df.columns

     # check consistency with normalization method
     if subset:
@@ -324,7 +324,7 @@ def test_pearson_residuals_batch(pbmc3k_parametrized_small, subset, n_top_genes):
         "highly_variable_nbatches",
         "highly_variable_intersection",
     ]:
-        assert key in output_df.keys()
+        assert key in output_df.columns

     # general checks on ranks, hvg flag and residual variance
     _check_pearson_hvg_columns(output_df, n_top_genes)
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index 92eb61c252..b34db78780 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -1590,11 +1590,13 @@ def test_color_cycler(caplog):
     colors = sns.color_palette("deep")
     cyl = sns.rcmod.cycler("color", sns.color_palette("deep"))

-    with caplog.at_level(logging.WARNING):
-        with plt.rc_context({"axes.prop_cycle": cyl, "patch.facecolor": colors[0]}):
-            sc.pl.umap(pbmc, color="phase")
-            plt.show()
-            plt.close()
+    with (
+        caplog.at_level(logging.WARNING),
+        plt.rc_context({"axes.prop_cycle": cyl, "patch.facecolor": colors[0]}),
+    ):
+        sc.pl.umap(pbmc, color="phase")
+        plt.show()
+        plt.close()

     assert caplog.text == ""
diff --git a/tests/test_rank_genes_groups_logreg.py b/tests/test_rank_genes_groups_logreg.py
index 3cc294487e..618de375f7 100644
--- a/tests/test_rank_genes_groups_logreg.py
+++ b/tests/test_rank_genes_groups_logreg.py
@@ -40,7 +40,7 @@ def test_rank_genes_groups_with_renamed_categories_use_rep():
     assert adata.uns["rank_genes_groups"]["names"][0].tolist() == ("1", "3", "0")

     sc.tl.rank_genes_groups(adata, "blobs", method="logreg")
-    assert not adata.uns["rank_genes_groups"]["names"][0].tolist() == ("3", "1", "0")
+    assert adata.uns["rank_genes_groups"]["names"][0].tolist() != ("3", "1", "0")

 def test_rank_genes_groups_with_unsorted_groups():
diff --git a/tests/test_scaling.py b/tests/test_scaling.py
index fad2443dc1..0ad62bbc7d 100644
--- a/tests/test_scaling.py
+++ b/tests/test_scaling.py
@@ -120,7 +120,7 @@ def test_mask_string():
     adata.obs["some cells"] = np.array((0, 0, 1, 1, 1, 0, 0), dtype=bool)
     sc.pp.scale(adata, mask_obs="some cells")
     assert np.array_equal(adata.X, X_centered_for_mask)
-    assert "mean of some cells" in adata.var.keys()
+    assert "mean of some cells" in adata.var.columns

 @pytest.mark.parametrize("zero_center", [True, False])