diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index ed6f91c9f..c4fb9c520 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -1,10 +1,10 @@ """Delay space spectrum estimation and filtering.""" -import typing +from typing import List, Optional, Tuple, TypeVar import numpy as np import scipy.linalg as la -from caput import config, mpiarray +from caput import config, memh5, mpiarray from cora.util import units from numpy.lib.recfunctions import structured_to_unstructured @@ -138,7 +138,7 @@ def process(self, ss): # A specific subclass of a FreqContainer -FreqContainerType = typing.TypeVar("FreqContainerType", bound=containers.FreqContainer) +FreqContainerType = TypeVar("FreqContainerType", bound=containers.FreqContainer) class DelayFilterBase(task.SingleTask): @@ -641,41 +641,19 @@ def _process_data(self, ss): ) # Find the relevant axis positions - data_axes = ss.datasets[self.dataset].attrs["axis"] - freq_axis_pos = list(data_axes).index("freq") - average_axis_pos = list(data_axes).index(self.average_axis) - - # Create a view of the dataset with the relevant axes at the back, - # and all other axes compressed. End result is packed as - # [baseline_axis, average_axis, freq_axis]. - data_view = np.moveaxis( - ss.datasets[self.dataset][:].local_array, - [average_axis_pos, freq_axis_pos], - [-2, -1], + data_view, bl_axes = flatten_axes( + ss.datasets[self.dataset], [self.average_axis, "freq"] ) - data_view = data_view.reshape(-1, data_view.shape[-2], data_view.shape[-1]) - data_view = mpiarray.MPIArray.wrap(data_view, axis=2, comm=ss.comm) - nbase = int(np.prod(data_view.shape[:-2])) - data_view = data_view.redistribute(axis=0) - - # ... do the same for the weights, but we also need to make the weights full - # size - weight_full = np.zeros( - ss.datasets[self.dataset][:].shape, dtype=ss.weight.dtype + weight_view, _ = flatten_axes( + ss.weight, + [self.average_axis, "freq"], + match_dset=ss.datasets[self.dataset], ) - weight_full[:] = match_axes(ss.datasets[self.dataset], ss.weight) - weight_view = np.moveaxis( - weight_full, [average_axis_pos, freq_axis_pos], [-2, -1] - ) - weight_view = weight_view.reshape( - -1, weight_view.shape[-2], weight_view.shape[-1] - ) - weight_view = mpiarray.MPIArray.wrap(weight_view, axis=2, comm=ss.comm) - weight_view = weight_view.redistribute(axis=0) # Use the "baselines" axis to generically represent all the other axes # Initialise the spectrum container + nbase = data_view.global_shape[0] if self.output_power_spectrum: delay_spec = containers.DelaySpectrum( baseline=nbase, @@ -692,7 +670,6 @@ def _process_data(self, ss): ) delay_spec.redistribute("baseline") delay_spec.spectrum[:] = 0.0 - bl_axes = [da for da in data_axes if da not in [self.average_axis, "freq"]] # Copy the index maps for all the flattened axes into the output container, and # write out their order into an attribute so we can reconstruct this easily @@ -1458,6 +1435,73 @@ def match_axes(dset1, dset2): return dset2[:][bcast_slice] +def flatten_axes( + dset: memh5.MemDatasetDistributed, + axes_to_keep: List[str], + match_dset: Optional[memh5.MemDatasetDistributed] = None, +) -> Tuple[mpiarray.MPIArray, List[str]]: + """Move the specified axes of the dataset to the back, and flatten all others. + + Optionally this will add length-1 axes to match the axes of another dataset. + + Parameters + ---------- + dset + The dataset to reshape. + axes_to_keep + The names of the axes to keep. + match_dset + An optional dataset to match the shape of. 
+
+    Returns
+    -------
+    flat_array
+        The MPIArray representing the re-arranged dataset. Distributed along the
+        flattened axis.
+    flat_axes
+        The names of the flattened axes from slowest to fastest varying.
+    """
+    # Find the relevant axis positions
+    data_axes = list(dset.attrs["axis"])
+
+    # Check that the requested axes actually exist
+    for axis in axes_to_keep:
+        if axis not in data_axes:
+            raise ValueError(f"Specified {axis=} not present in dataset.")
+
+    # If specified, broadcast the dataset against the axes of the matching dataset,
+    # and use the expanded array (and its axis names) from here on
+    if match_dset and tuple(dset.attrs["axis"]) != tuple(match_dset.attrs["axis"]):
+        dset_full = np.empty(match_dset[:].local_shape, dtype=dset.dtype)
+        dset_full[:] = match_axes(match_dset, dset)
+
+        data_axes = list(match_dset.attrs["axis"])
+        data_array = mpiarray.MPIArray.wrap(
+            dset_full, axis=match_dset[:].axis, comm=match_dset[:].comm
+        )
+    else:
+        data_array = dset[:]
+
+    axes_ind = [data_axes.index(axis) for axis in axes_to_keep]
+
+    # Make sure the array is distributed along one of the preserved axes
+    if data_array.axis not in axes_ind:
+        data_array = data_array.redistribute(axes_ind[0])
+
+    # Create a view of the dataset with the relevant axes at the back,
+    # and all others moved to the front (retaining their relative order)
+    other_axes = [ax for ax in range(len(data_axes)) if ax not in axes_ind]
+    data_array = data_array.transpose(other_axes + axes_ind)
+
+    # Get the explicit shape of the axes that will remain, but set the distributed one
+    # to None (as will be needed for MPIArray.reshape)
+    remaining_shape = list(data_array.shape)
+    remaining_shape[data_array.axis] = None
+    new_ax_len = int(np.prod(remaining_shape[: -len(axes_ind)]))
+    remaining_shape = remaining_shape[-len(axes_ind) :]
+
+    # Reshape the MPIArray, and redistribute over the flattened axis
+    data_array = data_array.reshape((new_ax_len, *remaining_shape))
+    data_array = data_array.redistribute(axis=0)
+
+    other_axes_names = [data_axes[ax] for ax in other_axes]
+
+    return data_array, other_axes_names
+
+
 def _move_front(arr: np.ndarray, axis: int, shape: tuple) -> np.ndarray:
     # Move the specified axis to the front and flatten to give a 2D array
     new_arr = np.moveaxis(arr, axis, 0)
diff --git a/draco/analysis/ringmapmaker.py b/draco/analysis/ringmapmaker.py
index 340357f66..927179098 100644
--- a/draco/analysis/ringmapmaker.py
+++ b/draco/analysis/ringmapmaker.py
@@ -36,8 +36,17 @@ class MakeVisGrid(task.SingleTask):
 
     This will fill out the visibilities in the half plane `x >= 0` where x is the
     EW baseline separation.
+
+    Attributes
+    ----------
+    centered : bool
+        If set, place the zero NS separation at the center of the y-axis with
+        the baselines given in ascending order. Otherwise the zero separation
+        is at position zero, and the baselines are in FFT order.
     """
 
