From cb8b5938b2dc8f606c8d7c73a7db129d35f54c80 Mon Sep 17 00:00:00 2001 From: ljgray Date: Wed, 17 Aug 2022 11:06:59 -0700 Subject: [PATCH] feat(rfi): add frequency bounds to mad_cut_rolling. --- ch_util/rfi.py | 85 ++++++++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/ch_util/rfi.py b/ch_util/rfi.py index 6446aaee..fab75c94 100644 --- a/ch_util/rfi.py +++ b/ch_util/rfi.py @@ -195,68 +195,57 @@ def number_deviations( auto_ii, auto_vis, auto_flag = get_autocorrelations(data, stack, normalize) # Create static flag of frequencies that are known to be bad - if apply_static_mask: - static_flag = ~frequency_mask(data.freq) - else: - static_flag = np.ones(data.nfreq, dtype=np.bool) + static_flag = ( + ~frequency_mask(data.freq) + if apply_static_mask + else np.ones(data.nfreq, dtype=np.bool) + )[:, np.newaxis] - static_flag = static_flag[:, np.newaxis] if parallel: - static_flag = mpiarray.MPIArray.wrap(static_flag, axis=0) - # _sf = mpiarray.MPIArray(static_flag.shape, axis=0, dtype=bool) - # _sf[:] = static_flag[:] - # static_flag = _sf + # Ensure these are distributed across frequency + auto_vis = auto_vis.redistribute(0) + auto_flag = auto_flag.redistribute(0) + static_flag = mpiarray.MPIArray.wrap(static_flag[auto_vis.local_bounds], axis=0) # Calculate frequency interval in bins fwidth = ( int(freq_width / np.median(np.abs(np.diff(data.freq)))) + 1 if not flag1d else 1 ) - # Calculate time interval in samples twidth = int(time_width / np.median(np.abs(np.diff(data.time)))) + 1 - # Redistribute back across frequencies - if parallel: - auto_vis = auto_vis.redistribute(0) - auto_flag = auto_flag.redistribute(0) - static_flag = static_flag.redistribute(0) - # Create an empty array for number of median absolute deviations ndev = np.zeros_like(auto_vis, dtype=np.float32) - auto_flag_view = auto_flag.local_array if parallel else auto_flag - auto_vis_view = auto_vis.local_array if parallel else auto_vis + auto_flag_view = auto_flag.allgather() if parallel else auto_flag + static_flag_view = static_flag.allgather() if parallel else static_flag ndev_view = ndev.local_array if parallel else ndev - # mfd_view = mfd.local_array if parallel else mfd - static_flag_view = static_flag.local_array if parallel else static_flag # Loop over extracted autos and create a mask for each for ind in range(auto_vis.shape[1]): - # Create a quick copy flg = static_flag_view & auto_flag_view[:, ind] - arr = auto_vis_view[:, ind].copy() - + # Gather enire array onto each rank + arr = auto_vis[:, ind].allgather() if parallel else auto_vis[:, ind] # Use NaNs to ignore previously flagged data when computing the MAD arr = np.where(flg, arr.real, np.nan) + local_bounds = auto_vis.local_bounds if parallel else None # Apply RFI flagger if rolling: - ndev_i = mad_cut_rolling(arr, twidth=twidth, fwidth=fwidth, mask=False) + # Limit bounds to the local portion of the array + ndev_i = mad_cut_rolling( + arr, twidth=twidth, fwidth=fwidth, mask=False, limit_range=local_bounds + ) elif flag1d: - ndev_i = mad_cut_1d(arr, twidth=twidth, mask=False) + ndev_i = mad_cut_1d(arr[local_bounds, :], twidth=twidth, mask=False) else: - ndev_i = mad_cut_2d(arr, twidth=twidth, fwidth=fwidth, mask=False) + ndev_i = mad_cut_2d(arr[local_bounds, :], twidth=twidth, fwidth=fwidth, mask=False) ndev_view[:, ind, :] = ndev_i # Fill any values equal to NaN with the user specified fill value ndev_view[~np.isfinite(ndev_view)] = fill_value - # Convert back to an MPIArray and redistribute over freq axis - if parallel: - ndev = ndev.redistribute(0) - auto_vis = auto_vis.redistribute(0) - return auto_ii, auto_vis, ndev @@ -571,7 +560,13 @@ def _rolling_window(a, window): def mad_cut_rolling( - data, fwidth=64, twidth=42, threshold=5.0, freq_flat=True, mask=True + data, + fwidth=64, + twidth=42, + threshold=5.0, + freq_flat=True, + mask=True, + limit_range: slice = None, ): """Mask out RFI by placing a cut on the absolute deviation. Compared to `mad_cut_2d`, this function calculates @@ -598,12 +593,20 @@ def mad_cut_rolling( mask : boolean, optional If True return the mask, if False return the number of median absolute deviations. + limit_range : slice, optional + Data is limited to this range in the freqeuncy axis. Defaults to None. Returns ------- mask : np.ndarray[freq, time] Mask or number of median absolute deviations for each sample. """ + # Make sure we have an odd number of samples + fwidth += int(not (fwidth % 2)) + twidth += int(not (twidth % 2)) + + foff = fwidth // 2 + toff = twidth // 2 nfreq, ntime = data.shape @@ -612,19 +615,21 @@ def mad_cut_rolling( mfd = tools.invert_no_zero(nanmedian(data, axis=1)) data *= mfd[:, np.newaxis] - # Make sure we have an odd number of samples - fwidth += int(not (fwidth % 2)) - twidth += int(not (twidth % 2)) - - foff = fwidth // 2 - toff = twidth // 2 - # Add NaNs around the edges of the array so that we don't have to treat them separately eshp = [nfreq + fwidth - 1, ntime + twidth - 1] - exp_data = np.full(eshp, np.nan, dtype=data.dtype) exp_data[foff : foff + nfreq, toff : toff + ntime] = data + if limit_range is not None: + # Get only desired slice + expsl = slice( + max(limit_range.start, 0), + min(limit_range.stop + 2 * foff, exp_data.shape[0]), + ) + dsl = slice(max(limit_range.start, 0), min(limit_range.stop, data.shape[0])) + exp_data = exp_data[expsl, :] + data = data[dsl, :] + # Use numpy slices to construct the rolling windowed data win_data = _rolling_window(exp_data, (fwidth, twidth))