Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean res precip #2353

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions workflows/prognostic_c48_run/runtime/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,11 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics:
if self._reservoir_predict_stepper.is_diagnostic: # type: ignore
rename_diagnostics(diags, label="reservoir_predictor")

state_updates[TOTAL_PRECIP] = precipitation_sum(
self._state[TOTAL_PRECIP], net_moistening, self._timestep,
precip = self._reservoir_predict_stepper.update_precip( # type: ignore
self._state[TOTAL_PRECIP], net_moistening
)

diags.update(precip)
state_updates[TOTAL_PRECIP] = precip[TOTAL_PRECIP]
self._state.update_mass_conserving(state_updates)

diags.update({name: self._state[name] for name in self._states_to_output})
Expand All @@ -614,9 +615,7 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics:
"cnvprcp_after_python": self._wrapper.get_diagnostic_by_name(
"cnvprcp"
).data_array,
TOTAL_PRECIP_RATE: precipitation_rate(
self._state[TOTAL_PRECIP], self._timestep
),
TOTAL_PRECIP_RATE: diags["total_precip_rate_res_interval_avg"],
}
)

Expand Down
107 changes: 102 additions & 5 deletions workflows/prognostic_c48_run/runtime/steppers/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
cast,
Sequence,
Dict,
Union,
)

import fv3fit
from fv3fit._shared.halos import append_halos_using_mpi
from fv3fit.reservoir.adapters import ReservoirDatasetAdapter
from runtime.names import SST, SPHUM, TEMP
from runtime.names import SST, SPHUM, TEMP, PHYSICS_PRECIP_RATE, TOTAL_PRECIP
from runtime.tendency import add_tendency, tendencies_from_state_updates
from runtime.diagnostics import (
enforce_heating_and_moistening_tendency_constraints,
Expand Down Expand Up @@ -54,7 +55,7 @@ class ReservoirConfig:
limiter. Defaults to false.
"""

models: Mapping[int, str]
models: Mapping[Union[int, str], str]
synchronize_steps: int = 1
reservoir_timestep: str = "3h" # TODO: Could this be inferred?
time_average_inputs: bool = False
Expand All @@ -63,6 +64,22 @@ class ReservoirConfig:
rename_mapping: NameDict = dataclasses.field(default_factory=dict)
hydrostatic: bool = False
mse_conserving_limiter: bool = False
interval_average_precipitation: bool = False

def __post_init__(self):
# This handles cases in automatic config writing where json/yaml
# do not allow integer keys
_models = {}
for key, url in self.models.items():
try:
int_key = int(key)
_models[int_key] = url
except (ValueError) as e:
raise ValueError(
"Keys in reservoir_corrector.models must be integers "
"or string representation of integers."
) from e
self.models = _models


class _FiniteStateMachine:
Expand Down Expand Up @@ -104,6 +121,49 @@ def __call__(self, state: str):
)


class PrecipTracker:
def __init__(self, reservoir_timestep_seconds: float):
self.reservoir_timestep_seconds = reservoir_timestep_seconds
self.physics_precip_averager = TimeAverageInputs([PHYSICS_PRECIP_RATE])
self._air_temperature_at_previous_interval = None
self._specific_humidity_at_previous_interval = None

def increment_physics_precip_rate(self, physics_precip_rate):
self.physics_precip_averager.increment_running_average(
{PHYSICS_PRECIP_RATE: physics_precip_rate}
)

def interval_avg_precip_rates(self, net_moistening_due_to_reservoir):
physics_precip_rate = self.physics_precip_averager.get_averages()[
PHYSICS_PRECIP_RATE
]
total_precip_rate = physics_precip_rate - net_moistening_due_to_reservoir
total_precip_rate = total_precip_rate.where(total_precip_rate >= 0, 0)
reservoir_precip_rate = total_precip_rate - physics_precip_rate
return {
"total_precip_rate_res_interval_avg": total_precip_rate,
"physics_precip_rate_res_interval_avg": physics_precip_rate,
"reservoir_precip_rate_res_interval_avg": reservoir_precip_rate,
}

def accumulated_precip_update(
self,
physics_precip_total_over_model_timestep,
reservoir_precip_rate_over_res_interval,
reservoir_timestep,
):
# Since the reservoir correction is only applied every reservoir_timestep,
# all of the precip due to the reservoir is put into the accumulated precip
# in the model timestep at update time.
m_per_mm = 1 / 1000
reservoir_total_precip = (
reservoir_precip_rate_over_res_interval * reservoir_timestep * m_per_mm
)
total_precip = physics_precip_total_over_model_timestep + reservoir_total_precip
total_precip.attrs["units"] = "m"
return total_precip


class TimeAverageInputs:
"""
Copy of time averaging components from runtime.diagnostics.manager to
Expand Down Expand Up @@ -170,6 +230,7 @@ def __init__(
warm_start: bool = False,
hydrostatic: bool = False,
mse_conserving_limiter: bool = False,
precip_tracker: Optional[PrecipTracker] = None,
):
self.model = model
self.synchronize_steps = synchronize_steps
Expand All @@ -181,6 +242,7 @@ def __init__(
self.warm_start = warm_start
self.hydrostatic = hydrostatic
self.mse_conserving_limiter = mse_conserving_limiter
self.precip_tracker = precip_tracker

if state_machine is None:
state_machine = _FiniteStateMachine()
Expand Down Expand Up @@ -250,8 +312,8 @@ def _get_inputs_from_state(self, state):
)
except RuntimeError:
raise ValueError(
"MPI not available or tile dimension does not exist in state fields"
" during reservoir increment update"
"MPI not available or tile dimension does not exist in state "
"fields during reservoir increment update"
)
reservoir_inputs = rc_in_with_halos

Expand Down Expand Up @@ -360,6 +422,11 @@ def __call__(self, time, state):
if self.input_averager is not None:
self.input_averager.increment_running_average(inputs)

if self.precip_tracker is not None:
self.precip_tracker.increment_physics_precip_rate(
state[PHYSICS_PRECIP_RATE]
)

if self._is_rc_update_step(time):
logger.info(f"Reservoir model predict at time {time}")
if self.input_averager is not None:
Expand Down Expand Up @@ -407,6 +474,13 @@ def __call__(self, time, state):
tendencies=tendency_updates_from_constraints,
dt=self.model_timestep,
)
# Adjust corrective tendencies to be averages over
# the full reservoir timestep
for key in tendency_updates_from_constraints:
if key != "specific_humidity_limiter_active":
tendency_updates_from_constraints[key] *= (
self.model_timestep / self.timestep.total_seconds()
)
tendencies.update(tendency_updates_from_constraints)

