Skip to content

Commit

Permalink
Add SIM checks (#3258)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Sep 24, 2024
1 parent 8b2088d commit d998742
Show file tree
Hide file tree
Showing 38 changed files with 160 additions and 282 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ select = [
"PLR0917", # Ban APIs with too many positional parameters
"FBT", # No positional boolean parameters
"PT", # Pytest style
"SIM", # Simplify control flow
]
ignore = [
# line too long -> we accept long comment lines; black gets rid of long code lines
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def logpath(self) -> Path | None:
def logpath(self, logpath: Path | str | None):
_type_check(logpath, "logfile", (str, Path))
# set via “file object” branch of logfile.setter
- self.logfile = Path(logpath).open("a")
+ self.logfile = Path(logpath).open("a") # noqa: SIM115
self._logpath = Path(logpath)

@property
Expand Down Expand Up @@ -519,7 +519,7 @@ def __str__(self) -> str:
return "\n".join(
f"{k} = {v!r}"
for k, v in inspect.getmembers(self)
- if not k.startswith("_") and not k == "getdoc"
+ if not k.startswith("_") and k != "getdoc"
)


Expand Down
17 changes: 7 additions & 10 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import warnings
from collections import namedtuple
- from contextlib import contextmanager
+ from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, singledispatch, wraps
from operator import mul, truediv
Expand Down Expand Up @@ -281,10 +281,8 @@ def get_igraph_from_adjacency(adjacency, directed=None):
g = ig.Graph(directed=directed)
g.add_vertices(adjacency.shape[0]) # this adds adjacency.shape[0] vertices
g.add_edges(list(zip(sources, targets)))
- try:
+ with suppress(KeyError):
  g.es["weight"] = weights
- except KeyError:
- pass
if g.vcount() != adjacency.shape[0]:
logg.warning(
f"The constructed graph has only {g.vcount()} nodes. "
Expand Down Expand Up @@ -613,11 +611,10 @@ def _(
out: sparse.csr_matrix | sparse.csc_matrix | None = None,
) -> sparse.csr_matrix | sparse.csc_matrix:
check_op(op)
- if out is not None:
- if X.data is not out.data:
- raise ValueError(
- "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
- )
+ if out is not None and X.data is not out.data:
+ raise ValueError(
+ "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
+ )
if not allow_divide_by_zero and op is truediv:
scaling_array = scaling_array.copy() + (scaling_array == 0)

Expand Down Expand Up @@ -684,7 +681,7 @@ def _(
column_scale = axis == 1

if isinstance(scaling_array, DaskArray):
- if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or (
+ if (row_scale and X.chunksize[0] != scaling_array.chunksize[0]) or (
column_scale
and (
(
Expand Down
5 changes: 1 addition & 4 deletions src/scanpy/_utils/compute/is_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ def _is_constant_csr_rows(
for i in range(n):
start = indptr[i]
stop = indptr[i + 1]
- if stop - start == shape[1]:
- val = data[start]
- else:
- val = 0
+ val = data[start] if stop - start == shape[1] else 0
for j in range(start, stop):
if data[j] != val:
result[i] = False
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/datasets/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def moignard15() -> AnnData:
}
# annotate each observation/cell
adata.obs["exp_groups"] = [
- next(gname for gname in groups.keys() if sname.startswith(gname))
+ next(gname for gname in groups if sname.startswith(gname))
for sname in adata.obs_names
]
# fix the order and colors of names in "groups"
Expand Down
5 changes: 1 addition & 4 deletions src/scanpy/datasets/_ebi_expression_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ def read_mtx_from_stream(stream: BinaryIO) -> sparse.csr_matrix:
n, m, _ = (int(x) for x in curline[:-1].split(b" "))

max_int32 = np.iinfo(np.int32).max
- if n > max_int32 or m > max_int32:
- coord_dtype = np.int64
- else:
- coord_dtype = np.int32
+ coord_dtype = np.int64 if n > max_int32 or m > max_int32 else np.int32

data = pd.read_csv(
stream,
Expand Down
5 changes: 2 additions & 3 deletions src/scanpy/external/pl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

+ import contextlib
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -214,10 +215,8 @@ def sam(
return axes

if isinstance(c, str):
- try:
+ with contextlib.suppress(KeyError):
  c = np.array(list(adata.obs[c]))
- except KeyError:
- pass

if isinstance(c[0], (str, np.str_)) and isinstance(c, (np.ndarray, list)):
import samalg.utilities as ut
Expand Down
10 changes: 2 additions & 8 deletions src/scanpy/get/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ def _check_indices(
use_raw: bool = False,
) -> tuple[list[str], list[str], list[str]]:
"""Common logic for checking indices for obs_df and var_df."""
- if use_raw:
- alt_repr = "adata.raw"
- else:
- alt_repr = "adata"
+ alt_repr = "adata.raw" if use_raw else "adata"

alt_dim = ("obs", "var")[dim == "obs"]

Expand Down Expand Up @@ -288,10 +285,7 @@ def obs_df(
var = adata.raw.var
else:
var = adata.var
- if gene_symbols is not None:
- alias_index = pd.Index(var[gene_symbols])
- else:
- alias_index = None
+ alias_index = pd.Index(var[gene_symbols]) if gene_symbols is not None else None

obs_cols, var_idx_keys, var_symbols = _check_indices(
adata.obs,
Expand Down
11 changes: 4 additions & 7 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

+ import contextlib
from collections.abc import Mapping
from textwrap import indent
from types import MappingProxyType
Expand Down Expand Up @@ -469,10 +470,7 @@ def transitions(self) -> np.ndarray | csr_matrix:
-----
This has not been tested, in contrast to `transitions_sym`.
"""
- if issparse(self.Z):
- Zinv = self.Z.power(-1)
- else:
- Zinv = np.diag(1.0 / np.diag(self.Z))
+ Zinv = self.Z.power(-1) if issparse(self.Z) else np.diag(1.0 / np.diag(self.Z))
return self.Z @ self.transitions_sym @ Zinv

@property
Expand Down Expand Up @@ -591,10 +589,9 @@ def compute_neighbors(

if isinstance(index, NNDescent):
# very cautious here
- try:
+ # TODO catch the correct exception
+ with contextlib.suppress(Exception):
  self._rp_forest = _make_forest_dict(index)
- except Exception: # TODO catch the correct exception
- pass
start_connect = logg.debug("computed neighbors", time=start_neighbors)

if method == "umap":
Expand Down
75 changes: 23 additions & 52 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ def scatter(
if x is None or y is None:
raise ValueError("Either provide a `basis` or `x` and `y`.")
if (
- (x in adata.obs.keys() or x in var_index)
- and (y in adata.obs.keys() or y in var_index)
- and (color is None or color in adata.obs.keys() or color in var_index)
+ (x in adata.obs.columns or x in var_index)
+ and (y in adata.obs.columns or y in var_index)
+ and (color is None or color in adata.obs.columns or color in var_index)
):
return _scatter_obs(**args)
if (
- (x in adata.var.keys() or x in adata.obs.index)
- and (y in adata.var.keys() or y in adata.obs.index)
- and (color is None or color in adata.var.keys() or color in adata.obs.index)
+ (x in adata.var.columns or x in adata.obs.index)
+ and (y in adata.var.columns or y in adata.obs.index)
+ and (color is None or color in adata.var.columns or color in adata.obs.index)
):
adata_T = adata.T
axs = _scatter_obs(
Expand Down Expand Up @@ -217,14 +217,12 @@ def _scatter_obs(
use_raw = _check_use_raw(adata, use_raw)

# Process layers
- if layers in ["X", None] or (
- isinstance(layers, str) and layers in adata.layers.keys()
- ):
+ if layers in ["X", None] or (isinstance(layers, str) and layers in adata.layers):
layers = (layers, layers, layers)
elif isinstance(layers, Collection) and len(layers) == 3:
layers = tuple(layers)
for layer in layers:
- if layer not in adata.layers.keys() and layer not in ["X", None]:
+ if layer not in adata.layers and layer not in ["X", None]:
raise ValueError(
"`layers` should have elements that are "
"either None or in adata.layers.keys()."
Expand Down Expand Up @@ -256,7 +254,7 @@ def _scatter_obs(
)
if title is not None and isinstance(title, str):
title = [title]
- highlights = adata.uns["highlights"] if "highlights" in adata.uns else []
+ highlights = adata.uns.get("highlights", [])
if basis is not None:
try:
# ignore the '0th' diffusion component
Expand Down Expand Up @@ -292,19 +290,14 @@ def _scatter_obs(
n = Y.shape[0]
size = 120000 / n

- if legend_loc.startswith("on data") and legend_fontsize is None:
- legend_fontsize = rcParams["legend.fontsize"]
- elif legend_fontsize is None:
+ if legend_fontsize is None:
legend_fontsize = rcParams["legend.fontsize"]

palette_was_none = False
if palette is None:
palette_was_none = True
if isinstance(palette, Sequence) and not isinstance(palette, str):
- if not is_color_like(palette[0]):
- palettes = palette
- else:
- palettes = [palette]
+ palettes = palette if not is_color_like(palette[0]) else [palette]
else:
palettes = [palette for _ in range(len(keys))]
palettes = [_utils.default_palette(palette) for palette in palettes]
Expand All @@ -328,7 +321,7 @@ def _scatter_obs(
else:
component_name = None
axis_labels = (x, y) if component_name is None else None
- show_ticks = True if component_name is None else False
+ show_ticks = component_name is None

# generate the colors
color_ids: list[np.ndarray | ColorLike] = []
Expand Down Expand Up @@ -364,9 +357,8 @@ def _scatter_obs(
categoricals.append(ikey)
color_ids.append(c)

- if right_margin is None and len(categoricals) > 0:
- if legend_loc == "right margin":
- right_margin = 0.5
+ if right_margin is None and len(categoricals) > 0 and legend_loc == "right margin":
+ right_margin = 0.5
if title is None and keys[0] is not None:
title = [
key.replace("_", " ") if not is_color_like(key) else "" for key in keys
Expand Down Expand Up @@ -488,10 +480,7 @@ def add_centroid(centroids, name, Y, mask):

all_pos = np.zeros((len(adata.obs[key].cat.categories), 2))
for iname, name in enumerate(adata.obs[key].cat.categories):
- if name in centroids:
- all_pos[iname] = centroids[name]
- else:
- all_pos[iname] = [np.nan, np.nan]
+ all_pos[iname] = centroids.get(name, [np.nan, np.nan])
if legend_loc == "on data export":
filename = settings.writedir / "pos.csv"
logg.warning(f"exporting label positions to {filename}")
Expand Down Expand Up @@ -1245,10 +1234,7 @@ def heatmap(
groupby_width = 0.2 if categorical else 0
if figsize is None:
height = 6
- if show_gene_labels:
- heatmap_width = len(var_names) * 0.3
- else:
- heatmap_width = 8
+ heatmap_width = len(var_names) * 0.3 if show_gene_labels else 8
width = heatmap_width + dendro_width + groupby_width
else:
width, height = figsize
Expand Down Expand Up @@ -1352,10 +1338,7 @@ def heatmap(
dendro_height = 0.8 if dendrogram else 0
groupby_height = 0.13 if categorical else 0
if figsize is None:
- if show_gene_labels:
- heatmap_height = len(var_names) * 0.18
- else:
- heatmap_height = 4
+ heatmap_height = len(var_names) * 0.18 if show_gene_labels else 4
width = 10
height = heatmap_height + dendro_height + groupby_height
else:
Expand Down Expand Up @@ -1440,10 +1423,7 @@ def heatmap(
for idx, (label, pos) in enumerate(
zip(var_group_labels, var_group_positions)
):
- if var_groups_subset_of_groupby:
- label_code = label2code[label]
- else:
- label_code = idx
+ label_code = label2code[label] if var_groups_subset_of_groupby else idx
arr += [label_code] * (pos[1] + 1 - pos[0])
gene_groups_ax.imshow(
np.array([arr]).T, aspect="auto", cmap=groupby_cmap, norm=norm
Expand Down Expand Up @@ -1892,10 +1872,7 @@ def correlation_matrix(
labels = adata.obs[groupby].cat.categories
num_rows = corr_matrix.shape[0]
colorbar_height = 0.2
- if dendrogram:
- dendrogram_width = 1.8
- else:
- dendrogram_width = 0
+ dendrogram_width = 1.8 if dendrogram else 0
if figsize is None:
corr_matrix_height = num_rows * 0.6
height = corr_matrix_height + colorbar_height
Expand Down Expand Up @@ -2052,7 +2029,7 @@ def _prepare_dataframe(
"groupby has to be a valid observation. "
f"Given {group}, is not in observations: {adata.obs_keys()}" + msg
)
- if group in adata.obs.keys() and group == adata.obs.index.name:
+ if group in adata.obs.columns and group == adata.obs.index.name:
raise ValueError(
f"Given group {group} is both and index and a column level, "
"which is ambiguous."
Expand Down Expand Up @@ -2171,10 +2148,7 @@ def _plot_gene_groups_brackets(
if orientation == "top":
# rotate labels if any of them is longer than 4 characters
if rotation is None and group_labels:
- if max([len(x) for x in group_labels]) > 4:
- rotation = 90
- else:
- rotation = 0
+ rotation = 90 if max([len(x) for x in group_labels]) > 4 else 0
for idx in range(len(left)):
verts.append((left[idx], 0)) # lower-left
verts.append((left[idx], 0.6)) # upper-left
Expand Down Expand Up @@ -2600,11 +2574,8 @@ def _plot_categories_as_colorblocks(
)
if len(labels) > 1:
groupby_ax.set_xticks(ticks)
- if max([len(str(x)) for x in labels]) < 3:
- # if the labels are small do not rotate them
- rotation = 0
- else:
- rotation = 90
+ # if the labels are small do not rotate them
+ rotation = 0 if max(len(str(x)) for x in labels) < 3 else 90
groupby_ax.set_xticklabels(labels, rotation=rotation)

# remove x ticks
Expand Down
Loading

0 comments on commit d998742

Please sign in to comment.