Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docstrings for last notebooks and miscellaneous docstrings #40

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/cedalion/dataclasses/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ class Recording:
It maps to the NirsElement in the snirf format but it also holds additional
attributes (masks, headmodel, aux_obj) for which there is no corresponding
entity in the snirf format.

Attributes:
timeseries (OrderedDict[str, NDTimeSeries]): A dictionary of timeseries objects.
The keys are the names of the timeseries.
masks (OrderedDict[str, xr.DataArray]): A dictionary of masks. The keys are the
names of the masks.
geo3d (LabeledPointCloud): A labeled point cloud representing the 3D geometry of
the recording.
geo2d (LabeledPointCloud): A labeled point cloud representing the 2D geometry of
the recording.
stim (pd.DataFrame): A dataframe containing the stimulus information.
aux_ts (OrderedDict[str, NDTimeSeries]): A dictionary of auxiliary timeseries
objects.
aux_obj (OrderedDict[str, Any]): A dictionary of auxiliary objects.
head_model (Optional[Any]): A head model object.
meta_data (OrderedDict[str, Any]): A dictionary of meta data.
"""

timeseries: OrderedDict[str, NDTimeSeries] = field(default_factory=OrderedDict)
Expand All @@ -37,6 +53,7 @@ class Recording:
)

def __repr__(self):
"""Return a string representation of the Recording object."""
return (
f"<Recording | "
f" timeseries: {list(self.timeseries.keys())}, "
Expand All @@ -47,6 +64,15 @@ def __repr__(self):
)

def get_timeseries(self, key: Optional[str] = None) -> NDTimeSeries:
"""Get a timeseries object by key.

Args:
key (Optional[str]): The key of the timeseries to retrieve. If None, the
last timeseries is returned.

Returns:
NDTimeSeries: The requested timeseries object.
"""
if not self.timeseries:
raise ValueError("timeseries dict is empty.")

Expand All @@ -72,6 +98,15 @@ def set_timeseries(self, key: str, value: NDTimeSeries, overwrite: bool = False)
self.timeseries[key] = value

def get_mask(self, key: Optional[str] = None) -> xr.DataArray:
"""Get a mask by key.

Args:
key (Optional[str]): The key of the mask to retrieve. If None, the last
mask is returned.

Returns:
xr.DataArray: The requested mask.
"""
if not self.masks:
raise ValueError("masks dict is empty.")

Expand All @@ -83,12 +118,28 @@ def get_mask(self, key: Optional[str] = None) -> xr.DataArray:
return self.masks[last_key]

def set_mask(self, key: str, value: xr.DataArray, overwrite: bool = False):
    """Set a mask.

    Args:
        key (str): The key of the mask to set.
        value (xr.DataArray): The mask to set.
        overwrite (bool): Whether to overwrite an existing mask with the same key.
            Defaults to False.

    Raises:
        ValueError: If a mask with the given key already exists and overwrite
            is False.
    """
    # PEP 8: use truthiness instead of comparing against the False singleton.
    if not overwrite and key in self.masks:
        raise ValueError(f"a mask with key '{key}' already exists!")

    self.masks[key] = value

def get_timeseries_type(self, key):
"""Get the type of a timeseries.

Args:
key (str): The key of the timeseries.

Returns:
str: The type of the timeseries.
"""
if key not in self.timeseries:
raise KeyError(f"unknown timeseries '{key}'")

Expand All @@ -109,13 +160,23 @@ def get_timeseries_type(self, key):

@property
def source_labels(self):
    """Get the unique source labels from the timeseries.

    Returns:
        list: A sorted list of unique source labels. Empty if no timeseries
            provides a "source" coordinate.
    """
    labels = [
        ts.source.values for ts in self.timeseries.values() if "source" in ts.coords
    ]
    # np.hstack raises ValueError on an empty sequence; return an empty
    # list instead when no timeseries carries a "source" coordinate.
    if not labels:
        return []
    return list(np.unique(np.hstack(labels)))

@property
def detector_labels(self):
"""Get the unique detector labels from the timeseries.

Returns:
list: A list of unique detector labels.
"""
labels = [
ts.detector.values
for ts in self.timeseries.values()
Expand All @@ -125,6 +186,11 @@ def detector_labels(self):

@property
def wavelengths(self):
"""Get the unique wavelengths from the timeseries.

