diff --git a/examples/PCArecurse_motion_correct.ipynb b/examples/PCArecurse_motion_correct.ipynb new file mode 100644 index 0000000..b0a632d --- /dev/null +++ b/examples/PCArecurse_motion_correct.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "773ef4aa", + "metadata": {}, + "outputs": [], + "source": [ + "import cedalion\n", + "import cedalion.nirs\n", + "import cedalion.sigproc.quality as quality\n", + "from cedalion.sigproc.artifact import detect_baselineshift, detect_outliers\n", + "from cedalion.sigproc.motion_correct import motionCorrectSpline, motionCorrectSplineSG\n", + "import cedalion.xrutils as xrutils\n", + "import cedalion.datasets as datasets\n", + "import xarray as xr\n", + "import matplotlib.pyplot as p\n", + "from functools import reduce\n", + "\n", + "\n", + "from cedalion import Quantity, units" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2d29061", + "metadata": {}, + "outputs": [], + "source": [ + "# get example finger tapping dataset\n", + "snirf_element = datasets.get_fingertapping()\n", + "amp = snirf_element[0].data[0]\n", + "geo = snirf_element[0].geo3d\n", + "od = cedalion.nirs.int2od(amp)\n", + "\n", + "data = xr.Dataset(\n", + " data_vars = {\n", + " \"amp\" : amp,\n", + " \"od\" : od,\n", + " \"geo3d\": geo\n", + " })\n", + "\n", + "\n", + "# Plot some data for visual validation\n", + "f,ax = p.subplots(1,1, figsize=(12,4))\n", + "ax.plot( data.amp.time, data.amp.sel(channel=\"S1D1\", wavelength=\"850\"), \"r-\", label=\"850nm\")\n", + "ax.plot( data.amp.time, data.amp.sel(channel=\"S1D1\", wavelength=\"760\"), \"r-\", label=\"760nm\")\n", + "p.legend()\n", + "ax.set_xlabel(\"time / s\")\n", + "ax.set_ylabel(\"Signal intensity / a.u.\")" + ] + }, + { + "cell_type": "markdown", + "id": "33ee847e", + "metadata": {}, + "source": [ + "# Detect motion and perform PCA filtering \n", + "\n", + "The motion_correct_PCA_recurse algortithm first detects motion in the the OD data. It then iteratively calls motion_correct_PCA which performs PCA filtering on all time points labelled as motion. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "973dcf0c", + "metadata": {}, + "outputs": [], + "source": [ + "# typical motion id parameters\n", + "t_motion = 0.5\n", + "t_mask = 1\n", + "stdev_thresh = 20\n", + "amp_thresh = 5\n", + "\n", + "# motion identification \n", + "tIncCh = id_motion(fNIRSdata=data.od, t_motion=t_motion, t_mask=t_mask, \n", + " stedv_thresh=stdev_thresh, amp_thresh=amp_thresh) \n", + "tInc = id_motion_refine(tIncCh, 'all')[0]\n", + "tInc.values = np.hstack([False, tInc.values[:-1]]) # manual shift to account for indexing differences\n", + "\n", + "# call motion_correct_PCA\n", + "nSV=0.97 # discard n components up to 97% of variance \n", + "od_cleaned = motion_correct_PCA(fNIRSdata=data.od, tInc=tInc, nSV=nSV)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9031cd39", + "metadata": {}, + "outputs": [], + "source": [ + "# plot difference between uncorrected OD and after PCA filter correction\n", + "f,ax = p.subplots(1,1, figsize=(12,4))\n", + "ax.plot( data.od.time, data.od.sel(channel=\"S1D1\", wavelength=\"760\"), \"b-\", label=\"850nm OD\")\n", + "ax.plot( od_cleaned.time, od_cleaned.sel(channel=\"S1D1\", wavelength=\"760\"), \"g-\", label=\"850nm OD post PCA filtering\")\n", + "p.legend()\n", + "ax.set_xlabel(\"time / s\")\n", + "ax.set_ylabel(\"Optical density / a.u.\")" + ] + }, + { + "cell_type": "markdown", + "id": "a53f438c", + "metadata": {}, + "source": [ + "# Iterative PCA filtering \n", + "\n", + "Above, the PCA filtering was performed once. motion_correct_PCA_recurse iteratively calls the motion detection and motion_correct_PCA until either it reaches the maximum number of iterations specified or until there is no longer any motion detected. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb893ffb", + "metadata": {}, + "outputs": [], + "source": [ + "maxIter = 5\n", + "\n", + "od_cleaned_fully = motion_correct_PCA_recurse(fNIRSdata=data.od, t_motion=t_motion, t_mask=t_mask, \n", + " stedv_thresh=stdev_thresh, amp_thresh=amp_thresh, nSV=nSV, maxIter=maxIter)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c154425", + "metadata": {}, + "outputs": [], + "source": [ + "# plot difference between uncorrected OD and after iterative PCA filter correction\n", + "f,ax = p.subplots(1,1, figsize=(12,4))\n", + "ax.plot( data.od.time, data.od.sel(channel=\"S1D1\", wavelength=\"760\"), \"b-\", label=\"850nm OD\")\n", + "ax.plot( od_cleaned_fully.time, od_cleaned_fully.sel(channel=\"S1D1\", wavelength=\"760\"), \"g-\", label=\"850nm OD post PCA filtering\")\n", + "p.legend()\n", + "ax.set_xlabel(\"time / s\")\n", + "ax.set_ylabel(\"Optical density / a.u.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/cedalion/sigproc/motion_correct.py b/src/cedalion/sigproc/motion_correct.py index 5bc3797..753c8f5 100644 --- a/src/cedalion/sigproc/motion_correct.py +++ b/src/cedalion/sigproc/motion_correct.py @@ -367,6 +367,7 @@ def motion_correct_PCA_recurse(fNIRSdata:cdt.NDTimeSeries, t_motion:Quantity = 0 tInc = id_motion_refine(tIncCh, 'all')[0] tInc.values = np.hstack([False, tInc.values[:-1]]) + return fNIRSdata_cleaned, svs, nSV, tInc