diff --git a/src/cedalion/dataclasses/recording.py b/src/cedalion/dataclasses/recording.py index 8b07fa9..7105088 100644 --- a/src/cedalion/dataclasses/recording.py +++ b/src/cedalion/dataclasses/recording.py @@ -19,6 +19,22 @@ class Recording: It maps to the NirsElement in the snirf format but it also holds additional attributes (masks, headmodel, aux_obj) for which there is no corresponding entity in the snirf format. + + Attributes: + timeseries (OrderedDict[str, NDTimeSeries]): A dictionary of timeseries objects. + The keys are the names of the timeseries. + masks (OrderedDict[str, xr.DataArray]): A dictionary of masks. The keys are the + names of the masks. + geo3d (LabeledPointCloud): A labeled point cloud representing the 3D geometry of + the recording. + geo2d (LabeledPointCloud): A labeled point cloud representing the 2D geometry of + the recording. + stim (pd.DataFrame): A dataframe containing the stimulus information. + aux_ts (OrderedDict[str, NDTimeSeries]): A dictionary of auxiliary timeseries + objects. + aux_obj (OrderedDict[str, Any]): A dictionary of auxiliary objects. + head_model (Optional[Any]): A head model object. + meta_data (OrderedDict[str, Any]): A dictionary of meta data. """ timeseries: OrderedDict[str, NDTimeSeries] = field(default_factory=OrderedDict) @@ -37,6 +53,7 @@ class Recording: ) def __repr__(self): + """Return a string representation of the Recording object.""" return ( f" NDTimeSeries: + """Get a timeseries object by key. + + Args: + key (Optional[str]): The key of the timeseries to retrieve. If None, the + last timeseries is returned. + + Returns: + NDTimeSeries: The requested timeseries object. + """ if not self.timeseries: raise ValueError("timeseries dict is empty.") @@ -72,6 +98,15 @@ def set_timeseries(self, key: str, value: NDTimeSeries, overwrite: bool = False) self.timeseries[key] = value def get_mask(self, key: Optional[str] = None) -> xr.DataArray: + """Get a mask by key. + + Args: + key (Optional[str]): The key of the mask to retrieve. If None, the last + mask is returned. + + Returns: + xr.DataArray: The requested mask. + """ if not self.masks: raise ValueError("masks dict is empty.") @@ -83,12 +118,28 @@ def get_mask(self, key: Optional[str] = None) -> xr.DataArray: return self.masks[last_key] def set_mask(self, key: str, value: xr.DataArray, overwrite: bool = False): + """Set a mask. + + Args: + key (str): The key of the mask to set. + value (xr.DataArray): The mask to set. + overwrite (bool): Whether to overwrite an existing mask with the same key. + Defaults to False. + """ if (overwrite is False) and (key in self.masks): raise ValueError(f"a mask with key '{key}' already exists!") self.masks[key] = value def get_timeseries_type(self, key): + """Get the type of a timeseries. + + Args: + key (str): The key of the timeseries. + + Returns: + str: The type of the timeseries. + """ if key not in self.timeseries: raise KeyError(f"unknown timeseries '{key}'") @@ -109,6 +160,11 @@ def get_timeseries_type(self, key): @property def source_labels(self): + """Get the unique source labels from the timeseries. + + Returns: + list: A list of unique source labels. + """ labels = [ ts.source.values for ts in self.timeseries.values() if "source" in ts.coords ] @@ -116,6 +172,11 @@ def source_labels(self): @property def detector_labels(self): + """Get the unique detector labels from the timeseries. + + Returns: + list: A list of unique detector labels. + """ labels = [ ts.detector.values for ts in self.timeseries.values() @@ -125,6 +186,11 @@ def detector_labels(self): @property def wavelengths(self): + """Get the unique wavelengths from the timeseries. + + Returns: + list: A list of unique wavelengths. + """ wl = [ ts.wavelength.values for ts in self.timeseries.values() @@ -134,4 +200,9 @@ def wavelengths(self): @property def trial_types(self): + """Get the unique trial types from the stimulus dataframe. + + Returns: + list: A list of unique trial types. + """ return list(self.stim["trial_type"].drop_duplicates()) diff --git a/src/cedalion/dataclasses/schemas.py b/src/cedalion/dataclasses/schemas.py index ef58844..db6dd83 100644 --- a/src/cedalion/dataclasses/schemas.py +++ b/src/cedalion/dataclasses/schemas.py @@ -145,6 +145,21 @@ def build_labeled_points( labels: Optional[list[str]] = None, types: Optional[list[str]] = None, ): + """Build a labeled point cloud data array. + + Args: + coordinates (ArrayLike, optional): The coordinates of the points. Defaults to None. + crs (str, optional): The coordinate system. Defaults to "pos". + units (Optional[pint.Unit | str], optional): The units of the coordinates. + Defaults to "1". + labels (Optional[list[str]], optional): The labels of the points. Defaults to + None. + types (Optional[list[str]], optional): The types of the points. Defaults to + None. + + Returns: + xr.DataArray: The labeled point cloud data array. + """ if coordinates is None: coordinates = np.zeros((0, 3), dtype=float) else: diff --git a/src/cedalion/geometry/landmarks.py b/src/cedalion/geometry/landmarks.py index e7366d2..b20d7cb 100644 --- a/src/cedalion/geometry/landmarks.py +++ b/src/cedalion/geometry/landmarks.py @@ -111,14 +111,21 @@ def _intersect_mesh_with_triangle( class LandmarksBuilder1010: """Construct the 10-10-system on scalp surface based on :cite:t:`Oostenveld2001`. - Args: - scalp_surface: a triangle-mesh representing the scalp - landmarks: positions of "Nz", "Iz", "LPA", "RPA" - + Attributes: + scalp_surface (Surface): a triangle-mesh representing the scalp + landmarks_mm (LabeledPointCloud): positions of all 10-10 landmarks in mm + vtk_mesh (vtk.vtkPolyData): the scalp surface as a VTK mesh + lines (List[np.ndarray]): points along the lines connecting the landmarks """ @validate_schemas def __init__(self, scalp_surface: Surface, landmarks: LabeledPointCloud): + """Initialize the LandmarksBuilder1010. + + Args: + scalp_surface (Surface): a triangle-mesh representing the scalp + landmarks (LabeledPointCloud): positions of "Nz", "Iz", "LPA", "RPA" + """ if isinstance(scalp_surface, TrimeshSurface): scalp_surface = VTKSurface.from_trimeshsurface(scalp_surface) @@ -144,6 +151,7 @@ def _estimate_cranial_vertex_by_height(self): return highest_vertices.mean(axis=0) def _estimate_cranial_vertex_from_lines(self): + """Estimate the cranial vertex by intersecting lines through the head.""" if "Cz" in self.landmarks_mm.label: cz1 = self.landmarks_mm.loc["Cz"].values # FIXME remove Cz from landmarks @@ -184,6 +192,14 @@ def _estimate_cranial_vertex_from_lines(self): def _add_landmarks_along_line( self, triangle_labels: List[str], labels: List[str], dists: List[float] ): + """Add landmarks along a line defined by three landmarks. + + Args: + triangle_labels (List[str]): Labels of the three landmarks defining the line + labels (List[str]): Labels for the new landmarks + dists (List[float]): Distances along the line where the new landmarks should + be placed. + """ assert len(triangle_labels) == 3 assert len(labels) == len(dists) assert all([label in self.landmarks_mm.label for label in triangle_labels]) @@ -213,6 +229,7 @@ def _add_landmarks_along_line( self.lines.append(points) def build(self): + """Construct the 10-10-system on the scalp surface.""" warnings.warn("WIP: distance calculation around ears") cz = self._estimate_cranial_vertex_from_lines() @@ -300,11 +317,11 @@ def order_ref_points_6(landmarks: xr.DataArray, twoPoints: str) -> xr.DataArray: """Reorder a set of six landmarks based on spatial relationships and give labels. Args: - landmarks: coordinates for six landmark points - twoPoints: two reference points ('Nz' or 'Iz') for orientation. + landmarks (xr.DataArray): coordinates for six landmark points + twoPoints (str): two reference points ('Nz' or 'Iz') for orientation. Returns: - the landmarks ordered as "Nz", "Iz", "RPA", "LPA", "Cz" + xr.DataArray: the landmarks ordered as "Nz", "Iz", "RPA", "LPA", "Cz" """ # Validate input parameters diff --git a/src/cedalion/geometry/registration.py b/src/cedalion/geometry/registration.py index 01ad9a8..03d6648 100644 --- a/src/cedalion/geometry/registration.py +++ b/src/cedalion/geometry/registration.py @@ -38,8 +38,8 @@ def register_trans_rot( between the two point clouds. Args: - coords_target: Target point cloud. - coords_trafo: Source point cloud. + coords_target (LabeledPointCloud): Target point cloud. + coords_trafo (LabeledPointCloud): Source point cloud. Returns: cdt.AffineTransform: Affine transformation between the two point clouds. @@ -109,6 +109,15 @@ def loss(params, coords_target, coords_trafo): def _std_distance_to_cog(points: cdt.LabeledPointCloud): + """Calculate the standard deviation of the distances to the center of gravity. + + Args: + points: Point cloud for which to calculate the standard deviation of the + distances to the center of gravity. + + Returns: + float: Standard deviation of the distances to the center of gravity. + """ dists = xrutils.norm(points - points.mean("label"), points.points.crs) return dists.std("label").item() @@ -124,8 +133,8 @@ def register_trans_rot_isoscale( between the two point clouds. Args: - coords_target: Target point cloud. - coords_trafo: Source point cloud. + coords_target (LabeledPointCloud): Target point cloud. + coords_trafo (LabeledPointCloud): Source point cloud. Returns: cdt.AffineTransform: Affine transformation between the two point clouds. @@ -186,10 +195,10 @@ def gen_xform_from_pts(p1: np.ndarray, p2: np.ndarray) -> np.ndarray: """Calculate the affine transformation matrix T that transforms p1 to p2. Args: - p1: Source points (p x m) where p is the number of points and m is the number of - dimensions. - p2: Target points (p x m) where p is the number of points and m is the number of - dimensions. + p1 (np.ndarray): Source points (p x m) where p is the number of points and m is + the number of dimensions. + p2 (np.ndarray): Target points (p x m) where p is the number of points and m is + the number of dimensions. Returns: Affine transformation matrix T. @@ -231,6 +240,19 @@ def register_icp( niterations=1000, random_sample_fraction=0.5, ): + """Iterative Closest Point algorithm for registration. + + Args: + surface (Surface): Surface mesh to which to register the points. + landmarks (LabeledPointCloud): Landmarks to use for registration. + geo3d (LabeledPointCloud): Points to register to the surface. + niterations (int): Number of iterations for the ICP algorithm (default 1000). + random_sample_fraction (float): Fraction of points to use in each iteration + (default 0.5). + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing the losses and transformations + """ units = "mm" landmarks_mm = landmarks.pint.to(units).points.to_homogeneous().pint.dequantify() geo3d_mm = geo3d.pint.to(units).points.to_homogeneous().pint.dequantify() @@ -470,6 +492,15 @@ def find_spread_points(points_xr : xr.DataArray) -> np.ndarray: def simple_scalp_projection(geo3d : cdt.LabeledPointCloud) -> cdt.LabeledPointCloud: + """Projects 3D coordinates onto a 2D plane using a simple scalp projection. + + Args: + geo3d (LabeledPointCloud): 3D coordinates of points to project. Requires the + landmarks Nz, LPA, and RPA. + + Returns: + A LabeledPointCloud containing the 2D coordinates of the projected points. + """ for label in ["LPA", "RPA", "Nz"]: if label not in geo3d.label: raise ValueError("this projection needs the landmarks Nz, LPA and RPA.") diff --git a/src/cedalion/geometry/segmentation.py b/src/cedalion/geometry/segmentation.py index a9d9ac9..ef91962 100644 --- a/src/cedalion/geometry/segmentation.py +++ b/src/cedalion/geometry/segmentation.py @@ -50,6 +50,15 @@ def surface_from_segmentation( def cell_coordinates(volume, flat: bool = False): + """Create a DataArray with the coordinates of the cells in a volume. + + Args: + volume (xr.DataArray): The volume to get the cell coordinates from. + flat (bool): Whether to flatten the coordinates. + + Returns: + xr.DataArray: A DataArray with the coordinates of the cells in the volume. + """ # coordinates in voxel space i = np.arange(volume.shape[0]) j = np.arange(volume.shape[1]) diff --git a/src/cedalion/imagereco/forward_model.py b/src/cedalion/imagereco/forward_model.py index 3daaba9..006417d 100644 --- a/src/cedalion/imagereco/forward_model.py +++ b/src/cedalion/imagereco/forward_model.py @@ -110,6 +110,8 @@ def from_segmentation( the scalp surface. smoothing(float): Smoothing factor for the brain and scalp surfaces. brain_face_count (Optional[int]): Number of faces for the brain surface. + scalp_face_count (Optional[int]): Number of faces for the scalp surface. + fill_holes (bool): Whether to fill holes in the segmentation masks. """ # load segmentation mask @@ -372,7 +374,7 @@ def save(self, foldername: str): """Save the head model to a folder. Args: - foldername : Folder to save the head model into. + foldername (str): Folder to save the head model into. Returns: None @@ -407,10 +409,10 @@ def load(cls, foldername: str): """Load the head model from a folder. Args: - foldername : Folder to load the head model from. + foldername (str): Folder to load the head model from. Returns: - Loaded head model. + TwoSurfaceHeadModel: Loaded head model. """ # Check if all files exist @@ -531,7 +533,7 @@ def snap_to_scalp_voxels( vec = np.zeros(self.scalp.nvertices) vec[idx[0,0]] = 1 voxel_idx = np.argwhere(self.voxel_to_vertex_scalp @ vec == 1)[:,0] - + if len(voxel_idx) > 0: # Get voxel coordinates from voxel indices try: @@ -572,26 +574,20 @@ class ForwardModel: ... Args: - head_model : TwoSurfaceHeadModel - Head model containing voxel projections to brain and scalp surfaces. - optode_pos : cdt.LabeledPointCloud - Optode positions. - optode_dir : xr.DataArray - Optode orientations (directions of light beams). - tissue_properties : xr.DataArray - Tissue properties for each tissue type. - volume : xr.DataArray - Voxelated head volume from segmentation masks. - unitinmm : float - Unit of head model, optodes expressed in mm. - measurement_list : pd.DataFrame - List of measurements of experiment with source, detector, channel and - wavelength. + head_model (TwoSurfaceHeadModel): Head model containing voxel projections to brain + and scalp surfaces. + optode_pos (cdt.LabeledPointCloud): Optode positions. + optode_dir (xr.DataArray): Optode orientations (directions of light beams). + tissue_properties (xr.DataArray): Tissue properties for each tissue type. + volume (xr.DataArray): Voxelated head volume from segmentation masks. + unitinmm (float): Unit of head model, optodes expressed in mm. + measurement_list (pd.DataFrame): List of measurements of experiment with source, + detector, channel, and wavelength. Methods: - compute_fluence(nphoton) + compute_fluence(nphoton): Compute fluence for each channel and wavelength from photon simulation. - compute_sensitivity(fluence_all, fluence_at_optodes) + compute_sensitivity(fluence_all, fluence_at_optodes): Compute sensitivity matrix from fluence. """ @@ -655,14 +651,11 @@ def _get_fluence_from_mcx(self, i_optode: int, nphoton: int): """Run MCX simulation to get fluence for one optode. Args: - i_optode : int - Index of the optode. - nphoton : int - Number of photons to simulate. + i_optode (int): Index of the optode. + nphoton (int): Number of photons to simulate. Returns: - np.ndarray - Fluence in each voxel. + np.ndarray: Fluence in each voxel. """ cfg = { @@ -694,14 +687,11 @@ def _fluence_at_optodes(self, fluence, emitting_opt): """Fluence caused by one optode at the positions of all other optodes. Args: - fluence : np.ndarray - Fluence in each voxel. - emitting_opt : int - Index of the emitting optode. + fluence (np.ndarray): Fluence in each voxel. + emitting_opt (int): Index of the emitting optode. Returns: - np.ndarray - Fluence at all optode positions. + np.ndarray: Fluence at all optode positions. """ n_optodes = len(self.optode_pos) @@ -736,12 +726,10 @@ def compute_fluence_mcx(self, nphoton: int = 1e8): """Compute fluence for each channel and wavelength using MCX package. Args: - nphoton : int - Number of photons to simulate. + nphoton (int): Number of photons to simulate. Returns: - xr.DataArray - Fluence in each voxel for each channel and wavelength. + xr.DataArray: Fluence in each voxel for each channel and wavelength. References: (:cite:t:`Fang2009`) Qianqian Fang and David A. Boas, "Monte Carlo @@ -823,12 +811,11 @@ def compute_fluence_nirfaster( """Compute fluence for each channel and wavelength using NIRFASTer package. Args: - meshingparam : ff.utils.MeshingParam - Parameters to be used by the CGAL mesher. Note:they should all be double + meshingparam (ff.utils.MeshingParam) Parameters to be used by the CGAL + mesher. Note: they should all be double Returns: - xr.DataArray - Fluence in each voxel for each channel and wavelength. + xr.DataArray: Fluence in each voxel for each channel and wavelength. References: (:cite:t:`Dehghani2009`) Dehghani, Hamid, et al. "Near infrared optical @@ -1015,14 +1002,12 @@ def compute_sensitivity(self, fluence_all, fluence_at_optodes): """Compute sensitivity matrix from fluence. Args: - fluence_all : xr.DataArray - Fluence in each voxel for each wavelength. - fluence_at_optodes : xr.DataArray - Fluence at all optode positions for each wavelength. + fluence_all (xr.DataArray): Fluence in each voxel for each wavelength. + fluence_at_optodes (xr.DataArray): Fluence at all optode positions for each + wavelength. Returns: - xr.DataArray - Sensitivity matrix for each channel, vertex and wavelength. + xr.DataArray: Sensitivity matrix for each channel, vertex and wavelength. """ channels = self.measurement_list.channel.unique().tolist() diff --git a/src/cedalion/imagereco/utils.py b/src/cedalion/imagereco/utils.py index f4f8433..a31ba60 100644 --- a/src/cedalion/imagereco/utils.py +++ b/src/cedalion/imagereco/utils.py @@ -17,7 +17,19 @@ def map_segmentation_mask_to_surface( transform_vox2ras: cdt.AffineTransform, # FIXME surface: cdc.Surface, ): - """Find for each voxel the closest vertex on the surface.""" + """Find for each voxel the closest vertex on the surface. + + Args: + segmentation_mask (xr.DataArray): A binary mask of shape (segmentation_type, i, + j, k). + transform_vox2ras (xr.DataArray): The affine transformation from voxel to RAS + space. + surface (cedalion.dataclasses.Surface): The surface to map the voxels to. + + Returns: + coo_array: A sparse matrix of shape (ncells, nvertices) that maps voxels to + cells. + """ assert surface.crs == transform_vox2ras.dims[0] @@ -48,6 +60,17 @@ def map_segmentation_mask_to_surface( def normal_hrf(t, t_peak, t_std, vmax): + """Create a normal hrf. + + Args: + t (np.ndarray): The time points. + t_peak (float): The peak time. + t_std (float): The standard deviation. + vmax (float): The maximum value of the HRF. + + Returns: + np.ndarray: The HRF. + """ hrf = scipy.stats.norm.pdf(t, loc=t_peak, scale=t_std) hrf *= vmax / hrf.max() return hrf @@ -61,6 +84,20 @@ def create_mock_activation_below_point( spatial_size: units.Quantity, vmax: units.Quantity, ): + """Create a mock activation below a point. + + Args: + head_model (cedalion.imagereco.forward_model.TwoSurfaceHeadModel): The head + model. + point (cdt.LabeledPointCloud): The point below which to create the activation. + time_length (units.Quantity): The length of the activation. + sampling_rate (units.Quantity): The sampling rate. + spatial_size (units.Quantity): The spatial size of the activation. + vmax (units.Quantity): The maximum value of the activation. + + Returns: + xr.DataArray: The activation. + """ # assert head_model.crs == point.points.crs _, vidx = head_model.brain.kdtree.query(point) diff --git a/src/cedalion/io/anatomy.py b/src/cedalion/io/anatomy.py index 4a5548a..47e3d9e 100644 --- a/src/cedalion/io/anatomy.py +++ b/src/cedalion/io/anatomy.py @@ -115,6 +115,16 @@ def read_segmentation_masks( def cell_coordinates(mask, affine, units="mm"): + """Get the coordinates of each voxel in the transformed mask. + + Args: + mask (xr.DataArray): A binary mask of shape (i, j, k). + affine (np.ndarray): Affine transformation matrix. + units (str): Units of the output coordinates. + + Returns: + xr.DataArray: Coordinates of the center of each voxel in the mask. + """ # coordinates in voxel space i = np.arange(mask.shape[0]) j = np.arange(mask.shape[1]) diff --git a/src/cedalion/io/forward_model.py b/src/cedalion/io/forward_model.py index 3f28f88..f33fa57 100644 --- a/src/cedalion/io/forward_model.py +++ b/src/cedalion/io/forward_model.py @@ -7,15 +7,11 @@ def save_Adot(fn: str, Adot: xr.DataArray): """Save Adot to a netCDF file. - Parameters: - ---------- - fn: str - File name to save the data to. - Adot: xr.DataArray - Data to save. + Args: + fn (str): File name to save the data to. + Adot (xr.DataArray): Data to save. Returns: - ------- None """ @@ -26,12 +22,10 @@ def load_Adot(fn: str): """Load Adot from a netCDF file. Args: - fn: str - File name to load the data from. + fn (str): File name to load the data from. Returns: - Adot: xr.DataArray - Data loaded from the file. + xr.DataArray: Data loaded from the file. """ Adot = xr.open_dataset(fn) @@ -88,7 +82,14 @@ def save_fluence(fn : str, fluence_all, fluence_at_optodes): def load_fluence(fn : str): - """Load forward model computation results.""" + """Load forward model computation results. + + Args: + fn (str): File name to load the data from. + + Returns: + Tuple[xr.DataArray, xr.DataArray]: Fluence data loaded from the file. + """ with h5py.File(fn, "r") as f: diff --git a/src/cedalion/io/photogrammetry.py b/src/cedalion/io/photogrammetry.py index c793cc0..97b8928 100644 --- a/src/cedalion/io/photogrammetry.py +++ b/src/cedalion/io/photogrammetry.py @@ -10,13 +10,14 @@ def read_photogrammetry_einstar(fn): photogrammetry pipeline using an einstar device. Args: - fn: the filename of the einstar photogrammatry output file + fn (str): The filename of the einstar photogrammetry output file. Returns: - fiducials : cedalion.LabeledPoints - The fiducials as a cedalion LabeledPoints object. - optodes : cedalion.LabeledPoints - The optodes as a cedalion LabeledPoints object. + tuple: A tuple containing: + - fiducials (cedalion.LabeledPoints): The fiducials as a cedalion + LabeledPoints object. + - optodes (cedalion.LabeledPoints): The optodes as a cedalion LabeledPoints + object. """ fiducials, optodes = read_einstar(fn) @@ -28,13 +29,12 @@ def read_einstar(fn): """Read optodes and fiducials from einstar devices. Args: - fn: The filename of the einstar photogrammatry output file. + fn (str): The filename of the einstar photogrammetry output file. Returns: - fiducials : OrderedDict - The fiducials as an OrderedDict. - optodes : OrderedDict - The optodes as an OrderedDict. + tuple: A tuple containing: + - fiducials (OrderedDict): The fiducials as an OrderedDict. + - optodes (OrderedDict): The optodes as an OrderedDict. """ with open(fn, "r") as f: @@ -54,16 +54,15 @@ def opt_fid_to_xr(fiducials, optodes): """Convert OrderedDicts fiducials and optodes to cedalion LabeledPoints objects. Args: - fiducials : OrderedDict - The fiducials as an OrderedDict. - optodes : OrderedDict - The optodes as an OrderedDict. + fiducials (OrderedDict): The fiducials as an OrderedDict. + optodes (OrderedDict): The optodes as an OrderedDict. Returns: - fiducials : cedalion.LabeledPoints - The fiducials as a cedalion LabeledPoints object. - optodes : cedalion.LabeledPoints - The optodes as a cedalion LabeledPoints object. + tuple: A tuple containing: + - fiducials (cedalion.LabeledPoints): The fiducials as a cedalion + LabeledPoints object. + - optodes (cedalion.LabeledPoints): The optodes as a cedalion LabeledPoints + object. """ # FIXME: this should get a different CRS diff --git a/src/cedalion/io/snirf.py b/src/cedalion/io/snirf.py index a0f001d..68af0a6 100644 --- a/src/cedalion/io/snirf.py +++ b/src/cedalion/io/snirf.py @@ -150,7 +150,13 @@ def reduce_ndim_sourceLabels(sourceLabels: np.ndarray) -> list: snirf supports multidimensional source labels but we don't. This function tries to reduce n-dimensional source labels - to a unique common prefix to obtain only one label per source + to a unique common prefix to obtain only one label per source. + + Args: + sourceLabels (np.ndarray): The source labels to reduce. + + Returns: + list: The reduced source labels. """ labels = [] for i_src in range(sourceLabels.shape[0]): @@ -181,11 +187,10 @@ def labels_and_positions(probe, dim: int = 3): Args: probe: Nirs probe geometry variable, see snirf documentation (:cite:t:`Tucker2022`). - dim: must be either 2 or 3. + dim (int): Must be either 2 or 3. Returns: - Tuple(np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray): - Tuple containing the source, detector and landmark labels/positions. + tuple: A tuple containing the source, detector, and landmark labels/positions. """ def convert_none(probe, attrname, default): attr = getattr(probe, attrname) @@ -247,14 +252,13 @@ def geometry_from_probe(nirs_element: NirsElement, dim: int = 3): """Extract 3D coordinates of optodes and landmarks from probe information. Args: - nirs_element: Nirs data element as specified in the snirf + nirs_element (NirsElement): Nirs data element as specified in the snirf documentation (:cite:t:`Tucker2022`). - dim: must be either 2 or 3. + dim (int): Must be either 2 or 3. Returns: - result (xr.DataArray, (label, pos)): A DataArray containing the 3D coordinates - of optodes and landmarks, with dimensions 'label' and - 'pos' and coordinates 'label' and 'type'. + xr.DataArray: A DataArray containing the 3D coordinates of optodes and landmarks, + with dimensions 'label' and 'pos' and coordinates 'label' and 'type'. """ probe = nirs_element.probe @@ -633,6 +637,18 @@ def _get_time_coords( data_element: DataElement, df_measurement_list: pd.DataFrame, ) -> dict[str, ArrayLike]: + """Get time coordinates for the NIRS data element. + + Args: + nirs_element (NirsElement): NIRS data element containing metadata. + data_element (DataElement): Data element containing time and dataTimeSeries. + df_measurement_list (pd.DataFrame): DataFrame containing the measurement list. + + Returns: + tuple: A tuple containing: + - indices (None): Placeholder for indices. + - coordinates (dict[str, ArrayLike]): Dictionary with time coordinates. + """ time = data_element.time time_unit = nirs_element.metaDataTags.TimeUnit @@ -651,6 +667,17 @@ def _get_channel_coords( nirs_element: NirsElement, df_measurement_list: pd.DataFrame, ) -> tuple[ArrayLike, dict[str, ArrayLike]]: + """Get channel coordinates for the NIRS data element. + + Args: + nirs_element (NirsElement): NIRS data element containing probe information. + df_measurement_list (pd.DataFrame): DataFrame containing the measurement list. + + Returns: + tuple: A tuple containing: + - indices (None): Placeholder for indices. + - coordinates (dict[str, ArrayLike]): Dictionary with channel coordinates. + """ sourceLabels, detectorLabels, landmarkLabels, _, _, _ = labels_and_positions( nirs_element.probe ) @@ -813,6 +840,21 @@ def measurement_list_from_stacked( detector_labels=None, wavelengths=None, ): + """Create a measurement list from a stacked array. + + Args: + stacked_array (xr.DataArray): Stacked array containing the data. + data_type (str): Data type of the data. + trial_types (list[str]): List of trial types. + stacked_channel (str): Name of the channel dimension in the stacked array. + source_labels (list[str]): List of source labels. + detector_labels (list[str]): List of detector labels. + wavelengths (list[float]): List of wavelengths. + + Returns: + tuple: A tuple containing the source labels, detector labels, wavelengths, and + the measurement list. + """ if source_labels is None: source_labels = list(np.unique(stacked_array.source.values)) if detector_labels is None: diff --git a/src/cedalion/models/glm/design_matrix.py b/src/cedalion/models/glm/design_matrix.py index e94b113..c463388 100644 --- a/src/cedalion/models/glm/design_matrix.py +++ b/src/cedalion/models/glm/design_matrix.py @@ -320,11 +320,11 @@ def max_corr_short_channel(ts_long: cdt.NDTimeSeries, ts_short: cdt.NDTimeSeries correleation coefficient in any wavelength or chromophore. Args: - ts_long: time series of long channels - ts_short: time series of short channels + ts_long (NDTimeSeries): time series of long channels + ts_short (NDTimeSeries): time series of short channels Returns: - channel-wise regressors + xr.DataArray: channel-wise regressors """ dim3 = xrutils.other_dim(ts_long, "channel", "time") diff --git a/src/cedalion/sigproc/quality.py b/src/cedalion/sigproc/quality.py index 9a39722..2df758c 100644 --- a/src/cedalion/sigproc/quality.py +++ b/src/cedalion/sigproc/quality.py @@ -81,6 +81,20 @@ def psp( window_length: Annotated[Quantity, "[time]"], psp_thresh: float, ): + """Calculate the phase slope index (PSP) metric. + + Args: + amplitudes (NDTimeSeries): input time + series + window_length (Quantity): size of the computation window + psp_thresh (float): if the calculated PSP metric falls below this threshold then the + corresponding time window should be excluded. + + Returns: + A tuple (psp, psp_mask), where psp is a DataArray with coords from the input + NDTimeseries containing the PSP metric. psp_mask is a boolean mask DataArray + with coords from psp, true where psp_thresh is met. + """ # FIXME make these configurable cardiac_fmin = 0.5 * units.Hz cardiac_fmax = 2.5 * units.Hz @@ -153,7 +167,14 @@ def psp( @cdc.validate_schemas def gvtd(amplitudes: NDTimeSeries): - """Calculate GVTD metric.""" + """Calculate GVTD metric. + + Args: + amplitudes (:class:`NDTimeSeries`, (channel, wavelength, time)): input time + series + Returns: + A DataArray with coords from the input NDTimeseries containing the GVTD metric. + """ fcut_min = 0.01 fcut_max = 0.5 @@ -517,7 +538,16 @@ def id_motion_refine(ma_mask: cdt.NDTimeSeries, operator: str): def detect_outliers_std( ts: cdt.NDTimeSeries, t_window: Annotated[Quantity, "[time]"], iqr_threshold=2 ): - """Detect outliers in fNIRSdata based on standard deviation of signal.""" + """Detect outliers in fNIRSdata based on standard deviation of signal. + + Args: + ts (NDTimeSeries): input time series + t_window (Quantity): window size for standard deviation + iqr_threshold: threshold for IQR based outlier detection. Default is 2. + + Returns: + A DataArray with coords from the input NDTimeseries containing the outlier mask. + """ ts = ts.pint.dequantify() fs = freq.sampling_rate(ts) @@ -551,7 +581,15 @@ def detect_outliers_std( @cdc.validate_schemas def detect_outliers_grad(ts: cdt.NDTimeSeries, iqr_threshold=1.5): - """Detect outliers in fNIRSdata based on gradient of signal.""" + """Detect outliers in fNIRSdata based on gradient of signal. + + Args: + ts (NDTimeSeries): input time series + iqr_threshold: threshold for IQR based outlier detection. Default is 1.5. + + Returns: + A DataArray with coords from the input NDTimeseries containing the outlier mask. + """ ts = ts.pint.dequantify() @@ -594,6 +632,17 @@ def detect_outliers( iqr_threshold_std : float =2, iqr_threshold_grad : float =1.5, ): + """Detect outliers based on standard deviation and gradient of signal. + + Args: + ts (NDTimeSeries): input time series + t_window_std (Quantity): window size for standard deviation + iqr_threshold_std: threshold for IQR based outlier detection in standard deviation + iqr_threshold_grad: threshold for IQR based outlier detection in gradient + + Returns: + A DataArray with coords from the input NDTimeseries containing the outlier mask. + """ mask_std = detect_outliers_std(ts, t_window_std, iqr_threshold_std) mask_grad = detect_outliers_grad(ts, iqr_threshold_grad) @@ -603,11 +652,15 @@ def detect_outliers( def _mask1D_to_segments(mask: ArrayLike): """Find consecutive segments for a boolean mask. - Given a boolean mask, this function returns an integer array `segements` of - shape (nsegments,3) in which - - segments[:,0] is the first index of the segment - - segments[:,1]-1 is the last index of the segment and - - segments[:,2] is the integer-converted mask value in that segment + Args: + mask (ArrayLike): boolean mask + + Returns: + Given a boolean mask, this function returns an integer array `segments` of + shape (nsegments,3) in which + - segments[:,0] is the first index of the segment + - segments[:,1]-1 is the last index of the segment and + - segments[:,2] is the integer-converted mask value in that segment """ # FIXME decide how to index: @@ -625,6 +678,16 @@ def _mask1D_to_segments(mask: ArrayLike): def _calculate_snr(ts, fs, segments): + """Calculate signal to noise ratio for a time series. + + Args: + ts (ArrayLike): Time series + fs (float): Sampling rate + segments (ArrayLike): Segments of the time series + + Returns: + float: Signal to noise ratio + """ # Calculate signal to noise ratio by considering only longer segments. # Only segments longer than 3s are used. Segments may be clean or tainted. long_seg_snr = [ @@ -643,6 +706,16 @@ def _calculate_snr(ts, fs, segments): return snr def _calculate_delta_threshold(ts, segments, threshold_samples): + """Calculate delta threshold for a time series. + + Args: + ts (ArrayLike): Time series + segments (ArrayLike): Segments of the time series + threshold_samples (int): Threshold samples + + Returns: + float: Delta threshold + """ # for long segments (>threshold_samples (0.5s)) that are not marked as artifacts # calculate the absolute differences of samples that are threshold_samples away # from each other @@ -663,6 +736,16 @@ def _calculate_delta_threshold(ts, segments, threshold_samples): def detect_baselineshift(ts: cdt.NDTimeSeries, outlier_mask: cdt.NDTimeSeries): + """Detect baseline shifts in fNIRS data. + + Args: + ts (NDTimeSeries): input time series + outlier_mask (NDTimeSeries): boolean mask indicating outliers + + Returns: + A DataArray with coords from the input NDTimeseries containing the baseline + shift mask. + """ ts = ts.pint.dequantify() #ts = ts.stack(measurement=["channel", "wavelength"]).sortby("wavelength") diff --git a/src/cedalion/sigproc/tasks.py b/src/cedalion/sigproc/tasks.py index 7a699bf..2574c1d 100644 --- a/src/cedalion/sigproc/tasks.py +++ b/src/cedalion/sigproc/tasks.py @@ -38,12 +38,12 @@ def od2conc( """Calculate hemoglobin concentrations from optical density data. Args: - rec: container of timeseries data - dpf: differential path length factors - spectrum: label of the extinction coefficients to use - ts_input: name of intensity timeseries. If None, this tasks operates on - the last timeseries in rec.timeseries. - ts_output: name of optical density timeseries. + rec (Recording): container of timeseries data + dpf (dict[float, float]): differential path length factors + spectrum (str): label of the extinction coefficients to use (default: "prahl") + ts_input (str | None): name of intensity timeseries. If None, this tasks operates + on the last timeseries in rec.timeseries. + ts_output (str): name of optical density timeseries (default: "conc"). """ ts = rec.get_timeseries(ts_input) @@ -67,6 +67,17 @@ def snr( aux_obj_output: str = "snr", mask_output: str = "snr", ): + """Calculate signal-to-noise ratio (SNR) of timeseries data. + + Args: + rec (cdc.Recording): The recording object containing the data. + snr_thresh (float): The SNR threshold. + ts_input (str | None, optional): The input time series. Defaults to None. + aux_obj_output (str, optional): The key for storing the SNR in the auxiliary + object. Defaults to "snr". + mask_output (str, optional): The key for storing the mask in the recording + object. Defaults to "snr". + """ ts = rec.get_timeseries(ts_input) snr, snr_mask = cedalion.sigproc.quality.snr(ts, snr_thresh) @@ -84,6 +95,18 @@ def sd_dist( aux_obj_output: str = "sd_dist", mask_output: str = "sd_dist", ): + """Calculate source-detector separations and mask channels outside a range. + + Args: + rec (cdc.Recording): The recording object containing the data. + sd_min (Annotated[Quantity, "[length]"]): The minimum source-detector separation. + sd_max (Annotated[Quantity, "[length]"]): The maximum source-detector separation. + ts_input (str | None, optional): The input time series. Defaults to None. + aux_obj_output (str, optional): The key for storing the source-detector distances + in the auxiliary object. Defaults to "sd_dist". + mask_output (str, optional): The key for storing the mask in the recording object. + Defaults to "sd_dist". + """ ts = rec.get_timeseries(ts_input) sd_dist, mask = cedalion.sigproc.quality.sd_dist(ts, rec.geo3d, (sd_min, sd_max)) diff --git a/src/cedalion/vtktutils.py b/src/cedalion/vtktutils.py index deb3701..0d47e53 100644 --- a/src/cedalion/vtktutils.py +++ b/src/cedalion/vtktutils.py @@ -6,6 +6,14 @@ def trimesh_to_vtk_polydata(mesh: trimesh.Trimesh): + """Convert a Trimesh object to a VTK PolyData object. + + Args: + mesh (trimesh.Trimesh): The input trimesh object. + + Returns: + vtk.vtkPolyData: The converted VTK PolyData object. + """ ntris, ndim_cells = mesh.faces.shape nvertices, ndim_vertices = mesh.vertices.shape @@ -45,6 +53,14 @@ def trimesh_to_vtk_polydata(mesh: trimesh.Trimesh): def pyvista_polydata_to_trimesh(polydata: pv.PolyData) -> trimesh.Trimesh: + """Convert a PyVista PolyData object to a Trimesh object. + + Args: + polydata (pv.PolyData): The input PyVista PolyData object. + + Returns: + trimesh.Trimesh: The converted Trimesh object. + """ vertices = polydata.points faces = polydata.regular_faces diff --git a/src/cedalion/xrutils.py b/src/cedalion/xrutils.py index 70c426f..2b39e7f 100644 --- a/src/cedalion/xrutils.py +++ b/src/cedalion/xrutils.py @@ -168,7 +168,7 @@ def other_dim(data_array: xr.DataArray, *dims: str) -> str: its name. Args: - data_array: a xr.DataArray + data_array: an xr.DataArray *dims: names of dimensions Returns: