Skip to content

Commit

Permalink
Get params from storage not from config
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Sep 28, 2023
1 parent e89768d commit deed3d7
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 77 deletions.
50 changes: 18 additions & 32 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
if TYPE_CHECKING:
import numpy.typing as npt

from ert.config import AnalysisConfig, AnalysisModule, EnkfObs, EnsembleConfig
from ert.config import AnalysisConfig, AnalysisModule, EnkfObs
from ert.enkf_main import EnKFMain
from ert.storage import EnsembleAccessor, EnsembleReader

Expand Down Expand Up @@ -213,12 +213,11 @@ def _save_to_temp_storage(

def _save_temp_storage_to_disk(
target_fs: EnsembleAccessor,
ensemble_config: "EnsembleConfig",
temp_storage: TempStorage,
iens_active_index: npt.NDArray[np.int_],
) -> None:
for key, matrix in temp_storage.items():
config_node = ensemble_config.parameter_configs[key]
config_node = target_fs.experiment.parameter_configuration[key]
for i, realization in enumerate(iens_active_index):
if isinstance(config_node, GenKwConfig):
assert isinstance(matrix, np.ndarray)
Expand All @@ -244,32 +243,33 @@ def _save_temp_storage_to_disk(

def _create_temporary_parameter_storage(
source_fs: EnsembleReader,
ensemble_config: EnsembleConfig,
iens_active_index: npt.NDArray[np.int_],
) -> TempStorage:
temp_storage = TempStorage()
t_genkw = 0.0
t_surface = 0.0
t_field = 0.0
_logger.debug("_create_temporary_parameter_storage() - start")
for key in ensemble_config.parameters:
config_node = ensemble_config.parameter_configs[key]
for (
param_group,
config_node,
) in source_fs.experiment.parameter_configuration.items():
matrix: Union[npt.NDArray[np.double], xr.DataArray]
if isinstance(config_node, GenKwConfig):
t = time.perf_counter()
matrix = source_fs.load_parameters(key, iens_active_index).values.T
matrix = source_fs.load_parameters(param_group, iens_active_index).values.T
t_genkw += time.perf_counter() - t
elif isinstance(config_node, SurfaceConfig):
t = time.perf_counter()
matrix = source_fs.load_parameters(key, iens_active_index)
matrix = source_fs.load_parameters(param_group, iens_active_index)
t_surface += time.perf_counter() - t
elif isinstance(config_node, Field):
t = time.perf_counter()
matrix = source_fs.load_parameters(key, iens_active_index)
matrix = source_fs.load_parameters(param_group, iens_active_index)
t_field += time.perf_counter() - t
else:
raise NotImplementedError(f"{type(config_node)} is not supported")
temp_storage[key] = matrix
temp_storage[param_group] = matrix
_logger.debug(
f"_create_temporary_parameter_storage() time_used gen_kw={t_genkw:.4f}s, \
surface={t_surface:.4f}s, field={t_field:.4f}s"
Expand Down Expand Up @@ -378,26 +378,23 @@ def _load_observations_and_responses(


def analysis_ES(
updatestep: "UpdateConfiguration",
updatestep: UpdateConfiguration,
obs: EnkfObs,
rng: np.random.Generator,
module: "AnalysisModule",
module: AnalysisModule,
alpha: float,
std_cutoff: float,
global_scaling: float,
smoother_snapshot: SmootherSnapshot,
ens_mask: npt.NDArray[np.bool_],
ensemble_config: "EnsembleConfig",
source_fs: EnsembleReader,
target_fs: EnsembleAccessor,
progress_callback: ProgressCallback,
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

progress_callback(Progress(Task("Loading data", 1, 3), None))
temp_storage = _create_temporary_parameter_storage(
source_fs, ensemble_config, iens_active_index
)
temp_storage = _create_temporary_parameter_storage(source_fs, iens_active_index)

ensemble_size = ens_mask.sum()
param_ensemble = _param_ensemble_for_projection(
Expand Down Expand Up @@ -470,22 +467,19 @@ def analysis_ES(
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)

progress_callback(Progress(Task("Storing data", 3, 3), None))
_save_temp_storage_to_disk(
target_fs, ensemble_config, temp_storage, iens_active_index
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)


def analysis_IES(
updatestep: "UpdateConfiguration",
updatestep: UpdateConfiguration,
obs: EnkfObs,
rng: np.random.Generator,
module: "AnalysisModule",
module: AnalysisModule,
alpha: float,
std_cutoff: float,
global_scaling: float,
smoother_snapshot: SmootherSnapshot,
ens_mask: npt.NDArray[np.bool_],
ensemble_config: "EnsembleConfig",
source_fs: EnsembleReader,
target_fs: EnsembleAccessor,
iterative_ensemble_smoother: ies.SIES,
Expand All @@ -494,9 +488,7 @@ def analysis_IES(
iens_active_index = np.flatnonzero(ens_mask)

progress_callback(Progress(Task("Loading data", 1, 3), None))
temp_storage = _create_temporary_parameter_storage(
source_fs, ensemble_config, iens_active_index
)
temp_storage = _create_temporary_parameter_storage(source_fs, iens_active_index)
progress_callback(Progress(Task("Updating data", 2, 3), None))

ensemble_size = ens_mask.sum()
Expand Down Expand Up @@ -553,9 +545,7 @@ def analysis_IES(
)

progress_callback(Progress(Task("Storing data", 3, 3), None))
_save_temp_storage_to_disk(
target_fs, ensemble_config, temp_storage, iens_active_index
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)


def _write_update_report(
Expand Down Expand Up @@ -648,7 +638,6 @@ def smootherUpdate(

analysis_config = self.ert.analysisConfig()
obs = self.ert.getObservations()
ensemble_config = self.ert.ensembleConfig()

alpha = analysis_config.enkf_alpha
std_cutoff = analysis_config.std_cutoff
Expand All @@ -671,7 +660,6 @@ def smootherUpdate(
global_scaling,
smoother_snapshot,
ens_mask,
ensemble_config,
prior_storage,
posterior_storage,
progress_callback,
Expand Down Expand Up @@ -706,7 +694,6 @@ def iterative_smoother_update(
analysis_config = self.ert.analysisConfig()

obs = self.ert.getObservations()
ensemble_config = self.ert.ensembleConfig()

alpha = analysis_config.enkf_alpha
std_cutoff = analysis_config.std_cutoff
Expand All @@ -730,7 +717,6 @@ def iterative_smoother_update(
1.0,
smoother_snapshot,
ens_mask,
ensemble_config,
prior_storage,
posterior_storage,
w_container,
Expand Down
13 changes: 8 additions & 5 deletions tests/performance_tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,19 @@ def poly_template(monkeypatch):
def test_memory_smoothing(poly_template):
ert_config = ErtConfig.from_file("poly.ert")
ert = EnKFMain(ert_config)
tgt = mock_target_accessor()
src = make_source_accessor(poly_template, ert)
with open_storage(poly_template, mode="w") as storage:
tgt = storage.create_ensemble(
src.experiment_id,
src.ensemble_size,
iteration=1,
name="tgt",
prior_ensemble=src,
)
smoother = ESUpdate(ert)
smoother.smootherUpdate(src, tgt, str(uuid.uuid4()))


def mock_target_accessor() -> EnsembleAccessor:
return Mock(spec=EnsembleAccessor)


def make_source_accessor(path: Path, ert: EnKFMain) -> EnsembleReader:
path = Path(path) / "ensembles"
with open_storage(path, mode="w") as storage:
Expand Down
Loading

0 comments on commit deed3d7

Please sign in to comment.