Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wavelet Power Spectra #237

Merged
merged 7 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 77 additions & 33 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Delay space spectrum estimation and filtering."""

import typing
from typing import List, Optional, Tuple, TypeVar

import numpy as np
import scipy.linalg as la
from caput import config, mpiarray
from caput import config, memh5, mpiarray
from cora.util import units
from numpy.lib.recfunctions import structured_to_unstructured

Expand Down Expand Up @@ -138,7 +138,7 @@ def process(self, ss):


# A specific subclass of a FreqContainer
FreqContainerType = typing.TypeVar("FreqContainerType", bound=containers.FreqContainer)
FreqContainerType = TypeVar("FreqContainerType", bound=containers.FreqContainer)


class DelayFilterBase(task.SingleTask):
Expand Down Expand Up @@ -641,41 +641,19 @@ def _process_data(self, ss):
)

# Find the relevant axis positions
data_axes = ss.datasets[self.dataset].attrs["axis"]
freq_axis_pos = list(data_axes).index("freq")
average_axis_pos = list(data_axes).index(self.average_axis)

# Create a view of the dataset with the relevant axes at the back,
# and all other axes compressed. End result is packed as
# [baseline_axis, average_axis, freq_axis].
data_view = np.moveaxis(
ss.datasets[self.dataset][:].local_array,
[average_axis_pos, freq_axis_pos],
[-2, -1],
data_view, bl_axes = flatten_axes(
ss.datasets[self.dataset], [self.average_axis, "freq"]
)
data_view = data_view.reshape(-1, data_view.shape[-2], data_view.shape[-1])
data_view = mpiarray.MPIArray.wrap(data_view, axis=2, comm=ss.comm)
nbase = int(np.prod(data_view.shape[:-2]))
data_view = data_view.redistribute(axis=0)

# ... do the same for the weights, but we also need to make the weights full
# size
weight_full = np.zeros(
ss.datasets[self.dataset][:].shape, dtype=ss.weight.dtype
weight_view, _ = flatten_axes(
ss.weight,
[self.average_axis, "freq"],
match_dset=ss.datasets[self.dataset],
)
weight_full[:] = match_axes(ss.datasets[self.dataset], ss.weight)
weight_view = np.moveaxis(
weight_full, [average_axis_pos, freq_axis_pos], [-2, -1]
)
weight_view = weight_view.reshape(
-1, weight_view.shape[-2], weight_view.shape[-1]
)
weight_view = mpiarray.MPIArray.wrap(weight_view, axis=2, comm=ss.comm)
weight_view = weight_view.redistribute(axis=0)

# Use the "baselines" axis to generically represent all the other axes

# Initialise the spectrum container
nbase = data_view.global_shape[0]
if self.output_power_spectrum:
delay_spec = containers.DelaySpectrum(
baseline=nbase,
Expand All @@ -692,7 +670,6 @@ def _process_data(self, ss):
)
delay_spec.redistribute("baseline")
delay_spec.spectrum[:] = 0.0
bl_axes = [da for da in data_axes if da not in [self.average_axis, "freq"]]

# Copy the index maps for all the flattened axes into the output container, and
# write out their order into an attribute so we can reconstruct this easily
Expand Down Expand Up @@ -1458,6 +1435,73 @@ def match_axes(dset1, dset2):
return dset2[:][bcast_slice]


def flatten_axes(
dset: memh5.MemDatasetDistributed,
axes_to_keep: List[str],
match_dset: Optional[memh5.MemDatasetDistributed] = None,
) -> Tuple[mpiarray.MPIArray, List[str]]:
"""Move the specified axes of the dataset to the back, and flatten all others.

Optionally this will add length-1 axes to match the axes of another dataset.

Parameters
----------
dset
The dataset to reshape.
axes_to_keep
The names of the axes to keep.
match_dset
An optional dataset to match the shape of.

