Skip to content

Commit

Permalink
Added PCArecurse
Browse files Browse the repository at this point in the history
Added two functions to motion_correct.py - motion_correct_PCA and motion_correct_PCA_recurse which are modelled after the Homer functions.
  • Loading branch information
lauracarlton committed Mar 22, 2024
1 parent 391c2b9 commit 086de44
Showing 1 changed file with 158 additions and 1 deletion.
159 changes: 158 additions & 1 deletion src/cedalion/sigproc/motion_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from scipy.interpolate import UnivariateSpline
from scipy.signal import savgol_filter
import xarray as xr
import cedalion.xrutils as xrutils
from scipy.linalg import svd
import cedalion.typing as cdt
from cedalion import Quantity, units
import cedalion.dataclasses as cdc
from .artifact import detect_outliers, detect_baselineshift
from .artifact import detect_outliers, detect_baselineshift, id_motion, id_motion_refine

@cdc.validate_schemas
def motionCorrectSpline(fNIRSdata:cdt.NDTimeSeries, tIncCh:cdt.NDTimeSeries): #, mlAct:cdt.NDTimeSeries):
Expand Down Expand Up @@ -212,5 +214,160 @@ def motionCorrectSplineSG(fNIRSdata:cdt.NDTimeSeries, framesize_sec:Quantity = 1

return dodSplineSG

#%%
@cdc.validate_schemas
def motion_correct_PCA(fNIRSdata:cdt.NDTimeSeries, tInc:cdt.NDTimeSeries, nSV:Quantity = 0.97):
'''
Apply motion correction using PCA filter on segments of data idenitfied as motion artefact.
Based on Homer3 [1] v1.80.2 "hmrR_MotionCorrectPCA.m"
Boston University Neurophotonics Center
https://github.com/BUNPC/Homer3
Inputs:
fNIRSdata (cdt.NDTimeSeries): The fNIRS data to be motion corrected.
tInc (cdt.NDTimeSeries): The time series indicating the presence of motion artifacts.
nSV (Quantity): Specifies the number of prinicpal components to remove from the data. If nSV < 1 then the filter removes the first
n components of the data that removes a fraction of the variance up to nSV.
Returns:
fNIRSdata_cleaned (cdt.NDTimeSeries): The motion-corrected fNIRS data.
svs (np.array): the singular values of the PCA.
nSV (Quantity): the number of principal components removed from the data.
"""
'''
# apply mask to get only points with motion
y, m = xrutils.apply_mask(fNIRSdata, tInc, 'drop', 'none')

# stack y and od
y = y.stack(measurement = ['channel', 'wavelength']).sortby('wavelength').pint.dequantify()

fNIRSdata = fNIRSdata.stack(measurement = ['channel', 'wavelength']).sortby('wavelength').pint.dequantify()

# PCA
yo = y.copy()
c = np.dot(y.T, y)

V, St, foo = xr.apply_ufunc(svd, c)

svs = St / np.sum(St)

svsc = svs.copy()
for idx in range(1,svs.shape[0]):
svsc[idx] = svsc[idx-1] + svs[idx]

if nSV < 1 and nSV > 0:
ev = svsc < nSV
nSV = np.where(ev == 0)[0][0]

ev = np.zeros((svs.shape[0], 1))
ev[:nSV] = 1
ev = np.diag(np.squeeze(ev))

# remove top PCs
yc = yo - np.dot(np.dot(y, V), np.dot(ev, V.T))

# insert cleaned signal back into od
lstMs = np.where(np.diff(tInc.values.astype(int)) == 1)[0]
lstMf = np.where(np.diff(tInc.values.astype(int)) == -1)[0]


if len(lstMs) == 0:
lstMs = [0]
if len(lstMf) == 0:
lstMf = len(tInc)-1
if lstMs[0] > lstMf[0]:
lstMs = np.insert(lstMs, 0, 0)
if lstMs[-1] > lstMf[-1]:
lstMf = np.append(lstMf, len(tInc)-1)


lstMb = lstMf - lstMs

for ii in range(1, len(lstMb)):
lstMb[ii] = lstMb[ii - 1] + lstMb[ii]

lstMb = lstMb-1

yc_ts = yc.values
fNIRSdata_cleaned_ts = fNIRSdata.copy().values
fNIRSdata_ts = fNIRSdata.copy().values

for jj in range(fNIRSdata_cleaned_ts.shape[1]):

lst = np.arange(lstMs[0], lstMf[0])

if lstMs[0] > 0:
fNIRSdata_cleaned_ts[lst, jj] = yc_ts[:lstMb[0]+1, jj] - yc_ts[0, jj] + fNIRSdata_cleaned_ts[lst[0], jj]
else:
fNIRSdata_cleaned_ts[lst, jj] = yc_ts[:lstMb[0]+1, jj] - yc_ts[lstMb[0], jj] + fNIRSdata_cleaned_ts[lst[-1], jj]

for kk in range(len(lstMf) - 1):
lst = np.arange(lstMf[kk] - 1, lstMs[kk + 1]+1 )
fNIRSdata_cleaned_ts[lst, jj] = fNIRSdata_ts[lst, jj] - fNIRSdata_ts[lst[0], jj] + fNIRSdata_cleaned_ts[lst[0], jj]

lst = np.arange(lstMs[kk + 1], lstMf[kk + 1])
fNIRSdata_cleaned_ts[lst, jj] = fNIRSdata_ts[lstMb[kk]+1 :lstMb[kk + 1]+1, jj] - fNIRSdata_ts[lstMb[kk]+1 , jj] + fNIRSdata_cleaned_ts[lst[0], jj]


if lstMf[-1] < len(fNIRSdata_ts)-1:
lst = np.arange(lstMf[-1] - 1, len(fNIRSdata_ts))
fNIRSdata_cleaned_ts[lst, jj] = fNIRSdata_ts[lst, jj] - fNIRSdata_ts[lst[0], jj] + fNIRSdata_cleaned_ts[lst[0], jj]



fNIRSdata_cleaned = fNIRSdata.copy()
fNIRSdata_cleaned.values = fNIRSdata_cleaned_ts

fNIRSdata_cleaned = fNIRSdata_cleaned.unstack('measurement').pint.quantify()
fNIRSdata_cleaned = fNIRSdata_cleaned.transpose("channel", "wavelength", "time")


return fNIRSdata_cleaned, nSV, svs


#%%
def motion_correct_PCA_recurse(fNIRSdata:cdt.NDTimeSeries, t_motion:Quantity = 0.5, t_mask:Quantity = 1, stdev_thresh:Quantity = 20, amp_thresh:Quantity = 5, nSV:Quantity = 0.97, maxIter:Quantity = 5):
'''
Identify motion artefacts in input fNIRSdata. If any active channel exhibits signal change greater than STDEVthresh or AMPthresh,
then that segment of data is marked as a motion artefact. motion_correct_PCA is applied to all segments of data identified as a motion
artefact. This is called until maxIter is reached or there are no motion artefacts identified.
Inputs:
fNIRSdata (cdt.NDTimeSeries): The fNIRS data to be motion corrected.
tMotion (Quantity): check for signal change indicative of a motion artefact over time range tMotion. (units of seconds)
tMask (Quantity): mark data +/- tMask seconds aroundthe identified motion artefact as a motion artefact.
stdev_thresh (Quantity): if the signal d for any given active channel changes by more than stdev_thresh * stdev(d) over the time interval tMotion
then this time point is marked as a motion artefact
amp_thresh (Quantity): if the signal d for any given active channel changes by more than amp_thresh over the time interval tMotion then this time point
is marked as a motion artefact.
Returns:
fNIRSdata_cleaned (cdt.NDTimeSeries): The motion-corrected fNIRS data.
svs (np.array): the singular values of the PCA.
nSV (int): the number of principal components removed from the data.
'''

tIncCh = id_motion(fNIRSdata, t_motion, t_mask, stdev_thresh, amp_thresh) # unit stripped error x2

tInc = id_motion_refine(tIncCh, 'all')[0]
tInc.values = np.hstack([False, tInc.values[:-1]])

nI = 0
fNIRSdata_cleaned = fNIRSdata.copy()

while sum(tInc.values) > 0 and nI < maxIter:

nI = nI+1

fNIRSdata_cleaned, nSV, svs = motion_correct_PCA(fNIRSdata_cleaned, tInc, nSV=nSV)


tIncCh = id_motion(fNIRSdata_cleaned, t_motion, t_mask, stdev_thresh, amp_thresh)
tInc = id_motion_refine(tIncCh, 'all')[0]
tInc.values = np.hstack([False, tInc.values[:-1]])

return fNIRSdata_cleaned, svs, nSV, tInc



0 comments on commit 086de44

Please sign in to comment.