Skip to content

Commit

Permalink
Merge pull request #10 from YosefLab/unique-names
Browse files Browse the repository at this point in the history
ensure names unique
  • Loading branch information
colganwi authored May 9, 2024
2 parents df51692 + 10c590c commit 30166a4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
13 changes: 12 additions & 1 deletion src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def _update_tree_labels(self):
mapping
)

def _check_uniqueness(self):
names = "Observation" if self.dim == "obs" else "Variable"
if not getattr(self.parent, self.dim).index.is_unique:
warnings.warn(
f"{names} names must be unique to store a tree. Calling `.{self.dim}_names_make_unique` to make them unique.",
stacklevel=2,
)
getattr(self.parent, self.dim).index = ad.utils.make_index_unique(getattr(self.parent, self.dim).index)

def copy(self):
d = AxisTrees(self.parent, self._axis)
for k, v in self.items():
Expand Down Expand Up @@ -122,7 +131,7 @@ def dim_names(self) -> pd.Index:
class AxisTrees(AxisTreesBase):
def __init__(
self,
parent: ad.AnnData,
parent: TreeData,
axis: int,
vals: Mapping | None = None,
):
Expand All @@ -141,6 +150,7 @@ def __getitem__(self, key: str) -> nx.DiGraph:
return nx.graphviews.generic_graph_view(self._data[key])

def __setitem__(self, key: str, value: nx.DiGraph):
self._check_uniqueness()
value, leaves = self._validate_tree(value, key)

for leaf in leaves:
Expand Down Expand Up @@ -194,6 +204,7 @@ def __getitem__(self, key: str) -> nx.DiGraph:
return subset_tree(self.parent_mapping[key], subset_leaves, asview=True)

def __setitem__(self, key: str, value: nx.DiGraph):
self._check_uniqueness()
value, _ = self._validate_tree(value, key) # Validate before mutating
warnings.warn(
f"Setting element `.{self.attrname}['{key}']` of view, initializing view as actual.", stacklevel=2
Expand Down
4 changes: 2 additions & 2 deletions src/treedata/_core/treedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def obst(self) -> AxisTrees:
return self._obst

@property
def vart(self):
def vart(self) -> AxisTrees:
"""Tree annotation of variables
Stores for each key a :class:`~networkx.DiGraph` with leaf nodes in
Expand Down Expand Up @@ -239,7 +239,7 @@ def obst(self, value):

@vart.setter
def vart(self, value):
vart = AxisTrees(self, 0, vals=dict(value))
vart = AxisTrees(self, 1, vals=dict(value))
self._vart = vart

def _gen_repr(self, n_obs, n_vars) -> str:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,13 @@ def test_transpose(adata, tree):
assert treedata.obst["tree"].nodes == treedata_transpose.vart["tree"].nodes
assert treedata_transpose.obst_keys() == []
assert np.array_equal(treedata.obs_names, treedata.T.obs_names)


@pytest.mark.parametrize("dim", ["obs", "var"])
def test_not_unique(X, tree, dim):
with pytest.warns(UserWarning):
tdata = td.TreeData(pd.DataFrame(X, index=["0", "1", "1"], columns=["0", "1", "1"]))
assert not getattr(tdata, f"{dim}_names").is_unique
with pytest.warns(UserWarning):
setattr(tdata, f"{dim}t", {"tree": tree})
assert getattr(tdata, f"{dim}_names").is_unique

0 comments on commit 30166a4

Please sign in to comment.