Dev branch for reservoir model training features #2300

Draft · wants to merge 12 commits into base: master
20 changes: 12 additions & 8 deletions external/fv3fit/fv3fit/reservoir/config.py
@@ -65,7 +65,7 @@ class ReservoirTrainingConfig(Hyperparameters):
output_variables: time series variables, must be subset of input_variables
reservoir_hyperparameters: hyperparameters for reservoir
readout_hyperparameters: hyperparameters for readout
n_batches_burn: number of training batches at start of time series to use
n_timesteps_synchronize: number of timesteps at start of time series to use
for synchronization. This data is used to update the reservoir state
but is not included in training.
input_noise: stddev of normal distribution which is sampled to add input
@@ -88,21 +88,23 @@ class ReservoirTrainingConfig(Hyperparameters):
subdomain: CubedsphereSubdomainConfig
reservoir_hyperparameters: ReservoirHyperparameters
readout_hyperparameters: BatchLinearRegressorHyperparameters
n_batches_burn: int
n_timesteps_synchronize: int
input_noise: float
seed: int = 0
n_jobs: Optional[int] = 1
square_half_hidden_state: bool = False
autoencoder_path: Optional[str] = None
hybrid_autoencoder_path: Optional[str] = None
hybrid_variables: Optional[Sequence[str]] = None
_METADATA_NAME = "reservoir_training_config.yaml"

def __post_init__(self):
if set(self.output_variables).issubset(self.input_variables) is False:
raise ValueError(
f"Output variables {self.output_variables} must be a subset of "
f"input variables {self.input_variables}."
)
output_vars = set(self.output_variables)
is_subset = output_vars.issubset(self.input_variables)
is_disjoint = output_vars.isdisjoint(self.input_variables)
if not (is_subset or is_disjoint):
raise ValueError(
f"Output variables {self.output_variables} must either be a subset "
f"of input variables {self.input_variables} or mutually exclusive."
)
if self.hybrid_variables is not None:
hybrid_and_input_vars_intersection = set(
self.hybrid_variables
@@ -119,7 +121,9 @@ def variables(self) -> Set[str]:
hybrid_vars = list(self.hybrid_variables) # type: ignore
else:
hybrid_vars = []
return set(list(self.input_variables) + hybrid_vars)
return set(
list(self.input_variables) + list(self.output_variables) + hybrid_vars
)

@classmethod
def from_dict(cls, kwargs) -> "ReservoirTrainingConfig":
@@ -148,7 +152,7 @@ def from_dict(cls, kwargs) -> "ReservoirTrainingConfig":

def dump(self, path: str):
metadata = {
"n_batches_burn": self.n_batches_burn,
"n_timesteps_synchronize": self.n_timesteps_synchronize,
"input_noise": self.input_noise,
"seed": self.seed,
"n_jobs": self.n_jobs,
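The updated __post_init__ check is meant to allow the output variables either to be a subset of the input variables or to be entirely separate from them; only a partial overlap is rejected, and the variables property correspondingly now returns the union of the input, output, and hybrid variables. Below is a minimal sketch of that rule, not taken from the PR; the variable names are hypothetical.

# Hypothetical variable names; illustrates only the subset-or-disjoint rule
# enforced by __post_init__ above.
inputs = {"air_temperature", "specific_humidity"}


def is_valid(output_vars, input_vars):
    # Allowed: outputs are a subset of the inputs, or share nothing with them.
    return output_vars.issubset(input_vars) or output_vars.isdisjoint(input_vars)


assert is_valid({"air_temperature"}, inputs)  # subset: allowed
assert is_valid({"total_precipitation"}, inputs)  # disjoint: allowed
# A partial overlap is the case that raises ValueError in the config:
assert not is_valid({"air_temperature", "total_precipitation"}, inputs)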
2 changes: 1 addition & 1 deletion external/fv3fit/fv3fit/reservoir/domain.py
@@ -203,7 +203,7 @@ def dump(self, path):
metadata = {
"subdomain_layout": self.subdomain_layout,
"rank_dims": self.rank_dims,
"rank_extent": self.rank_extent,
"rank_extent": list(self.rank_extent),
"overlap": self.overlap,
}
with fsspec.open(path, "w") as f:
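The one-line change above casts rank_extent to a list before it is written to YAML. The motivation is not stated in the diff; presumably it is that PyYAML's default dumper serializes tuples with a !!python/tuple tag that yaml.safe_load refuses to construct, while a plain list round-trips cleanly. A small self-contained illustration, with a hypothetical extent:

import yaml

extent = (48, 48)  # hypothetical rank extent, not taken from the PR

print(yaml.dump({"rank_extent": extent}))        # rank_extent: !!python/tuple ...
print(yaml.dump({"rank_extent": list(extent)}))  # rank_extent:\n- 48\n- 48

try:
    yaml.safe_load(yaml.dump({"rank_extent": extent}))
except yaml.YAMLError as err:
    print("safe_load rejects the tuple form:", err)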
66 changes: 55 additions & 11 deletions external/fv3fit/fv3fit/reservoir/model.py
@@ -1,7 +1,7 @@
import fsspec
import numpy as np
import os
from typing import Iterable, Hashable, Sequence, cast
from typing import Iterable, Hashable, Sequence, cast, Optional
import xarray as xr
import yaml

@@ -35,6 +35,7 @@ def _transpose_xy_dims(ds: xr.Dataset, rank_dims: Sequence[str]):
@io.register("hybrid-reservoir")
class HybridReservoirComputingModel(Predictor):
_HYBRID_VARIABLES_NAME = "hybrid_variables.yaml"
_AUTOENCODER_SUBDIR = "autoencoder"

def __init__(
self,
@@ -45,6 +46,8 @@ def __init__(
readout: ReservoirComputingReadout,
rank_divider: RankDivider,
autoencoder: ReloadableTransfomer,
output_autoencoder: Optional[ReloadableTransfomer] = None,
hybrid_autoencoder: Optional[ReloadableTransfomer] = None,
square_half_hidden_state: bool = False,
):
self.reservoir_model = ReservoirComputingModel(
@@ -55,20 +58,24 @@ def __init__(
square_half_hidden_state=square_half_hidden_state,
rank_divider=rank_divider,
autoencoder=autoencoder,
output_autoencoder=output_autoencoder,
)
self.input_variables = input_variables
self.hybrid_variables = hybrid_variables
self.output_variables = output_variables
self.readout = readout
self.square_half_hidden_state = square_half_hidden_state
self.rank_divider = rank_divider
self.autoencoder = autoencoder
self.input_autoencoder = autoencoder
self.output_autoencoder = output_autoencoder
self.hybrid_autoencoder = hybrid_autoencoder

def predict(self, hybrid_input: Sequence[np.ndarray]):
# hybrid input is assumed to be in original spatial xy dims
# (x, y, feature) and does not include overlaps.
hybrid_autoencoder = self.hybrid_autoencoder or self.input_autoencoder
encoded_hybrid_input = encode_columns(
input_arrs=hybrid_input, transformer=self.autoencoder
input_arrs=hybrid_input, transformer=hybrid_autoencoder
)
if encoded_hybrid_input.shape[:2] != tuple(
self.rank_divider.rank_extent_without_overlap
@@ -93,9 +100,11 @@ def predict(self, hybrid_input: Sequence[np.ndarray]):
)
flat_prediction = self.readout.predict(flattened_readout_input).reshape(-1)
prediction = self.rank_divider.merge_subdomains(flat_prediction)

output_autoencoder = self.output_autoencoder or self.input_autoencoder
decoded_prediction = decode_columns(
encoded_output=prediction,
transformer=self.autoencoder,
transformer=output_autoencoder,
xy_shape=self.rank_divider.rank_extent_without_overlap,
)
return decoded_prediction
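As a usage sketch (not part of the PR; the sizes and the model variable are hypothetical), predict expects one array per hybrid variable in the original x, y, feature layout without overlap cells, encodes them with the hybrid autoencoder when one was provided (otherwise the input autoencoder), and decodes the readout output with the output autoencoder under the same fallback:

import numpy as np

# Hypothetical sizes; `model` is assumed to be a HybridReservoirComputingModel
# whose reservoir state has already been synchronized via increment_state.
nx, ny, nz = 48, 48, 8
hybrid_input = [np.zeros((nx, ny, nz)) for _ in model.hybrid_variables]
decoded_prediction = model.predict(hybrid_input)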
@@ -122,20 +131,36 @@ def dump(self, path: str) -> None:
self.reservoir_model.dump(path)
with fsspec.open(os.path.join(path, self._HYBRID_VARIABLES_NAME), "w") as f:
f.write(yaml.dump({"hybrid_variables": self.hybrid_variables}))
if self.hybrid_autoencoder is not None:
fv3fit.dump(
self.hybrid_autoencoder,
os.path.join(path, self._AUTOENCODER_SUBDIR, "hybrid"),
)

@classmethod
def load(cls, path: str) -> "HybridReservoirComputingModel":
pure_reservoir_model = ReservoirComputingModel.load(path)
with fsspec.open(os.path.join(path, cls._HYBRID_VARIABLES_NAME), "r") as f:
hybrid_variables = yaml.safe_load(f)["hybrid_variables"]

try:
hybrid_autoencoder = cast(
ReloadableTransfomer,
fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR, "hybrid")),
)
except KeyError:
hybrid_autoencoder = None # type: ignore

return cls(
input_variables=pure_reservoir_model.input_variables,
output_variables=pure_reservoir_model.output_variables,
reservoir=pure_reservoir_model.reservoir,
readout=pure_reservoir_model.readout,
square_half_hidden_state=pure_reservoir_model.square_half_hidden_state,
rank_divider=pure_reservoir_model.rank_divider,
autoencoder=pure_reservoir_model.autoencoder,
autoencoder=pure_reservoir_model.input_autoencoder,
output_autoencoder=pure_reservoir_model.output_autoencoder,
hybrid_autoencoder=hybrid_autoencoder,
hybrid_variables=hybrid_variables,
)
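To make the new serialization layout concrete, here is a hedged sketch of a dump/load round trip. The path is hypothetical and `model` is assumed to be an existing HybridReservoirComputingModel; the listing names only the artifacts touched by the dump methods in this diff, while the reservoir, readout, rank divider, and metadata files are written as before.

# Hypothetical path; autoencoder-related layout written under the model
# directory by the dump methods shown above:
#   hybrid_variables.yaml
#   autoencoder/input    - always written
#   autoencoder/output   - only when output_autoencoder is not None
#   autoencoder/hybrid   - only when hybrid_autoencoder is not None
model.dump("/tmp/hybrid_reservoir_model")
reloaded = HybridReservoirComputingModel.load("/tmp/hybrid_reservoir_model")
assert reloaded.hybrid_variables == model.hybrid_variables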

@@ -200,6 +225,7 @@ def __init__(
readout: ReservoirComputingReadout,
rank_divider: RankDivider,
autoencoder: ReloadableTransfomer,
output_autoencoder: Optional[ReloadableTransfomer] = None,
square_half_hidden_state: bool = False,
):
"""_summary_
@@ -219,7 +245,8 @@ def __init__(
self.readout = readout
self.square_half_hidden_state = square_half_hidden_state
self.rank_divider = rank_divider
self.autoencoder = autoencoder
self.input_autoencoder = autoencoder
self.output_autoencoder = output_autoencoder

def process_state_to_readout_input(self):
if self.square_half_hidden_state is True:
@@ -236,9 +263,10 @@ def predict(self):
readout_input = self.process_state_to_readout_input()
flat_prediction = self.readout.predict(readout_input).reshape(-1)
prediction = self.rank_divider.merge_subdomains(flat_prediction)
output_autoencoder = self.output_autoencoder or self.input_autoencoder
decoded_prediction = decode_columns(
encoded_output=prediction,
transformer=self.autoencoder,
transformer=output_autoencoder,
xy_shape=self.rank_divider.rank_extent_without_overlap,
)
return decoded_prediction
@@ -253,7 +281,7 @@ def reset_state(self):
def increment_state(self, prediction_with_overlap: Sequence[np.ndarray]) -> None:
# input array is in native x, y, z_feature coordinates
encoded_xy_input_arrs = encode_columns(
prediction_with_overlap, self.autoencoder
prediction_with_overlap, self.input_autoencoder
)
encoded_flattened_subdomains = self.rank_divider.flatten_subdomains_to_columns(
encoded_xy_input_arrs, with_overlap=True
@@ -281,8 +309,15 @@ def dump(self, path: str) -> None:
f.write(yaml.dump(metadata))

self.rank_divider.dump(os.path.join(path, self._RANK_DIVIDER_NAME))
if self.autoencoder is not None:
fv3fit.dump(self.autoencoder, os.path.join(path, self._AUTOENCODER_SUBDIR))
fv3fit.dump(
self.input_autoencoder,
os.path.join(path, self._AUTOENCODER_SUBDIR, "input"),
)
if self.output_autoencoder is not None:
fv3fit.dump(
self.output_autoencoder,
os.path.join(path, self._AUTOENCODER_SUBDIR, "output"),
)

@classmethod
def load(cls, path: str) -> "ReservoirComputingModel":
@@ -298,8 +333,16 @@ def load(cls, path: str) -> "ReservoirComputingModel":

autoencoder = cast(
ReloadableTransfomer,
fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR)),
fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR, "input")),
)
try:
output_autoencoder = cast(
ReloadableTransfomer,
fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR, "output")),
)
except KeyError:
output_autoencoder = None # type: ignore

return cls(
input_variables=metadata["input_variables"],
output_variables=metadata["output_variables"],
@@ -308,4 +351,5 @@ def load(cls, path: str) -> "ReservoirComputingModel":
square_half_hidden_state=metadata["square_half_hidden_state"],
rank_divider=rank_divider,
autoencoder=autoencoder,
output_autoencoder=output_autoencoder,
)