Refactor(rfi): split number_deviations into two functions.
A significant part of `number_deviations` was devoted to extracting the
autocorrelations from the input data. It makes sense to split this
functionality out into its own function, `get_autocorrelations` (as
sketched just below).
ljgray committed Aug 17, 2022
1 parent 3a0a46a commit e4a9e40
Showing 1 changed file with 81 additions and 46 deletions.
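
Before the diff, a minimal, self-contained sketch (toy arrays only; not code from the repository) of the extraction step being factored out: an autocorrelation is any product whose two input indices coincide.

import numpy as np

# Toy product axis for a 3-input correlator (the field names are illustrative)
prod = np.array(
    [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)],
    dtype=[("input_a", int), ("input_b", int)],
)

# Same selection idiom as in the diff below: keep products with pp[0] == pp[1]
auto_ii, auto_pi = np.array(
    list(zip(*[(pp[0], ind) for ind, pp in enumerate(prod) if pp[0] == pp[1]]))
)

print(auto_ii)  # [0 1 2] -> input index of each autocorrelation
print(auto_pi)  # [0 3 5] -> position of each autocorrelation along the product axis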
127 changes: 81 additions & 46 deletions ch_util/rfi.py
@@ -27,6 +27,7 @@

import warnings
import logging
from typing import Tuple

import numpy as np
import scipy.signal as sig
@@ -180,7 +181,6 @@ def number_deviations(
        Number of median absolute deviations of the autocorrelations
        from the local median.
    """

    from caput import memh5, mpiarray

    if fill_value is None:
@@ -192,51 +192,7 @@
    data.redistribute("freq")

    # Extract the auto correlations
    prod = data.index_map["prod"][data.index_map["stack"]["prod"]]
    auto_ii, auto_pi = np.array(
        list(zip(*[(pp[0], ind) for ind, pp in enumerate(prod) if pp[0] == pp[1]]))
    )

    auto_vis = data.vis[:, auto_pi, :].real.copy()

    # If requested, average over all inputs to construct the stacked autocorrelations
    # for the instrument (also known as the incoherent beam)
    if stack:
        weight = (data.weight[:, auto_pi, :] > 0.0).astype(np.float32)

        # Do not include bad inputs in the average
        partial_stack = data.index_map["stack"].size < data.index_map["prod"].size

        if not partial_stack and hasattr(data, "input_flags"):
            input_flags = data.input_flags[:]
            logger.info(
                "There are on average %d good inputs."
                % np.mean(np.sum(input_flags, axis=0), axis=-1)
            )

            if np.any(input_flags) and not np.all(input_flags):
                logger.info("Applying input_flags to weight.")
                weight *= input_flags[np.newaxis, auto_ii, :].astype(weight.dtype)

        if normalize:
            logger.info("Normalizing autocorrelations prior to stacking.")
            med_auto = nanmedian(
                np.where(weight, auto_vis, np.nan), axis=-1, keepdims=True
            )
            med_auto = np.where(np.isfinite(med_auto), med_auto, 0.0)
            auto_vis *= tools.invert_no_zero(med_auto)

        norm = np.sum(weight, axis=1, keepdims=True)

        auto_vis = np.sum(
            weight * auto_vis, axis=1, keepdims=True
        ) * tools.invert_no_zero(norm)

        auto_flag = norm > 0.0
        auto_ii = np.zeros(1, dtype=int)

    else:
        auto_flag = data.weight[:, auto_pi, :] > 0.0
    auto_ii, auto_vis, auto_flag = get_autocorrelations(data, stack, normalize)

    # Create static flag of frequencies that are known to be bad
    if apply_static_mask:
@@ -304,6 +260,85 @@
    return auto_ii, auto_vis, ndev
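
The diff elides the body of `number_deviations` that actually computes `ndev`; as a generic, hedged illustration of the metric described in its docstring (number of median absolute deviations from the local median; note the real code works with a local, windowed median where this toy uses a global one):

import numpy as np

x = np.array([1.0, 1.1, 0.9, 5.0, 1.05])
med = np.median(x)                # stand-in for the local median
mad = np.median(np.abs(x - med))  # median absolute deviation (MAD)
ndev = np.abs(x - med) / mad      # deviations measured in units of the MAD
print(ndev)                       # the 5.0 sample stands out (79 MADs away)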


def get_autocorrelations(
    data, stack: bool = False, normalize: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extract autocorrelations from a data stack.
Parameters
----------
data : `andata.CorrData`
Must contain vis and weight attributes that are both
`np.ndarray[nfreq, nprod, ntime]`.
stack: bool, optional
Average over all autocorrelations.
normalize : bool, optional
Normalize by the median value over time prior to averaging over
autocorrelations. Only relevant if `stack` is True.
Returns
-------
auto_ii: np.ndarray[ninput,]
Index of the inputs that have been processed.
If stack is True, then [0] will be returned.
auto_vis: np.ndarray[nfreq, ninput, ntime]
The autocorrelations that were used to calculate
the number of deviations.
auto_flag: np.ndarray[nfreq, ninput, ntime]
Indices where data weights are positive
"""
    # Extract the auto correlations
    prod = data.index_map["prod"][data.index_map["stack"]["prod"]]
    auto_ii, auto_pi = np.array(
        list(zip(*[(pp[0], ind) for ind, pp in enumerate(prod) if pp[0] == pp[1]]))
    )

    auto_vis = data.vis[:, auto_pi, :].real.copy()

    # If requested, average over all inputs to construct the stacked autocorrelations
    # for the instrument (also known as the incoherent beam)
    if stack:
        weight = (data.weight[:, auto_pi, :] > 0.0).astype(np.float32)

        # Do not include bad inputs in the average
        partial_stack = data.index_map["stack"].size < data.index_map["prod"].size

        if not partial_stack and hasattr(data, "input_flags"):
            input_flags = data.input_flags[:]
            logger.info(
                "There are on average %d good inputs."
                % np.mean(np.sum(input_flags, axis=0), axis=-1)
            )

            if np.any(input_flags) and not np.all(input_flags):
                logger.info("Applying input_flags to weight.")
                weight *= input_flags[np.newaxis, auto_ii, :].astype(weight.dtype)

        if normalize:
            logger.info("Normalizing autocorrelations prior to stacking.")
            med_auto = nanmedian(
                np.where(weight, auto_vis, np.nan), axis=-1, keepdims=True
            )
            med_auto = np.where(np.isfinite(med_auto), med_auto, 0.0)
            auto_vis *= tools.invert_no_zero(med_auto)

        norm = np.sum(weight, axis=1, keepdims=True)

        auto_vis = np.sum(
            weight * auto_vis, axis=1, keepdims=True
        ) * tools.invert_no_zero(norm)

        auto_flag = norm > 0.0
        auto_ii = np.zeros(1, dtype=int)

    else:
        auto_flag = data.weight[:, auto_pi, :] > 0.0

    return auto_ii, auto_vis, auto_flag
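
A standalone sketch of the stack=True branch above, using np.nanmedian in place of the module's nanmedian helper and a stand-in for tools.invert_no_zero (an assumption: a reciprocal that maps zero to zero, matching how it is used here):

import numpy as np

def invert_no_zero(x):
    # Elementwise 1/x, but 0 -> 0 instead of inf (stand-in for tools.invert_no_zero)
    return np.where(x == 0.0, 0.0, 1.0 / np.where(x == 0.0, 1.0, x))

rng = np.random.default_rng(0)
auto_vis = rng.random((2, 3, 5))                           # (nfreq, ninput, ntime)
weight = (rng.random((2, 3, 5)) > 0.3).astype(np.float32)  # good-sample mask

# Normalize each input by its median over time, ignoring masked samples
med = np.nanmedian(np.where(weight, auto_vis, np.nan), axis=-1, keepdims=True)
med = np.where(np.isfinite(med), med, 0.0)
auto_vis = auto_vis * invert_no_zero(med)

# Weighted average over the input axis; samples with no good inputs become 0
norm = np.sum(weight, axis=1, keepdims=True)
stacked = np.sum(weight * auto_vis, axis=1, keepdims=True) * invert_no_zero(norm)
auto_flag = norm > 0.0  # True where at least one input contributed

print(stacked.shape, auto_flag.shape)  # (2, 1, 5) (2, 1, 5)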


def spectral_cut(data, fil_window=15, only_autos=False):
    """Flag out the TV bands, or other constant spectral RFI.
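Finally, a hedged end-to-end sketch exercising the unstacked path of the new helper on a duck-typed stand-in for `andata.CorrData` (the stand-in, its field names, and the toy shapes are assumptions for illustration; only the attributes `get_autocorrelations` actually touches are provided):

import numpy as np
from types import SimpleNamespace
from ch_util import rfi

ninput, nfreq, ntime = 3, 2, 5

# All products of a 3-input correlator: (0,0), (0,1), (0,2), (1,1), (1,2), (2,2)
prod = np.array(
    [(i, j) for i in range(ninput) for j in range(i, ninput)],
    dtype=[("input_a", int), ("input_b", int)],
)
# Identity stack axis: every product maps to itself (no stacking applied)
stack = np.zeros(prod.size, dtype=[("prod", int)])
stack["prod"] = np.arange(prod.size)

data = SimpleNamespace(
    index_map={"prod": prod, "stack": stack},
    vis=np.ones((nfreq, prod.size, ntime), dtype=np.complex64),
    weight=np.ones((nfreq, prod.size, ntime), dtype=np.float32),
)

auto_ii, auto_vis, auto_flag = rfi.get_autocorrelations(data)
print(auto_ii)          # [0 1 2]
print(auto_vis.shape)   # (2, 3, 5)
print(auto_flag.all())  # True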