Returns:
list: A list of unique wavelengths.
"""
wl = [
ts.wavelength.values
for ts in self.timeseries.values()
Expand All @@ -134,4 +200,9 @@ def wavelengths(self):

@property
def trial_types(self):
    """Get the unique trial types from the stimulus dataframe.

    Returns:
        list: A list of unique trial types, in order of first occurrence.
    """
    # Deduplicate while keeping the first-occurrence order of the column.
    unique_trial_types = self.stim["trial_type"].drop_duplicates()
    return list(unique_trial_types)
15 changes: 15 additions & 0 deletions src/cedalion/dataclasses/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ def build_labeled_points(
labels: Optional[list[str]] = None,
types: Optional[list[str]] = None,
):
"""Build a labeled point cloud data array.

Args:
coordinates (ArrayLike, optional): The coordinates of the points. Defaults to None.
crs (str, optional): The coordinate system. Defaults to "pos".
units (Optional[pint.Unit | str], optional): The units of the coordinates.
Defaults to "1".
labels (Optional[list[str]], optional): The labels of the points. Defaults to
None.
types (Optional[list[str]], optional): The types of the points. Defaults to
None.

Returns:
xr.DataArray: The labeled point cloud data array.
"""
if coordinates is None:
coordinates = np.zeros((0, 3), dtype=float)
else:
Expand Down
31 changes: 24 additions & 7 deletions src/cedalion/geometry/landmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,21 @@ def _intersect_mesh_with_triangle(
class LandmarksBuilder1010:
"""Construct the 10-10-system on scalp surface based on :cite:t:`Oostenveld2001`.

Args:
scalp_surface: a triangle-mesh representing the scalp
landmarks: positions of "Nz", "Iz", "LPA", "RPA"

Attributes:
scalp_surface (Surface): a triangle-mesh representing the scalp
landmarks_mm (LabeledPointCloud): positions of all 10-10 landmarks in mm
vtk_mesh (vtk.vtkPolyData): the scalp surface as a VTK mesh
lines (List[np.ndarray]): points along the lines connecting the landmarks
"""

@validate_schemas
def __init__(self, scalp_surface: Surface, landmarks: LabeledPointCloud):
"""Initialize the LandmarksBuilder1010.

Args:
scalp_surface (Surface): a triangle-mesh representing the scalp
landmarks (LabeledPointCloud): positions of "Nz", "Iz", "LPA", "RPA"
"""
if isinstance(scalp_surface, TrimeshSurface):
scalp_surface = VTKSurface.from_trimeshsurface(scalp_surface)

Expand All @@ -144,6 +151,7 @@ def _estimate_cranial_vertex_by_height(self):
return highest_vertices.mean(axis=0)

def _estimate_cranial_vertex_from_lines(self):
"""Estimate the cranial vertex by intersecting lines through the head."""
if "Cz" in self.landmarks_mm.label:
cz1 = self.landmarks_mm.loc["Cz"].values
# FIXME remove Cz from landmarks
Expand Down Expand Up @@ -184,6 +192,14 @@ def _estimate_cranial_vertex_from_lines(self):
def _add_landmarks_along_line(
self, triangle_labels: List[str], labels: List[str], dists: List[float]
):
"""Add landmarks along a line defined by three landmarks.

Args:
triangle_labels (List[str]): Labels of the three landmarks defining the line
labels (List[str]): Labels for the new landmarks
dists (List[float]): Distances along the line where the new landmarks should
be placed.
"""
assert len(triangle_labels) == 3
assert len(labels) == len(dists)
assert all([label in self.landmarks_mm.label for label in triangle_labels])
Expand Down Expand Up @@ -213,6 +229,7 @@ def _add_landmarks_along_line(
self.lines.append(points)

def build(self):
"""Construct the 10-10-system on the scalp surface."""
warnings.warn("WIP: distance calculation around ears")

cz = self._estimate_cranial_vertex_from_lines()
Expand Down Expand Up @@ -300,11 +317,11 @@ def order_ref_points_6(landmarks: xr.DataArray, twoPoints: str) -> xr.DataArray:
"""Reorder a set of six landmarks based on spatial relationships and give labels.

Args:
landmarks: coordinates for six landmark points
twoPoints: two reference points ('Nz' or 'Iz') for orientation.
landmarks (xr.DataArray): coordinates for six landmark points
twoPoints (str): two reference points ('Nz' or 'Iz') for orientation.

