Skip to content

Commit

Permalink
feat(rfi): add frequency bounds to mad_cut_rolling.
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Sep 6, 2022
1 parent facd85a commit e8d8605
Showing 1 changed file with 47 additions and 40 deletions.
87 changes: 47 additions & 40 deletions ch_util/rfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,68 +195,59 @@ def number_deviations(
auto_ii, auto_vis, auto_flag = get_autocorrelations(data, stack, normalize)

# Create static flag of frequencies that are known to be bad
if apply_static_mask:
static_flag = ~frequency_mask(data.freq)
else:
static_flag = np.ones(data.nfreq, dtype=np.bool)
static_flag = (
~frequency_mask(data.freq)
if apply_static_mask
else np.ones(data.nfreq, dtype=np.bool)
)[:, np.newaxis]

static_flag = static_flag[:, np.newaxis]
if parallel:
static_flag = mpiarray.MPIArray.wrap(static_flag, axis=0)
# _sf = mpiarray.MPIArray(static_flag.shape, axis=0, dtype=bool)
# _sf[:] = static_flag[:]
# static_flag = _sf
# Ensure these are distributed across frequency
auto_vis = auto_vis.redistribute(0)
auto_flag = auto_flag.redistribute(0)
static_flag = mpiarray.MPIArray.wrap(static_flag[auto_vis.local_bounds], axis=0)

# Calculate frequency interval in bins
fwidth = (
int(freq_width / np.median(np.abs(np.diff(data.freq)))) + 1 if not flag1d else 1
)

# Calculate time interval in samples
twidth = int(time_width / np.median(np.abs(np.diff(data.time)))) + 1

# Redistribute back across frequencies
if parallel:
auto_vis = auto_vis.redistribute(0)
auto_flag = auto_flag.redistribute(0)
static_flag = static_flag.redistribute(0)

# Create an empty array for number of median absolute deviations
ndev = np.zeros_like(auto_vis, dtype=np.float32)

auto_flag_view = auto_flag.local_array if parallel else auto_flag
auto_vis_view = auto_vis.local_array if parallel else auto_vis
auto_flag_view = auto_flag.allgather() if parallel else auto_flag
static_flag_view = static_flag.allgather() if parallel else static_flag
ndev_view = ndev.local_array if parallel else ndev
# mfd_view = mfd.local_array if parallel else mfd
static_flag_view = static_flag.local_array if parallel else static_flag

# Loop over extracted autos and create a mask for each
for ind in range(auto_vis.shape[1]):

# Create a quick copy
flg = static_flag_view & auto_flag_view[:, ind]
arr = auto_vis_view[:, ind].copy()

# Gather enire array onto each rank
arr = auto_vis[:, ind].allgather() if parallel else auto_vis[:, ind]
# Use NaNs to ignore previously flagged data when computing the MAD
arr = np.where(flg, arr.real, np.nan)
local_bounds = auto_vis.local_bounds if parallel else slice(None)
# Apply RFI flagger
if rolling:
ndev_i = mad_cut_rolling(arr, twidth=twidth, fwidth=fwidth, mask=False)
# Limit bounds to the local portion of the array
ndev_i = mad_cut_rolling(
arr, twidth=twidth, fwidth=fwidth, mask=False, limit_range=local_bounds
)
elif flag1d:
ndev_i = mad_cut_1d(arr, twidth=twidth, mask=False)
ndev_i = mad_cut_1d(arr[local_bounds, :], twidth=twidth, mask=False)
else:
ndev_i = mad_cut_2d(arr, twidth=twidth, fwidth=fwidth, mask=False)
ndev_i = mad_cut_2d(
arr[local_bounds, :], twidth=twidth, fwidth=fwidth, mask=False
)

ndev_view[:, ind, :] = ndev_i

# Fill any values equal to NaN with the user specified fill value
ndev_view[~np.isfinite(ndev_view)] = fill_value

# Convert back to an MPIArray and redistribute over freq axis
if parallel:
ndev = ndev.redistribute(0)
auto_vis = auto_vis.redistribute(0)

return auto_ii, auto_vis, ndev


Expand Down Expand Up @@ -571,7 +562,13 @@ def _rolling_window(a, window):


def mad_cut_rolling(
data, fwidth=64, twidth=42, threshold=5.0, freq_flat=True, mask=True
data,
fwidth=64,
twidth=42,
threshold=5.0,
freq_flat=True,
mask=True,
limit_range: slice = slice(None),
):
"""Mask out RFI by placing a cut on the absolute deviation.
Compared to `mad_cut_2d`, this function calculates
Expand All @@ -598,12 +595,20 @@ def mad_cut_rolling(
mask : boolean, optional
If True return the mask, if False return the number of
median absolute deviations.
limit_range : slice, optional
Data is limited to this range in the freqeuncy axis. Defaults to slice(None).
Returns
-------
mask : np.ndarray[freq, time]
Mask or number of median absolute deviations for each sample.
"""
# Make sure we have an odd number of samples
fwidth += int(not (fwidth % 2))
twidth += int(not (twidth % 2))

foff = fwidth // 2
toff = twidth // 2

nfreq, ntime = data.shape

Expand All @@ -612,19 +617,21 @@ def mad_cut_rolling(
mfd = tools.invert_no_zero(nanmedian(data, axis=1))
data *= mfd[:, np.newaxis]

# Make sure we have an odd number of samples
fwidth += int(not (fwidth % 2))
twidth += int(not (twidth % 2))

foff = fwidth // 2
toff = twidth // 2

# Add NaNs around the edges of the array so that we don't have to treat them separately
eshp = [nfreq + fwidth - 1, ntime + twidth - 1]

exp_data = np.full(eshp, np.nan, dtype=data.dtype)
exp_data[foff : foff + nfreq, toff : toff + ntime] = data

if limit_range != slice(None):
# Get only desired slice
expsl = slice(
max(limit_range.start, 0),
min(limit_range.stop + 2 * foff, exp_data.shape[0]),
)
dsl = slice(max(limit_range.start, 0), min(limit_range.stop, data.shape[0]))
exp_data = exp_data[expsl, :]
data = data[dsl, :]

# Use numpy slices to construct the rolling windowed data
win_data = _rolling_window(exp_data, (fwidth, twidth))

Expand Down

0 comments on commit e8d8605

Please sign in to comment.