Skip to content

Commit

Permalink
Add merge subdomain method (#2243)
Browse files Browse the repository at this point in the history
This adds a method to `fv3fit.reservoir.domain.RankDivider` that is used
to reshape reservoir model outputs from a flat array to the original x,
y, z dims.
 
The reshaping in the function is not very intuitive to follow in the
code, a more detailed breakdown of the reshaping is in [this
notebook](https://github.com/ai2cm/explore/blob/master/annak/2023-05-10-hybrid-reservoir/2023-06-15-subdomain-reshaping.ipynb).
The section "Breakdown of what's going on in the merge function"
demonstrates what the various reshaping steps are doing.




Added public API:
- `RankDivider.merge_subdomains`
 

- [x] Tests added

Resolves #<github issues> [JIRA-TAG]

Coverage reports (updated automatically):
  • Loading branch information
AnnaKwa authored Jun 21, 2023
1 parent aaf517b commit 9327855
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 1 deletion.
13 changes: 13 additions & 0 deletions external/fv3fit/fv3fit/reservoir/_reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,16 @@ def stack_data(tensor, keep_first_dim: bool):
return np.reshape(tensor, (n_samples, -1))
else:
return np.reshape(tensor, -1)


def split_1d_samples_into_2d_rows(
arr: np.ndarray, n_rows: int, keep_first_dim_shape: bool
) -> np.ndarray:
# Consecutive chunks of 1d array form rows of 2d array
# ex. 1d to 2d reshaping (8,) -> (2,4) for n_rows=2
# [1,2,3,4,5,6,7,8] -> [[1,2,3,4], [5,6,7,8]]
if keep_first_dim_shape is True:
time_dim_size = arr.shape[0]
return np.reshape(arr, (time_dim_size, n_rows, -1), order="C")
else:
return np.reshape(arr, (n_rows, -1), order="C")
36 changes: 35 additions & 1 deletion external/fv3fit/fv3fit/reservoir/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tensorflow as tf
from typing import Sequence, Iterable
import yaml
from ._reshaping import stack_data
from ._reshaping import stack_data, split_1d_samples_into_2d_rows
import pace.util


Expand Down Expand Up @@ -202,6 +202,40 @@ def load(cls, path):
metadata = yaml.safe_load(f)
return cls(**metadata)

def merge_subdomains(self, flat_prediction: np.ndarray):
# raw prediction from readout is a long 1D array consisting of concatenated
# flattened subdomain predictions

# separate the prediction into its constituent subdomains
subdomain_rows = split_1d_samples_into_2d_rows(
flat_prediction, n_rows=self.n_subdomains, keep_first_dim_shape=False,
)
subdomain_2d_predictions = []

# reshape each subdomain into (x, y, z) dims
for subdomain_row in subdomain_rows:
subdomain_2d_prediction = self.unstack_subdomain(
subdomain_row, with_overlap=False,
)
subdomain_2d_predictions.append(subdomain_2d_prediction)

subdomain_shape_without_overlap = (
self.subdomain_xy_size_without_overlap,
self.subdomain_xy_size_without_overlap,
)

# reshape the flat list of 3D subdomains into a single array that
# is a Xdomain, Ydomain grid with a (x, y, z) subdomain in each block
z_block_dims = (
*self.subdomain_layout,
*subdomain_shape_without_overlap,
self._n_features,
)
domain_z_blocks = np.array(subdomain_2d_predictions).reshape(*z_block_dims)

# Merge along Xdomain, Ydomain dims into a single array of dims (x, y, z)
return np.concatenate(np.concatenate(domain_z_blocks, axis=2), axis=0)


class TimeSeriesRankDivider(RankDivider):
def get_subdomain_tensor_slice(
Expand Down
21 changes: 21 additions & 0 deletions external/fv3fit/tests/reservoir/test__reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fv3fit.reservoir._reshaping import (
flatten_2d_keeping_columns_contiguous,
stack_data,
split_1d_samples_into_2d_rows,
)


Expand All @@ -27,3 +28,23 @@ def test_flatten_2d_keeping_columns_contiguous():
np.testing.assert_array_equal(
flatten_2d_keeping_columns_contiguous(x), np.array([1, 3, 5, 2, 4, 6])
)


def test_split_1d_samples_into_2d_rows():
x = np.arange(12)
x_2d = split_1d_samples_into_2d_rows(x, n_rows=4, keep_first_dim_shape=False)
np.testing.assert_array_equal(
x_2d, np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
)


def test_split_1d_samples_into_2d_rows_keep_first_dim_shape():
nt = 3
x = np.array([np.arange(12) for i in range(nt)])
x_2d = split_1d_samples_into_2d_rows(x, n_rows=4, keep_first_dim_shape=True)

expected = np.array(
[np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) for i in range(nt)]
)

np.testing.assert_array_equal(x_2d, expected)
26 changes: 26 additions & 0 deletions external/fv3fit/tests/reservoir/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,29 @@ def test_RankDivider_subdomain_xy_size_without_overlap():
overlap=2,
)
assert divider.subdomain_xy_size_without_overlap == 2


def test_RankDivider_merge_subdomains():
# Original (x, y, z) dims are (4, 4, 2)
horizontal_array = np.arange(16).reshape(4, 4)
data_orig = np.stack([horizontal_array, -1.0 * horizontal_array], axis=-1)
rank_divider = RankDivider(
subdomain_layout=(2, 2),
rank_dims=["x", "y", "z"],
rank_extent=(4, 4, 2),
overlap=0,
)

# 'prediction' will just be the subdomains reshaped into columns and
# concatenated together. We want the `merge_subdomains` function to
# be able to take this 1D array and reshape it into the correct (x,y,z)
# dimensions matching the original data.
subdomain_columns = rank_divider.flatten_subdomains_to_columns(
data_orig, with_overlap=False
)
prediction = np.concatenate(
[subdomain_columns[:, s] for s in range(rank_divider.n_subdomains)], axis=-1
)

merged = rank_divider.merge_subdomains(prediction)
np.testing.assert_array_equal(merged, data_orig)

0 comments on commit 9327855

Please sign in to comment.