+    centered = config.Property(proptype=bool, default=False)
+
     def setup(self, tel):
         """Set the Telescope instance to use.
 
@@ -88,9 +97,16 @@ def process(self, sstream):
 
         # Define several variables describing the baseline configuration.
         nx = np.abs(xind).max() + 1
-        ny = 2 * np.abs(yind).max() + 1
+        max_yind = np.abs(yind).max()
+        ny = 2 * max_yind + 1
 
         vis_pos_x = np.arange(nx) * min_xsep
-        vis_pos_y = np.fft.fftfreq(ny, d=(1.0 / (ny * min_ysep)))
+
+        if self.centered:
+            vis_pos_y = np.arange(-max_yind, max_yind + 1) * min_ysep
+            ns_offset = max_yind
+        else:
+            vis_pos_y = np.fft.fftfreq(ny, d=(1.0 / (ny * min_ysep)))
+            ns_offset = 0
 
         # Extract the right ascension to initialise the new container with (or
         # calculate from timestamp)
@@ -135,15 +151,15 @@ def process(self, sstream):
 
         # Unpack visibilities into new array
         for vis_ind, (p_ind, x_ind, y_ind) in enumerate(zip(pind, xind, yind)):
             # Different behavior for intracylinder and intercylinder baselines.
-            gsv[p_ind, :, x_ind, y_ind, :] = ssv[:, vis_ind]
-            gsw[p_ind, :, x_ind, y_ind, :] = ssw[:, vis_ind]
-            gsr[p_ind, x_ind, y_ind, :] = redundancy[vis_ind]
+            gsv[p_ind, :, x_ind, ns_offset + y_ind, :] = ssv[:, vis_ind]
+            gsw[p_ind, :, x_ind, ns_offset + y_ind, :] = ssw[:, vis_ind]
+            gsr[p_ind, x_ind, ns_offset + y_ind, :] = redundancy[vis_ind]
 
             if x_ind == 0:
                 pc_ind = pconjmap[p_ind]
-                gsv[pc_ind, :, x_ind, -y_ind, :] = ssv[:, vis_ind].conj()
-                gsw[pc_ind, :, x_ind, -y_ind, :] = ssw[:, vis_ind]
-                gsr[pc_ind, x_ind, -y_ind, :] = redundancy[vis_ind]
+                gsv[pc_ind, :, x_ind, ns_offset - y_ind, :] = ssv[:, vis_ind].conj()
+                gsw[pc_ind, :, x_ind, ns_offset - y_ind, :] = ssw[:, vis_ind]
+                gsr[pc_ind, x_ind, ns_offset - y_ind, :] = redundancy[vis_ind]
 
         return grid
 
diff --git a/draco/analysis/transform.py b/draco/analysis/transform.py
index 867f3d04b..1bbec5b81 100644
--- a/draco/analysis/transform.py
+++ b/draco/analysis/transform.py
@@ -976,15 +976,20 @@ def process(self, polcont):
         YY_ind = list(polcont.index_map["pol"]).index("YY")
 
         for name, dset in polcont.datasets.items():
+            out_dset = outcont.datasets[name]
 
             if "pol" not in dset.attrs["axis"]:
-                outcont.datasets[name][:] = dset[:]
+                out_dset[:] = dset[:]
             else:
                 pol_axis_pos = list(dset.attrs["axis"]).index("pol")
 
                 sl = tuple([slice(None)] * pol_axis_pos)
-                outcont.datasets[name][(*sl, 0)] = dset[(*sl, XX_ind)]
-                outcont.datasets[name][(*sl, 0)] += dset[(*sl, YY_ind)]
-                outcont.datasets[name][:] *= 0.5
+                out_dset[(*sl, 0)] = dset[(*sl, XX_ind)]
+                out_dset[(*sl, 0)] += dset[(*sl, YY_ind)]
+
+                if np.issubdtype(out_dset.dtype, np.integer):
+                    out_dset[:] //= 2
+                else:
+                    out_dset[:] *= 0.5
 
         return outcont
 
diff --git a/draco/analysis/wavelet.py b/draco/analysis/wavelet.py
new file mode 100644
index 000000000..4720fb0da
--- /dev/null
+++ b/draco/analysis/wavelet.py
@@ -0,0 +1,136 @@
+"""Wavelet power spectrum estimation."""
+import os
+
+import numpy as np
+import pywt
+import scipy.fft as fft
+import scipy.linalg as la
+from caput import config, mpiutil
+
+from ..core import containers, task
+from ..util import _fast_tools
+from .delay import flatten_axes
+
+
+class WaveletSpectrumEstimator(task.SingleTask):
+    """Estimate a continuous wavelet power spectrum from the data.
+
+    This requires the input of the underlying data, *and* an estimate of its delay
+    spectrum to allow us to in-fill any masked frequencies.
+
+    Attributes
+    ----------
+    dataset : str
+        The name of the dataset to compute the spectrum of.
+    average_axis : str
+        The axis to average over when estimating the variance.
+    ndelay : int
+        The number of delay scales to evaluate.
+    wavelet : str
+        The name of the wavelet to use. See `pywt.wavelist` for the options.
+    chunks : int
+        The number of chunks to split the scales into, to control peak memory.
+ """ + + dataset = config.Property(proptype=str, default="vis") + average_axis = config.Property(proptype=str) + ndelay = config.Property(proptype=int, default=128) + wavelet = config.Property(proptype=str, default="morl") + chunks = config.Property(proptype=int, default=4) + + def process( + self, + data: containers.FreqContainer, + dspec: containers.DelaySpectrum, + ) -> containers.WaveletSpectrum: + """Estimate the wavelet power spectrum. + + Parameters + ---------- + data + Must have a frequency axis, and the axis to average over. Any other axes + will be flattened in the output. + dspec + The delay spectrum. The flattened `baseline` axis must match the remaining + axes in `data`. + + Returns + ------- + wspec + The wavelet spectrum. The non-frequency and averaging axes have been + collapsed into `baseline`. + """ + workers = int(os.environ.get("OMP_NUM_THREADS", 1)) + + dset_view, bl_axes = flatten_axes( + data[self.dataset], [self.average_axis, "freq"] + ) + weight_view, _ = flatten_axes( + data.weight, + [self.average_axis, "freq"], + match_dset=data[self.dataset], + ) + + nbase = dset_view.global_shape[0] + + df = np.abs(data.freq[1] - data.freq[0]) + delay_scales = np.arange(1, self.ndelay + 1) / (2 * df * self.ndelay) + + wv_scales = pywt.frequency2scale(self.wavelet, delay_scales * df) + + wspec = containers.WaveletSpectrum( + baseline=nbase, + axes_from=data, + attrs_from=data, + delay=delay_scales, + ) + # Copy the index maps for all the flattened axes into the output container, and + # write out their order into an attribute so we can reconstruct this easily + # when loading in the spectrum + for ax in bl_axes: + wspec.create_index_map(ax, data.index_map[ax]) + wspec.attrs["baseline_axes"] = bl_axes + + wspec.redistribute("baseline") + dspec.redistribute("baseline") + ws = wspec.spectrum[:].local_array + ww = wspec.weight[:].local_array + ds = dspec.spectrum[:].local_array + + # Construct the freq<->delay mapping Fourier matrix + F = np.exp( + -2.0j + * np.pi + * dspec.index_map["delay"][np.newaxis, :] + * data.freq[:, np.newaxis] + ) + + for ii in range(dset_view.shape[0]): + self.log.info(f"Transforming baseline {ii} of {dset_view.shape[0]}") + d = dset_view.local_array[ii] + w = weight_view.local_array[ii] + + # Construct an averaged frequency mask and use it to set the output + # weights + Ni = w.mean(axis=0) + ww[ii] = Ni + + # Construct a Wiener filter to in-fill the data + D = ds[ii] + Df = (F * D[np.newaxis, :]) @ F.T.conj() + iDf = la.inv(Df) + Ci = iDf + np.diag(Ni) + + # Solve for the infilled data + d_infill = la.solve( + Ci, + Ni[:, np.newaxis] * d.T, + assume_a="pos", + overwrite_a=True, + overwrite_b=True, + ).T + + # Doing the cwt and calculating the variance can eat a bunch of + # memory. 
Break it up into chunks to try and control this + for _, s, e in mpiutil.split_m(wv_scales.shape[0], self.chunks).T: + with fft.set_workers(workers): + wd, _ = pywt.cwt( + d_infill, + scales=wv_scales[s:e], + wavelet=self.wavelet, + axis=-1, + sampling_period=df, + method="fft", + ) + + # Calculate and set the variance + _fast_tools._fast_var(wd, ws[ii, s:e]) + + return wspec diff --git a/draco/core/containers.py b/draco/core/containers.py index bfe2d6e4e..a648e78f5 100644 --- a/draco/core/containers.py +++ b/draco/core/containers.py @@ -2408,6 +2408,39 @@ def freq(self): return self.attrs["freq"] +class WaveletSpectrum(ContainerBase): + """Container for a wavelet power spectrum.""" + + _axes = ("baseline", "freq", "delay") + + _dataset_spec = { + "spectrum": { + "axes": ["baseline", "delay", "freq"], + "dtype": np.float64, + "initialise": True, + "distributed": True, + "distributed_axis": "baseline", + }, + "weight": { + "axes": ["baseline", "freq"], + "dtype": np.float64, + "initialise": True, + "distributed": True, + "distributed_axis": "baseline", + }, + } + + @property + def spectrum(self): + """The wavelet spectrum.""" + return self.datasets["spectrum"] + + @property + def weight(self): + """The weights for the spectrum.""" + return self.datasets["weight"] + + class Powerspectrum2D(ContainerBase): """Container for a 2D cartesian power spectrum. diff --git a/draco/util/_fast_tools.pyx b/draco/util/_fast_tools.pyx index cd9d7907e..cc11b6063 100644 --- a/draco/util/_fast_tools.pyx +++ b/draco/util/_fast_tools.pyx @@ -1,16 +1,25 @@ """A few miscellaneous Cython routines to speed up critical operations.""" from cython.parallel import prange, parallel +from cython.view cimport array as cvarray cimport cython import numpy as np cimport numpy as np from libc.stdint cimport int16_t, int32_t +from libc.stdlib cimport malloc, free from libc.math cimport sin from libc.math cimport cos from libc.math cimport fabs +ctypedef double complex complex128 +ctypedef float complex complex64 + +cdef extern from "complex.h" nogil: + double creal(complex128) + double cimag(complex128) + cdef inline int int_max(int a, int b) nogil: return a if a >= b else b @@ -260,3 +269,70 @@ ctypedef fused real_or_complex: double complex float float complex + + +@cython.wraparound(False) +@cython.boundscheck(False) +@cython.cdivision(True) +cpdef _fast_var( + real_or_complex[:, :, ::1] arr, + double[:, ::1] out +): + """Fast parallel variance evaluation. + + This is specialised to dim=3 arrays with the variance taken along axis=1. This + implementation uses the Youngs and Cramer single pass algorithm. + + Parameters + ---------- + arr + The array to take the variance of. + out + The array to write the variance into. 
+ """ + cdef int N1, N2, N3 + cdef int ii, jj, kk + cdef real_or_complex t + cdef double t2 + cdef real_or_complex* T + cdef int size = sizeof(T) + + N1 = arr.shape[0] + N2 = arr.shape[1] + N3 = arr.shape[2] + + if arr.shape[0] != out.shape[0] or arr.shape[2] != out.shape[1]: + raise ValueError("Input and output array shapes incompatible.") + + with nogil, parallel(): + + T = malloc(N3 * sizeof(real_or_complex)) + + for ii in prange(N1): + + for kk in range(N3): + out[ii, kk] = 0.0 + T[kk] = arr[ii, 0, kk] + + for jj in range(1, N2): + for kk in range(N3): + t = arr[ii, jj, kk] + T[kk] += t + t = (jj + 1) * t - T[kk] + if ( + (real_or_complex is cython.doublecomplex) or + (real_or_complex is cython.floatcomplex) + ): + t2 = creal(t)**2 + cimag(t)**2 + else: + t2 = t * t + out[ii, kk] += t2 / (jj * (jj + 1)) + + + for kk in range(N3): + out[ii, kk] /= N2 + + free(T) + + + diff --git a/requirements.txt b/requirements.txt index a547bcb17..3d5fa7aed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ mpi4py numpy>=1.17 scipy>=0.10 skyfield +pywavelets \ No newline at end of file