Skip to content

Commit

Permalink
fix: resolve MPIArray warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Jul 11, 2022
1 parent e80df30 commit e534d72
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
27 changes: 12 additions & 15 deletions ch_util/rfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ def number_deviations(
list(zip(*[(pp[0], ind) for ind, pp in enumerate(prod) if pp[0] == pp[1]]))
)

auto_vis = data.vis[:, auto_pi, :].view(np.ndarray).copy().real
auto_vis = data.vis[:, auto_pi, :].real.copy()

# If requested, average over all inputs to construct the stacked autocorrelations
# for the instrument (also known as the incoherent beam)
if stack:
weight = (data.weight[:, auto_pi, :].view(np.ndarray) > 0.0).astype(np.float32)
weight = (data.weight[:, auto_pi, :] > 0.0).astype(np.float32)

# Do not include bad inputs in the average
partial_stack = data.index_map["stack"].size < data.index_map["prod"].size
Expand Down Expand Up @@ -236,12 +236,7 @@ def number_deviations(
auto_ii = np.zeros(1, dtype=np.int)

else:
auto_flag = data.weight[:, auto_pi, :].view(np.ndarray) > 0.0

# Convert back to an MPIArray distributed over the freq axis
if parallel:
auto_flag = mpiarray.MPIArray.wrap(auto_flag, axis=0, comm=data.vis.comm)
auto_vis = mpiarray.MPIArray.wrap(auto_vis, axis=0, comm=data.vis.comm)
auto_flag = data.weight[:, auto_pi, :] > 0.0

# Now redistribute the array over inputs
if parallel:
Expand All @@ -257,7 +252,7 @@ def number_deviations(
static_flag = static_flag[:, np.newaxis]

# Create an empty array for number of median absolute deviations
ndev = np.zeros(auto_vis.shape, dtype=np.float32)
ndev = np.zeros_like(auto_vis, dtype=np.float32)

# Calculate frequency interval in bins
fwidth = (
Expand All @@ -267,12 +262,16 @@ def number_deviations(
# Calculate time interval in samples
twidth = int(time_width / np.median(np.abs(np.diff(data.time)))) + 1

auto_flag_view = auto_flag.local_array if parallel else auto_flag
auto_vis_view = auto_vis.local_array if parallel else auto_vis_view
ndev_view = ndev.local_array if parallel else ndev

# Loop over extracted autos and create a mask for each
for ind in range(auto_vis.shape[1]):

# Create a quick copy
flg = static_flag & auto_flag[:, ind].view(np.ndarray)
arr = auto_vis[:, ind].view(np.ndarray).copy()
flg = static_flag & auto_flag_view[:, ind]
arr = auto_vis_view[:, ind].copy()

# Use NaNs to ignore previously flagged data when computing the MAD
arr = np.where(flg, arr.real, np.nan)
Expand All @@ -285,16 +284,14 @@ def number_deviations(
else:
ndev_i = mad_cut_2d(arr, twidth=twidth, fwidth=fwidth, mask=False)

ndev[:, ind, :] = ndev_i
ndev_view[:, ind, :] = ndev_i

# Fill any values equal to NaN with the user specified fill value
ndev = np.where(np.isfinite(ndev), ndev, fill_value)
ndev_view[~np.isfinite(ndev_view)] = fill_value

# Convert back to an MPIArray and redistribute over freq axis
if parallel:
ndev = mpiarray.MPIArray.wrap(ndev, axis=1, comm=data.vis.comm)
ndev = ndev.redistribute(0)

auto_vis = auto_vis.redistribute(0)

return auto_ii, auto_vis, ndev
Expand Down
6 changes: 5 additions & 1 deletion ch_util/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,11 @@ def apply_timing_correction(self, timestream, copy=False, **kwargs):
else:
is_obj = True

vis = timestream.vis[:] if not copy else timestream.vis[:].copy()
# This works for both distributed and non-distributed datasets
vis = timestream.vis[:].view(np.ndarray)

if copy:
vis = vis.copy()

freq = kwargs.pop("freq") if "freq" in kwargs else timestream.freq[:]
prod = (
Expand Down

0 comments on commit e534d72

Please sign in to comment.