Returns
-------
flat_array
The MPIArray representing the re-arranged dataset. Distributed along the
flattened axis.
flat_axes
The names of the flattened axes from slowest to fastest varying.
"""
# Find the relevant axis positions
data_axes = list(dset.attrs["axis"])

# Check that the requested datasets actually exist
for axis in axes_to_keep:
if axis not in data_axes:
raise ValueError(f"Specified {axis=} not present in dataset.")

# If specified, add extra axes to match the shape of the given dataset
if match_dset and tuple(dset.attrs["axis"]) != tuple(match_dset.attrs["axis"]):
dset_full = np.empty_like(match_dset[:])
dset_full[:] = match_axes(match_dset, dset)

axes_ind = [data_axes.index(axis) for axis in axes_to_keep]

# Get an MPIArray and make sure it is distributed along one of the preserved axes
data_array = dset[:]
if data_array.axis not in axes_ind:
data_array = data_array.redistribute(axes_ind[0])

# Create a view of the dataset with the relevant axes at the back,
# and all others moved to the front (retaining their relative order)
other_axes = [ax for ax in range(len(data_axes)) if ax not in axes_ind]
data_array = data_array.transpose(other_axes + axes_ind)

# Get the explicit shape of the axes that will remain, but set the distributed one
# to None (as will be needed for MPIArray.reshape)
remaining_shape = list(data_array.shape)
remaining_shape[data_array.axis] = None
new_ax_len = np.prod(remaining_shape[: -len(axes_ind)])
remaining_shape = remaining_shape[-len(axes_ind) :]

# Reshape the MPIArray, and redistribute over the flattened axis
data_array = data_array.reshape((new_ax_len, *remaining_shape))
data_array = data_array.redistribute(axis=0)

other_axes_names = [data_axes[ax] for ax in other_axes]

return data_array, other_axes_names


def _move_front(arr: np.ndarray, axis: int, shape: tuple) -> np.ndarray:
# Move the specified axis to the front and flatten to give a 2D array
new_arr = np.moveaxis(arr, axis, 0)
Expand Down
32 changes: 24 additions & 8 deletions draco/analysis/ringmapmaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,17 @@ class MakeVisGrid(task.SingleTask):

This will fill out the visibilities in the half plane `x >= 0` where x is the EW
baseline separation.

Attributes
----------
centered : bool
If set, place the zero NS separation at the center of the y-axis with
the baselines given in ascending order. Otherwise the zero separation
is at position zero, and the baselines are in FFT order.
"""

centered = config.Property(proptype=bool, default=False)

def setup(self, tel):
"""Set the Telescope instance to use.

Expand Down Expand Up @@ -88,9 +97,16 @@ def process(self, sstream):

# Define several variables describing the baseline configuration.
nx = np.abs(xind).max() + 1
ny = 2 * np.abs(yind).max() + 1
max_yind = np.abs(yind).max()
ny = 2 * max_yind + 1
vis_pos_x = np.arange(nx) * min_xsep
vis_pos_y = np.fft.fftfreq(ny, d=(1.0 / (ny * min_ysep)))

if self.centered:
vis_pos_y = np.arange(-max_yind, max_yind + 1) * min_ysep
ns_offset = max_yind
else:
vis_pos_y = np.fft.fftfreq(ny, d=(1.0 / (ny * min_ysep)))
ns_offset = 0

# Extract the right ascension to initialise the new container with (or
# calculate from timestamp)
Expand Down Expand Up @@ -135,15 +151,15 @@ def process(self, sstream):
# Unpack visibilities into new array
for vis_ind, (p_ind, x_ind, y_ind) in enumerate(zip(pind, xind, yind)):
# Different behavior for intracylinder and intercylinder baselines.
gsv[p_ind, :, x_ind, y_ind, :] = ssv[:, vis_ind]
gsw[p_ind, :, x_ind, y_ind, :] = ssw[:, vis_ind]
gsr[p_ind, x_ind, y_ind, :] = redundancy[vis_ind]
gsv[p_ind, :, x_ind, ns_offset + y_ind, :] = ssv[:, vis_ind]
gsw[p_ind, :, x_ind, ns_offset + y_ind, :] = ssw[:, vis_ind]
gsr[p_ind, x_ind, ns_offset - y_ind, :] = redundancy[vis_ind]

if x_ind == 0:
pc_ind = pconjmap[p_ind]
gsv[pc_ind, :, x_ind, -y_ind, :] = ssv[:, vis_ind].conj()
gsw[pc_ind, :, x_ind, -y_ind, :] = ssw[:, vis_ind]
gsr[pc_ind, x_ind, -y_ind, :] = redundancy[vis_ind]
gsv[pc_ind, :, x_ind, ns_offset - y_ind, :] = ssv[:, vis_ind].conj()
gsw[pc_ind, :, x_ind, ns_offset - y_ind, :] = ssw[:, vis_ind]
gsr[pc_ind, x_ind, ns_offset - y_ind, :] = redundancy[vis_ind]

return grid

Expand Down
13 changes: 9 additions & 4 deletions draco/analysis/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,15 +976,20 @@ def process(self, polcont):
YY_ind = list(polcont.index_map["pol"]).index("YY")

for name, dset in polcont.datasets.items():
out_dset = outcont.datasets[name]
if "pol" not in dset.attrs["axis"]:
outcont.datasets[name][:] = dset[:]
out_dset[:] = dset[:]
else:
pol_axis_pos = list(dset.attrs["axis"]).index("pol")

sl = tuple([slice(None)] * pol_axis_pos)
outcont.datasets[name][(*sl, 0)] = dset[(*sl, XX_ind)]
outcont.datasets[name][(*sl, 0)] += dset[(*sl, YY_ind)]
outcont.datasets[name][:] *= 0.5
out_dset[(*sl, 0)] = dset[(*sl, XX_ind)]
out_dset[(*sl, 0)] += dset[(*sl, YY_ind)]

if np.issubdtype(out_dset.dtype, np.integer):
out_dset[:] //= 2
else:
out_dset[:] *= 0.5

return outcont

Expand Down
136 changes: 136 additions & 0 deletions draco/analysis/wavelet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Wavelet power spectrum estimation."""
import os

