diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 049354b0d3..2218477454 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -601,10 +601,11 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics: if self._reservoir_predict_stepper.is_diagnostic: # type: ignore rename_diagnostics(diags, label="reservoir_predictor") - state_updates[TOTAL_PRECIP] = precipitation_sum( - self._state[TOTAL_PRECIP], net_moistening, self._timestep, + precip = self._reservoir_predict_stepper.update_precip( # type: ignore + self._state[TOTAL_PRECIP], net_moistening ) - + diags.update(precip) + state_updates[TOTAL_PRECIP] = precip[TOTAL_PRECIP] self._state.update_mass_conserving(state_updates) diags.update({name: self._state[name] for name in self._states_to_output}) @@ -614,9 +615,7 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics: "cnvprcp_after_python": self._wrapper.get_diagnostic_by_name( "cnvprcp" ).data_array, - TOTAL_PRECIP_RATE: precipitation_rate( - self._state[TOTAL_PRECIP], self._timestep - ), + TOTAL_PRECIP_RATE: diags["total_precip_rate_res_interval_avg"], } ) diff --git a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py index 6daba607a3..46a91799fa 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py +++ b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py @@ -12,12 +12,13 @@ cast, Sequence, Dict, + Union, ) import fv3fit from fv3fit._shared.halos import append_halos_using_mpi from fv3fit.reservoir.adapters import ReservoirDatasetAdapter -from runtime.names import SST, SPHUM, TEMP +from runtime.names import SST, SPHUM, TEMP, PHYSICS_PRECIP_RATE, TOTAL_PRECIP from runtime.tendency import add_tendency, tendencies_from_state_updates from runtime.diagnostics import ( enforce_heating_and_moistening_tendency_constraints, @@ -54,7 +55,7 @@ class ReservoirConfig: limiter. Defaults to false. """ - models: Mapping[int, str] + models: Mapping[Union[int, str], str] synchronize_steps: int = 1 reservoir_timestep: str = "3h" # TODO: Could this be inferred? time_average_inputs: bool = False @@ -63,6 +64,22 @@ class ReservoirConfig: rename_mapping: NameDict = dataclasses.field(default_factory=dict) hydrostatic: bool = False mse_conserving_limiter: bool = False + interval_average_precipitation: bool = False + + def __post_init__(self): + # This handles cases in automatic config writing where json/yaml + # do not allow integer keys + _models = {} + for key, url in self.models.items(): + try: + int_key = int(key) + _models[int_key] = url + except (ValueError) as e: + raise ValueError( + "Keys in reservoir_corrector.models must be integers " + "or string representation of integers." + ) from e + self.models = _models class _FiniteStateMachine: @@ -104,6 +121,49 @@ def __call__(self, state: str): ) +class PrecipTracker: + def __init__(self, reservoir_timestep_seconds: float): + self.reservoir_timestep_seconds = reservoir_timestep_seconds + self.physics_precip_averager = TimeAverageInputs([PHYSICS_PRECIP_RATE]) + self._air_temperature_at_previous_interval = None + self._specific_humidity_at_previous_interval = None + + def increment_physics_precip_rate(self, physics_precip_rate): + self.physics_precip_averager.increment_running_average( + {PHYSICS_PRECIP_RATE: physics_precip_rate} + ) + + def interval_avg_precip_rates(self, net_moistening_due_to_reservoir): + physics_precip_rate = self.physics_precip_averager.get_averages()[ + PHYSICS_PRECIP_RATE + ] + total_precip_rate = physics_precip_rate - net_moistening_due_to_reservoir + total_precip_rate = total_precip_rate.where(total_precip_rate >= 0, 0) + reservoir_precip_rate = total_precip_rate - physics_precip_rate + return { + "total_precip_rate_res_interval_avg": total_precip_rate, + "physics_precip_rate_res_interval_avg": physics_precip_rate, + "reservoir_precip_rate_res_interval_avg": reservoir_precip_rate, + } + + def accumulated_precip_update( + self, + physics_precip_total_over_model_timestep, + reservoir_precip_rate_over_res_interval, + reservoir_timestep, + ): + # Since the reservoir correction is only applied every reservoir_timestep, + # all of the precip due to the reservoir is put into the accumulated precip + # in the model timestep at update time. + m_per_mm = 1 / 1000 + reservoir_total_precip = ( + reservoir_precip_rate_over_res_interval * reservoir_timestep * m_per_mm + ) + total_precip = physics_precip_total_over_model_timestep + reservoir_total_precip + total_precip.attrs["units"] = "m" + return total_precip + + class TimeAverageInputs: """ Copy of time averaging components from runtime.diagnostics.manager to @@ -170,6 +230,7 @@ def __init__( warm_start: bool = False, hydrostatic: bool = False, mse_conserving_limiter: bool = False, + precip_tracker: Optional[PrecipTracker] = None, ): self.model = model self.synchronize_steps = synchronize_steps @@ -181,6 +242,7 @@ def __init__( self.warm_start = warm_start self.hydrostatic = hydrostatic self.mse_conserving_limiter = mse_conserving_limiter + self.precip_tracker = precip_tracker if state_machine is None: state_machine = _FiniteStateMachine() @@ -250,8 +312,8 @@ def _get_inputs_from_state(self, state): ) except RuntimeError: raise ValueError( - "MPI not available or tile dimension does not exist in state fields" - " during reservoir increment update" + "MPI not available or tile dimension does not exist in state " + "fields during reservoir increment update" ) reservoir_inputs = rc_in_with_halos @@ -360,6 +422,11 @@ def __call__(self, time, state): if self.input_averager is not None: self.input_averager.increment_running_average(inputs) + if self.precip_tracker is not None: + self.precip_tracker.increment_physics_precip_rate( + state[PHYSICS_PRECIP_RATE] + ) + if self._is_rc_update_step(time): logger.info(f"Reservoir model predict at time {time}") if self.input_averager is not None: @@ -407,6 +474,13 @@ def __call__(self, time, state): tendencies=tendency_updates_from_constraints, dt=self.model_timestep, ) + # Adjust corrective tendencies to be averages over + # the full reservoir timestep + for key in tendency_updates_from_constraints: + if key != "specific_humidity_limiter_active": + tendency_updates_from_constraints[key] *= ( + self.model_timestep / self.timestep.total_seconds() + ) tendencies.update(tendency_updates_from_constraints) else: @@ -418,6 +492,24 @@ def get_diagnostics(self, state, tendency): diags = compute_diagnostics(state, tendency, self.label, self.hydrostatic) return diags, diags[f"net_moistening_due_to_{self.label}"] + def update_precip( + self, physics_precip, net_moistening_due_to_reservoir, + ): + diags = {} + + # running average gets reset in this call + precip_rates = self.precip_tracker.interval_avg_precip_rates( + net_moistening_due_to_reservoir + ) + diags.update(precip_rates) + + diags[TOTAL_PRECIP] = self.precip_tracker.accumulated_precip_update( + physics_precip, + diags["reservoir_precip_rate_res_interval_avg"], + self.timestep.total_seconds(), + ) + return diags + def open_rc_model(path: str) -> ReservoirDatasetAdapter: return cast(ReservoirDatasetAdapter, fv3fit.load(path)) @@ -462,7 +554,11 @@ def get_reservoir_steppers( increment_averager, predict_averager = _get_time_averagers( model, config.time_average_inputs ) - + _precip_tracker_kwargs = {} + if config.interval_average_precipitation: + _precip_tracker_kwargs["precip_tracker"] = PrecipTracker( + reservoir_timestep_seconds=rc_tdelta.total_seconds(), + ) incrementer = ReservoirIncrementOnlyStepper( model, init_time, @@ -487,5 +583,6 @@ def get_reservoir_steppers( model_timestep=model_timestep, hydrostatic=config.hydrostatic, mse_conserving_limiter=config.mse_conserving_limiter, + **_precip_tracker_kwargs, ) return incrementer, predictor diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out index 65af8316d4..a705c72c36 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out @@ -436,6 +436,7 @@ radiation_scheme: null reservoir_corrector: diagnostic_only: false hydrostatic: false + interval_average_precipitation: false models: 0: gs://vcm-ml-scratch/rc-model-tile-0 1: gs://vcm-ml-scratch/rc-model-tile-1 diff --git a/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py b/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py index 900e50f02b..7d078f5ea5 100644 --- a/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py +++ b/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py @@ -307,3 +307,8 @@ def test_model_paths_and_rank_index_mismatch_on_load(): reservoir.get_reservoir_steppers( config, 1, datetime(2020, 1, 1), MODEL_TIMESTEP ) + + +def test_reservoir_config_raises_error_on_invalid_key(): + with pytest.raises(ValueError): + ReservoirConfig({"a": "model"}, 1, reservoir_timestep="10m")