Skip to content

Commit

Permalink
Fix set_transform_to_* methods in the Scene class
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-dark committed Oct 18, 2024
1 parent 58ca2d0 commit 9b57b2f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 60 deletions.
117 changes: 64 additions & 53 deletions apis/python/src/tiledbsoma/_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from somacore import (
CoordinateSpace,
CoordinateTransform,
IdentityTransform,
)

from . import _funcs, _tdb_handles
Expand Down Expand Up @@ -65,6 +64,34 @@ def __init__(
else:
self._coord_space = coordinate_space_from_json(coord_space)

def _check_transform_to_asset(
self,
transform: CoordinateTransform,
asset_coord_space: Optional[CoordinateSpace],
) -> None:
"""Raises an error if the scene coordinate space is not set or the
scene axis names do not match the input transform axis names.
"""
if self.coordinate_space is None:
raise SOMAError(
"The scene coordinate space must be set before setting a transform."
)
if transform.input_axes != self.coordinate_space.axis_names:
raise ValueError(
f"The name of the transform input axes, {transform.input_axes}, do "
f"not match the name of the axes in the scene coordinate space, "
f"{self.coordinate_space.axis_names}."
)
if (
asset_coord_space is not None
and transform.output_axes != asset_coord_space.axis_names
):
raise ValueError(
f"The name of the transform output axes, {transform.output_axes}, do "
f"not match the name of the axes in the provided coordinate space, "
f"{asset_coord_space.axis_names}."
)

def _open_subcollection(
self, subcollection: Union[str, Sequence[str]]
) -> CollectionBase[AnySOMAObject]:
Expand Down Expand Up @@ -320,35 +347,11 @@ def set_transform_to_multiscale_image(
Lifecycle: experimental
"""
if not isinstance(subcollection, str):
raise NotImplementedError()

# Check the transform matches this
if self.coordinate_space is None:
raise SOMAError(
"The scene coordinate space must be set before registering an image."
)
if transform.output_axes != self.coordinate_space.axis_names:
raise ValueError(
f"The name of the transform output axes, {transform.output_axes}, do "
f"not match the name of the axes in the scene coordinate space, "
f"{self.coordinate_space.axis_names}."
)

# Create the coordinate space if it does not exist. Otherwise, check it is
# compatible with the provide transform.
if coordinate_space is None:
if isinstance(transform, IdentityTransform):
coordinate_space = self.coordinate_space
else:
coordinate_space = CoordinateSpace.from_axis_names(transform.input_axes)
else:
if transform.input_axes != coordinate_space.axis_names:
raise ValueError(
f"The name of the transform input axes, {transform.input_axes}, do "
f"not match the name of the axes in the provided coordinate space, "
f"{coordinate_space.axis_names}."
)
# Check the transform is compatible with the oordinate spaces.
self._check_transform_to_asset(transform, coordinate_space)
assert (
self.coordinate_space is not None
) # Assert for typing - verified in the above method.

# Check asset exists in the specified location.
coll = self._open_subcollection(subcollection)
Expand All @@ -361,7 +364,19 @@ def set_transform_to_multiscale_image(
if not isinstance(image, MultiscaleImage):
raise TypeError(f"'{key}' in '{subcollection}' is not an MultiscaleImage.")

image.coordinate_space = coordinate_space
# Either set the new coordinate space or check the axes of the current
# coordinate space the multiscale image is defined on.
if coordinate_space is None:
if image.coordinate_space.axis_names != transform.output_axes:
raise ValueError(
f"The name of transform output axes, {transform.output_axes}, do"
f"not match the name of the axes in the multiscale image coordinate"
f" space, {image.coordinate_space.axis_names}."
)
else:
image.coordinate_space = coordinate_space

# Set the transform metadata and return the multisclae image.
coll.metadata[f"soma_scene_registry_{key}"] = transform_to_json(transform)
return image

Expand Down Expand Up @@ -399,27 +414,11 @@ def set_transform_to_point_cloud_dataframe(
Lifecycle: experimental
"""
if not isinstance(subcollection, str):
raise NotImplementedError()
if self.coordinate_space is None:
raise SOMAError(
"The scene coordinate space must be set before registering a point "
"cloud dataframe."
)
# Create the coordinate space if it does not exist. Otherwise, check it is
# compatible with the provide transform.
if coordinate_space is None:
if isinstance(transform, IdentityTransform):
coordinate_space = self.coordinate_space
else:
coordinate_space = CoordinateSpace.from_axis_names(transform.input_axes)
else:
if transform.input_axes != coordinate_space.axis_names:
raise ValueError(
f"The name of the transform input axes, {transform.input_axes}, do "
f"not match the name of the axes in the provided coordinate space, "
f"{coordinate_space.axis_names}."
)
# Check the transform is compatible with the Scene coordinate spaces.
self._check_transform_to_asset(transform, coordinate_space)
assert (
self.coordinate_space is not None
) # Assert for typing - verified in the above method.

# Check asset exists in the specified location.
try:
Expand All @@ -435,7 +434,19 @@ def set_transform_to_point_cloud_dataframe(
f"'{key}' in '{subcollection}' is not an PointCloudDataFrame."
)

point_cloud.coordinate_space = coordinate_space
# Either set the new coordinate space or check the axes of the current point
# cloud coordinate space.
if coordinate_space is None:
if point_cloud.coordinate_space.axis_names != transform.output_axes:
raise ValueError(
f"The name of transform output axes, {transform.output_axes}, do"
f"not match the name of the axes in the point cloud coordinate "
f"space, {point_cloud.coordinate_space.axis_names}."
)
else:
point_cloud.coordinate_space = coordinate_space

# Set the transform metadata and return the point cloud.
coll.metadata[f"soma_scene_registry_{key}"] = transform_to_json(transform)
return point_cloud

Expand Down
60 changes: 53 additions & 7 deletions apis/python/tests/test_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def scene(self, tmp_path_factory):
varl = scene.add_new_collection("varl")
varl.metadata["name"] = "varl"

rna = varl.add_new_collection("RNA") # type: ignore[attr-defined]
rna = varl.add_new_collection("RNA")
rna.metadata["name"] = "varl/RNA"

# Add a collection that is not part of the set data model.
Expand Down Expand Up @@ -245,15 +245,19 @@ def test_scene_point_cloud(tmp_path, coord_transform, transform_kwargs):
scene["obsl"] = soma.Collection.create(obsl_uri)

asch = pa.schema([("x", pa.float64()), ("y", pa.float64())])
coord_space = soma.CoordinateSpace([soma.Axis(name="x"), soma.Axis(name="y")])
coord_space = soma.CoordinateSpace(
[soma.Axis(name="x_scene"), soma.Axis(name="y_scene")]
)

# TODO Add transform directly to add_new_point_cloud
scene.add_new_point_cloud_dataframe(
"ptc", subcollection="obsl", transform=None, schema=asch
)

transform = coord_transform(
input_axes=("x", "y"), output_axes=("x", "y"), **transform_kwargs
input_axes=("x_scene", "y_scene"),
output_axes=("x", "y"),
**transform_kwargs,
)

# The scene coordinate space must be set before registering
Expand All @@ -266,6 +270,24 @@ def test_scene_point_cloud(tmp_path, coord_transform, transform_kwargs):
with pytest.raises(KeyError):
scene.set_transform_to_point_cloud_dataframe("bad", transform)

# Mismatched input axes.
transform_bad = coord_transform(
input_axes=("x", "y"),
output_axes=("x", "y"),
**transform_kwargs,
)
with pytest.raises(ValueError):
scene.set_transform_to_point_cloud_dataframe("ptc", transform_bad)

# Mismatched output axes.
transform_bad = coord_transform(
input_axes=("x_scene", "y_scene"),
output_axes=("x_scene", "y_scene"),
**transform_kwargs,
)
with pytest.raises(ValueError):
scene.set_transform_to_point_cloud_dataframe("ptc", transform_bad)

# Not a PointCloudDataFrame
scene["obsl"]["col"] = soma.Collection.create(urljoin(obsl_uri, "col"))
with pytest.raises(typeguard.TypeCheckError):
Expand Down Expand Up @@ -313,7 +335,9 @@ def test_scene_multiscale_image(tmp_path, coord_transform, transform_kwargs):
img_uri = urljoin(baseuri, "img")
scene["img"] = soma.Collection.create(img_uri)

coord_space = soma.CoordinateSpace([soma.Axis(name="x"), soma.Axis(name="y")])
coord_space = soma.CoordinateSpace(
[soma.Axis(name="x_scene"), soma.Axis(name="y_scene")]
)

# TODO Add transform directly to add_new_multiscale_image
scene.add_new_multiscale_image(
Expand All @@ -325,7 +349,7 @@ def test_scene_multiscale_image(tmp_path, coord_transform, transform_kwargs):
)

transform = coord_transform(
input_axes=("x", "y"),
input_axes=("x_scene", "y_scene"),
output_axes=("x", "y"),
**transform_kwargs,
)
Expand All @@ -349,6 +373,24 @@ def test_scene_multiscale_image(tmp_path, coord_transform, transform_kwargs):
with pytest.raises(typeguard.TypeCheckError):
scene.set_transform_to_multiscale_image("col", transform)

# Mismatched input axes.
transform_bad = coord_transform(
input_axes=("x", "y"),
output_axes=("x", "y"),
**transform_kwargs,
)
with pytest.raises(ValueError):
scene.set_transform_to_multiscale_image("msi", transform_bad)

# Mismatched output axes.
transform_bad = coord_transform(
input_axes=("x_scene", "y_scene"),
output_axes=("x_scene", "y_scene"),
**transform_kwargs,
)
with pytest.raises(ValueError):
scene.set_transform_to_multiscale_image("msi", transform_bad)

scene.set_transform_to_multiscale_image("msi", transform)

msi_transform = scene.get_transform_to_multiscale_image("msi")
Expand Down Expand Up @@ -387,13 +429,17 @@ def test_scene_geometry_dataframe(tmp_path, coord_transform, transform_kwargs):

gdf_uri = urljoin(obsl_uri, "gdf")
asch = pa.schema([("x", pa.float64()), ("y", pa.float64())])
coord_space = soma.CoordinateSpace([soma.Axis(name="x"), soma.Axis(name="y")])
coord_space = soma.CoordinateSpace(
[soma.Axis(name="x_scene"), soma.Axis(name="y_scene")]
)

# TODO replace with Scene.add_new_geometry_dataframe when implemented
scene["obsl"]["gdf"] = soma.GeometryDataFrame.create(gdf_uri, schema=asch)

transform = coord_transform(
input_axes=("x", "y"), output_axes=("x", "y"), **transform_kwargs
input_axes=("x_scene", "y_scene"),
output_axes=("x", "y"),
**transform_kwargs,
)

# The scene coordinate space must be set before registering
Expand Down

0 comments on commit 9b57b2f

Please sign in to comment.