From 4fa8c00340580097f71fc33e05249752ce773a39 Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Tue, 19 Dec 2023 09:24:16 +0100 Subject: [PATCH] refactor(delay): use flatten_axes routine --- draco/analysis/delay.py | 37 +++++++------------------------------ 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/draco/analysis/delay.py b/draco/analysis/delay.py index 5a27ff55c..c4fb9c520 100644 --- a/draco/analysis/delay.py +++ b/draco/analysis/delay.py @@ -641,41 +641,19 @@ def _process_data(self, ss): ) # Find the relevant axis positions - data_axes = ss.datasets[self.dataset].attrs["axis"] - freq_axis_pos = list(data_axes).index("freq") - average_axis_pos = list(data_axes).index(self.average_axis) - - # Create a view of the dataset with the relevant axes at the back, - # and all other axes compressed. End result is packed as - # [baseline_axis, average_axis, freq_axis]. - data_view = np.moveaxis( - ss.datasets[self.dataset][:].local_array, - [average_axis_pos, freq_axis_pos], - [-2, -1], + data_view, bl_axes = flatten_axes( + ss.datasets[self.dataset], [self.average_axis, "freq"] ) - data_view = data_view.reshape(-1, data_view.shape[-2], data_view.shape[-1]) - data_view = mpiarray.MPIArray.wrap(data_view, axis=2, comm=ss.comm) - nbase = int(np.prod(data_view.shape[:-2])) - data_view = data_view.redistribute(axis=0) - - # ... do the same for the weights, but we also need to make the weights full - # size - weight_full = np.zeros( - ss.datasets[self.dataset][:].shape, dtype=ss.weight.dtype + weight_view, _ = flatten_axes( + ss.weight, + [self.average_axis, "freq"], + match_dset=ss.datasets[self.dataset], ) - weight_full[:] = match_axes(ss.datasets[self.dataset], ss.weight) - weight_view = np.moveaxis( - weight_full, [average_axis_pos, freq_axis_pos], [-2, -1] - ) - weight_view = weight_view.reshape( - -1, weight_view.shape[-2], weight_view.shape[-1] - ) - weight_view = mpiarray.MPIArray.wrap(weight_view, axis=2, comm=ss.comm) - weight_view = weight_view.redistribute(axis=0) # Use the "baselines" axis to generically represent all the other axes # Initialise the spectrum container + nbase = data_view.global_shape[0] if self.output_power_spectrum: delay_spec = containers.DelaySpectrum( baseline=nbase, @@ -692,7 +670,6 @@ def _process_data(self, ss): ) delay_spec.redistribute("baseline") delay_spec.spectrum[:] = 0.0 - bl_axes = [da for da in data_axes if da not in [self.average_axis, "freq"]] # Copy the index maps for all the flattened axes into the output container, and # write out their order into an attribute so we can reconstruct this easily