Skip to content

Commit

Permalink
add example notebook for PCA motion correction
Browse files Browse the repository at this point in the history
  • Loading branch information
lauracarlton committed Mar 22, 2024
1 parent d97f04e commit da1652c
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
167 changes: 167 additions & 0 deletions examples/PCArecurse_motion_correct.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "773ef4aa",
"metadata": {},
"outputs": [],
"source": [
"import cedalion\n",
"import cedalion.nirs\n",
"import cedalion.sigproc.quality as quality\n",
"from cedalion.sigproc.artifact import detect_baselineshift, detect_outliers\n",
"from cedalion.sigproc.motion_correct import motionCorrectSpline, motionCorrectSplineSG\n",
"import cedalion.xrutils as xrutils\n",
"import cedalion.datasets as datasets\n",
"import xarray as xr\n",
"import matplotlib.pyplot as p\n",
"from functools import reduce\n",
"\n",
"\n",
"from cedalion import Quantity, units"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2d29061",
"metadata": {},
"outputs": [],
"source": [
"# get example finger tapping dataset\n",
"snirf_element = datasets.get_fingertapping()\n",
"amp = snirf_element[0].data[0]\n",
"geo = snirf_element[0].geo3d\n",
"od = cedalion.nirs.int2od(amp)\n",
"\n",
"data = xr.Dataset(\n",
" data_vars = {\n",
" \"amp\" : amp,\n",
" \"od\" : od,\n",
" \"geo3d\": geo\n",
" })\n",
"\n",
"\n",
"# Plot some data for visual validation\n",
"f,ax = p.subplots(1,1, figsize=(12,4))\n",
"ax.plot( data.amp.time, data.amp.sel(channel=\"S1D1\", wavelength=\"850\"), \"r-\", label=\"850nm\")\n",
"ax.plot( data.amp.time, data.amp.sel(channel=\"S1D1\", wavelength=\"760\"), \"r-\", label=\"760nm\")\n",
"p.legend()\n",
"ax.set_xlabel(\"time / s\")\n",
"ax.set_ylabel(\"Signal intensity / a.u.\")"
]
},
{
"cell_type": "markdown",
"id": "33ee847e",
"metadata": {},
"source": [
"# Detect motion and perform PCA filtering \n",
"\n",
"The motion_correct_PCA_recurse algortithm first detects motion in the the OD data. It then iteratively calls motion_correct_PCA which performs PCA filtering on all time points labelled as motion. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "973dcf0c",
"metadata": {},
"outputs": [],
"source": [
"# typical motion id parameters\n",
"t_motion = 0.5\n",
"t_mask = 1\n",
"stdev_thresh = 20\n",
"amp_thresh = 5\n",
"\n",
"# motion identification \n",
"tIncCh = id_motion(fNIRSdata=data.od, t_motion=t_motion, t_mask=t_mask, \n",
" stedv_thresh=stdev_thresh, amp_thresh=amp_thresh) \n",
"tInc = id_motion_refine(tIncCh, 'all')[0]\n",
"tInc.values = np.hstack([False, tInc.values[:-1]]) # manual shift to account for indexing differences\n",
"\n",
"# call motion_correct_PCA\n",
"nSV=0.97 # discard n components up to 97% of variance \n",
"od_cleaned = motion_correct_PCA(fNIRSdata=data.od, tInc=tInc, nSV=nSV)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9031cd39",
"metadata": {},
"outputs": [],
"source": [
"# plot difference between uncorrected OD and after PCA filter correction\n",
"f,ax = p.subplots(1,1, figsize=(12,4))\n",
"ax.plot( data.od.time, data.od.sel(channel=\"S1D1\", wavelength=\"760\"), \"b-\", label=\"850nm OD\")\n",
"ax.plot( od_cleaned.time, od_cleaned.sel(channel=\"S1D1\", wavelength=\"760\"), \"g-\", label=\"850nm OD post PCA filtering\")\n",
"p.legend()\n",
"ax.set_xlabel(\"time / s\")\n",
"ax.set_ylabel(\"Optical density / a.u.\")"
]
},
{
"cell_type": "markdown",
"id": "a53f438c",
"metadata": {},
"source": [
"# Iterative PCA filtering \n",
"\n",
"Above, the PCA filtering was performed once. motion_correct_PCA_recurse iteratively calls the motion detection and motion_correct_PCA until either it reaches the maximum number of iterations specified or until there is no longer any motion detected. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb893ffb",
"metadata": {},
"outputs": [],
"source": [
"maxIter = 5\n",
"\n",
"od_cleaned_fully = motion_correct_PCA_recurse(fNIRSdata=data.od, t_motion=t_motion, t_mask=t_mask, \n",
" stedv_thresh=stdev_thresh, amp_thresh=amp_thresh, nSV=nSV, maxIter=maxIter)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c154425",
"metadata": {},
"outputs": [],
"source": [
"# plot difference between uncorrected OD and after iterative PCA filter correction\n",
"f,ax = p.subplots(1,1, figsize=(12,4))\n",
"ax.plot( data.od.time, data.od.sel(channel=\"S1D1\", wavelength=\"760\"), \"b-\", label=\"850nm OD\")\n",
"ax.plot( od_cleaned_fully.time, od_cleaned_fully.sel(channel=\"S1D1\", wavelength=\"760\"), \"g-\", label=\"850nm OD post PCA filtering\")\n",
"p.legend()\n",
"ax.set_xlabel(\"time / s\")\n",
"ax.set_ylabel(\"Optical density / a.u.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions src/cedalion/sigproc/motion_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def motion_correct_PCA_recurse(fNIRSdata:cdt.NDTimeSeries, t_motion:Quantity = 0
tInc = id_motion_refine(tIncCh, 'all')[0]
tInc.values = np.hstack([False, tInc.values[:-1]])


return fNIRSdata_cleaned, svs, nSV, tInc


Expand Down

0 comments on commit da1652c

Please sign in to comment.