diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dbf943272a..d1cec5f545 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,10 @@ repos: - id: ruff args: ["--fix"] - id: ruff-format + # The following can be removed once PLR0917 is out of preview + - name: ruff preview rules + id: ruff + args: ["--preview", "--select=PLR0917"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: diff --git a/docs/conf.py b/docs/conf.py index 2e164b9c4c..b98a8d6012 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -183,8 +183,6 @@ def setup(app: Sphinx): qualname_overrides = { "sklearn.neighbors._dist_metrics.DistanceMetric": "sklearn.metrics.DistanceMetric", - # If the docs are built with an old version of numpy, this will make it work: - "numpy.random.RandomState": "numpy.random.mtrand.RandomState", "scanpy.plotting._matrixplot.MatrixPlot": "scanpy.pl.MatrixPlot", "scanpy.plotting._dotplot.DotPlot": "scanpy.pl.DotPlot", "scanpy.plotting._stacked_violin.StackedViolin": "scanpy.pl.StackedViolin", @@ -192,12 +190,14 @@ def setup(app: Sphinx): } nitpick_ignore = [ + # Technical issues + ("py:class", "numpy.int64"), # documented as “attribute” # Will probably be documented ("py:class", "scanpy._settings.Verbosity"), ("py:class", "scanpy.neighbors.OnFlySymMatrix"), # Currently undocumented # https://github.com/mwaskom/seaborn/issues/1810 - ("py:class", "seaborn.ClusterGrid"), + ("py:class", "seaborn.matrix.ClusterGrid"), ("py:class", "samalg.SAM"), # Won’t be documented ("py:class", "scanpy.plotting._utils._AxesSubplot"), diff --git a/docs/extensions/cite.py b/docs/extensions/cite.py index 5db46edc47..3b8afd34a7 100644 --- a/docs/extensions/cite.py +++ b/docs/extensions/cite.py @@ -15,7 +15,7 @@ from sphinx.application import Sphinx -def cite_role( +def cite_role( # noqa: PLR0917 name: str, rawsource: str, text: str, diff --git a/docs/extensions/debug_docstrings.py b/docs/extensions/debug_docstrings.py index 
87bc210cef..208a8793b0 100644 --- a/docs/extensions/debug_docstrings.py +++ b/docs/extensions/debug_docstrings.py @@ -13,7 +13,7 @@ _pd_orig = sphinx.ext.napoleon._process_docstring -def pd_new(app, what, name, obj, options, lines): +def pd_new(app, what, name, obj, options, lines): # noqa: PLR0917 _pd_orig(app, what, name, obj, options, lines) print(*lines, sep="\n") diff --git a/docs/extensions/function_images.py b/docs/extensions/function_images.py index 7042daf1ee..0503be39ce 100644 --- a/docs/extensions/function_images.py +++ b/docs/extensions/function_images.py @@ -9,7 +9,7 @@ from sphinx.ext.autodoc import Options -def insert_function_images( +def insert_function_images( # noqa: PLR0917 app: Sphinx, what: str, name: str, obj: Any, options: Options, lines: list[str] ): path = app.config.api_dir / f"{name}.png" diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index 856a66b233..013f69d9c2 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -34,3 +34,5 @@ ``` * Dropped support for Python 3.8. [More details here](https://numpy.org/neps/nep-0029-deprecation_policy.html). {pr}`2695` {smaller}`P Angerer` +* Deprecated specifying large numbers of function parameters by position as opposed to by name/keyword in all public APIs. + e.g. 
prefer `sc.tl.umap(adata, min_dist=0.1, spread=0.8)` over `sc.tl.umap(adata, 0.1, 0.8)` {pr}`2702` {smaller}`P Angerer` diff --git a/pyproject.toml b/pyproject.toml index b585e85ad5..ba0999437a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "umap-learn>=0.3.10", "packaging", "session-info", + "legacy-api-wrap>=1.4", # for positional API deprecations "get-annotations; python_version < '3.10'", ] dynamic = ["version"] @@ -168,6 +169,8 @@ markers = [ "gpu: tests that use a GPU (currently unused, but needs to be specified here as we import anndata.tests.helpers, which uses it)", ] filterwarnings = [ + # legacy-api-wrap: internal use of positional API + "error:The specified parameters:FutureWarning", # When calling `.show()` in tests, this is raised "ignore:FigureCanvasAgg is non-interactive:UserWarning", # We explicitly handle these errors in tests @@ -202,6 +205,7 @@ select = [ "TID251", # Banned imports "ICN", # Follow import conventions "PTH", # Pathlib instead of os.path + "PLR0917", # Ban APIs with too many positional parameters ] ignore = [ # line too long -> we accept long comment lines; black gets rid of long code lines diff --git a/scanpy/_compat.py b/scanpy/_compat.py index ab254e2b28..244f8588fa 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -1,8 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field +from functools import partial from pathlib import Path +from legacy_api_wrap import legacy_api from packaging import version try: @@ -61,3 +63,6 @@ def pkg_version(package): from importlib.metadata import version as v return version.parse(v(package)) + + +old_positionals = partial(legacy_api, category=FutureWarning) diff --git a/scanpy/_settings.py b/scanpy/_settings.py index 9bcf30760a..04c94b6204 100644 --- a/scanpy/_settings.py +++ b/scanpy/_settings.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, TextIO, Union from . 
import logging +from ._compat import old_positionals from .logging import _RootLogger, _set_log_file, _set_log_level if TYPE_CHECKING: @@ -36,6 +37,8 @@ class Verbosity(IntEnum): + """Logging verbosity levels.""" + error = 0 warning = 1 info = 2 @@ -102,14 +105,14 @@ def __init__( file_format_figs: str = "pdf", autosave: bool = False, autoshow: bool = True, - writedir: str | Path = "./write/", - cachedir: str | Path = "./cache/", - datasetdir: str | Path = "./data/", - figdir: str | Path = "./figures/", + writedir: Path | str = "./write/", + cachedir: Path | str = "./cache/", + datasetdir: Path | str = "./data/", + figdir: Path | str = "./figures/", cache_compression: str | None = "lzf", max_memory=15, n_jobs=1, - logfile: str | Path | None = None, + logfile: Path | str | None = None, categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"), _frameon: bool = True, _vector_friendly: bool = False, @@ -269,7 +272,7 @@ def writedir(self) -> Path: return self._writedir @writedir.setter - def writedir(self, writedir: str | Path): + def writedir(self, writedir: Path | str): _type_check(writedir, "writedir", (str, Path)) self._writedir = Path(writedir) @@ -281,7 +284,7 @@ def cachedir(self) -> Path: return self._cachedir @cachedir.setter - def cachedir(self, cachedir: str | Path): + def cachedir(self, cachedir: Path | str): _type_check(cachedir, "cachedir", (str, Path)) self._cachedir = Path(cachedir) @@ -293,7 +296,7 @@ def datasetdir(self) -> Path: return self._datasetdir @datasetdir.setter - def datasetdir(self, datasetdir: str | Path): + def datasetdir(self, datasetdir: Path | str): _type_check(datasetdir, "datasetdir", (str, Path)) self._datasetdir = Path(datasetdir).resolve() @@ -305,7 +308,7 @@ def figdir(self) -> Path: return self._figdir @figdir.setter - def figdir(self, figdir: str | Path): + def figdir(self, figdir: Path | str): _type_check(figdir, "figdir", (str, Path)) self._figdir = Path(figdir) @@ -365,7 +368,7 @@ def logpath(self) -> Path | 
None: return self._logpath @logpath.setter - def logpath(self, logpath: str | Path | None): + def logpath(self, logpath: Path | str | None): _type_check(logpath, "logfile", (str, Path)) # set via “file object” branch of logfile.setter self.logfile = Path(logpath).open("a") @@ -385,7 +388,7 @@ def logfile(self) -> TextIO: return self._logfile @logfile.setter - def logfile(self, logfile: str | Path | TextIO | None): + def logfile(self, logfile: Path | str | TextIO | None): if not hasattr(logfile, "write") and logfile: self.logpath = logfile else: # file object @@ -413,8 +416,23 @@ def categories_to_ignore(self, categories_to_ignore: Iterable[str]): # Functions # -------------------------------------------------------------------------------- + @old_positionals( + "scanpy", + "dpi", + "dpi_save", + "frameon", + "vector_friendly", + "fontsize", + "figsize", + "color_map", + "format", + "facecolor", + "transparent", + "ipython_format", + ) def set_figure_params( self, + *, scanpy: bool = True, dpi: int = 80, dpi_save: int = 150, @@ -427,7 +445,7 @@ def set_figure_params( facecolor: str | None = None, transparent: bool = False, ipython_format: str = "png2x", - ): + ) -> None: """\ Set resolution/size, styling and format of figures. diff --git a/scanpy/_utils/__init__.py b/scanpy/_utils/__init__.py index 20fac13363..81e5de4c9f 100644 --- a/scanpy/_utils/__init__.py +++ b/scanpy/_utils/__init__.py @@ -45,7 +45,7 @@ def __repr__(self) -> str: _empty = Empty.token # e.g. https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html -AnyRandom = Union[None, int, random.RandomState] # maybe in the future random.Generator +AnyRandom = Union[int, random.RandomState, None] # maybe in the future random.Generator EPS = 1e-15 @@ -96,42 +96,56 @@ def type_doc(name: str): ) -def deprecated_arg_names(arg_mapping: Mapping[str, str]): - """ - Decorator which marks a functions keyword arguments as deprecated. 
It will - result in a warning being emitted when the deprecated keyword argument is - used, and the function being called with the new argument. - - Parameters - ---------- - arg_mapping - Mapping from deprecated argument name to current argument name. - """ - +def renamed_arg(old_name, new_name, *, pos_0: bool = False): def decorator(func): @wraps(func) - def func_wrapper(*args, **kwargs): - warnings.simplefilter("always", DeprecationWarning) # turn off filter - for old, new in arg_mapping.items(): - if old in kwargs: - warnings.warn( - f"Keyword argument '{old}' has been " - f"deprecated in favour of '{new}'. " - f"'{old}' will be removed in a future version.", - category=DeprecationWarning, - stacklevel=2, + def wrapper(*args, **kwargs): + if old_name in kwargs: + f_name = func.__name__ + pos_str = ( + ( + f" at first position. Call it as `{f_name}(val, ...)` " + f"instead of `{f_name}({old_name}=val, ...)`" ) - val = kwargs.pop(old) - kwargs[new] = val - # reset filter - warnings.simplefilter("default", DeprecationWarning) + if pos_0 + else "" + ) + msg = ( + f"In function `{f_name}`, argument `{old_name}` " + f"was renamed to `{new_name}`{pos_str}." 
+ ) + warnings.warn(msg, FutureWarning, stacklevel=3) + if pos_0: + args = (kwargs.pop(old_name), *args) + else: + kwargs[new_name] = kwargs.pop(old_name) return func(*args, **kwargs) - return func_wrapper + return wrapper return decorator +def _import_name(name: str) -> Any: + from importlib import import_module + + parts = name.split(".") + obj = import_module(parts[0]) + for i, name in enumerate(parts[1:]): + try: + obj = import_module(f"{obj.__name__}.{name}") + except ModuleNotFoundError: + break + else: + i = len(parts) + for name in parts[i + 1 :]: + try: + obj = getattr(obj, name) + except AttributeError: + raise RuntimeError(f"{parts[:i]}, {parts[i + 1:]}, {obj} {name}") + return obj + + def _one_of_ours(obj, root: str): return ( hasattr(obj, "__name__") @@ -146,19 +160,19 @@ def descend_classes_and_funcs(mod: ModuleType, root: str, encountered=None): if encountered is None: encountered = WeakSet() for obj in vars(mod).values(): - if not _one_of_ours(obj, root): + if not _one_of_ours(obj, root) or obj in encountered: continue + encountered.add(obj) if callable(obj) and not isinstance(obj, MethodType): yield obj if isinstance(obj, type): for m in vars(obj).values(): if callable(m) and _one_of_ours(m, root): yield m - elif isinstance(obj, ModuleType) and obj not in encountered: + elif isinstance(obj, ModuleType): if obj.__name__.startswith("scanpy.tests"): # Python’s import mechanism seems to add this to `scanpy`’s attributes continue - encountered.add(obj) yield from descend_classes_and_funcs(obj, root, encountered) @@ -264,6 +278,7 @@ def compute_association_matrix_of_groups( adata: AnnData, prediction: str, reference: str, + *, normalization: Literal["prediction", "reference"] = "prediction", threshold: float = 0.01, max_n_names: int | None = 2, @@ -595,7 +610,7 @@ def select_groups( return groups_order_subset, groups_masks -def warn_with_traceback(message, category, filename, lineno, file=None, line=None): +def warn_with_traceback(message, category, 
filename, lineno, file=None, line=None): # noqa: PLR0917 """Get full tracebacks when warning is raised by setting warnings.showwarning = warn_with_traceback diff --git a/scanpy/datasets/_datasets.py b/scanpy/datasets/_datasets.py index f190a0eb7f..f040eb7a52 100644 --- a/scanpy/datasets/_datasets.py +++ b/scanpy/datasets/_datasets.py @@ -10,6 +10,7 @@ from .. import _utils from .. import logging as logg +from .._compat import old_positionals from .._settings import settings from ..readwrite import read, read_visium from ._utils import check_datasetdir_exists, filter_oldformatwarning @@ -20,7 +21,11 @@ HERE = Path(__file__).parent +@old_positionals( + "n_variables", "n_centers", "cluster_std", "n_observations", "random_state" +) def blobs( + *, n_variables: int = 11, n_centers: int = 5, cluster_std: float = 1.0, @@ -192,7 +197,7 @@ def paul15() -> ad.AnnData: # names reflecting the cell type identifications from the paper cell_type = 6 * ["Ery"] cell_type += "MEP Mk GMP GMP DC Baso Baso Mo Mo Neu Neu Eos Lymph".split() - adata.obs["paul15_clusters"] = [f"{i}{cell_type[i-1]}" for i in clusters] + adata.obs["paul15_clusters"] = [f"{i}{cell_type[i - 1]}" for i in clusters] # make string annotations categorical (optional) _utils.sanitize_anndata(adata) # just keep the first of the two equivalent names per gene diff --git a/scanpy/experimental/pp/_highly_variable_genes.py b/scanpy/experimental/pp/_highly_variable_genes.py index 69854c965d..c8192c9381 100644 --- a/scanpy/experimental/pp/_highly_variable_genes.py +++ b/scanpy/experimental/pp/_highly_variable_genes.py @@ -131,6 +131,7 @@ def clac_clipped_res_dense(gene: int, cell: int) -> np.float64: def _highly_variable_pearson_residuals( adata: AnnData, + *, theta: float = 100, clip: float | None = None, n_top_genes: int = 1000, diff --git a/scanpy/experimental/pp/_normalization.py b/scanpy/experimental/pp/_normalization.py index 0287c3d8f6..614ef3b704 100644 --- a/scanpy/experimental/pp/_normalization.py +++ 
b/scanpy/experimental/pp/_normalization.py @@ -33,7 +33,7 @@ from collections.abc import Mapping -def _pearson_residuals(X, theta, clip, check_values, copy=False): +def _pearson_residuals(X, theta, clip, check_values, copy: bool = False): X = X.copy() if copy else X # check theta @@ -90,7 +90,7 @@ def normalize_pearson_residuals( layer: str | None = None, inplace: bool = True, copy: bool = False, -) -> dict[str, np.ndarray] | None: +) -> AnnData | dict[str, np.ndarray] | None: """\ Applies analytic Pearson residual normalization, based on [Lause21]_. diff --git a/scanpy/external/exporting.py b/scanpy/external/exporting.py index 7f7158bd4f..80d091e13d 100644 --- a/scanpy/external/exporting.py +++ b/scanpy/external/exporting.py @@ -14,6 +14,7 @@ import scipy.sparse from pandas.api.types import CategoricalDtype +from .._compat import old_positionals from .._utils import NeighborsView from ..preprocessing._utils import _get_mean_var @@ -22,18 +23,29 @@ from anndata import AnnData +__all__ = ["spring_project", "cellbrowser"] + +@old_positionals( + "subplot_name", + "cell_groupings", + "custom_color_tracks", + "total_counts_key", + "neighbors_key", + "overwrite", +) def spring_project( adata: AnnData, project_dir: Path | str, embedding_method: str, + *, subplot_name: str | None = None, cell_groupings: str | Iterable[str] | None = None, custom_color_tracks: str | Iterable[str] | None = None, total_counts_key: str = "n_counts", neighbors_key: str | None = None, overwrite: bool = False, -): +) -> None: """\ Exports to a SPRING project directory [Weinreb17]_. 
@@ -96,8 +108,8 @@ def spring_project( # Make project directory and subplot directory (subplot has same name as project) # For now, the subplot is just all cells in adata - project_dir: Path = Path(project_dir) - subplot_dir: Path = ( + project_dir = Path(project_dir) + subplot_dir = ( project_dir.parent if subplot_name is None else project_dir / subplot_name ) subplot_dir.mkdir(parents=True, exist_ok=True) @@ -472,10 +484,21 @@ def _export_PAGA_to_SPRING(adata, paga_coords, outpath): return None +@old_positionals( + "embedding_keys", + "annot_keys", + "cluster_field", + "nb_marker", + "skip_matrix", + "html_dir", + "port", + "do_debug", +) def cellbrowser( adata: AnnData, data_dir: Path | str, data_name: str, + *, embedding_keys: Iterable[str] | Mapping[str, str] | str | None = None, annot_keys: Iterable[str] | Mapping[str, str] | None = ( "louvain", diff --git a/scanpy/external/pl.py b/scanpy/external/pl.py index 021d9d96c9..9e65f9a305 100644 --- a/scanpy/external/pl.py +++ b/scanpy/external/pl.py @@ -5,8 +5,10 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +from anndata import AnnData # noqa: TCH002 from matplotlib.axes import Axes # noqa: TCH002 +from .._compat import old_positionals from .._utils import _doc_params from ..plotting import _utils, embedding from ..plotting._docs import ( @@ -22,7 +24,15 @@ if TYPE_CHECKING: from collections.abc import Collection - from anndata import AnnData + +__all__ = [ + "phate", + "trimap", + "harmony_timeseries", + "sam", + "wishbone_marker_trajectory", + "scrublet_score_distribution", +] @doctest_needs("phate") @@ -33,7 +43,7 @@ scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def phate(adata, **kwargs) -> list[Axes] | None: +def phate(adata: AnnData, **kwargs) -> list[Axes] | None: """\ Scatter plot in PHATE basis. 
@@ -83,7 +93,7 @@ def phate(adata, **kwargs) -> list[Axes] | None: scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def trimap(adata, **kwargs) -> Axes | list[Axes] | None: +def trimap(adata: AnnData, **kwargs) -> Axes | list[Axes] | None: """\ Scatter plot in TriMap basis. @@ -109,7 +119,7 @@ def trimap(adata, **kwargs) -> Axes | list[Axes] | None: show_save_ax=doc_show_save_ax, ) def harmony_timeseries( - adata, *, show: bool = True, return_fig: bool = False, **kwargs + adata: AnnData, *, show: bool = True, return_fig: bool = False, **kwargs ) -> Axes | list[Axes] | None: """\ Scatter plot in Harmony force-directed layout basis. @@ -145,13 +155,16 @@ def harmony_timeseries( p.set_axis_off() if return_fig: return fig - elif not show: - return axes + if show: + return None + return axes +@old_positionals("c", "cmap", "linewidth", "edgecolor", "axes", "colorbar", "s") def sam( adata: AnnData, projection: str | np.ndarray = "X_umap", + *, c: str | np.ndarray | None = None, cmap: str = "Spectral_r", linewidth: float = 0.0, @@ -246,10 +259,22 @@ def sam( return axes +@old_positionals( + "no_bins", + "smoothing_factor", + "min_delta", + "show_variance", + "figsize", + "return_fig", + "show", + "save", + "ax", +) @_doc_params(show_save_ax=doc_show_save_ax) def wishbone_marker_trajectory( adata: AnnData, markers: Collection[str], + *, no_bins: int = 150, smoothing_factor: int = 1, min_delta: float = 0.1, @@ -329,12 +354,17 @@ def wishbone_marker_trajectory( if return_fig: return fig - elif not show: - return ax + if show: + return None + return ax +@old_positionals( + "scale_hist_obs", "scale_hist_sim", "figsize", "return_fig", "show", "save" +) def scrublet_score_distribution( - adata, + adata: AnnData, + *, scale_hist_obs: str = "log", scale_hist_sim: str = "linear", figsize: tuple[float, float] | None = (8, 3), @@ -464,5 +494,6 @@ def _plot_scores( _utils.savefig_or_show("scrublet_score_distribution", show=show, save=save) if return_fig: return fig 
- elif not show: - return axs + if show: + return None + return axs diff --git a/scanpy/external/pp/_bbknn.py b/scanpy/external/pp/_bbknn.py index 7c3d6e6d15..f0dce56585 100644 --- a/scanpy/external/pp/_bbknn.py +++ b/scanpy/external/pp/_bbknn.py @@ -2,23 +2,25 @@ from typing import TYPE_CHECKING, Callable +from ..._compat import old_positionals +from ...testing._doctests import doctest_needs + if TYPE_CHECKING: from anndata import AnnData from sklearn.metrics import DistanceMetric -from ...testing._doctests import doctest_needs - +@old_positionals("batch_key", "use_rep", "approx", "use_annoy", "metric", "copy") @doctest_needs("bbknn") def bbknn( adata: AnnData, + *, batch_key: str = "batch", use_rep: str = "X_pca", approx: bool = True, use_annoy: bool = True, metric: str | Callable | DistanceMetric = "euclidean", copy: bool = False, - *, neighbors_within_batch: int = 3, n_pcs: int = 50, trim: int | None = None, @@ -29,7 +31,7 @@ def bbknn( set_op_mix_ratio: float = 1.0, local_connectivity: int = 1, **kwargs, -) -> AnnData: +) -> AnnData | None: """\ Batch balanced kNN [Polanski19]_. 
diff --git a/scanpy/external/pp/_dca.py b/scanpy/external/pp/_dca.py index 30e2735a4a..105b25dbd1 100644 --- a/scanpy/external/pp/_dca.py +++ b/scanpy/external/pp/_dca.py @@ -3,6 +3,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal +from ..._compat import old_positionals + if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -13,9 +15,35 @@ _AEType = Literal["zinb-conddisp", "zinb", "nb-conddisp", "nb"] +@old_positionals( + "ae_type", + "normalize_per_cell", + "scale", + "log1p", + "hidden_size", + "hidden_dropout", + "batchnorm", + "activation", + "init", + "network_kwds", + "epochs", + "reduce_lr", + "early_stop", + "batch_size", + "optimizer", + "random_state", + "threads", + "learning_rate", + "verbose", + "training_kwds", + "return_model", + "return_info", + "copy", +) def dca( adata: AnnData, mode: Literal["denoise", "latent"] = "denoise", + *, ae_type: _AEType = "nb-conddisp", normalize_per_cell: bool = True, scale: bool = True, diff --git a/scanpy/external/pp/_harmony_integrate.py b/scanpy/external/pp/_harmony_integrate.py index 63847f9ee1..5ebd79b77c 100644 --- a/scanpy/external/pp/_harmony_integrate.py +++ b/scanpy/external/pp/_harmony_integrate.py @@ -7,16 +7,19 @@ import numpy as np +from ..._compat import old_positionals from ...testing._doctests import doctest_needs if TYPE_CHECKING: from anndata import AnnData +@old_positionals("basis", "adjusted_basis") @doctest_needs("harmonypy") def harmony_integrate( adata: AnnData, key: str, + *, basis: str = "X_pca", adjusted_basis: str = "X_pca_harmony", **kwargs, diff --git a/scanpy/external/pp/_hashsolo.py b/scanpy/external/pp/_hashsolo.py index 1adb176564..9d18206448 100644 --- a/scanpy/external/pp/_hashsolo.py +++ b/scanpy/external/pp/_hashsolo.py @@ -23,11 +23,14 @@ import pandas as pd from scipy.stats import norm +from ..._compat import old_positionals from ..._utils import check_nonnegative_integers from ...testing._doctests import doctest_skip if 
TYPE_CHECKING: - import anndata + from collections.abc import Sequence + + from anndata import AnnData def _calculate_log_likelihoods(data, number_of_noise_barcodes): @@ -261,15 +264,19 @@ def _calculate_bayes_rule(data, priors, number_of_noise_barcodes): } +@old_positionals( + "priors", "pre_existing_clusters", "number_of_noise_barcodes", "inplace" +) @doctest_skip("Illustrative but not runnable doctest code") def hashsolo( - adata: anndata.AnnData, - cell_hashing_columns: list, - priors: list = [0.01, 0.8, 0.19], - pre_existing_clusters: str = None, - number_of_noise_barcodes: int = None, + adata: AnnData, + cell_hashing_columns: Sequence[str], + *, + priors: tuple[float, float, float] = (0.01, 0.8, 0.19), + pre_existing_clusters: str | None = None, + number_of_noise_barcodes: int | None = None, inplace: bool = True, -): +) -> AnnData | None: """Probabilistic demultiplexing of cell hashing data using HashSolo [Bernstein20]_. .. note:: @@ -281,9 +288,9 @@ def hashsolo( The (annotated) data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. cell_hashing_columns - A list specifying `.obs` columns that contain cell hashing counts. + `.obs` columns that contain cell hashing counts. priors - A list specifying the prior probability of each hypothesis, in + Prior probabilities of each hypothesis, in the order `[negative, singlet, doublet]`. The default is set to `[0.01, 0.8, 0.19]` assuming barcode counts are from cells that have passed QC in the transcriptome space, e.g. 
UMI counts, pct diff --git a/scanpy/external/pp/_scanorama_integrate.py b/scanpy/external/pp/_scanorama_integrate.py index 4f57db0934..7718e9ded5 100644 --- a/scanpy/external/pp/_scanorama_integrate.py +++ b/scanpy/external/pp/_scanorama_integrate.py @@ -7,16 +7,21 @@ import numpy as np +from ..._compat import old_positionals from ...testing._doctests import doctest_needs if TYPE_CHECKING: from anndata import AnnData +@old_positionals( + "basis", "adjusted_basis", "knn", "sigma", "approx", "alpha", "batch_size" +) @doctest_needs("scanorama") def scanorama_integrate( adata: AnnData, key: str, + *, basis: str = "X_pca", adjusted_basis: str = "X_scanorama", knn: int = 20, @@ -25,7 +30,7 @@ def scanorama_integrate( alpha: float = 0.10, batch_size: int = 5000, **kwargs, -): +) -> None: """\ Use Scanorama [Hie19]_ to integrate different experiments. diff --git a/scanpy/external/pp/_scrublet.py b/scanpy/external/pp/_scrublet.py index f24950a5b0..0e21352dc5 100644 --- a/scanpy/external/pp/_scrublet.py +++ b/scanpy/external/pp/_scrublet.py @@ -7,13 +7,35 @@ from ... import logging as logg from ... 
import preprocessing as pp +from ..._compat import old_positionals from ...get import _get_obs_rep +@old_positionals( + "adata_sim", + "batch_key", + "sim_doublet_ratio", + "expected_doublet_rate", + "stdev_doublet_rate", + "synthetic_doublet_umi_subsampling", + "knn_dist_metric", + "normalize_variance", + "log_transform", + "mean_center", + "n_prin_comps", + "use_approx_neighbors", + "get_doublet_neighbor_parents", + "n_neighbors", + "threshold", + "verbose", + "copy", + "random_state", +) def scrublet( adata: AnnData, + *, adata_sim: AnnData | None = None, - batch_key: str = None, + batch_key: str | None = None, sim_doublet_ratio: float = 2.0, expected_doublet_rate: float = 0.05, stdev_doublet_rate: float = 0.02, @@ -278,6 +300,7 @@ def _run_scrublet(ad_obs, ad_sim=None): def _scrublet_call_doublets( + *, adata_obs: AnnData, adata_sim: AnnData, n_neighbors: int | None = None, @@ -496,9 +519,13 @@ def _scrublet_call_doublets( return adata_obs +@old_positionals( + "layer", "sim_doublet_ratio", "synthetic_doublet_umi_subsampling", "random_seed" +) def scrublet_simulate_doublets( adata: AnnData, - layer=None, + *, + layer: str | None = None, sim_doublet_ratio: float = 2.0, synthetic_doublet_umi_subsampling: float = 1.0, random_seed: int = 0, diff --git a/scanpy/external/tl/_harmony_timeseries.py b/scanpy/external/tl/_harmony_timeseries.py index 9c873868b7..ce904fd0cb 100644 --- a/scanpy/external/tl/_harmony_timeseries.py +++ b/scanpy/external/tl/_harmony_timeseries.py @@ -10,16 +10,19 @@ import pandas as pd from ... 
import logging as logg +from ..._compat import old_positionals from ...testing._doctests import doctest_needs if TYPE_CHECKING: from anndata import AnnData +@old_positionals("n_neighbors", "n_components", "n_jobs", "copy") @doctest_needs("harmony") def harmony_timeseries( adata: AnnData, tp: str, + *, n_neighbors: int = 30, n_components: int | None = 1000, n_jobs: int = -2, diff --git a/scanpy/external/tl/_palantir.py b/scanpy/external/tl/_palantir.py index 5510c0c287..0cf613b605 100644 --- a/scanpy/external/tl/_palantir.py +++ b/scanpy/external/tl/_palantir.py @@ -8,21 +8,34 @@ import pandas as pd from ... import logging as logg +from ..._compat import old_positionals from ...testing._doctests import doctest_needs if TYPE_CHECKING: from anndata import AnnData +@old_positionals( + "n_components", + "knn", + "alpha", + "use_adjacency_matrix", + "distances_key", + "n_eigs", + "impute_data", + "n_steps", + "copy", +) @doctest_needs("palantir") def palantir( adata: AnnData, + *, n_components: int = 10, knn: int = 30, alpha: float = 0, use_adjacency_matrix: bool = False, distances_key: str | None = None, - n_eigs: int = None, + n_eigs: int | None = None, impute_data: bool = True, n_steps: int = 3, copy: bool = False, @@ -243,11 +256,22 @@ def palantir( return adata if copy else None +@old_positionals( + "ms_data", + "terminal_states", + "knn", + "num_waypoints", + "n_jobs", + "scale_components", + "use_early_cell_as_start", + "max_iterations", +) def palantir_results( adata: AnnData, early_cell: str, + *, ms_data: str = "X_palantir_multiscale", - terminal_states: list = None, + terminal_states: list | None = None, knn: int = 30, num_waypoints: int = 1200, n_jobs: int = -1, diff --git a/scanpy/external/tl/_phate.py b/scanpy/external/tl/_phate.py index a4731cae69..e926b7ad11 100644 --- a/scanpy/external/tl/_phate.py +++ b/scanpy/external/tl/_phate.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Literal from ... 
import logging as logg +from ..._compat import old_positionals from ..._settings import settings from ...testing._doctests import doctest_needs @@ -15,10 +16,26 @@ from ..._utils import AnyRandom +@old_positionals( + "k", + "a", + "n_landmark", + "t", + "gamma", + "n_pca", + "knn_dist", + "mds_dist", + "mds", + "n_jobs", + "random_state", + "verbose", + "copy", +) @doctest_needs("phate") def phate( adata: AnnData, n_components: int = 2, + *, k: int = 5, a: int = 15, n_landmark: int = 2000, diff --git a/scanpy/external/tl/_phenograph.py b/scanpy/external/tl/_phenograph.py index 809a7dcab3..753e84b785 100644 --- a/scanpy/external/tl/_phenograph.py +++ b/scanpy/external/tl/_phenograph.py @@ -9,6 +9,8 @@ from anndata import AnnData from ... import logging as logg +from ..._compat import old_positionals +from ..._utils import renamed_arg from ...testing._doctests import doctest_needs if TYPE_CHECKING: @@ -18,10 +20,30 @@ from ...tools._leiden import MutableVertexPartition +@renamed_arg("adata", "data", pos_0=True) +@old_positionals( + "k", + "directed", + "prune", + "min_cluster_size", + "jaccard", + "primary_metric", + "n_jobs", + "q_tol", + "louvain_time_limit", + "nn_method", + "partition_type", + "resolution_parameter", + "n_iterations", + "use_weights", + "seed", + "copy", +) @doctest_needs("phenograph") def phenograph( - adata: AnnData | np.ndarray | spmatrix, + data: AnnData | np.ndarray | spmatrix, clustering_algo: Literal["louvain", "leiden"] | None = "louvain", + *, k: int = 30, directed: bool = False, prune: bool = False, @@ -44,7 +66,7 @@ def phenograph( seed: int | None = None, copy: bool = False, **kargs: Any, -) -> tuple[np.ndarray | None, spmatrix, float | None]: +) -> tuple[np.ndarray | None, spmatrix, float | None] | None: """\ PhenoGraph clustering [Levine15]_. @@ -63,7 +85,7 @@ def phenograph( Parameters ---------- - adata + data AnnData, or Array of data to cluster, or sparse matrix of k-nearest neighbor graph. 
If ndarray, n-by-d array of n cells in d dimensions. if sparse matrix, n-by-n adjacency matrix. @@ -206,13 +228,14 @@ def phenograph( "pip install -U PhenoGraph" ) - if isinstance(adata, AnnData): + if isinstance(data, AnnData): + adata = data try: - data = adata.obsm["X_pca"] + data = data.obsm["X_pca"] except KeyError: - raise KeyError("Please run `sc.pp.pca` on `adata` and try again!") + raise KeyError("Please run `sc.pp.pca` on `data` and try again!") else: - data = adata + adata = None copy = True comm_key = ( @@ -246,7 +269,8 @@ def phenograph( if copy: return communities, graph, Q - else: + + if adata is not None: adata.obsp[ig_key] = graph.tocsr() if comm_key: adata.obs[comm_key] = pd.Categorical(communities) diff --git a/scanpy/external/tl/_sam.py b/scanpy/external/tl/_sam.py index f74da3af9a..086a49b2d1 100644 --- a/scanpy/external/tl/_sam.py +++ b/scanpy/external/tl/_sam.py @@ -5,17 +5,33 @@ from typing import TYPE_CHECKING, Literal +from ... import logging as logg +from ..._compat import old_positionals +from ...testing._doctests import doctest_needs + if TYPE_CHECKING: from anndata import AnnData from samalg import SAM -from ... import logging as logg -from ...testing._doctests import doctest_needs - +@old_positionals( + "max_iter", + "num_norm_avg", + "k", + "distance", + "standardization", + "weight_pcs", + "sparse_pca", + "n_pcs", + "n_genes", + "projection", + "inplace", + "verbose", +) @doctest_needs("samalg") def sam( adata: AnnData, + *, max_iter: int = 10, num_norm_avg: int = 50, k: int = 20, diff --git a/scanpy/external/tl/_trimap.py b/scanpy/external/tl/_trimap.py index d9a01a7921..84a6d9e5f6 100644 --- a/scanpy/external/tl/_trimap.py +++ b/scanpy/external/tl/_trimap.py @@ -8,6 +8,7 @@ import scipy.sparse as scp from ... 
import logging as logg +from ..._compat import old_positionals from ..._settings import settings from ...testing._doctests import doctest_needs @@ -15,10 +16,22 @@ from anndata import AnnData +@old_positionals( + "n_inliers", + "n_outliers", + "n_random", + "metric", + "weight_adj", + "lr", + "n_iters", + "verbose", + "copy", +) @doctest_needs("trimap") def trimap( adata: AnnData, n_components: int = 2, + *, n_inliers: int = 10, n_outliers: int = 5, n_random: int = 5, diff --git a/scanpy/external/tl/_wishbone.py b/scanpy/external/tl/_wishbone.py index 4f9f7d7f61..dec93d458c 100644 --- a/scanpy/external/tl/_wishbone.py +++ b/scanpy/external/tl/_wishbone.py @@ -7,6 +7,7 @@ import pandas as pd from ... import logging +from ..._compat import old_positionals from ...testing._doctests import doctest_needs if TYPE_CHECKING: @@ -15,10 +16,12 @@ from anndata import AnnData +@old_positionals("branch", "k", "components", "num_waypoints") @doctest_needs("wishbone") def wishbone( adata: AnnData, start_cell: str, + *, branch: bool = True, k: int = 15, components: Iterable[int] = (1, 2, 3), diff --git a/scanpy/get/get.py b/scanpy/get/get.py index cb8b262dfb..f6842d67ab 100644 --- a/scanpy/get/get.py +++ b/scanpy/get/get.py @@ -1,7 +1,7 @@ """This module contains helper functions for accessing data.""" from __future__ import annotations -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np import pandas as pd @@ -106,6 +106,7 @@ def rank_genes_groups_df( def _check_indices( dim_df: pd.DataFrame, alt_index: pd.Index, + *, dim: Literal["obs", "var"], keys: list[str], alias_index: pd.Index | None = None, @@ -217,8 +218,8 @@ def obs_df( keys: Iterable[str] = (), obsm_keys: Iterable[tuple[str, int]] = (), *, - layer: str = None, - gene_symbols: str = None, + layer: str | None = None, + gene_symbols: str | None = None, use_raw: bool = False, ) -> pd.DataFrame: """\ @@ -286,8 +287,8 @@ def obs_df( obs_cols, var_idx_keys, var_symbols 
= _check_indices( adata.obs, var.index, - "obs", - keys, + dim="obs", + keys=keys, alias_index=alias_index, use_raw=use_raw, ) @@ -335,7 +336,7 @@ def var_df( keys: Iterable[str] = (), varm_keys: Iterable[tuple[str, int]] = (), *, - layer: str = None, + layer: str | None = None, ) -> pd.DataFrame: """\ Return values for observations in adata. @@ -357,7 +358,9 @@ def var_df( and `varm_keys`. """ # Argument handling - var_cols, obs_idx_keys, _ = _check_indices(adata.var, adata.obs_names, "var", keys) + var_cols, obs_idx_keys, _ = _check_indices( + adata.var, adata.obs_names, dim="var", keys=keys + ) # initialize df df = pd.DataFrame(index=adata.var.index) @@ -395,7 +398,14 @@ def var_df( return df -def _get_obs_rep(adata, *, use_raw=False, layer=None, obsm=None, obsp=None): +def _get_obs_rep( + adata: AnnData, + *, + use_raw: bool = False, + layer: str | None = None, + obsm: str | None = None, + obsp: str | None = None, +): """ Choose array aligned with obs annotation. """ @@ -426,7 +436,15 @@ def _get_obs_rep(adata, *, use_raw=False, layer=None, obsm=None, obsp=None): ) -def _set_obs_rep(adata, val, *, use_raw=False, layer=None, obsm=None, obsp=None): +def _set_obs_rep( + adata: AnnData, + val: Any, + *, + use_raw: bool = False, + layer: str | None = None, + obsm: str | None = None, + obsp: str | None = None, +): """ Set value for observation rep. 
""" diff --git a/scanpy/logging.py b/scanpy/logging.py index 9f836ec977..f59d2c6831 100644 --- a/scanpy/logging.py +++ b/scanpy/logging.py @@ -16,6 +16,9 @@ from ._settings import ScanpyConfig +# This is currently the only documented API +__all__ = ["print_versions"] + HINT = (INFO + DEBUG) // 2 logging.addLevelName(HINT, "HINT") diff --git a/scanpy/metrics/_gearys_c.py b/scanpy/metrics/_gearys_c.py index 7027754f3a..4efb9a14c7 100644 --- a/scanpy/metrics/_gearys_c.py +++ b/scanpy/metrics/_gearys_c.py @@ -188,7 +188,7 @@ def _gearys_c_inner_sparse_x_densevec(g_data, g_indices, g_indptr, x, W): @numba.njit(cache=True) -def _gearys_c_inner_sparse_x_sparsevec( +def _gearys_c_inner_sparse_x_sparsevec( # noqa: PLR0917 g_data, g_indices, g_indptr, x_data, x_indices, N, W ): x = np.zeros(N, dtype=np.float_) @@ -231,7 +231,7 @@ def _gearys_c_mtx(g_data, g_indices, g_indptr, X): @numba.njit(cache=True, parallel=True) -def _gearys_c_mtx_csr( +def _gearys_c_mtx_csr( # noqa: PLR0917 g_data, g_indices, g_indptr, x_data, x_indices, x_indptr, x_shape ): M, N = x_shape diff --git a/scanpy/metrics/_morans_i.py b/scanpy/metrics/_morans_i.py index a069b44bb3..ed98e75afa 100644 --- a/scanpy/metrics/_morans_i.py +++ b/scanpy/metrics/_morans_i.py @@ -126,7 +126,7 @@ def morans_i( @njit(cache=True) -def _morans_i_vec_W_sparse( +def _morans_i_vec_W_sparse( # noqa: PLR0917 g_data: np.ndarray, g_indices: np.ndarray, g_indptr: np.ndarray, @@ -191,7 +191,7 @@ def _morans_i_mtx( @njit(cache=True, parallel=True) -def _morans_i_mtx_csr( +def _morans_i_mtx_csr( # noqa: PLR0917 g_data: np.ndarray, g_indices: np.ndarray, g_indptr: np.ndarray, diff --git a/scanpy/neighbors/__init__.py b/scanpy/neighbors/__init__.py index fbda62718e..65595f0b87 100644 --- a/scanpy/neighbors/__init__.py +++ b/scanpy/neighbors/__init__.py @@ -11,6 +11,8 @@ from scipy.sparse import csr_matrix, issparse from sklearn.utils import check_random_state +from .._compat import old_positionals + if TYPE_CHECKING: from anndata 
import AnnData from igraph import Graph @@ -276,6 +278,7 @@ def __init__( self, get_row: Callable[[Any], np.ndarray], shape: tuple[int, int], + *, DC_start: int = 0, DC_end: int = -1, rows: MutableMapping[Any, np.ndarray] | None = None, @@ -342,9 +345,11 @@ class Neighbors: Where to look in `.uns` and `.obsp` for neighbors data """ + @old_positionals("n_dcs", "neighbors_key") def __init__( self, adata: AnnData, + *, n_dcs: int | None = None, neighbors_key: str | None = None, ): diff --git a/scanpy/plotting/__init__.py b/scanpy/plotting/__init__.py index 0116771407..16ef150bfc 100644 --- a/scanpy/plotting/__init__.py +++ b/scanpy/plotting/__init__.py @@ -34,7 +34,12 @@ rank_genes_groups_violin, sim, ) -from ._tools.paga import paga, paga_adjacency, paga_compare, paga_path +from ._tools.paga import ( + paga, + paga_adjacency, # noqa: F401 + paga_compare, + paga_path, +) from ._tools.scatterplots import ( diffmap, draw_graph, @@ -83,7 +88,6 @@ "rank_genes_groups_violin", "sim", "paga", - "paga_adjacency", "paga_compare", "paga_path", "diffmap", diff --git a/scanpy/plotting/_anndata.py b/scanpy/plotting/_anndata.py index bdbf111d47..1ff8d94e81 100755 --- a/scanpy/plotting/_anndata.py +++ b/scanpy/plotting/_anndata.py @@ -18,6 +18,7 @@ from .. import get from .. import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import _check_use_raw, _doc_params, sanitize_anndata from . 
import _utils @@ -41,6 +42,8 @@ from anndata import AnnData from cycler import Cycler from matplotlib.axes import Axes + from seaborn import FacetGrid + from seaborn.matrix import ClusterGrid VALID_LEGENDLOCS = { "none", @@ -65,26 +68,44 @@ _VarNames = Union[str, Sequence[str]] +@old_positionals( + "color", + "use_raw", + "layers", + "sort_order", + "alpha", + "basis", + "groups", + "components", + "projection", + "legend_loc", + "legend_fontsize", + "legend_fontweight", + "legend_fontoutline", + "color_map", + # 17 positionals are enough for backwards compatibility +) @_doc_params(scatter_temp=doc_scatter_basic, show_save_ax=doc_show_save_ax) def scatter( adata: AnnData, x: str | None = None, y: str | None = None, - color: str | Collection[str] = None, + *, + color: str | Collection[str] | None = None, use_raw: bool | None = None, - layers: str | Collection[str] = None, + layers: str | Collection[str] | None = None, sort_order: bool = True, alpha: float | None = None, basis: _Basis | None = None, - groups: str | Iterable[str] = None, - components: str | Collection[str] = None, + groups: str | Iterable[str] | None = None, + components: str | Collection[str] | None = None, projection: Literal["2d", "3d"] = "2d", legend_loc: str = "right margin", legend_fontsize: int | float | _FontSize | None = None, legend_fontweight: int | _FontWeight | None = None, - legend_fontoutline: float = None, - color_map: str | Colormap = None, - palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] = None, + legend_fontoutline: float | None = None, + color_map: str | Colormap | None = None, + palette: Cycler | ListedColormap | ColorLike | Sequence[ColorLike] | None = None, frameon: bool | None = None, right_margin: float | None = None, left_margin: float | None = None, @@ -94,7 +115,7 @@ def scatter( show: bool | None = None, save: str | bool | None = None, ax: Axes | None = None, -): +) -> Axes | list[Axes] | None: """\ Scatter plot along observations or variables axes. 
@@ -163,6 +184,7 @@ def scatter( def _scatter_obs( + *, adata: AnnData, x=None, y=None, @@ -190,7 +212,7 @@ def _scatter_obs( show=None, save=None, ax=None, -): +) -> Axes | list[Axes] | None: """See docstring of scatter.""" sanitize_anndata(adata) @@ -398,7 +420,7 @@ def add_centroid(centroids, name, Y, mask): iname, adata, Y, - projection, + projection=projection, size=size, alpha=alpha, marker=marker, @@ -424,7 +446,7 @@ def add_centroid(centroids, name, Y, mask): iname, adata, Y, - projection, + projection=projection, size=size, alpha=alpha, marker=marker, @@ -509,23 +531,37 @@ def add_centroid(centroids, name, Y, mask): show = settings.autoshow if show is None else show _utils.savefig_or_show("scatter" if basis is None else basis, show=show, save=save) - if not show: - return axs if len(keys) > 1 else axs[0] - - + if show: + return None + if len(keys) > 1: + return axs + return axs[0] + + +@old_positionals( + "dictionary", + "indices", + "labels", + "color", + "n_points", + "log", + "include_lowest", + "show", +) def ranking( adata: AnnData, attr: Literal["var", "obs", "uns", "varm", "obsm"], keys: str | Sequence[str], - dictionary=None, - indices=None, - labels=None, - color="black", - n_points=30, - log=False, - include_lowest=False, - show=None, -): + *, + dictionary: str | None = None, + indices: Sequence[int] | None = None, + labels: str | Sequence[str] | None = None, + color: ColorLike = "black", + n_points: int = 30, + log: bool = False, + include_lowest: bool = False, + show: bool | None = None, +) -> gridspec.GridSpec | None: """\ Plot rankings. 
@@ -625,15 +661,34 @@ def ranking( (1.05 if score_max > 0 else 0.95) * score_max, ) show = settings.autoshow if show is None else show - if not show: - return gs - - + if show: + return None + return gs + + +@old_positionals( + "log", + "use_raw", + "stripplot", + "jitter", + "size", + "layer", + "scale", + "order", + "multi_panel", + "xlabel", + "ylabel", + "rotation", + "show", + "save", + "ax", +) @_doc_params(show_save_ax=doc_show_save_ax) def violin( adata: AnnData, keys: str | Sequence[str], groupby: str | None = None, + *, log: bool = False, use_raw: bool | None = None, stripplot: bool = True, @@ -650,7 +705,7 @@ def violin( save: bool | str | None = None, ax: Axes | None = None, **kwds, -): +) -> Axes | FacetGrid | None: """\ Violin plot. @@ -802,7 +857,7 @@ def violin( # keys if groupby is None. y = ys[0] - g = sns.catplot( + g: sns.axisgrid.FacetGrid = sns.catplot( y=y, data=obs_tidy, kind="violin", @@ -840,7 +895,7 @@ def violin( if ax is None: axs, _, _, _ = setup_axes( - ax=ax, + ax, panels=["x"] if groupby is None else keys, show_ticks=True, right_margin=0.3, @@ -881,24 +936,26 @@ def violin( ax.tick_params(axis="x", labelrotation=rotation) show = settings.autoshow if show is None else show _utils.savefig_or_show("violin", show=show, save=save) - if not show: - if multi_panel and groupby is None and len(ys) == 1: - return g - elif len(axs) == 1: - return axs[0] - else: - return axs + if show: + return None + if multi_panel and groupby is None and len(ys) == 1: + return g + if len(axs) == 1: + return axs[0] + return axs +@old_positionals("use_raw", "show", "save") @_doc_params(show_save_ax=doc_show_save_ax) def clustermap( adata: AnnData, - obs_keys: str = None, + obs_keys: str | None = None, + *, use_raw: bool | None = None, show: bool | None = None, save: bool | str | None = None, **kwds, -): +) -> ClusterGrid | None: """\ Hierarchically-clustered heatmap. 
@@ -919,7 +976,7 @@ def clustermap( Returns ------- - If `show` is `False`, a :class:`~seaborn.ClusterGrid` object + If `show` is `False`, a :class:`~seaborn.matrix.ClusterGrid` object (see :func:`~seaborn.clustermap`). Examples @@ -960,10 +1017,31 @@ def clustermap( _utils.savefig_or_show("clustermap", show=show, save=save) if show: plt.show() - else: - return g - - + return None + return g + + +@old_positionals( + "use_raw", + "log", + "num_categories", + "dendrogram", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "var_group_rotation", + "layer", + "standard_scale", + "swap_axes", + "show_gene_labels", + "show", + "save", + "figsize", + "vmin", + "vmax", + "vcenter", + "norm", +) @_doc_params( vminmax=doc_vboundnorm, show_save_ax=doc_show_save_ax, @@ -973,6 +1051,7 @@ def heatmap( adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, @@ -993,7 +1072,7 @@ def heatmap( vcenter: float | None = None, norm: Normalize | None = None, **kwds, -): +) -> dict[str, Axes] | None: """\ Heatmap of the expression values of genes. 
@@ -1021,7 +1100,7 @@ def heatmap( Returns ------- - List of :class:`~matplotlib.axes.Axes` + Dict of :class:`~matplotlib.axes.Axes` Examples ------- @@ -1048,9 +1127,9 @@ def heatmap( adata, var_names, groupby, - use_raw, - log, - num_categories, + use_raw=use_raw, + log=log, + num_categories=num_categories, gene_symbols=gene_symbols, layer=layer, ) @@ -1377,15 +1456,29 @@ def heatmap( _utils.savefig_or_show("heatmap", show=show, save=save) show = settings.autoshow if show is None else show - if not show: - return return_ax_dict - - + if show: + return None + return return_ax_dict + + +@old_positionals( + "use_raw", + "log", + "dendrogram", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "layer", + "show", + "save", + "figsize", +) @_doc_params(show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args) def tracksplot( adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, dendrogram: bool | str = False, @@ -1397,7 +1490,7 @@ def tracksplot( save: str | bool | None = None, figsize: tuple[float, float] | None = None, **kwds, -): +) -> dict[str, Axes] | None: """\ In this type of plot each var_name is plotted as a filled line plot where the y values correspond to the var_name values and x is each of the cells. 
Best results @@ -1460,9 +1553,9 @@ def tracksplot( adata, var_names, groupby, - use_raw, - log, - None, + use_raw=use_raw, + log=log, + num_categories=None, # TODO: fix this line gene_symbols=gene_symbols, layer=layer, ) @@ -1636,8 +1729,9 @@ def tracksplot( _utils.savefig_or_show("tracksplot", show=show, save=save) show = settings.autoshow if show is None else show - if not show: - return return_ax_dict + if show: + return None + return return_ax_dict @_doc_params(show_save_ax=doc_show_save_ax) @@ -1651,7 +1745,7 @@ def dendrogram( show: bool | None = None, save: str | bool | None = None, ax: Axes | None = None, -): +) -> Axes: """\ Plots a dendrogram of the categories defined in `groupby`. @@ -1705,10 +1799,23 @@ def dendrogram( return ax +@old_positionals( + "show_correlation_numbers", + "dendrogram", + "figsize", + "show", + "save", + "ax", + "vmin", + "vmax", + "vcenter", + "norm", +) @_doc_params(show_save_ax=doc_show_save_ax, vminmax=doc_vboundnorm) def correlation_matrix( adata: AnnData, groupby: str, + *, show_correlation_numbers: bool = False, dendrogram: bool | str | None = None, figsize: tuple[float, float] | None = None, @@ -1720,7 +1827,7 @@ def correlation_matrix( vcenter: float | None = None, norm: Normalize | None = None, **kwds, -) -> Axes | list[Axes]: +) -> list[Axes] | None: """\ Plots the correlation matrix computed as part of `sc.tl.dendrogram`. @@ -1749,6 +1856,7 @@ def correlation_matrix( Returns ------- + If `show=False`, returns a list of :class:`matplotlib.axes.Axes` objects. 
Examples -------- @@ -1869,20 +1977,22 @@ def correlation_matrix( show = settings.autoshow if show is None else show _utils.savefig_or_show("correlation_matrix", show=show, save=save) - if ax is None and not show: - return axs + if ax is not None or show: + return None + return axs def _prepare_dataframe( adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str] | None = None, + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, - layer=None, + layer: str | None = None, gene_symbols: str | None = None, -): +) -> tuple[Sequence[str], pd.DataFrame]: """ Given the anndata object, prepares a data frame in which the row index are the categories defined by group by and the columns correspond to var_names. @@ -1997,6 +2107,7 @@ def _prepare_dataframe( def _plot_gene_groups_brackets( gene_groups_ax: Axes, + *, group_positions: Iterable[tuple[int, int]], group_labels: Sequence[str], left_adjustment: float = -0.3, @@ -2134,6 +2245,7 @@ def _reorder_categories_after_dendrogram( adata: AnnData, groupby, dendrogram, + *, var_names=None, var_group_labels=None, var_group_positions=None, @@ -2274,6 +2386,7 @@ def _plot_dendrogram( dendro_ax: Axes, adata: AnnData, groupby: str, + *, dendrogram_key: str | None = None, orientation: Literal["top", "bottom", "left", "right"] = "right", remove_labels: bool = True, diff --git a/scanpy/plotting/_baseplot_class.py b/scanpy/plotting/_baseplot_class.py index 65ebfa64b5..a3cc34e5fa 100644 --- a/scanpy/plotting/_baseplot_class.py +++ b/scanpy/plotting/_baseplot_class.py @@ -13,6 +13,7 @@ from matplotlib import pyplot as plt from .. 
import logging as logg +from .._compat import old_positionals from ._anndata import _get_dendrogram_key, _plot_dendrogram, _prepare_dataframe from ._utils import ColorLike, _AxesSubplot, check_colornorm, make_grid_spec @@ -72,11 +73,30 @@ class BasePlot: MAX_NUM_CATEGORIES = 500 # maximum number of categories allowed to be plotted + @old_positionals( + "use_raw", + "log", + "num_categories", + "categories_order", + "title", + "figsize", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "var_group_rotation", + "layer", + "ax", + "vmin", + "vmax", + "vcenter", + "norm", + ) def __init__( self, adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, @@ -113,9 +133,9 @@ def __init__( adata, self.var_names, groupby, - use_raw, - log, - num_categories, + use_raw=use_raw, + log=log, + num_categories=num_categories, layer=layer, gene_symbols=gene_symbols, ) @@ -365,7 +385,8 @@ def add_totals( } return self - def style(self, cmap: str | None = DEFAULT_COLORMAP): + @old_positionals("cmap") + def style(self, *, cmap: str | None = DEFAULT_COLORMAP): """\ Set visual style parameters @@ -381,8 +402,10 @@ def style(self, cmap: str | None = DEFAULT_COLORMAP): self.cmap = cmap + @old_positionals("show", "title", "width") def legend( self, + *, show: bool | None = True, title: str | None = DEFAULT_COLOR_LEGEND_TITLE, width: float | None = DEFAULT_LEGENDS_WIDTH, @@ -428,7 +451,7 @@ def legend( return self - def get_axes(self): + def get_axes(self) -> dict[str, Axes]: if self.ax_dict is None: self.make_figure() return self.ax_dict @@ -466,7 +489,7 @@ def _plot_totals( for p in total_barplot_ax.patches: p.set_x(p.get_x() + 0.5) if p.get_height() >= 1000: - display_number = f"{np.round(p.get_height()/1000, decimals=1)}k" + display_number = f"{np.round(p.get_height() / 1000, decimals=1)}k" else: display_number = np.round(p.get_height(), decimals=1) 
total_barplot_ax.annotate( @@ -496,7 +519,7 @@ def _plot_totals( max_x = max([p.get_width() for p in total_barplot_ax.patches]) for p in total_barplot_ax.patches: if p.get_width() >= 1000: - display_number = f"{np.round(p.get_width()/1000, decimals=1)}k" + display_number = f"{np.round(p.get_width() / 1000, decimals=1)}k" else: display_number = np.round(p.get_width(), decimals=1) total_barplot_ax.annotate( @@ -926,6 +949,7 @@ def _format_first_three_categories(_categories): @staticmethod def _plot_var_groups_brackets( gene_groups_ax: Axes, + *, group_positions: Iterable[tuple[int, int]], group_labels: Sequence[str], left_adjustment: float = -0.3, diff --git a/scanpy/plotting/_dotplot.py b/scanpy/plotting/_dotplot.py index 8d21cf384b..c074c864cd 100644 --- a/scanpy/plotting/_dotplot.py +++ b/scanpy/plotting/_dotplot.py @@ -6,6 +6,7 @@ from matplotlib import pyplot as plt from .. import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params from ._baseplot_class import BasePlot, _VarNames, doc_common_groupby_plot_args @@ -111,11 +112,35 @@ class DotPlot(BasePlot): DEFAULT_PLOT_X_PADDING = 0.8 # a unit is the distance between two x-axis ticks DEFAULT_PLOT_Y_PADDING = 1.0 # a unit is the distance between two y-axis ticks + @old_positionals( + "use_raw", + "log", + "num_categories", + "categories_order", + "title", + "figsize", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "var_group_rotation", + "layer", + "expression_cutoff", + "mean_only_expressed", + "standard_scale", + "dot_color_df", + "dot_size_df", + "ax", + "vmin", + "vmax", + "vcenter", + "norm", + ) def __init__( self, adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, @@ -129,7 +154,7 @@ def __init__( layer: str | None = None, expression_cutoff: float = 0.0, mean_only_expressed: bool = False, - 
standard_scale: Literal["var", "group"] = None, + standard_scale: Literal["var", "group"] | None = None, dot_color_df: pd.DataFrame | None = None, dot_size_df: pd.DataFrame | None = None, ax: _AxesSubplot | None = None, @@ -251,8 +276,23 @@ def __init__( self.show_size_legend = True self.show_colorbar = True + @old_positionals( + "cmap", + "color_on", + "dot_max", + "dot_min", + "smallest_dot", + "largest_dot", + "dot_edge_color", + "dot_edge_lw", + "size_exponent", + "grid", + "x_padding", + "y_padding", + ) def style( self, + *, cmap: str = DEFAULT_COLORMAP, color_on: Literal["dot", "square"] | None = DEFAULT_COLOR_ON, dot_max: float | None = DEFAULT_DOT_MAX, @@ -367,8 +407,17 @@ def style( return self + @old_positionals( + "show", + "show_size_legend", + "show_colorbar", + "size_title", + "colorbar_title", + "width", + ) def legend( self, + *, show: bool | None = True, show_size_legend: bool | None = True, show_colorbar: bool | None = True, @@ -568,6 +617,7 @@ def _dotplot( dot_size, dot_color, dot_ax, + *, cmap: str = "Reds", color_on: str | None = "dot", y_label: str | None = None, @@ -798,6 +848,22 @@ def _dotplot( return normalize, dot_min, dot_max +@old_positionals( + "use_raw", + "log", + "num_categories", + "expression_cutoff", + "mean_only_expressed", + "cmap", + "dot_max", + "dot_min", + "standard_scale", + "smallest_dot", + "title", + "colorbar_title", + "size_title", + # No need to have backwards compat for > 16 positional parameters +) @_doc_params( show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args, @@ -808,6 +874,7 @@ def dotplot( adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, diff --git a/scanpy/plotting/_matrixplot.py b/scanpy/plotting/_matrixplot.py index fe1b7844da..c25960e4fd 100644 --- a/scanpy/plotting/_matrixplot.py +++ b/scanpy/plotting/_matrixplot.py @@ -7,6 +7,7 @@ from matplotlib import rcParams 
from .. import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params from ._baseplot_class import BasePlot, _VarNames, doc_common_groupby_plot_args @@ -25,6 +26,7 @@ import pandas as pd from anndata import AnnData + from matplotlib.axes import Axes from matplotlib.colors import Normalize @@ -94,11 +96,32 @@ class MatrixPlot(BasePlot): DEFAULT_EDGE_COLOR = "gray" DEFAULT_EDGE_LW = 0.1 + @old_positionals( + "use_raw", + "log", + "num_categories", + "categories_order", + "title", + "figsize", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "var_group_rotation", + "layer", + "standard_scale", + "ax", + "values_df", + "vmin", + "vmax", + "vcenter", + "norm", + ) def __init__( self, adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, @@ -280,6 +303,23 @@ def _mainplot(self, ax): return normalize +@old_positionals( + "use_raw", + "log", + "num_categories", + "figsize", + "dendrogram", + "title", + "cmap", + "colorbar_title", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "var_group_rotation", + "layer", + "standard_scale", + # 17 positionals are enough for backwards compatibility +) @_doc_params( show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args, @@ -290,6 +330,7 @@ def matrixplot( adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, @@ -303,7 +344,7 @@ def matrixplot( var_group_labels: Sequence[str] | None = None, var_group_rotation: float | None = None, layer: str | None = None, - standard_scale: Literal["var", "group"] = None, + standard_scale: Literal["var", "group"] | None = None, values_df: pd.DataFrame | None = None, swap_axes: bool = False, show: bool | None = None, @@ -315,7 +356,7 @@ def matrixplot( 
vcenter: float | None = None, norm: Normalize | None = None, **kwds, -) -> MatrixPlot | dict | None: +) -> MatrixPlot | dict[str, Axes] | None: """\ Creates a heatmap of the mean expression values per group of each var_names. @@ -411,9 +452,9 @@ def matrixplot( mp = mp.style(cmap=cmap).legend(title=colorbar_title) if return_fig: return mp - else: - mp.make_figure() - savefig_or_show(MatrixPlot.DEFAULT_SAVE_PREFIX, show=show, save=save) - show = settings.autoshow if show is None else show - if not show: - return mp.get_axes() + mp.make_figure() + savefig_or_show(MatrixPlot.DEFAULT_SAVE_PREFIX, show=show, save=save) + show = settings.autoshow if show is None else show + if show: + return None + return mp.get_axes() diff --git a/scanpy/plotting/_preprocessing.py b/scanpy/plotting/_preprocessing.py index f470d73187..9dce5dca3f 100644 --- a/scanpy/plotting/_preprocessing.py +++ b/scanpy/plotting/_preprocessing.py @@ -6,6 +6,8 @@ from matplotlib import pyplot as plt from matplotlib import rcParams +from .._compat import old_positionals +from .._settings import settings from . import _utils # -------------------------------------------------------------------------------- @@ -13,13 +15,15 @@ # -------------------------------------------------------------------------------- +@old_positionals("log", "show", "save", "highly_variable_genes") def highly_variable_genes( adata_or_result: AnnData | pd.DataFrame | np.recarray, + *, log: bool = False, show: bool | None = None, save: bool | str | None = None, highly_variable_genes: bool = True, -): +) -> None: """Plot dispersions or normalized variance versus means for genes. Produces Supp. Fig. 5c of Zheng et al. 
(2017) and MeanVarPlot() and @@ -91,9 +95,11 @@ def highly_variable_genes( + (" (normalized)" if idx == 0 else " (not normalized)") ) + show = settings.autoshow if show is None else show _utils.savefig_or_show("filter_genes_dispersion", show=show, save=save) - if show is False: - return plt.gca() + if show: + return None + return plt.gca() # backwards compat @@ -102,7 +108,7 @@ def filter_genes_dispersion( log: bool = False, show: bool | None = None, save: bool | str | None = None, -): +) -> None: """\ Plot dispersions versus means for genes. diff --git a/scanpy/plotting/_qc.py b/scanpy/plotting/_qc.py index 35f3b63c53..2aec45cbf9 100644 --- a/scanpy/plotting/_qc.py +++ b/scanpy/plotting/_qc.py @@ -6,6 +6,8 @@ import pandas as pd from matplotlib import pyplot as plt +from .._compat import old_positionals +from .._settings import settings from .._utils import _doc_params from ..preprocessing._normalization import normalize_total from . import _utils @@ -16,10 +18,12 @@ from matplotlib.axes import Axes +@old_positionals("show", "save", "ax", "gene_symbols", "log") @_doc_params(show_save_ax=doc_show_save_ax) def highest_expr_genes( adata: AnnData, n_top: int = 30, + *, show: bool | None = None, save: str | bool | None = None, ax: Axes | None = None, @@ -97,6 +101,8 @@ def highest_expr_genes( ax.set_xlabel("% of total counts") if log: ax.set_xscale("log") + show = settings.autoshow if show is None else show _utils.savefig_or_show("highest_expr_genes", show=show, save=save) - if show is False: - return ax + if show: + return None + return ax diff --git a/scanpy/plotting/_stacked_violin.py b/scanpy/plotting/_stacked_violin.py index e058358400..6e264ef652 100644 --- a/scanpy/plotting/_stacked_violin.py +++ b/scanpy/plotting/_stacked_violin.py @@ -8,6 +8,7 @@ from matplotlib.colors import Normalize, is_color_like from .. 
import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params from ._baseplot_class import BasePlot, _VarNames, doc_common_groupby_plot_args @@ -130,11 +131,31 @@ class StackedViolin(BasePlot): # None will draw unadorned violins. DEFAULT_INNER = None + @old_positionals( + "use_raw", + "log", + "num_categories", + "categories_order", + "title", + "figsize", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "var_group_rotation", + "layer", + "standard_scale", + "ax", + "vmin", + "vmax", + "vcenter", + "norm", + ) def __init__( self, adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, use_raw: bool | None = None, log: bool = False, num_categories: int = 7, @@ -146,7 +167,7 @@ def __init__( var_group_labels: Sequence[str] | None = None, var_group_rotation: float | None = None, layer: str | None = None, - standard_scale: Literal["var", "group"] = None, + standard_scale: Literal["var", "group"] | None = None, ax: _AxesSubplot | None = None, vmin: float | None = None, vmax: float | None = None, @@ -205,8 +226,22 @@ def __init__( self.kwds.setdefault("linewidth", self.DEFAULT_LINE_WIDTH) self.kwds.setdefault("scale", self.DEFAULT_SCALE) + @old_positionals( + "cmap", + "stripplot", + "jitter", + "jitter_size", + "linewidth", + "row_palette", + "scale", + "yticklabels", + "ylim", + "x_padding", + "y_padding", + ) def style( self, + *, cmap: str | None = DEFAULT_COLORMAP, stripplot: bool | None = DEFAULT_STRIPPLOT, jitter: float | bool | None = DEFAULT_JITTER, @@ -551,6 +586,23 @@ def _setup_violin_axes_ticks(self, row_ax, num_cols): ) +@old_positionals( + "log", + "use_raw", + "num_categories", + "title", + "colorbar_title", + "figsize", + "dendrogram", + "gene_symbols", + "var_group_positions", + "var_group_labels", + "standard_scale", + "var_group_rotation", + "layer", + "stripplot", + # 17 positionals are enough for backwards compatibility +) 
@_doc_params( show_save_ax=doc_show_save_ax, common_plot_args=doc_common_plot_args, @@ -561,6 +613,7 @@ def stacked_violin( adata: AnnData, var_names: _VarNames | Mapping[str, _VarNames], groupby: str | Sequence[str], + *, log: bool = False, use_raw: bool | None = None, num_categories: int = 7, @@ -732,9 +785,9 @@ def stacked_violin( ).legend(title=colorbar_title) if return_fig: return vp - else: - vp.make_figure() - savefig_or_show(StackedViolin.DEFAULT_SAVE_PREFIX, show=show, save=save) - show = settings.autoshow if show is None else show - if not show: - return vp.get_axes() + vp.make_figure() + savefig_or_show(StackedViolin.DEFAULT_SAVE_PREFIX, show=show, save=save) + show = settings.autoshow if show is None else show + if show: + return None + return vp.get_axes() diff --git a/scanpy/plotting/_tools/__init__.py b/scanpy/plotting/_tools/__init__.py index 0d22821020..d55c1fed22 100644 --- a/scanpy/plotting/_tools/__init__.py +++ b/scanpy/plotting/_tools/__init__.py @@ -13,6 +13,7 @@ from scanpy.get import obs_df from ... 
import logging as logg +from ..._compat import old_positionals from ..._settings import settings from ..._utils import _doc_params, sanitize_anndata, subsample from ...get import rank_genes_groups_df @@ -96,9 +97,11 @@ def pca_overview(adata: AnnData, **params): pca_scatter = pca +@old_positionals("include_lowest", "n_points", "show", "save") def pca_loadings( adata: AnnData, components: str | Sequence[int] | None = None, + *, include_lowest: bool = True, n_points: int | None = None, show: bool | None = None, @@ -169,9 +172,11 @@ def pca_loadings( savefig_or_show("pca_loadings", show=show, save=save) +@old_positionals("log", "show", "save") def pca_variance_ratio( adata: AnnData, n_pcs: int = 30, + *, log: bool = False, show: bool | None = None, save: bool | str | None = None, @@ -210,9 +215,11 @@ def pca_variance_ratio( # ------------------------------------------------------------------------------ +@old_positionals("color_map", "show", "save", "as_heatmap", "marker") def dpt_timeseries( adata: AnnData, - color_map: str | Colormap = None, + *, + color_map: str | Colormap | None = None, show: bool | None = None, save: bool | None = None, as_heatmap: bool = True, @@ -253,8 +260,10 @@ def dpt_timeseries( savefig_or_show("dpt_timeseries", save=save, show=show) +@old_positionals("color_map", "palette", "show", "save", "marker") def dpt_groups_pseudotime( adata: AnnData, + *, color_map: str | Colormap | None = None, palette: Sequence[str] | Cycler | None = None, show: bool | None = None, @@ -293,10 +302,22 @@ def dpt_groups_pseudotime( savefig_or_show("dpt_groups_pseudotime", save=save, show=show) +@old_positionals( + "n_genes", + "gene_symbols", + "key", + "fontsize", + "ncols", + "sharey", + "show", + "save", + "ax", +) @_doc_params(show_save_ax=doc_show_save_ax) def rank_genes_groups( adata: AnnData, groups: str | Sequence[str] | None = None, + *, n_genes: int = 20, gene_symbols: str | None = None, key: str | None = "rank_genes_groups", @@ -452,17 +473,18 @@ def 
_fig_show_save_or_axes(plot_obj, return_fig, show, save): """ if return_fig: return plot_obj - else: - plot_obj.make_figure() - savefig_or_show(plot_obj.DEFAULT_SAVE_PREFIX, show=show, save=save) - show = settings.autoshow if show is None else show - if show is False: - return plot_obj.get_axes() + plot_obj.make_figure() + savefig_or_show(plot_obj.DEFAULT_SAVE_PREFIX, show=show, save=save) + show = settings.autoshow if show is None else show + if show: + return None + return plot_obj.get_axes() def _rank_genes_groups_plot( adata: AnnData, plot_type: str = "heatmap", + *, groups: str | Sequence[str] | None = None, n_genes: int | None = None, groupby: str | None = None, @@ -626,10 +648,21 @@ def _rank_genes_groups_plot( ) +@old_positionals( + "n_genes", + "groupby", + "gene_symbols", + "var_names", + "min_logfoldchange", + "key", + "show", + "save", +) @_doc_params(params=doc_rank_genes_groups_plot_args, show_save_ax=doc_show_save_ax) def rank_genes_groups_heatmap( adata: AnnData, groups: str | Sequence[str] | None = None, + *, n_genes: int | None = None, groupby: str | None = None, gene_symbols: str | None = None, @@ -699,10 +732,21 @@ def rank_genes_groups_heatmap( ) +@old_positionals( + "n_genes", + "groupby", + "var_names", + "gene_symbols", + "min_logfoldchange", + "key", + "show", + "save", +) @_doc_params(params=doc_rank_genes_groups_plot_args, show_save_ax=doc_show_save_ax) def rank_genes_groups_tracksplot( adata: AnnData, groups: str | Sequence[str] | None = None, + *, n_genes: int | None = None, groupby: str | None = None, var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, @@ -752,6 +796,18 @@ def rank_genes_groups_tracksplot( ) +@old_positionals( + "n_genes", + "groupby", + "values_to_plot", + "var_names", + "gene_symbols", + "min_logfoldchange", + "key", + "show", + "save", + "return_fig", +) @_doc_params( params=doc_rank_genes_groups_plot_args, vals_to_plot=doc_rank_genes_groups_values_to_plot, @@ -760,6 +816,7 @@ def 
rank_genes_groups_tracksplot( def rank_genes_groups_dotplot( adata: AnnData, groups: str | Sequence[str] | None = None, + *, n_genes: int | None = None, groupby: str | None = None, values_to_plot: Literal[ @@ -902,14 +959,15 @@ def rank_genes_groups_dotplot( ) +@old_positionals("n_genes", "groupby", "gene_symbols") @_doc_params(params=doc_rank_genes_groups_plot_args, show_save_ax=doc_show_save_ax) def rank_genes_groups_stacked_violin( adata: AnnData, groups: str | Sequence[str] | None = None, + *, n_genes: int | None = None, groupby: str | None = None, gene_symbols: str | None = None, - *, var_names: Sequence[str] | Mapping[str, Sequence[str]] | None = None, min_logfoldchange: float | None = None, key: str | None = None, @@ -965,6 +1023,18 @@ def rank_genes_groups_stacked_violin( ) +@old_positionals( + "n_genes", + "groupby", + "values_to_plot", + "var_names", + "gene_symbols", + "min_logfoldchange", + "key", + "show", + "save", + "return_fig", +) @_doc_params( params=doc_rank_genes_groups_plot_args, vals_to_plot=doc_rank_genes_groups_values_to_plot, @@ -973,6 +1043,7 @@ def rank_genes_groups_stacked_violin( def rank_genes_groups_matrixplot( adata: AnnData, groups: str | Sequence[str] | None = None, + *, n_genes: int | None = None, groupby: str | None = None, values_to_plot: Literal[ @@ -1098,10 +1169,26 @@ def rank_genes_groups_matrixplot( ) +@old_positionals( + "n_genes", + "gene_names", + "gene_symbols", + "use_raw", + "key", + "split", + "scale", + "strip", + "jitter", + "size", + "ax", + "show", + "save", +) @_doc_params(show_save_ax=doc_show_save_ax) def rank_genes_groups_violin( adata: AnnData, groups: Sequence[str] | None = None, + *, n_genes: int = 20, gene_names: Iterable[str] | None = None, gene_symbols: str | None = None, @@ -1216,19 +1303,23 @@ def rank_genes_groups_violin( ) savefig_or_show(writekey, show=show, save=save) axs.append(_ax) - if show is False: - return axs + show = settings.autoshow if show is None else show + if show: + return None + 
return axs +@old_positionals("tmax_realization", "as_heatmap", "shuffle", "show", "save", "marker") def sim( - adata, + adata: AnnData, + *, tmax_realization: int | None = None, as_heatmap: bool = False, shuffle: bool = False, show: bool | None = None, save: bool | str | None = None, marker: str | Sequence[str] = ".", -): +) -> None: """\ Plot results of simulation. @@ -1292,14 +1383,34 @@ def sim( savefig_or_show("sim_shuffled", save=save, show=show) +@old_positionals( + "key", + "groupby", + "group", + "color_map", + "bg_dotsize", + "fg_dotsize", + "vmax", + "vmin", + "vcenter", + "norm", + "ncols", + "hspace", + "wspace", + "title", + "show", + "save", + "ax", + "return_fig", +) @_doc_params( vminmax=doc_vbound_percentile, panels=doc_panels, show_save_ax=doc_show_save_ax ) def embedding_density( adata: AnnData, - # on purpose, there is no asterisk here (for backward compat) - basis: str = "umap", # was positional before 1.4.5 - key: str | None = None, # was positional before 1.4.5 + basis: str = "umap", + *, + key: str | None = None, groupby: str | None = None, group: str | Sequence[str] | None | None = "all", color_map: Colormap | str = "YlOrRd", @@ -1567,8 +1678,10 @@ def embedding_density( if return_fig: return fig savefig_or_show(f"{key}_", show=show, save=save) - if show is False: - return ax + show = settings.autoshow if show is None else show + if show: + return None + return ax def _get_values_to_plot( @@ -1582,6 +1695,7 @@ def _get_values_to_plot( "log10_pvals_adj", ], gene_names: Sequence[str], + *, groups: Sequence[str] | None = None, key: str | None = "rank_genes_groups", gene_symbols: str | None = None, diff --git a/scanpy/plotting/_tools/paga.py b/scanpy/plotting/_tools/paga.py index 5fab4eae02..ec8656026d 100644 --- a/scanpy/plotting/_tools/paga.py +++ b/scanpy/plotting/_tools/paga.py @@ -18,6 +18,7 @@ from ... import _utils as _sc_utils from ... 
import logging as logg +from ..._compat import old_positionals from ..._settings import settings from .. import _utils from .._utils import _FontSize, _FontWeight, _IGraphLayout, matrix @@ -29,9 +30,33 @@ from matplotlib.axes import Axes +@old_positionals( + "edges", + "color", + "alpha", + "groups", + "components", + "projection", + "legend_loc", + "legend_fontsize", + "legend_fontweight", + "legend_fontoutline", + "color_map", + "palette", + "frameon", + "size", + "title", + "right_margin", + "left_margin", + "show", + "save", + "title_graph", + "groups_graph", +) def paga_compare( adata: AnnData, basis=None, + *, edges=False, color=None, alpha=None, @@ -53,7 +78,6 @@ def paga_compare( save=None, title_graph=None, groups_graph=None, - *, pos=None, **paga_graph_params, ): @@ -81,10 +105,7 @@ def paga_compare( ------- A list of :class:`~matplotlib.axes.Axes` if `show` is `False`. """ - axs, _, _, _ = _utils.setup_axes( - panels=[0, 1], - right_margin=right_margin, - ) + axs, _, _, _ = _utils.setup_axes(panels=[0, 1], right_margin=right_margin) if color is None: color = adata.uns["paga"]["groups"] suptitle = None # common title for entire figure @@ -171,12 +192,14 @@ def paga_compare( if suptitle is not None: plt.suptitle(suptitle) _utils.savefig_or_show("paga_compare", show=show, save=save) - if show is False: - return axs + if show: + return None + return axs def _compute_pos( adjacency_solid, + *, layout=None, random_state=0, init_pos=None, @@ -282,8 +305,28 @@ def _compute_pos( return pos_array +@old_positionals( + "threshold", + "color", + "layout", + "layout_kwds", + "init_pos", + "root", + "labels", + "single_component", + "solid_edges", + "dashed_edges", + "transitions", + "fontsize", + "fontweight", + "fontoutline", + "text_kwds", + "node_size_scale", + # 17 positionals are enough for backwards compat +) def paga( adata: AnnData, + *, threshold: float | None = None, color: str | Mapping[str | int, Mapping[Any, float]] | None = None, layout: _IGraphLayout | 
None = None, @@ -308,9 +351,9 @@ def paga( title: str | None = None, left_margin: float = 0.01, random_state: int | None = 0, - pos: np.ndarray | str | Path | None = None, + pos: np.ndarray | Path | str | None = None, normalize_to_color: bool = False, - cmap: str | Colormap = None, + cmap: str | Colormap | None = None, cax: Axes | None = None, colorbar=None, # TODO: this seems to be unused cb_kwds: Mapping[str, Any] = MappingProxyType({}), @@ -570,9 +613,7 @@ def is_flat(x): if plot: axs, panel_pos, draw_region_width, figure_width = _utils.setup_axes( - ax=ax, - panels=colors, - colorbars=colorbars, + ax, panels=colors, colorbars=colorbars ) if len(colors) == 1 and not isinstance(axs, list): @@ -634,17 +675,22 @@ def is_flat(x): if add_pos: adata.uns["paga"]["pos"] = pos logg.hint("added 'pos', the PAGA positions (adata.uns['paga'])") - if plot: - _utils.savefig_or_show("paga", show=show, save=save) - if len(colors) == 1 and isinstance(axs, list): - axs = axs[0] - if show is False: - return axs + + if not plot: + return None + _utils.savefig_or_show("paga", show=show, save=save) + if len(colors) == 1 and isinstance(axs, list): + axs = axs[0] + show = settings.autoshow if show is None else show + if show: + return None + return axs def _paga_graph( adata, ax, + *, solid_edges=None, dashed_edges=None, adjacency_solid=None, @@ -988,10 +1034,36 @@ def _paga_graph( return sct +@old_positionals( + "use_raw", + "annotations", + "color_map", + "color_maps_annotations", + "palette_groups", + "n_avg", + "groups_key", + "xlim", + "title", + "left_margin", + "ytick_fontsize", + "title_fontsize", + "show_node_names", + "show_yticks", + "show_colorbar", + "legend_fontsize", + "legend_fontweight", + "normalize_to_zero_one", + "as_heatmap", + "return_data", + "show", + "save", + "ax", +) def paga_path( adata: AnnData, nodes: Sequence[str | int], keys: Sequence[str], + *, use_raw: bool = True, annotations: Sequence[str] = ("dpt_pseudotime",), color_map: str | Colormap | None = 
None, @@ -1017,7 +1089,7 @@ def paga_path( show: bool | None = None, save: bool | str | None = None, ax: Axes | None = None, -) -> Axes | None: +) -> tuple[Axes, pd.DataFrame] | Axes | pd.DataFrame | None: """\ Gene expression and annotation changes along paths in the abstracted graph. @@ -1301,20 +1373,21 @@ def moving_average(a): df["groups"] = moving_average(groups) # groups is without moving average, yet if "dpt_pseudotime" in anno_dict: df["distance"] = anno_dict["dpt_pseudotime"].T - return ax, df if ax_was_none and not show else df - else: - return ax if ax_was_none and not show else None + if not ax_was_none or show: + return df if return_data else None + return (ax, df) if return_data else ax def paga_adjacency( - adata, - adjacency="connectivities", - adjacency_tree="connectivities_tree", - as_heatmap=True, - color_map=None, - show=None, - save=None, -): + adata: AnnData, + *, + adjacency: str = "connectivities", + adjacency_tree: str = "connectivities_tree", + as_heatmap: bool = True, + color_map: str | Colormap | None = None, + show: bool | None = None, + save: bool | str | None = None, +) -> None: """Connectivity of paga groups.""" connectivity = adata.uns[adjacency].toarray() connectivity_select = adata.uns[adjacency_tree] diff --git a/scanpy/plotting/_tools/scatterplots.py b/scanpy/plotting/_tools/scatterplots.py index b74524d259..694a5a9981 100644 --- a/scanpy/plotting/_tools/scatterplots.py +++ b/scanpy/plotting/_tools/scatterplots.py @@ -105,7 +105,7 @@ def embedding( return_fig: bool | None = None, marker: str | Sequence[str] = ".", **kwargs, -) -> Figure | Axes | None: +) -> Figure | Axes | list[Axes] | None: """\ Scatter plot for user specified embedding basis (e.g. 
umap, pca, etc) @@ -319,7 +319,7 @@ def embedding( if not categorical: vmin_float, vmax_float, vcenter_float, norm_obj = _get_vboundnorm( - vmin, vmax, vcenter, norm, count, color_vector + vmin, vmax, vcenter, norm=norm, index=count, colors=color_vector ) normalize = check_colornorm( vmin_float, @@ -427,7 +427,9 @@ def embedding( ax.autoscale_view() if edges: - _utils.plot_edges(ax, adata, basis, edges_width, edges_color, neighbors_key) + _utils.plot_edges( + ax, adata, basis, edges_width, edges_color, neighbors_key=neighbors_key + ) if arrows: _utils.plot_arrows(ax, adata, basis, arrows_kwds) @@ -467,8 +469,10 @@ def embedding( return fig axs = axs if grid else ax _utils.savefig_or_show(basis, show=show, save=save) - if show is False: - return axs + show = settings.autoshow if show is None else show + if show: + return None + return axs def _panel_grid(hspace, wspace, ncols, num_panels): @@ -502,9 +506,10 @@ def _get_vboundnorm( vmin: Sequence[VBound], vmax: Sequence[VBound], vcenter: Sequence[VBound], + *, norm: Sequence[Normalize], index: int, - color_vector: Sequence[float], + colors: Sequence[float], ) -> tuple[float | None, float | None]: """ Evaluates the value of vmin, vmax and vcenter, which could be a @@ -514,17 +519,17 @@ def _get_vboundnorm( Floats are accepted as p99.9 Alternatively, vmin/vmax could be a function that is applied to - the list of color values (`color_vector`). E.g. + the list of color values (`colors`). E.g. - def my_vmax(color_vector): np.percentile(color_vector, p=80) + def my_vmax(colors): np.percentile(colors, p=80) Parameters ---------- index This index of the plot - color_vector - List or values for the plot + colors + Values for the plot Returns ------- @@ -561,10 +566,10 @@ def my_vmax(color_vector): np.percentile(color_vector, p=80) f"Please check the correct format for percentiles." 
) # interpret value of vmin/vmax as quantile with the following syntax 'p99.9' - v_value = np.nanpercentile(color_vector, q=float(v_value[1:])) + v_value = np.nanpercentile(colors, q=float(v_value[1:])) elif callable(v_value): # interpret vmin/vmax as function - v_value = v_value(color_vector) + v_value = v_value(colors) if not isinstance(v_value, float): logg.error( f"The return of the function given for {v_name} is not valid. " @@ -627,7 +632,7 @@ def _wraps_plot_scatter(wrapper): scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def umap(adata, **kwargs) -> Axes | list[Axes] | None: +def umap(adata: AnnData, **kwargs) -> Figure | Axes | list[Axes] | None: """\ Scatter plot in UMAP basis. @@ -689,7 +694,7 @@ def umap(adata, **kwargs) -> Axes | list[Axes] | None: scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def tsne(adata, **kwargs) -> Axes | list[Axes] | None: +def tsne(adata: AnnData, **kwargs) -> Figure | Axes | list[Axes] | None: """\ Scatter plot in tSNE basis. @@ -729,7 +734,7 @@ def tsne(adata, **kwargs) -> Axes | list[Axes] | None: scatter_bulk=doc_scatter_embedding, show_save_ax=doc_show_save_ax, ) -def diffmap(adata, **kwargs) -> Axes | list[Axes] | None: +def diffmap(adata: AnnData, **kwargs) -> Figure | Axes | list[Axes] | None: """\ Scatter plot in Diffusion Map basis. @@ -771,7 +776,7 @@ def diffmap(adata, **kwargs) -> Axes | list[Axes] | None: ) def draw_graph( adata: AnnData, *, layout: _IGraphLayout | None = None, **kwargs -) -> Axes | list[Axes] | None: +) -> Figure | Axes | list[Axes] | None: """\ Scatter plot in graph-drawing basis. @@ -823,14 +828,14 @@ def draw_graph( show_save_ax=doc_show_save_ax, ) def pca( - adata, + adata: AnnData, *, annotate_var_explained: bool = False, show: bool | None = None, return_fig: bool | None = None, save: bool | str | None = None, **kwargs, -) -> Axes | list[Axes] | None: +) -> Figure | Axes | list[Axes] | None: """\ Scatter plot in PCA coordinates. 
@@ -881,41 +886,41 @@ def pca( return embedding( adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs ) - else: - if "pca" not in adata.obsm.keys() and "X_pca" not in adata.obsm.keys(): - raise KeyError( - f"Could not find entry in `obsm` for 'pca'.\n" - f"Available keys are: {list(adata.obsm.keys())}." - ) + if "pca" not in adata.obsm.keys() and "X_pca" not in adata.obsm.keys(): + raise KeyError( + f"Could not find entry in `obsm` for 'pca'.\n" + f"Available keys are: {list(adata.obsm.keys())}." + ) - label_dict = { - f"PC{i + 1}": f"PC{i + 1} ({round(v * 100, 2)}%)" - for i, v in enumerate(adata.uns["pca"]["variance_ratio"]) - } - - if return_fig is True: - # edit axis labels in returned figure - fig = embedding(adata, "pca", return_fig=return_fig, **kwargs) - for ax in fig.axes: - if xlabel := label_dict.get(ax.xaxis.get_label().get_text()): - ax.set_xlabel(xlabel) - if ylabel := label_dict.get(ax.yaxis.get_label().get_text()): - ax.set_ylabel(ylabel) - return fig + label_dict = { + f"PC{i + 1}": f"PC{i + 1} ({round(v * 100, 2)}%)" + for i, v in enumerate(adata.uns["pca"]["variance_ratio"]) + } - else: - # get the axs, edit the labels and apply show and save from user - axs = embedding(adata, "pca", show=False, save=False, **kwargs) - if isinstance(axs, list): - for ax in axs: - ax.set_xlabel(label_dict[ax.xaxis.get_label().get_text()]) - ax.set_ylabel(label_dict[ax.yaxis.get_label().get_text()]) - else: - axs.set_xlabel(label_dict[axs.xaxis.get_label().get_text()]) - axs.set_ylabel(label_dict[axs.yaxis.get_label().get_text()]) - _utils.savefig_or_show("pca", show=show, save=save) - if show is False: - return axs + if return_fig is True: + # edit axis labels in returned figure + fig = embedding(adata, "pca", return_fig=return_fig, **kwargs) + for ax in fig.axes: + if xlabel := label_dict.get(ax.xaxis.get_label().get_text()): + ax.set_xlabel(xlabel) + if ylabel := label_dict.get(ax.yaxis.get_label().get_text()): + ax.set_ylabel(ylabel) + return fig 
+ + # get the axs, edit the labels and apply show and save from user + axs = embedding(adata, "pca", show=False, save=False, **kwargs) + if isinstance(axs, list): + for ax in axs: + ax.set_xlabel(label_dict[ax.xaxis.get_label().get_text()]) + ax.set_ylabel(label_dict[ax.yaxis.get_label().get_text()]) + else: + axs.set_xlabel(label_dict[axs.xaxis.get_label().get_text()]) + axs.set_ylabel(label_dict[axs.yaxis.get_label().get_text()]) + _utils.savefig_or_show("pca", show=show, save=save) + show = settings.autoshow if show is None else show + if show: + return None + return axs @_wraps_plot_scatter @@ -926,7 +931,7 @@ def pca( show_save_ax=doc_show_save_ax, ) def spatial( - adata, + adata: AnnData, *, basis: str = "spatial", img: np.ndarray | None = None, @@ -943,7 +948,7 @@ def spatial( return_fig: bool | None = None, save: bool | str | None = None, **kwargs, -) -> Axes | list[Axes] | None: +) -> Figure | Axes | list[Axes] | None: """\ Scatter plot in spatial coordinates. @@ -1038,8 +1043,12 @@ def spatial( ax.set_xlim(cur_coords[0], cur_coords[1]) ax.set_ylim(cur_coords[3], cur_coords[2]) _utils.savefig_or_show("show", show=show, save=save) - if show is False or return_fig is True: - return axs + if return_fig: + return axs[0].figure + show = settings.autoshow if show is None else show + if show: + return None + return axs # Helpers @@ -1081,6 +1090,7 @@ def _components_to_dimensions( def _add_categorical_legend( ax, color_source_vector, + *, palette: dict, legend_loc: str, legend_fontweight, diff --git a/scanpy/plotting/_utils.py b/scanpy/plotting/_utils.py index 3661f36159..7fe9307918 100644 --- a/scanpy/plotting/_utils.py +++ b/scanpy/plotting/_utils.py @@ -13,18 +13,21 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.collections import PatchCollection -from matplotlib.colors import is_color_like +from matplotlib.colors import Colormap, is_color_like from matplotlib.figure import Figure from matplotlib.figure import 
SubplotParams as sppars from matplotlib.patches import Circle from .. import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import NeighborsView from . import palettes if TYPE_CHECKING: - import anndata + from anndata import AnnData + from numpy.typing import ArrayLike + from PIL.Image import Image ColorLike = _U[str, tuple[float, ...]] _IGraphLayout = Literal["fa", "fr", "rt", "rt_circular", "drl", "eq_tree", ...] @@ -44,19 +47,32 @@ class _AxesSubplot(Axes, axes.SubplotBase): # ------------------------------------------------------------------------------- +@old_positionals( + "xlabel", + "ylabel", + "xticks", + "yticks", + "title", + "colorbar_shrink", + "color_map", + "show", + "save", + "ax", +) def matrix( - matrix, - xlabel=None, - ylabel=None, - xticks=None, - yticks=None, - title=None, - colorbar_shrink=0.5, - color_map=None, - show=None, - save=None, - ax=None, -): + matrix: ArrayLike | Image, + *, + xlabel: str | None = None, + ylabel: str | None = None, + xticks: Collection[str] | None = None, + yticks: Collection[str] | None = None, + title: str | None = None, + colorbar_shrink: float = 0.5, + color_map: str | Colormap | None = None, + show: bool | None = None, + save: bool | str | None = None, + ax: Axes | None = None, +) -> None: """Plot a matrix.""" if ax is None: ax = plt.gca() @@ -88,6 +104,7 @@ def timeseries(X, **kwargs): def timeseries_subplot( X: np.ndarray, + *, time=None, color=None, var_names=(), @@ -165,7 +182,7 @@ def timeseries_subplot( def timeseries_as_heatmap( - X: np.ndarray, var_names: Collection[str] = (), highlights_x=(), color_map=None + X: np.ndarray, *, var_names: Collection[str] = (), highlights_x=(), color_map=None ): """\ Plot timeseries as heatmap. 
@@ -301,7 +318,7 @@ def savefig_or_show( writekey: str, show: bool | None = None, dpi: int | None = None, - ext: str = None, + ext: str | None = None, save: bool | str | None = None, ): if isinstance(save, str): @@ -336,7 +353,7 @@ def default_palette( return palette -def _validate_palette(adata: anndata.AnnData, key: str) -> None: +def _validate_palette(adata: AnnData, key: str) -> None: """ checks if the list of colors in adata.uns[f'{key}_colors'] is valid and updates the color list in adata.uns[f'{key}_colors'] if needed. @@ -506,7 +523,7 @@ def add_colors_for_categorical_sample_annotation( _set_default_colors_for_categorical_obs(adata, key) -def plot_edges(axs, adata, basis, edges_width, edges_color, neighbors_key=None): +def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=None): import networkx as nx if not isinstance(axs, cabc.Sequence): @@ -568,7 +585,7 @@ def plot_arrows(axs, adata, basis, arrows_kwds=None): def scatter_group( - ax, key, imask, adata, Y, projection="2d", size=3, alpha=None, marker="." + ax, key, imask, adata, Y, *, projection="2d", size=3, alpha=None, marker="." 
): """Scatter of group using representation of data Y.""" mask = adata.obs[key].cat.categories[imask] == adata.obs[key].values @@ -596,7 +613,8 @@ def scatter_group( def setup_axes( - ax: Axes | Sequence[Axes] = None, + ax: Axes | Sequence[Axes] | None = None, + *, panels="blue", colorbars=(False,), right_margin=None, @@ -679,6 +697,7 @@ def setup_axes( def scatter_base( Y: np.ndarray, + *, colors="blue", sort_order=True, alpha=None, @@ -726,7 +745,7 @@ def scatter_base( if len(markers) != len(colors) and len(markers) == 1: markers = [markers[0] for _ in range(len(colors))] axs, panel_pos, draw_region_width, figure_width = setup_axes( - ax=ax, + ax, panels=colors, colorbars=colorbars, projection=projection, @@ -1091,7 +1110,7 @@ def check_projection(projection): def circles( - x, y, s, ax, marker=None, c="b", vmin=None, vmax=None, scale_factor=1.0, **kwargs + x, y, *, s, ax, marker=None, c="b", vmin=None, vmax=None, scale_factor=1.0, **kwargs ): """ Taken from here: https://gist.github.com/syrte/592a062c562cd2a98a83 @@ -1153,6 +1172,7 @@ def circles( def make_grid_spec( ax_or_figsize: tuple[int, int] | _AxesSubplot, + *, nrows: int, ncols: int, wspace: float | None = None, @@ -1209,7 +1229,7 @@ def fix_kwds(kwds_dict, **kwargs): return kwargs -def _get_basis(adata: anndata.AnnData, basis: str): +def _get_basis(adata: AnnData, basis: str): if basis in adata.obsm.keys(): basis_key = basis diff --git a/scanpy/preprocessing/_combat.py b/scanpy/preprocessing/_combat.py index 550e35829e..f5feda94fe 100644 --- a/scanpy/preprocessing/_combat.py +++ b/scanpy/preprocessing/_combat.py @@ -8,6 +8,7 @@ from scipy.sparse import issparse from .. 
import logging as logg +from .._compat import old_positionals from .._utils import sanitize_anndata if TYPE_CHECKING: @@ -134,9 +135,11 @@ def _standardize_data( return s_data, design, var_pooled, stand_mean +@old_positionals("covariates", "inplace") def combat( adata: AnnData, key: str = "batch", + *, covariates: Collection[str] | None = None, inplace: bool = True, ) -> np.ndarray | None: @@ -245,10 +248,10 @@ def combat( s_data.iloc[:, batch_idxs].values, gamma_hat[i], delta_hat[i].values, - gamma_bar[i], - t2[i], - a_prior[i], - b_prior[i], + g_bar=gamma_bar[i], + t2=t2[i], + a=a_prior[i], + b=b_prior[i], ) gamma_star.append(gamma) @@ -288,6 +291,7 @@ def _it_sol( s_data: np.ndarray, g_hat: np.ndarray, d_hat: np.ndarray, + *, g_bar: float, t2: float, a: float, @@ -310,7 +314,7 @@ def _it_sol( Initial guess for gamma d_hat Initial guess for delta - g_bar, t_2, a, b + g_bar, t2, a, b Hyperparameters conv: float, optional (default: `0.0001`) convergence criterium diff --git a/scanpy/preprocessing/_deprecated/highly_variable_genes.py b/scanpy/preprocessing/_deprecated/highly_variable_genes.py index c7c89c45e1..6cc1f4c8c8 100644 --- a/scanpy/preprocessing/_deprecated/highly_variable_genes.py +++ b/scanpy/preprocessing/_deprecated/highly_variable_genes.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Literal +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd @@ -12,9 +12,12 @@ from .._distributed import materialize_as_ndarray from .._utils import _get_mean_var +if TYPE_CHECKING: + from scipy.sparse import spmatrix -def filter_genes_dispersion( - data: AnnData, + +def filter_genes_dispersion( # noqa: PLR0917 + data: AnnData | spmatrix | np.ndarray, flavor: Literal["seurat", "cell_ranger"] = "seurat", min_disp: float | None = None, max_disp: float | None = None, @@ -25,7 +28,7 @@ def filter_genes_dispersion( log: bool = True, subset: bool = True, copy: bool = False, -): +) -> AnnData | np.recarray | 
None: """\ Extract highly variable genes [Satija15]_ [Zheng17]_. diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 2e5c14f82e..823e364e4e 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -9,6 +9,7 @@ from anndata import AnnData from .. import logging as logg +from .._compat import old_positionals from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ._distributed import materialize_as_ndarray @@ -18,6 +19,7 @@ def _highly_variable_genes_seurat_v3( adata: AnnData, + *, layer: str | None = None, n_top_genes: int = 2000, batch_key: str | None = None, @@ -180,6 +182,7 @@ def _highly_variable_genes_seurat_v3( def _highly_variable_genes_single_batch( adata: AnnData, + *, layer: str | None = None, min_disp: float | None = 0.5, max_disp: float | None = np.inf, @@ -301,8 +304,24 @@ def _highly_variable_genes_single_batch( return df +@old_positionals( + "layer", + "n_top_genes", + "min_disp", + "max_disp", + "min_mean", + "max_mean", + "span", + "n_bins", + "flavor", + "subset", + "inplace", + "batch_key", + "check_values", +) def highly_variable_genes( adata: AnnData, + *, layer: str | None = None, n_top_genes: int | None = None, min_disp: float | None = 0.5, diff --git a/scanpy/preprocessing/_normalization.py b/scanpy/preprocessing/_normalization.py index 74ac636a93..ea50686e60 100644 --- a/scanpy/preprocessing/_normalization.py +++ b/scanpy/preprocessing/_normalization.py @@ -8,7 +8,7 @@ from sklearn.utils import sparsefuncs from .. 
import logging as logg -from .._compat import DaskArray +from .._compat import DaskArray, old_positionals from .._utils import view_to_actual from ..get import _get_obs_rep, _set_obs_rep @@ -18,7 +18,7 @@ from anndata import AnnData -def _normalize_data(X, counts, after=None, copy=False): +def _normalize_data(X, counts, after=None, copy: bool = False): X = X.copy() if copy else X if issubclass(X.dtype.type, (int, np.integer)): X = X.astype(np.float32) # TODO: Check if float64 should be used @@ -39,18 +39,30 @@ def _normalize_data(X, counts, after=None, copy=False): return X +@old_positionals( + "target_sum", + "exclude_highly_expressed", + "max_fraction", + "key_added", + "layer", + "layers", + "layer_norm", + "inplace", + "copy", +) def normalize_total( adata: AnnData, + *, target_sum: float | None = None, exclude_highly_expressed: bool = False, max_fraction: float = 0.05, key_added: str | None = None, layer: str | None = None, - layers: Literal["all"] | Iterable[str] = None, + layers: Literal["all"] | Iterable[str] | None = None, layer_norm: str | None = None, inplace: bool = True, copy: bool = False, -) -> dict[str, np.ndarray] | None: +) -> AnnData | dict[str, np.ndarray] | None: """\ Normalize counts per cell. diff --git a/scanpy/preprocessing/_pca.py b/scanpy/preprocessing/_pca.py index f5d7a34249..53377f1321 100644 --- a/scanpy/preprocessing/_pca.py +++ b/scanpy/preprocessing/_pca.py @@ -38,7 +38,7 @@ def pca( copy: bool = False, chunked: bool = False, chunk_size: int | None = None, -) -> AnnData | np.ndarray | spmatrix: +) -> AnnData | np.ndarray | spmatrix | None: """\ Principal component analysis [Pedregosa11]_. 
diff --git a/scanpy/preprocessing/_qc.py b/scanpy/preprocessing/_qc.py index 2c0a61e87b..241afb32c3 100644 --- a/scanpy/preprocessing/_qc.py +++ b/scanpy/preprocessing/_qc.py @@ -25,7 +25,7 @@ from anndata import AnnData -def _choose_mtx_rep(adata, use_raw=False, layer=None): +def _choose_mtx_rep(adata, use_raw: bool = False, layer: str | None = None): is_layer = layer is not None if use_raw and is_layer: raise ValueError( @@ -56,7 +56,7 @@ def describe_obs( percent_top: Collection[int] | None = (50, 100, 200, 500), layer: str | None = None, use_raw: bool = False, - log1p: str | None = True, + log1p: bool | None = True, inplace: bool = False, X=None, parallel=None, @@ -155,9 +155,9 @@ def describe_var( var_type: str = "genes", layer: str | None = None, use_raw: bool = False, - inplace=False, - log1p=True, - X=None, + inplace: bool = False, + log1p: bool = True, + X: spmatrix | np.ndarray | None = None, ) -> pd.DataFrame | None: """\ Describe variables of anndata. @@ -331,7 +331,7 @@ def calculate_qc_metrics( return obs_metrics, var_metrics -def top_proportions(mtx: np.array | spmatrix, n: int): +def top_proportions(mtx: np.ndarray | spmatrix, n: int): """\ Calculates cumulative proportions of top expressed genes @@ -383,7 +383,7 @@ def top_proportions_sparse_csr(data, indptr, n): def top_segment_proportions( - mtx: np.array | spmatrix, ns: Collection[int] + mtx: np.ndarray | spmatrix, ns: Collection[int] ) -> np.ndarray: """ Calculates total percentage of counts in top ns genes. @@ -409,7 +409,7 @@ def top_segment_proportions( def top_segment_proportions_dense( - mtx: np.array | spmatrix, ns: Collection[int] + mtx: np.ndarray | spmatrix, ns: Collection[int] ) -> np.ndarray: # Currently ns is considered to be 1 indexed ns = np.sort(ns) diff --git a/scanpy/preprocessing/_recipes.py b/scanpy/preprocessing/_recipes.py index 8a70b3094e..10556bc993 100644 --- a/scanpy/preprocessing/_recipes.py +++ b/scanpy/preprocessing/_recipes.py @@ -5,6 +5,7 @@ from .. 
import logging as logg from .. import preprocessing as pp +from .._compat import old_positionals from ._deprecated.highly_variable_genes import ( filter_genes_cv_deprecated, filter_genes_dispersion, @@ -17,8 +18,18 @@ from .._utils import AnyRandom +@old_positionals( + "log", + "mean_threshold", + "cv_threshold", + "n_pcs", + "svd_solver", + "random_state", + "copy", +) def recipe_weinreb17( adata: AnnData, + *, log: bool = True, mean_threshold: float = 0.01, cv_threshold: int = 2, @@ -68,8 +79,9 @@ def recipe_weinreb17( return adata if copy else None +@old_positionals("log", "plot", "copy") def recipe_seurat( - adata: AnnData, log: bool = True, plot: bool = False, copy: bool = False + adata: AnnData, *, log: bool = True, plot: bool = False, copy: bool = False ) -> AnnData | None: """\ Normalization and filtering as of Seurat [Satija15]_. @@ -100,8 +112,10 @@ def recipe_seurat( return adata if copy else None +@old_positionals("n_top_genes", "log", "plot", "copy") def recipe_zheng17( adata: AnnData, + *, n_top_genes: int = 1000, log: bool = True, plot: bool = False, diff --git a/scanpy/preprocessing/_simple.py b/scanpy/preprocessing/_simple.py index 05b875a9a9..0da7f9e961 100644 --- a/scanpy/preprocessing/_simple.py +++ b/scanpy/preprocessing/_simple.py @@ -17,11 +17,12 @@ from sklearn.utils import check_array, sparsefuncs from .. 
import logging as logg +from .._compat import old_positionals from .._settings import settings as sett from .._utils import ( AnyRandom, _check_array_function_arguments, - deprecated_arg_names, + renamed_arg, sanitize_anndata, view_to_actual, ) @@ -45,15 +46,19 @@ from numpy.typing import NDArray +@old_positionals( + "min_counts", "min_genes", "max_counts", "max_genes", "inplace", "copy" +) def filter_cells( - data: AnnData, + data: AnnData | spmatrix | np.ndarray, + *, min_counts: int | None = None, min_genes: int | None = None, max_counts: int | None = None, max_genes: int | None = None, inplace: bool = True, copy: bool = False, -) -> tuple[np.ndarray, np.ndarray] | None: +) -> AnnData | tuple[np.ndarray, np.ndarray] | None: """\ Filter cell outliers based on counts and numbers of genes expressed. @@ -138,7 +143,13 @@ def filter_cells( if isinstance(data, AnnData): adata = data.copy() if copy else data cell_subset, number = materialize_as_ndarray( - filter_cells(adata.X, min_counts, min_genes, max_counts, max_genes) + filter_cells( + adata.X, + min_counts=min_counts, + min_genes=min_genes, + max_counts=max_counts, + max_genes=max_genes, + ), ) if not inplace: return cell_subset, number @@ -182,15 +193,19 @@ def filter_cells( return cell_subset, number_per_cell +@old_positionals( + "min_counts", "min_cells", "max_counts", "max_cells", "inplace", "copy" +) def filter_genes( - data: AnnData, + data: AnnData | spmatrix | np.ndarray, + *, min_counts: int | None = None, min_cells: int | None = None, max_counts: int | None = None, max_cells: int | None = None, inplace: bool = True, copy: bool = False, -) -> AnnData | None | tuple[np.ndarray, np.ndarray]: +) -> AnnData | tuple[np.ndarray, np.ndarray] | None: """\ Filter genes based on number of cells or counts. 
@@ -290,17 +305,18 @@ def filter_genes( return gene_subset, number_per_gene +@renamed_arg("X", "data", pos_0=True) @singledispatch def log1p( - X: AnnData | np.ndarray | spmatrix, + data: AnnData | np.ndarray | spmatrix, *, base: Number | None = None, copy: bool = False, - chunked: bool = None, + chunked: bool | None = None, chunk_size: int | None = None, layer: str | None = None, obsm: str | None = None, -): +) -> AnnData | np.ndarray | spmatrix | None: """\ Logarithmize the data matrix. @@ -309,7 +325,7 @@ def log1p( Parameters ---------- - X + data The (annotated) data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. base @@ -334,7 +350,7 @@ def log1p( _check_array_function_arguments( chunked=chunked, chunk_size=chunk_size, layer=layer, obsm=obsm ) - return log1p_array(X, copy=copy, base=base) + return log1p_array(data, copy=copy, base=base) @log1p.register(spmatrix) @@ -397,12 +413,14 @@ def log1p_anndata( return adata +@old_positionals("copy", "chunked", "chunk_size") def sqrt( - data: AnnData, + data: AnnData | spmatrix | np.ndarray, + *, copy: bool = False, chunked: bool = False, chunk_size: int | None = None, -) -> AnnData | None: +) -> AnnData | spmatrix | np.ndarray | None: """\ Square root the data matrix. @@ -441,7 +459,7 @@ def sqrt( return X.sqrt() -def normalize_per_cell( +def normalize_per_cell( # noqa: PLR0917 data: AnnData | np.ndarray | spmatrix, counts_per_cell_after: float | None = None, counts_per_cell: np.ndarray | None = None, @@ -450,7 +468,7 @@ def normalize_per_cell( layers: Literal["all"] | Iterable[str] = (), use_rep: Literal["after", "X"] | None = None, min_counts: int = 1, -) -> AnnData | None: +) -> AnnData | np.ndarray | spmatrix | None: """\ Normalize total counts per cell. 
@@ -577,9 +595,11 @@ def normalize_per_cell( return X if copy else None +@old_positionals("layer", "n_jobs", "copy") def regress_out( adata: AnnData, keys: str | Sequence[str], + *, layer: str | None = None, n_jobs: int | None = None, copy: bool = False, @@ -725,16 +745,19 @@ def _regress_out_chunk(data): return np.vstack(responses_chunk_list) +@renamed_arg("X", "data", pos_0=True) +@old_positionals("zero_center", "max_value", "copy", "layer", "obsm", "mask") @singledispatch def scale( - X: AnnData | spmatrix | np.ndarray, + data: AnnData | spmatrix | np.ndarray, + *, zero_center: bool = True, max_value: float | None = None, copy: bool = False, layer: str | None = None, obsm: str | None = None, mask: NDArray[np.bool_] | str | None = None, -): +) -> AnnData | spmatrix | np.ndarray | None: """\ Scale data to unit variance and zero mean. @@ -745,7 +768,7 @@ def scale( Parameters ---------- - X + data The (annotated) data matrix of shape `n_obs` × `n_vars`. Rows correspond to cells and columns to genes. 
zero_center @@ -776,11 +799,15 @@ def scale( """ _check_array_function_arguments(layer=layer, obsm=obsm) if layer is not None: - raise ValueError(f"`layer` argument inappropriate for value of type {type(X)}") + raise ValueError( + f"`layer` argument inappropriate for value of type {type(data)}" + ) if obsm is not None: - raise ValueError(f"`obsm` argument inappropriate for value of type {type(X)}") + raise ValueError( + f"`obsm` argument inappropriate for value of type {type(data)}" + ) return scale_array( - X, zero_center=zero_center, max_value=max_value, copy=copy, mask=mask + data, zero_center=zero_center, max_value=max_value, copy=copy, mask=mask ) @@ -905,17 +932,18 @@ def scale_anndata( mask=mask, ) _set_obs_rep(adata, X, layer=layer, obsm=obsm) - if copy: - return adata + return adata if copy else None +@old_positionals("n_obs", "random_state", "copy") def subsample( data: AnnData | np.ndarray | spmatrix, fraction: float | None = None, + *, n_obs: int | None = None, random_state: AnyRandom = 0, copy: bool = False, -) -> AnnData | None: +) -> AnnData | tuple[np.ndarray | spmatrix, NDArray[np.int64]] | None: """\ Subsample to a fraction of the number of observations. 
@@ -970,7 +998,7 @@ def subsample( return X[obs_indices], obs_indices -@deprecated_arg_names({"target_counts": "counts_per_cell"}) +@renamed_arg("target_counts", "counts_per_cell") def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, diff --git a/scanpy/preprocessing/_utils.py b/scanpy/preprocessing/_utils.py index 9df47cb981..d1f5cf7498 100644 --- a/scanpy/preprocessing/_utils.py +++ b/scanpy/preprocessing/_utils.py @@ -47,7 +47,12 @@ def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int): raise ValueError("This function only works on sparse csr and csc matrices") if axis == ax_minor: return sparse_mean_var_major_axis( - mtx.data, mtx.indices, mtx.indptr, *shape, np.float64 + mtx.data, + mtx.indices, + mtx.indptr, + major_len=shape[0], + minor_len=shape[1], + dtype=np.float64, ) else: return sparse_mean_var_minor_axis(mtx.data, mtx.indices, *shape, np.float64) @@ -89,7 +94,7 @@ def sparse_mean_var_minor_axis(data, indices, major_len, minor_len, dtype): @numba.njit(cache=True) -def sparse_mean_var_major_axis(data, indices, indptr, major_len, minor_len, dtype): +def sparse_mean_var_major_axis(data, indices, indptr, *, major_len, minor_len, dtype): """ Computes mean and variance for a sparse array for the major axis. diff --git a/scanpy/readwrite.py b/scanpy/readwrite.py index 91fd629aac..14d1e9ee21 100644 --- a/scanpy/readwrite.py +++ b/scanpy/readwrite.py @@ -23,6 +23,7 @@ from matplotlib.image import imread from . 
import logging as logg +from ._compat import old_positionals from ._settings import settings from ._utils import Empty, _empty @@ -52,9 +53,19 @@ # -------------------------------------------------------------------------------- +@old_positionals( + "sheet", + "ext", + "delimiter", + "first_column_names", + "backup_url", + "cache", + "cache_compression", +) def read( filename: Path | str, backed: Literal["r", "r+"] | None = None, + *, sheet: str | None = None, ext: str | None = None, delimiter: str | None = None, @@ -136,8 +147,10 @@ def read( return read_h5ad(filename, backed=backed) +@old_positionals("genome", "gex_only", "backup_url") def read_10x_h5( - filename: str | Path, + filename: Path | str, + *, genome: str | None = None, gex_only: bool = True, backup_url: str | None = None, @@ -336,13 +349,13 @@ def _read_v3_10x_h5(filename, *, start=None): def read_visium( - path: str | Path, + path: Path | str, genome: str | None = None, *, count_file: str = "filtered_feature_bc_matrix.h5", library_id: str | None = None, load_images: bool | None = True, - source_image_path: str | Path | None = None, + source_image_path: Path | str | None = None, ) -> AnnData: """\ Read 10x-Genomics-formatted visum dataset. @@ -494,15 +507,16 @@ def read_visium( return adata +@old_positionals("var_names", "make_unique", "cache", "cache_compression", "gex_only") def read_10x_mtx( path: Path | str, + *, var_names: Literal["gene_symbols", "gene_ids"] = "gene_symbols", make_unique: bool = True, cache: bool = False, cache_compression: Literal["gzip", "lzf"] | None | Empty = _empty, gex_only: bool = True, - *, - prefix: str = None, + prefix: str | None = None, ) -> AnnData: """\ Read 10x-Genomics-formatted mtx directory. 
@@ -627,9 +641,11 @@ def _read_v3_10x_mtx( return adata +@old_positionals("ext", "compression", "compression_opts") def write( - filename: str | Path, + filename: Path | str, adata: AnnData, + *, ext: Literal["h5", "csv", "txt", "npz"] | None = None, compression: Literal["gzip", "lzf"] | None = "gzip", compression_opts: int | None = None, @@ -748,6 +764,7 @@ def write_params(path: Path | str, *args, **maps): def _read( filename: Path, + *, backed=None, sheet=None, ext=None, diff --git a/scanpy/testing/_doctests.py b/scanpy/testing/_doctests.py index ac711d4263..72ddc20dae 100644 --- a/scanpy/testing/_doctests.py +++ b/scanpy/testing/_doctests.py @@ -1,12 +1,9 @@ from __future__ import annotations -from types import FunctionType -from typing import TYPE_CHECKING, TypeVar +from collections.abc import Callable +from typing import TypeVar -if TYPE_CHECKING: - from collections.abc import Callable - -F = TypeVar("F", bound=FunctionType) +F = TypeVar("F", bound=Callable) def doctest_needs(mod: str) -> Callable[[F], F]: diff --git a/scanpy/testing/_pytest/__init__.py b/scanpy/testing/_pytest/__init__.py index a87b752daa..8933091344 100644 --- a/scanpy/testing/_pytest/__init__.py +++ b/scanpy/testing/_pytest/__init__.py @@ -2,10 +2,11 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest +from ..._utils import _import_name from .fixtures import * # noqa: F403 if TYPE_CHECKING: @@ -71,21 +72,3 @@ def pytest_itemcollected(item: pytest.Item) -> None: item.add_marker(marker) if skip_reason := getattr(func, "_doctest_skip_reason", False): item.add_marker(pytest.mark.skip(reason=skip_reason)) - - -def _import_name(name: str) -> Any: - from importlib import import_module - - parts = name.split(".") - obj = import_module(parts[0]) - for i, name in enumerate(parts[1:]): - try: - obj = import_module(f"{obj.__name__}.{name}") - except ModuleNotFoundError: - break - for name in parts[i + 1 :]: - try: - 
obj = getattr(obj, name) - except AttributeError: - raise RuntimeError(f"{parts[:i]}, {parts[i+1:]}, {obj} {name}") - return obj diff --git a/scanpy/testing/_pytest/fixtures/__init__.py b/scanpy/testing/_pytest/fixtures/__init__.py index 26bca50b3b..bd4ba90a27 100644 --- a/scanpy/testing/_pytest/fixtures/__init__.py +++ b/scanpy/testing/_pytest/fixtures/__init__.py @@ -41,7 +41,7 @@ def doctest_env(cache: pytest.Cache, tmp_path: Path) -> Generator[None, None, No showwarning_orig = warnings.showwarning - def showwarning(message, category, filename, lineno, file=None, line=None): + def showwarning(message, category, filename, lineno, file=None, line=None): # noqa: PLR0917 if file is None: if line is None: import linecache diff --git a/scanpy/tests/test_embedding_density.py b/scanpy/tests/test_embedding_density.py index bf40bf30da..187412b37b 100644 --- a/scanpy/tests/test_embedding_density.py +++ b/scanpy/tests/test_embedding_density.py @@ -27,4 +27,4 @@ def test_embedding_density_plot(): # Test that sc.pl.embedding_density() runs without error adata = pbmc68k_reduced() sc.tl.embedding_density(adata, "umap") - sc.pl.embedding_density(adata, "umap", "umap_density", show=False) + sc.pl.embedding_density(adata, "umap", key="umap_density", show=False) diff --git a/scanpy/tests/test_embedding_plots.py b/scanpy/tests/test_embedding_plots.py index b2fc458c76..54036c52a6 100644 --- a/scanpy/tests/test_embedding_plots.py +++ b/scanpy/tests/test_embedding_plots.py @@ -154,6 +154,7 @@ def vbounds(request): def test_missing_values_categorical( + *, fixture_request, image_comparer, adata, @@ -181,7 +182,7 @@ def test_missing_values_categorical( def test_missing_values_continuous( - fixture_request, image_comparer, adata, plotfunc, na_color, legend_loc, vbounds + *, fixture_request, image_comparer, adata, plotfunc, na_color, legend_loc, vbounds ): save_and_compare_images = partial(image_comparer, MISSING_VALUES_ROOT, tol=15) diff --git a/scanpy/tests/test_normalization.py 
b/scanpy/tests/test_normalization.py index c395b41c80..f1fa910b1e 100644 --- a/scanpy/tests/test_normalization.py +++ b/scanpy/tests/test_normalization.py @@ -193,6 +193,7 @@ def _check_pearson_pca_fields(ad, n_cells, n_comps): ], ) def test_normalize_pearson_residuals_pca( + *, pbmc3k_parametrized_small: Callable[[], AnnData], n_hvgs: int, n_comps: int, diff --git a/scanpy/tests/test_package_structure.py b/scanpy/tests/test_package_structure.py index 58afdcb06a..783e7c4eae 100644 --- a/scanpy/tests/test_package_structure.py +++ b/scanpy/tests/test_package_structure.py @@ -1,23 +1,54 @@ from __future__ import annotations -import inspect import os +from collections import defaultdict +from inspect import Parameter, signature from pathlib import Path -from types import FunctionType +from typing import TYPE_CHECKING, Any, TypedDict import pytest +from anndata import AnnData # CLI is locally not imported by default but on travis it is? import scanpy.cli -from scanpy._utils import descend_classes_and_funcs +from scanpy._utils import _import_name, descend_classes_and_funcs + +if TYPE_CHECKING: + from types import FunctionType mod_dir = Path(scanpy.__file__).parent proj_dir = mod_dir.parent -scanpy_functions = [ - c_or_f - for c_or_f in descend_classes_and_funcs(scanpy, "scanpy") - if isinstance(c_or_f, FunctionType) + +api_module_names = [ + "sc", + "sc.pp", + "sc.tl", + "sc.pl", + "sc.experimental.pp", + "sc.external.pp", + "sc.external.tl", + "sc.external.pl", + "sc.external.exporting", + "sc.get", + "sc.logging", + # "sc.neighbors", # Not documented + "sc.datasets", + "sc.queries", + "sc.metrics", +] +api_modules = { + mod_name: _import_name(f"scanpy{mod_name.removeprefix('sc')}") + for mod_name in api_module_names +} + + +# get all exported functions that aren’t re-exports from anndata +api_functions = [ + pytest.param(func, f"{mod_name}.{name}", id=f"{mod_name}.{name}") + for mod_name, mod in api_modules.items() + for name in sorted(mod.__all__) + if callable(func 
:= getattr(mod, name)) and func.__module__.startswith("scanpy.") ] @@ -31,13 +62,18 @@ def in_project_dir(): os.chdir(wd_orig) -@pytest.mark.parametrize("f", scanpy_functions) -def test_function_headers(f): - name = f"{f.__module__}.{f.__qualname__}" - filename = inspect.getsourcefile(f) - lines, lineno = inspect.getsourcelines(f) +@pytest.mark.xfail(reason="TODO: unclear if we want this to totally match, let’s see") +def test_descend_classes_and_funcs(): + funcs = set(descend_classes_and_funcs(scanpy, "scanpy")) + assert {p.values[0] for p in api_functions} == funcs + + +@pytest.mark.parametrize(("f", "qualname"), api_functions) +def test_function_headers(f, qualname): + filename = getsourcefile(f) + lines, lineno = getsourcelines(f) if f.__doc__ is None: - msg = f"Function `{name}` has no docstring" + msg = f"Function `{qualname}` has no docstring" text = lines[0] else: lines = getattr(f, "__orig_doc__", f.__doc__).split("\n") @@ -47,7 +83,7 @@ def test_function_headers(f): if not any(broken): return msg = f'''\ -Header of function `{name}`’s docstring should start with one-line description +Header of function `{qualname}`’s docstring should start with one-line description and be consistently indented like this: ␣␣␣␣"""\\ @@ -60,3 +96,104 @@ def test_function_headers(f): ''' text = f">{lines[broken[0]]}<" raise SyntaxError(msg, (filename, lineno, 2, text)) + + +def param_is_pos(p: Parameter) -> bool: + return p.kind in { + Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD, + } + + +def is_deprecated(f: FunctionType) -> bool: + # TODO: use deprecated decorator instead + # https://github.com/scverse/scanpy/issues/2505 + return f.__name__ in { + "normalize_per_cell", + "filter_genes_dispersion", + } + + +class ExpectedSig(TypedDict): + first_name: str + copy_default: Any + return_ann: str | None + + +copy_sigs: defaultdict[str, ExpectedSig | None] = defaultdict( + lambda: ExpectedSig(first_name="adata", copy_default=False, return_ann=None) +) +# full 
exceptions +copy_sigs["sc.external.tl.phenograph"] = None # external +copy_sigs["sc.pp.filter_genes_dispersion"] = None # deprecated +copy_sigs["sc.pp.filter_cells"] = None # unclear `inplace` situation +copy_sigs["sc.pp.filter_genes"] = None # unclear `inplace` situation +copy_sigs["sc.pp.subsample"] = None # returns indices along matrix +# partial exceptions: “data” instead of “adata” +copy_sigs["sc.pp.log1p"]["first_name"] = "data" +copy_sigs["sc.pp.normalize_per_cell"]["first_name"] = "data" +copy_sigs["sc.pp.pca"]["first_name"] = "data" +copy_sigs["sc.pp.scale"]["first_name"] = "data" +copy_sigs["sc.pp.sqrt"]["first_name"] = "data" +# other partial exceptions +copy_sigs["sc.pp.normalize_total"]["return_ann"] = copy_sigs[ + "sc.experimental.pp.normalize_pearson_residuals" +]["return_ann"] = "AnnData | dict[str, np.ndarray] | None" +copy_sigs["sc.external.pp.magic"]["copy_default"] = None + + +@pytest.mark.parametrize(("f", "qualname"), api_functions) +def test_sig_conventions(f, qualname): + sig = signature(f) + + # TODO: replace the following check with lint rule for all functions eventually + if not is_deprecated(f): + n_pos = sum(1 for p in sig.parameters.values() if param_is_pos(p)) + assert n_pos <= 3, "Public functions should have <= 3 positional parameters" + + first_param = next(iter(sig.parameters.values()), None) + if first_param is None: + return + + if first_param.name == "adata": + assert first_param.annotation in {"AnnData", AnnData} + elif first_param.name == "data": + assert first_param.annotation.startswith("AnnData |") + elif first_param.name in {"filename", "path"}: + assert first_param.annotation == "Path | str" + + # Test if functions with `copy` follow conventions + if (copy_param := sig.parameters.get("copy")) is not None and ( + expected_sig := copy_sigs[qualname] + ) is not None: + s = ExpectedSig( + first_name=first_param.name, + copy_default=copy_param.default, + return_ann=sig.return_annotation, + ) + expected_sig =
expected_sig.copy() + if expected_sig["return_ann"] is None: + expected_sig["return_ann"] = f"{first_param.annotation} | None" + assert s == expected_sig + if not is_deprecated(f): + assert not param_is_pos(copy_param) + + +def getsourcefile(obj): + """inspect.getsourcefile, but supports singledispatch""" + from inspect import getsourcefile + + if wrapped := getattr(obj, "__wrapped__", None): + return getsourcefile(wrapped) + + return getsourcefile(obj) + + +def getsourcelines(obj): + """inspect.getsourcelines, but supports singledispatch""" + from inspect import getsourcelines + + if wrapped := getattr(obj, "__wrapped__", None): + return getsourcelines(wrapped) + + return getsourcelines(obj) diff --git a/scanpy/tests/test_preprocessing.py b/scanpy/tests/test_preprocessing.py index 7f4547b242..3029068776 100644 --- a/scanpy/tests/test_preprocessing.py +++ b/scanpy/tests/test_preprocessing.py @@ -36,6 +36,12 @@ def test_log1p(tmp_path): assert np.allclose(ad4.X, A_l / np.log(2)) +def test_log1p_deprecated_arg(): + A = np.random.rand(200, 10).astype(np.float32) + with pytest.warns(FutureWarning, match=r".*`X` was renamed to `data`"): + sc.pp.log1p(X=A) + + @pytest.fixture(params=[None, 2]) def base(request): return request.param diff --git a/scanpy/tests/test_scaling.py b/scanpy/tests/test_scaling.py index 69cefe4458..d34c91e861 100644 --- a/scanpy/tests/test_scaling.py +++ b/scanpy/tests/test_scaling.py @@ -67,7 +67,7 @@ ), ], ) -def test_scale(typ, dtype, mask, X, X_centered, X_scaled): +def test_scale(*, typ, dtype, mask, X, X_centered, X_scaled): # test AnnData arguments # test scaling with default zero_center == True adata0 = AnnData(typ(X).astype(dtype)) diff --git a/scanpy/tools/__init__.py b/scanpy/tools/__init__.py index b791a48a35..9ca74faf81 100644 --- a/scanpy/tools/__init__.py +++ b/scanpy/tools/__init__.py @@ -7,15 +7,18 @@ from ._dpt import dpt from ._draw_graph import draw_graph from ._embedding_density import embedding_density -from ._ingest import 
Ingest, ingest +from ._ingest import ( + Ingest, # noqa: F401 + ingest, +) from ._leiden import leiden from ._louvain import louvain from ._marker_gene_overlap import marker_gene_overlap from ._paga import ( paga, - paga_compare_paths, - paga_degrees, - paga_expression_entropies, + paga_compare_paths, # noqa: F401 + paga_degrees, # noqa: F401 + paga_expression_entropies, # noqa: F401 ) from ._rank_genes_groups import filter_rank_genes_groups, rank_genes_groups from ._score_genes import score_genes, score_genes_cell_cycle @@ -38,15 +41,11 @@ def __getattr__(name: str) -> Any: "dpt", "draw_graph", "embedding_density", - "Ingest", "ingest", "leiden", "louvain", "marker_gene_overlap", "paga", - "paga_compare_paths", - "paga_degrees", - "paga_expression_entropies", "filter_rank_genes_groups", "rank_genes_groups", "score_genes", diff --git a/scanpy/tools/_dendrogram.py b/scanpy/tools/_dendrogram.py index ccae494748..a6a6d2c625 100644 --- a/scanpy/tools/_dendrogram.py +++ b/scanpy/tools/_dendrogram.py @@ -10,6 +10,7 @@ from pandas.api.types import CategoricalDtype from .. 
import logging as logg +from .._compat import old_positionals from .._utils import _doc_params from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation @@ -20,10 +21,22 @@ from anndata import AnnData +@old_positionals( + "n_pcs", + "use_rep", + "var_names", + "use_raw", + "cor_method", + "linkage_method", + "optimal_ordering", + "key_added", + "inplace", +) @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) def dendrogram( adata: AnnData, groupby: str | Sequence[str], + *, n_pcs: int | None = None, use_rep: str | None = None, var_names: Sequence[str] | None = None, @@ -137,7 +150,9 @@ def dendrogram( gene_names = adata.raw.var_names if use_raw else adata.var_names from ..plotting._anndata import _prepare_dataframe - categories, rep_df = _prepare_dataframe(adata, gene_names, groupby, use_raw) + categories, rep_df = _prepare_dataframe( + adata, gene_names, groupby, use_raw=use_raw + ) # aggregate values within categories using 'mean' mean_df = rep_df.groupby(level=0).mean() diff --git a/scanpy/tools/_diffmap.py b/scanpy/tools/_diffmap.py index 8b1a09eb16..9a7667ae47 100644 --- a/scanpy/tools/_diffmap.py +++ b/scanpy/tools/_diffmap.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from .._compat import old_positionals from ._dpt import _diffmap if TYPE_CHECKING: @@ -10,13 +11,15 @@ from .._utils import AnyRandom +@old_positionals("neighbors_key", "random_state", "copy") def diffmap( adata: AnnData, n_comps: int = 15, + *, neighbors_key: str | None = None, random_state: AnyRandom = 0, copy: bool = False, -): +) -> AnnData | None: """\ Diffusion Maps [Coifman05]_ [Haghverdi15]_ [Wolf18]_. diff --git a/scanpy/tools/_dpt.py b/scanpy/tools/_dpt.py index 864211d9bc..ba5ca88b7e 100644 --- a/scanpy/tools/_dpt.py +++ b/scanpy/tools/_dpt.py @@ -8,6 +8,7 @@ from natsort import natsorted from .. 
import logging as logg +from .._compat import old_positionals from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -34,9 +35,13 @@ def _diffmap(adata, n_comps=15, neighbors_key=None, random_state=0): ) +@old_positionals( + "n_branchings", "min_group_size", "allow_kendall_tau_shift", "neighbors_key", "copy" +) def dpt( adata: AnnData, n_dcs: int = 10, + *, n_branchings: int = 0, min_group_size: float = 0.01, allow_kendall_tau_shift: bool = True, @@ -201,12 +206,13 @@ class DPT(Neighbors): def __init__( self, - adata, - n_dcs=None, - min_group_size=0.01, - n_branchings=0, - allow_kendall_tau_shift=False, - neighbors_key=None, + adata: AnnData, + *, + n_dcs: int | None = None, + min_group_size: float = 0.01, + n_branchings: int = 0, + allow_kendall_tau_shift: bool = False, + neighbors_key: str | None = None, ): super().__init__(adata, n_dcs=n_dcs, neighbors_key=neighbors_key) self.flavor = "haghverdi16" @@ -316,13 +322,13 @@ def detect_branchings(self): ) # [third start end] # detect branching and update segs and segs_tips self.detect_branching( - segs, - segs_tips, - segs_connects, - segs_undecided, - segs_adjacency, - iseg, - tips3, + segs=segs, + segs_tips=segs_tips, + segs_connects=segs_connects, + segs_undecided=segs_undecided, + segs_adjacency=segs_adjacency, + iseg=iseg, + tips3=tips3, ) # store as class members self.segs = segs @@ -531,6 +537,7 @@ def order_pseudotime(self): def detect_branching( self, + *, segs: Sequence[np.ndarray], segs_tips: Sequence[np.ndarray], segs_connects, diff --git a/scanpy/tools/_draw_graph.py b/scanpy/tools/_draw_graph.py index 5a88597245..44a0393d78 100644 --- a/scanpy/tools/_draw_graph.py +++ b/scanpy/tools/_draw_graph.py @@ -7,6 +7,7 @@ from .. import _utils from .. 
import logging as logg +from .._compat import old_positionals from .._utils import AnyRandom, _choose_graph from ._utils import get_init_pos_from_paga @@ -18,9 +19,21 @@ _Layout = Literal[_LAYOUTS] +@old_positionals( + "init_pos", + "root", + "random_state", + "n_jobs", + "adjacency", + "key_added_ext", + "neighbors_key", + "obsp", + "copy", +) def draw_graph( adata: AnnData, layout: _Layout = "fa", + *, init_pos: str | bool | None = None, root: int | None = None, random_state: AnyRandom = 0, @@ -31,7 +44,7 @@ def draw_graph( obsp: str | None = None, copy: bool = False, **kwds, -): +) -> AnnData | None: """\ Force-directed graph drawing [Islam11]_ [Jacomy14]_ [Chippada18]_. diff --git a/scanpy/tools/_embedding_density.py b/scanpy/tools/_embedding_density.py index 44296556f0..00e97c0780 100644 --- a/scanpy/tools/_embedding_density.py +++ b/scanpy/tools/_embedding_density.py @@ -8,6 +8,7 @@ import numpy as np from .. import logging as logg +from .._compat import old_positionals from .._utils import sanitize_anndata if TYPE_CHECKING: @@ -35,10 +36,11 @@ def _calc_density(x: np.ndarray, y: np.ndarray): return scaled_z +@old_positionals("groupby", "key_added", "components") def embedding_density( adata: AnnData, - # there is no asterisk here for backward compat (previously, there was) - basis: str = "umap", # was positional before 1.4.5 + basis: str = "umap", + *, groupby: str | None = None, key_added: str | None = None, components: str | Sequence[str] | None = None, diff --git a/scanpy/tools/_ingest.py b/scanpy/tools/_ingest.py index e4d7b61ec0..14dc300f19 100644 --- a/scanpy/tools/_ingest.py +++ b/scanpy/tools/_ingest.py @@ -10,7 +10,7 @@ from sklearn.utils import check_random_state from .. 
import logging as logg -from .._compat import pkg_version +from .._compat import old_positionals, pkg_version from .._settings import settings from .._utils import NeighborsView from ..neighbors import FlatTree, RPForestDict @@ -22,10 +22,18 @@ ANNDATA_MIN_VERSION = version.parse("0.7rc1") +@old_positionals( + "obs", + "embedding_method", + "labeling_method", + "neighbors_key", + "inplace", +) @doctest_skip("illustrative short example but not runnable") def ingest( adata: AnnData, adata_ref: AnnData, + *, obs: str | Iterable[str] | None = None, embedding_method: str | Iterable[str] = ("umap", "pca"), labeling_method: str = "knn", @@ -392,7 +400,7 @@ def _init_pca(self, adata): else: self._pca_basis = adata.varm["PCs"] - def __init__(self, adata, neighbors_key=None): + def __init__(self, adata: AnnData, neighbors_key: str | None = None): # assume rep is X if all initializations fail to identify it self._rep = adata.X self._use_rep = "X" diff --git a/scanpy/tools/_leiden.py b/scanpy/tools/_leiden.py index d0d903dc4d..aa8a9f4981 100644 --- a/scanpy/tools/_leiden.py +++ b/scanpy/tools/_leiden.py @@ -130,8 +130,8 @@ def leiden( adjacency, restrict_indices = restrict_adjacency( adata, restrict_key, - restrict_categories, - adjacency, + restrict_categories=restrict_categories, + adjacency=adjacency, ) # convert it to igraph g = _utils.get_igraph_from_adjacency(adjacency, directed=directed) @@ -157,11 +157,11 @@ def leiden( key_added += "_R" groups = rename_groups( adata, - key_added, - restrict_key, - restrict_categories, - restrict_indices, - groups, + key_added=key_added, + restrict_key=restrict_key, + restrict_categories=restrict_categories, + restrict_indices=restrict_indices, + groups=groups, ) adata.obs[key_added] = pd.Categorical( values=groups.astype("U"), diff --git a/scanpy/tools/_louvain.py b/scanpy/tools/_louvain.py index 5fc9a9ac9e..51f3cb5bb7 100644 --- a/scanpy/tools/_louvain.py +++ b/scanpy/tools/_louvain.py @@ -11,6 +11,7 @@ from ..
import _utils from .. import logging as logg +from .._compat import old_positionals from .._utils import _choose_graph from ._utils_clustering import rename_groups, restrict_adjacency @@ -30,9 +31,24 @@ class MutableVertexPartition: MutableVertexPartition.__module__ = "louvain.VertexPartition" +@old_positionals( + "random_state", + "restrict_to", + "key_added", + "adjacency", + "flavor", + "directed", + "use_weights", + "partition_type", + "partition_kwargs", + "neighbors_key", + "obsp", + "copy", +) def louvain( adata: AnnData, resolution: float | None = None, + *, random_state: _utils.AnyRandom = 0, restrict_to: tuple[str, Sequence[str]] | None = None, key_added: str = "louvain", @@ -135,8 +151,8 @@ def louvain( adjacency, restrict_indices = restrict_adjacency( adata, restrict_key, - restrict_categories, - adjacency, + restrict_categories=restrict_categories, + adjacency=adjacency, ) if flavor in {"vtraag", "igraph"}: if flavor == "igraph" and resolution is not None: @@ -229,11 +245,11 @@ def louvain( key_added += "_R" groups = rename_groups( adata, - key_added, - restrict_key, - restrict_categories, - restrict_indices, - groups, + key_added=key_added, + restrict_key=restrict_key, + restrict_categories=restrict_categories, + restrict_indices=restrict_indices, + groups=groups, ) adata.obs[key_added] = pd.Categorical( values=groups.astype("U"), diff --git a/scanpy/tools/_paga.py b/scanpy/tools/_paga.py index c4f9ca334b..1481ede0f2 100644 --- a/scanpy/tools/_paga.py +++ b/scanpy/tools/_paga.py @@ -8,6 +8,7 @@ from .. import _utils from .. 
import logging as logg +from .._compat import old_positionals from ..neighbors import Neighbors if TYPE_CHECKING: @@ -16,14 +17,16 @@ _AVAIL_MODELS = {"v1.0", "v1.2"} +@old_positionals("use_rna_velocity", "model", "neighbors_key", "copy") def paga( adata: AnnData, groups: str | None = None, + *, use_rna_velocity: bool = False, model: Literal["v1.2", "v1.0"] = "v1.2", neighbors_key: str | None = None, copy: bool = False, -): +) -> AnnData | None: """\ Mapping out the coarse-grained connectivity structures of complex manifolds [Wolf19]_. @@ -405,7 +408,7 @@ def paga_degrees(adata: AnnData) -> list[int]: return degrees -def paga_expression_entropies(adata) -> list[float]: +def paga_expression_entropies(adata: AnnData) -> list[float]: """Compute the median expression entropy for each node-group. Parameters diff --git a/scanpy/tools/_rank_genes_groups.py b/scanpy/tools/_rank_genes_groups.py index 25d1f2cc54..de7e842b93 100644 --- a/scanpy/tools/_rank_genes_groups.py +++ b/scanpy/tools/_rank_genes_groups.py @@ -11,6 +11,7 @@ from .. import _utils from .. 
import logging as logg +from .._compat import old_positionals from .._utils import check_nonnegative_integers from ..get import _check_mask from ..preprocessing._simple import _get_mean_var @@ -449,10 +450,25 @@ def compute_statistics( self.stats.index = self.var_names -# TODO: Make arguments after groupby keyword only +@old_positionals( + "mask", + "use_raw", + "groups", + "reference", + "n_genes", + "rankby_abs", + "pts", + "key_added", + "copy", + "method", + "corr_method", + "tie_correct", + "layer", +) def rank_genes_groups( adata: AnnData, groupby: str, + *, mask: NDArray[np.bool_] | str | None = None, use_raw: bool | None = None, groups: Literal["all"] | Iterable[str] = "all", @@ -712,16 +728,27 @@ def _calc_frac(X): return n_nonzero / X.shape[0] +@old_positionals( + "key", + "groupby", + "use_raw", + "key_added", + "min_in_group_fraction", + "min_fold_change", + "max_out_group_fraction", + "compare_abs", +) def filter_rank_genes_groups( adata: AnnData, - key=None, - groupby=None, - use_raw=None, - key_added="rank_genes_groups_filtered", - min_in_group_fraction=0.25, - min_fold_change=1, - max_out_group_fraction=0.5, - compare_abs=False, + *, + key: str | None = None, + groupby: str | None = None, + use_raw: bool | None = None, + key_added: str = "rank_genes_groups_filtered", + min_in_group_fraction: float = 0.25, + min_fold_change: int | float = 1, + max_out_group_fraction: float = 0.5, + compare_abs: bool = False, ) -> None: """\ Filters out genes based on log fold change and fraction of genes expressing the diff --git a/scanpy/tools/_score_genes.py b/scanpy/tools/_score_genes.py index 303609e80a..3e31f0cc66 100644 --- a/scanpy/tools/_score_genes.py +++ b/scanpy/tools/_score_genes.py @@ -11,6 +11,7 @@ from scanpy._utils import _check_use_raw from .. 
import logging as logg +from .._compat import old_positionals if TYPE_CHECKING: from collections.abc import Sequence @@ -45,9 +46,13 @@ def _sparse_nanmean(X, axis): return m +@old_positionals( + "ctrl_size", "gene_pool", "n_bins", "score_name", "random_state", "copy", "use_raw" +) def score_genes( adata: AnnData, gene_list: Sequence[str], + *, ctrl_size: int = 50, gene_pool: Sequence[str] | None = None, n_bins: int = 25, @@ -198,8 +203,10 @@ def score_genes( return adata if copy else None +@old_positionals("s_genes", "g2m_genes", "copy") def score_genes_cell_cycle( adata: AnnData, + *, s_genes: Sequence[str], g2m_genes: Sequence[str], copy: bool = False, diff --git a/scanpy/tools/_sim.py b/scanpy/tools/_sim.py index 8d24172a61..d1ee6ac5b2 100644 --- a/scanpy/tools/_sim.py +++ b/scanpy/tools/_sim.py @@ -22,6 +22,7 @@ from .. import _utils, readwrite from .. import logging as logg +from .._compat import old_positionals from .._settings import settings if TYPE_CHECKING: @@ -30,8 +31,20 @@ from anndata import AnnData +@old_positionals( + "params_file", + "tmax", + "branching", + "nrRealizations", + "noiseObs", + "noiseDyn", + "step", + "seed", + "writedir", +) def sim( model: Literal["krumsiek11", "toggleswitch"], + *, params_file: bool = True, tmax: int | None = None, branching: bool | None = None, @@ -40,7 +53,7 @@ def sim( noiseDyn: float | None = None, step: int | None = None, seed: int | None = None, - writedir: str | Path | None = None, + writedir: Path | str | None = None, ) -> AnnData: """\ Simulate dynamic gene expression data [Wittmann09]_ [Wolf18]_. 
@@ -201,7 +214,7 @@ def sample_dynamic_data(**params): break logg.debug( f"mean nr of offdiagonal edges {nrOffEdges_list.mean()} " - f"compared to total nr {grnsim.dim*(grnsim.dim-1)/2.}" + f"compared to total nr {grnsim.dim * (grnsim.dim - 1) / 2.}" ) # more complex models @@ -275,6 +288,7 @@ def sample_dynamic_data(**params): def write_data( X, + *, dir=Path("sim/test"), append=False, header="", @@ -385,6 +399,7 @@ class GRNsim: def __init__( self, + *, dim=3, model="ex0", modelType="var", @@ -401,9 +416,8 @@ def __init__( either string for predefined model, or directory with a model file and a couple matrix files """ - self.dim = ( - dim if Coupl is None else Coupl.shape[0] - ) # number of nodes / dimension of system + # number of nodes / dimension of system + self.dim = dim if Coupl is None else Coupl.shape[0] self.maxnpar = 1 # maximal number of parents self.p_indep = 0.4 # fraction of independent genes self.model = model @@ -867,6 +881,7 @@ def process_rule(self, rule, pa, tuple): def write_data( self, X, + *, dir=Path("sim/test"), noiseObs=0.0, append=False, @@ -887,9 +902,9 @@ def write_data( # call helper function write_data( X, - dir, - append, - header, + dir=dir, + append=append, + header=header, varNames=self.varNames, Adj=self.Adj, Coupl=self.Coupl, diff --git a/scanpy/tools/_top_genes.py b/scanpy/tools/_top_genes.py index 96bee3efcf..316f9cb652 100644 --- a/scanpy/tools/_top_genes.py +++ b/scanpy/tools/_top_genes.py @@ -12,6 +12,7 @@ from sklearn import metrics from .. 
import logging as logg +from .._compat import old_positionals from .._utils import select_groups if TYPE_CHECKING: @@ -20,10 +21,12 @@ from anndata import AnnData +@old_positionals("group", "n_genes", "data", "method", "annotation_key") def correlation_matrix( adata: AnnData, name_list: Collection[str] | None = None, groupby: str | None = None, + *, group: int | None = None, n_genes: int = 20, data: Literal["Complete", "Group", "Rest"] = "Complete", diff --git a/scanpy/tools/_tsne.py b/scanpy/tools/_tsne.py index 9db096af65..4ed280ea0c 100644 --- a/scanpy/tools/_tsne.py +++ b/scanpy/tools/_tsne.py @@ -6,6 +6,7 @@ from packaging import version from .. import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import AnyRandom, _doc_params from ..neighbors._doc import doc_n_pcs, doc_use_rep @@ -15,10 +16,21 @@ from anndata import AnnData +@old_positionals( + "use_rep", + "perplexity", + "early_exaggeration", + "learning_rate", + "random_state", + "use_fast_tsne", + "n_jobs", + "copy", +) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( adata: AnnData, n_pcs: int | None = None, + *, use_rep: str | None = None, perplexity: float | int = 30, early_exaggeration: float | int = 12, @@ -27,7 +39,6 @@ def tsne( use_fast_tsne: bool = False, n_jobs: int | None = None, copy: bool = False, - *, metric: str = "euclidean", ) -> AnnData | None: """\ diff --git a/scanpy/tools/_umap.py b/scanpy/tools/_umap.py index 5a97c4fc3c..6b18438e1b 100644 --- a/scanpy/tools/_umap.py +++ b/scanpy/tools/_umap.py @@ -8,6 +8,7 @@ from sklearn.utils import check_array, check_random_state from .. 
import logging as logg +from .._compat import old_positionals from .._settings import settings from .._utils import AnyRandom, NeighborsView from ._utils import _choose_representation, get_init_pos_from_paga @@ -18,8 +19,25 @@ _InitPos = Literal["paga", "spectral", "random"] +@old_positionals( + "min_dist", + "spread", + "n_components", + "maxiter", + "alpha", + "gamma", + "negative_sample_rate", + "init_pos", + "random_state", + "a", + "b", + "copy", + "method", + "neighbors_key", +) def umap( adata: AnnData, + *, min_dist: float = 0.5, spread: float = 1.0, n_components: int = 2, diff --git a/scanpy/tools/_utils_clustering.py b/scanpy/tools/_utils_clustering.py index d741d481d0..47f652fbdf 100644 --- a/scanpy/tools/_utils_clustering.py +++ b/scanpy/tools/_utils_clustering.py @@ -1,10 +1,27 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + + import numpy as np + import pandas as pd + from anndata import AnnData + from numpy.typing import NDArray + from scipy.sparse import spmatrix + def rename_groups( - adata, key_added, restrict_key, restrict_categories, restrict_indices, groups -): - key_added = restrict_key + "_R" if key_added is None else key_added + adata: AnnData, + restrict_key: str, + *, + key_added: str | None, + restrict_categories: Iterable[str], + restrict_indices: NDArray[np.bool_], + groups: NDArray, +) -> pd.Series[str]: + key_added = f"{restrict_key}_R" if key_added is None else key_added all_groups = adata.obs[restrict_key].astype("U") prefix = "-".join(restrict_categories) + "," new_groups = [prefix + g for g in groups.astype("U")] @@ -12,7 +29,13 @@ def rename_groups( return all_groups -def restrict_adjacency(adata, restrict_key, restrict_categories, adjacency): +def restrict_adjacency( + adata: AnnData, + restrict_key: str, + *, + restrict_categories: Iterable[str], + adjacency: spmatrix, +) -> tuple[spmatrix, NDArray[np.bool_]]: if not 
isinstance(restrict_categories[0], str): raise ValueError( "You need to use strings to label categories, " "e.g. '1' instead of 1."