updated h5ad and zarr format #31

Merged 1 commit on Sep 5, 2024
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,18 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [Unreleased]

### Added

### Changed

- Encoding of `treedata` attributes in h5ad and zarr files. `label`, `allow_overlap`, `obst`, and `vart` are now separate fields in the file. (#31)

### Fixed

- `TreeData` objects with `.raw` specified can now be read (#31)

## [0.0.4] - 2024-09-02

### Added
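As the changelog entry above notes, `label`, `allow_overlap`, `obst`, and `vart` are now written as separate top-level fields instead of a single serialized `raw.treedata` blob. A minimal sketch of inspecting the new layout (the file name is hypothetical and assumes `h5py` is installed):

    import h5py

    with h5py.File("tdata.h5ad", "r") as f:
        print([k for k in ("label", "allow_overlap", "obst", "vart") if k in f])
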
125 changes: 71 additions & 54 deletions src/treedata/_core/read.py
@@ -1,45 +1,92 @@
from __future__ import annotations

import json
from collections.abc import MutableMapping, Sequence
from collections.abc import MutableMapping
from pathlib import Path
from typing import (
Literal,
)

import anndata as ad
import h5py
import networkx as nx
import zarr
from scipy import sparse

from treedata._core.aligned_mapping import AxisTrees
from treedata._core.treedata import TreeData
from treedata._utils import dict_to_digraph


def _dict_to_digraph(graph_dict: dict) -> nx.DiGraph:
"""Convert a dictionary to a networkx.DiGraph."""
G = nx.DiGraph()
# Add nodes and their attributes
for node, attrs in graph_dict["nodes"].items():
G.add_node(node, **attrs)
# Add edges and their attributes
for source, targets in graph_dict["edges"].items():
for target, attrs in targets.items():
G.add_edge(source, target, **attrs)
return G


def _parse_axis_trees(data: str) -> dict:
"""Parse AxisTrees from a string."""
return {k: _dict_to_digraph(v) for k, v in json.loads(data).items()}
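# Hedged sketch of the serialized form these helpers expect: each tree is a dict with
# "nodes" and "edges", and an AxisTrees mapping is a JSON string of {name: tree_dict}.
# The example values below are hypothetical:
#   tree_dict = {
#       "nodes": {"root": {}, "leaf": {"depth": 1}},
#       "edges": {"root": {"leaf": {"length": 1.0}}},
#   }
#   _dict_to_digraph(tree_dict)                          # DiGraph with edge root -> leaf
#   _parse_axis_trees(json.dumps({"tree1": tree_dict}))  # {"tree1": nx.DiGraph}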


def _parse_legacy(treedata_attrs: dict) -> dict:
"""Parse tree attributes from AnnData uns field."""
for j in ["obst", "vart"]:
if j in treedata_attrs:
treedata_attrs[j] = {k: _dict_to_digraph(v) for k, v in treedata_attrs[j].items()}
treedata_attrs["allow_overlap"] = bool(treedata_attrs["allow_overlap"])
treedata_attrs["label"] = treedata_attrs["label"] if "label" in treedata_attrs.keys() else None
return treedata_attrs


def _read_raw(f, backed):
"""Read raw from file."""
d = {}
for k in ["obs", "var"]:
if f"raw/{k}" in f:
d[k] = ad.experimental.read_elem(f[f"raw/{k}"])
if not backed:
d["X"] = ad.experimental.read_elem(f["raw/X"])
return d


def _read_tdata(f, filename, backed) -> dict:
"""Read TreeData from file."""
d = {}
if backed is None:
backed = False
elif backed is True:
backed = "r"
# Read X if not backed
if not backed:
d["X"] = ad.experimental.read_elem(f["X"])
else:
d.update({"filename": filename, "filemode": backed})
# Read standard elements
for k in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns", "label", "allow_overlap"]:
if k in f:
d[k] = ad.experimental.read_elem(f[k])
# Read raw
if "raw" in f:
d["raw"] = _read_raw(f, backed)
# Read axis tree elements
for k in ["obst", "vart"]:
if k in f:
d[k] = _parse_axis_trees(ad.experimental.read_elem(f[k]))
# Read legacy treedata format
if "raw.treedata" in f:
d.update(_parse_legacy(json.loads(ad.experimental.read_elem(f["raw.treedata"]))))
return d
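# Hedged sketch of the dict _read_tdata returns (keys depend on what the file contains):
#   in memory:  {"X": ..., "obs": ..., "var": ..., "obst": {...}, "label": ..., "allow_overlap": ...}
#   backed:     the same minus "X", plus "filename" and "filemode" ("r" or "r+")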


def read_h5ad(
filename: str | Path = None,
backed: Literal["r", "r+"] | bool | None = None,
) -> TreeData:
"""Read `.h5ad`-formatted hdf5 file.

@@ -52,33 +99,10 @@ def read_h5ad(
instead of fully loading it into memory (`memory` mode).
If you want to modify backed attributes of the TreeData object,
you need to choose `'r+'`.
"""
with h5py.File(filename, "r") as f:
if "raw.treedata" in f:
treedata_attrs = json.loads(f["raw.treedata"][()])
else:
treedata_attrs = None
tdata = _tdata_from_adata(adata, treedata_attrs)

return tdata
d = _read_tdata(f, filename, backed)
return TreeData(**d)
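# Hedged usage sketch (file name is hypothetical):
#   tdata = read_h5ad("tdata.h5ad")              # load everything into memory
#   tdata = read_h5ad("tdata.h5ad", backed="r")  # keep X on disk, read-only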


def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
@@ -89,13 +113,6 @@ def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData:
store
The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class.
"""
with zarr.open(store, mode="r") as f:
if "raw.treedata" in f:
treedata_attrs = json.loads(f["raw.treedata"][()])
else:
treedata_attrs = None
tdata = _tdata_from_adata(adata, treedata_attrs)

return tdata
d = _read_tdata(f, store, backed=False)
return TreeData(**d)
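For orientation, a brief usage sketch of `read_zarr` (the import path and store name are assumptions, not part of this diff); note that the Zarr reader always loads into memory (`backed=False` above):

    from treedata import read_zarr

    tdata = read_zarr("tdata.zarr")
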
55 changes: 17 additions & 38 deletions src/treedata/_core/treedata.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from copy import deepcopy
from pathlib import Path
@@ -11,17 +10,12 @@
)

import anndata as ad
import h5py
import networkx as nx
import numpy as np
import pandas as pd
import zarr
from anndata._core.index import _subset
from anndata._io import write_h5ad, write_zarr
from scipy import sparse

from treedata._utils import digraph_to_dict, make_serializable

from .aligned_mapping import (
AxisTrees,
)
@@ -181,8 +175,14 @@ def _init_as_actual(

# init from scratch
else:
if isinstance(label, str) or label is None:
self._tree_label = label
else:
raise ValueError("label has to be a string or None")
if isinstance(allow_overlap, bool) or isinstance(allow_overlap, np.bool_):
self._allow_overlap = bool(allow_overlap)
else:
raise ValueError("allow_overlap has to be a boolean")
self._obst = AxisTrees(self, 0, vals=obst)
self._vart = AxisTrees(self, 1, vals=vart)
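            # With the checks above, a TreeData constructed with a non-string label or a
            # non-boolean allow_overlap now fails fast. Hedged sketch (argument values are
            # hypothetical):
            #   TreeData(X, label=1)              # ValueError: label has to be a string or None
            #   TreeData(X, allow_overlap="yes")  # ValueError: allow_overlap has to be a boolean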

@@ -281,16 +281,6 @@ def to_adata(self) -> ad.AnnData:
"""Convert this TreeData object to an AnnData object."""
return ad.AnnData(self)


def _mutated_copy(self, **kwargs):
"""Creating TreeData with attributes optionally specified via kwargs."""
if self.isbacked:
@@ -366,7 +356,7 @@ def write_h5ad(
filename: PathLike | None = None,
compression: Literal["gzip", "lzf"] | None = None,
compression_opts: int | Any = None,
**kwargs,
):
"""Write `.h5ad`-formatted hdf5 file.

@@ -378,34 +368,26 @@
[`lzf`, `gzip`], see the h5py :ref:`dataset_compression`.
compression_opts
[`lzf`, `gzip`], see the h5py :ref:`dataset_compression`.
"""
from .write import write_h5ad

if filename is None and not self.isbacked:
raise ValueError("Provide a filename!")
if filename is None:
filename = self.filename

write_h5ad(Path(filename), self, compression=compression, compression_opts=compression_opts)

if self.isbacked:
self.file.filename = filename

write = write_h5ad # a shortcut and backwards compat
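    # Hedged usage sketch (file name is hypothetical):
    #   tdata.write_h5ad("tdata.h5ad", compression="gzip")
    #   tdata.write("tdata.h5ad")  # same call via the alias above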

def write_zarr(
self,
store: MutableMapping | PathLike,
chunks: bool | int | tuple[int, ...] | None = None,
**kwargs,
):
"""Write a hierarchical Zarr array store.

@@ -416,12 +398,9 @@ def write_zarr(
chunks
Chunk shape.
"""
from .write import write_zarr

write_zarr(Path(store), self, chunks=chunks)
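        # Hedged usage sketch (store path and chunk shape are hypothetical):
        #   tdata.write_zarr("tdata.zarr")
        #   tdata.write_zarr("tdata.zarr", chunks=(1000, 100))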

def to_memory(self, copy=False) -> TreeData:
"""Return a new AnnData object with all backed arrays loaded into memory.