Returns:
the landmarks ordered as "Nz", "Iz", "RPA", "LPA", "Cz"
xr.DataArray: the landmarks ordered as "Nz", "Iz", "RPA", "LPA", "Cz"
"""

# Validate input parameters
Expand Down
47 changes: 39 additions & 8 deletions src/cedalion/geometry/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def register_trans_rot(
between the two point clouds.

Args:
coords_target: Target point cloud.
coords_trafo: Source point cloud.
coords_target (LabeledPointCloud): Target point cloud.
coords_trafo (LabeledPointCloud): Source point cloud.

Returns:
cdt.AffineTransform: Affine transformation between the two point clouds.
Expand Down Expand Up @@ -109,6 +109,15 @@ def loss(params, coords_target, coords_trafo):


def _std_distance_to_cog(points: cdt.LabeledPointCloud):
    """Calculate the standard deviation of the distances to the center of gravity.

    Args:
        points: Point cloud for which to calculate the standard deviation of the
            distances to the center of gravity.

    Returns:
        float: Standard deviation of the distances to the center of gravity.
    """
    center_of_gravity = points.mean("label")
    # Per-point distance to the centroid, evaluated in the point cloud's
    # coordinate reference system.
    distances = xrutils.norm(points - center_of_gravity, points.points.crs)
    return distances.std("label").item()

Expand All @@ -124,8 +133,8 @@ def register_trans_rot_isoscale(
between the two point clouds.

Args:
coords_target: Target point cloud.
coords_trafo: Source point cloud.
coords_target (LabeledPointCloud): Target point cloud.
coords_trafo (LabeledPointCloud): Source point cloud.

Returns:
cdt.AffineTransform: Affine transformation between the two point clouds.
Expand Down Expand Up @@ -186,10 +195,10 @@ def gen_xform_from_pts(p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
"""Calculate the affine transformation matrix T that transforms p1 to p2.

Args:
p1: Source points (p x m) where p is the number of points and m is the number of
dimensions.
p2: Target points (p x m) where p is the number of points and m is the number of
dimensions.
p1 (np.ndarray): Source points (p x m) where p is the number of points and m is
the number of dimensions.
p2 (np.ndarray): Target points (p x m) where p is the number of points and m is
the number of dimensions.

Returns:
Affine transformation matrix T.
Expand Down Expand Up @@ -231,6 +240,19 @@ def register_icp(
niterations=1000,
random_sample_fraction=0.5,
):
"""Iterative Closest Point algorithm for registration.

Args:
surface (Surface): Surface mesh to which to register the points.
landmarks (LabeledPointCloud): Landmarks to use for registration.
geo3d (LabeledPointCloud): Points to register to the surface.
niterations (int): Number of iterations for the ICP algorithm (default 1000).
random_sample_fraction (float): Fraction of points to use in each iteration
(default 0.5).

Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing the losses and transformations
"""
units = "mm"
landmarks_mm = landmarks.pint.to(units).points.to_homogeneous().pint.dequantify()
geo3d_mm = geo3d.pint.to(units).points.to_homogeneous().pint.dequantify()
Expand Down Expand Up @@ -470,6 +492,15 @@ def find_spread_points(points_xr : xr.DataArray) -> np.ndarray:


def simple_scalp_projection(geo3d : cdt.LabeledPointCloud) -> cdt.LabeledPointCloud:
"""Projects 3D coordinates onto a 2D plane using a simple scalp projection.

Args:
geo3d (LabeledPointCloud): 3D coordinates of points to project. Requires the
landmarks Nz, LPA, and RPA.

Returns:
A LabeledPointCloud containing the 2D coordinates of the projected points.
"""
for label in ["LPA", "RPA", "Nz"]:
if label not in geo3d.label:
raise ValueError("this projection needs the landmarks Nz, LPA and RPA.")
Expand Down
9 changes: 9 additions & 0 deletions src/cedalion/geometry/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def surface_from_segmentation(


def cell_coordinates(volume, flat: bool = False):
"""Create a DataArray with the coordinates of the cells in a volume.

Args:
volume (xr.DataArray): The volume to get the cell coordinates from.
flat (bool): Whether to flatten the coordinates.

Returns:
xr.DataArray: A DataArray with the coordinates of the cells in the volume.
"""
# coordinates in voxel space
i = np.arange(volume.shape[0])
j = np.arange(volume.shape[1])
Expand Down
Loading