else:
Expand All @@ -418,6 +492,24 @@ def get_diagnostics(self, state, tendency):
diags = compute_diagnostics(state, tendency, self.label, self.hydrostatic)
return diags, diags[f"net_moistening_due_to_{self.label}"]

def update_precip(
self, physics_precip, net_moistening_due_to_reservoir,
):
diags = {}

# running average gets reset in this call
precip_rates = self.precip_tracker.interval_avg_precip_rates(
net_moistening_due_to_reservoir
)
diags.update(precip_rates)

diags[TOTAL_PRECIP] = self.precip_tracker.accumulated_precip_update(
physics_precip,
diags["reservoir_precip_rate_res_interval_avg"],
self.timestep.total_seconds(),
)
return diags


def open_rc_model(path: str) -> ReservoirDatasetAdapter:
return cast(ReservoirDatasetAdapter, fv3fit.load(path))
Expand Down Expand Up @@ -462,7 +554,11 @@ def get_reservoir_steppers(
increment_averager, predict_averager = _get_time_averagers(
model, config.time_average_inputs
)

_precip_tracker_kwargs = {}
if config.interval_average_precipitation:
_precip_tracker_kwargs["precip_tracker"] = PrecipTracker(
reservoir_timestep_seconds=rc_tdelta.total_seconds(),
)
incrementer = ReservoirIncrementOnlyStepper(
model,
init_time,
Expand All @@ -487,5 +583,6 @@ def get_reservoir_steppers(
model_timestep=model_timestep,
hydrostatic=config.hydrostatic,
mse_conserving_limiter=config.mse_conserving_limiter,
**_precip_tracker_kwargs,
)
return incrementer, predictor
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ radiation_scheme: null
reservoir_corrector:
diagnostic_only: false
hydrostatic: false
interval_average_precipitation: false
models:
0: gs://vcm-ml-scratch/rc-model-tile-0
1: gs://vcm-ml-scratch/rc-model-tile-1
Expand Down
5 changes: 5 additions & 0 deletions workflows/prognostic_c48_run/tests/test_reservoir_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,8 @@ def test_model_paths_and_rank_index_mismatch_on_load():
reservoir.get_reservoir_steppers(
config, 1, datetime(2020, 1, 1), MODEL_TIMESTEP
)


def test_reservoir_config_raises_error_on_invalid_key():
with pytest.raises(ValueError):
ReservoirConfig({"a": "model"}, 1, reservoir_timestep="10m")