import numpy as np
import pywt
import scipy.fft as fft
import scipy.linalg as la
from caput import config, mpiutil

from ..core import containers, task
from ..util import _fast_tools
from .delay import flatten_axes


class WaveletSpectrumEstimator(task.SingleTask):
"""Estimate a continuous wavelet power spectrum from the data.

This requires the input of the underlying data, *and* an estimate of its delay
spectrum to allow us to in-fill any masked frequencies.
"""

dataset = config.Property(proptype=str, default="vis")
average_axis = config.Property(proptype=str)
ndelay = config.Property(proptype=int, default=128)
wavelet = config.Property(proptype=str, default="morl")
chunks = config.Property(proptype=int, default=4)

def process(
self,
data: containers.FreqContainer,
dspec: containers.DelaySpectrum,
) -> containers.WaveletSpectrum:
"""Estimate the wavelet power spectrum.

Parameters
----------
data
Must have a frequency axis, and the axis to average over. Any other axes
will be flattened in the output.
dspec
The delay spectrum. The flattened `baseline` axis must match the remaining
axes in `data`.

Returns
-------
wspec
The wavelet spectrum. The non-frequency and averaging axes have been
collapsed into `baseline`.
"""
workers = int(os.environ.get("OMP_NUM_THREADS", 1))

dset_view, bl_axes = flatten_axes(
data[self.dataset], [self.average_axis, "freq"]
)
weight_view, _ = flatten_axes(
data.weight,
[self.average_axis, "freq"],
match_dset=data[self.dataset],
)

nbase = dset_view.global_shape[0]

df = np.abs(data.freq[1] - data.freq[0])
delay_scales = np.arange(1, self.ndelay + 1) / (2 * df * self.ndelay)

wv_scales = pywt.frequency2scale(self.wavelet, delay_scales * df)

wspec = containers.WaveletSpectrum(
baseline=nbase,
axes_from=data,
attrs_from=data,
delay=delay_scales,
)
# Copy the index maps for all the flattened axes into the output container, and
# write out their order into an attribute so we can reconstruct this easily
# when loading in the spectrum
for ax in bl_axes:
wspec.create_index_map(ax, data.index_map[ax])
wspec.attrs["baseline_axes"] = bl_axes

wspec.redistribute("baseline")
dspec.redistribute("baseline")
ws = wspec.spectrum[:].local_array
ww = wspec.weight[:].local_array
ds = dspec.spectrum[:].local_array

# Construct the freq<->delay mapping Fourier matrix
F = np.exp(
-2.0j
* np.pi
* dspec.index_map["delay"][np.newaxis, :]
* data.freq[:, np.newaxis]
)

for ii in range(dset_view.shape[0]):
self.log.info(f"Transforming baseline {ii} of {dset_view.shape[0]}")
d = dset_view.local_array[ii]
w = weight_view.local_array[ii]

# Construct an averaged frequency mask and use it to set the output
# weights
Ni = w.mean(axis=0)
ww[ii] = Ni

# Construct a Wiener filter to in-fill the data
D = ds[ii]
Df = (F * D[np.newaxis, :]) @ F.T.conj()
iDf = la.inv(Df)
Ci = iDf + np.diag(Ni)

# Solve for the infilled data
d_infill = la.solve(
Ci,
Ni[:, np.newaxis] * d.T,
assume_a="pos",
overwrite_a=True,
overwrite_b=True,
).T

# Doing the cwt and calculating the variance can eat a bunch of
# memory. Break it up into chunks to try and control this
for _, s, e in mpiutil.split_m(wv_scales.shape[0], self.chunks).T:
with fft.set_workers(workers):
wd, _ = pywt.cwt(
d_infill,
scales=wv_scales[s:e],
wavelet=self.wavelet,
axis=-1,
sampling_period=df,
method="fft",
)

# Calculate and set the variance
_fast_tools._fast_var(wd, ws[ii, s:e])

return wspec
Loading
Loading