diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index c58e19dd..a893d5cb 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -22,6 +22,10 @@ dependencies: - shapely=2.0 - statsmodels=0.14 - xarray=2023.7 + - xarray-datatree=0.0.13 # for testing - pytest - pytest-cov + - pip + - pip: + - https://github.com/mathause/filefinder/archive/refs/tags/v0.3.0.zip diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 38e54524..9e75aa91 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -5,6 +5,7 @@ Required dependencies --------------------- - Python (3.10 or later) +- `filefinder `__ - `dask `__ - `joblib `__ - `netcdf4 `__ @@ -18,6 +19,7 @@ Required dependencies - `scipy `__ - `statsmodels `__ - `xarray `__ +- `xarray-datatree `__ Optional dependencies --------------------- diff --git a/environment.yml b/environment.yml index d93ecf55..1dc28afd 100644 --- a/environment.yml +++ b/environment.yml @@ -22,8 +22,12 @@ dependencies: - scipy - statsmodels>=0.13 - xarray>=2023.04 # because pandas 2 is required + - xarray-datatree # for testing - black!=23 - pytest - pytest-cov - ruff + - pip + - pip: + - https://github.com/mathause/filefinder/archive/refs/tags/v0.3.0.zip diff --git a/mesmer/__init__.py b/mesmer/__init__.py index a36559d8..f233bdb7 100644 --- a/mesmer/__init__.py +++ b/mesmer/__init__.py @@ -11,7 +11,7 @@ from . 
import calibrate_mesmer, core, create_emulations, io, stats, testing, utils from .core import _data as data -from .core import geospatial, grid, mask, volc, weighted +from .core import datatree, geospatial, grid, mask, volc, weighted # "legacy" modules __all__ = [ @@ -25,6 +25,7 @@ __all__ += [ "core", "data", + "datatree", "geospatial", "grid", "mask", diff --git a/mesmer/calibrate_mesmer/train_gv.py b/mesmer/calibrate_mesmer/train_gv.py index 5d24f24d..b052cc2b 100644 --- a/mesmer/calibrate_mesmer/train_gv.py +++ b/mesmer/calibrate_mesmer/train_gv.py @@ -11,7 +11,7 @@ import xarray as xr from mesmer.io.save_mesmer_bundle import save_mesmer_data -from mesmer.stats import _fit_auto_regression_scen_ens, _select_ar_order_scen_ens +from mesmer.stats import fit_auto_regression_scen_ens, select_ar_order_scen_ens def train_gv(gv, targ, esm, cfg, save_params=True, **kwargs): @@ -172,11 +172,11 @@ def train_gv_AR(params_gv, gv, max_lag, sel_crit): # create temporary DataArray objects data = [xr.DataArray(data, dims=["run", "time"]) for data in gv.values()] - AR_order = _select_ar_order_scen_ens( - *data, dim="time", ens_dim="run", maxlag=max_lag, ic=sel_crit + AR_order = select_ar_order_scen_ens( + data, dim="time", ens_dim="run", maxlag=max_lag, ic=sel_crit ) - params = _fit_auto_regression_scen_ens( - *data, dim="time", ens_dim="run", lags=AR_order + params = fit_auto_regression_scen_ens( + data, dim="time", ens_dim="run", lags=AR_order ) # TODO: remove np.float64(...) 
(only here so the tests pass) diff --git a/mesmer/calibrate_mesmer/train_lv.py b/mesmer/calibrate_mesmer/train_lv.py index fd1ca66b..44d6dae9 100644 --- a/mesmer/calibrate_mesmer/train_lv.py +++ b/mesmer/calibrate_mesmer/train_lv.py @@ -10,9 +10,9 @@ from mesmer.io.save_mesmer_bundle import save_mesmer_data from mesmer.stats import ( - _fit_auto_regression_scen_ens, adjust_covariance_ar1, find_localized_empirical_covariance, + fit_auto_regression_scen_ens, ) from .train_utils import get_scenario_weights, stack_predictors_and_targets @@ -233,7 +233,7 @@ def train_lv_AR1_sci(params_lv, targs, y, wgt_scen_eq, aux, cfg): dims = ("run", "time", "cell") data = [xr.DataArray(data, dims=dims) for data in targ.values()] - params = _fit_auto_regression_scen_ens(*data, dim="time", ens_dim="run", lags=1) + params = fit_auto_regression_scen_ens(data, dim="time", ens_dim="run", lags=1) params_lv["AR1_int"][targ_name] = params.intercept.values params_lv["AR1_coef"][targ_name] = params.coeffs.values.squeeze() diff --git a/mesmer/core/datatree.py b/mesmer/core/datatree.py new file mode 100644 index 00000000..76e895d8 --- /dev/null +++ b/mesmer/core/datatree.py @@ -0,0 +1,161 @@ +import xarray as xr +from datatree import DataTree + + +def _extract_single_dataarray_from_dt(dt: DataTree) -> xr.DataArray: + """ + Extract a single DataArray from a DataTree node, holding one ``Dataset`` with one ``DataArray``. 
+ """ + # assert only one node in dt + if not len(list(dt.subtree)) == 1: + raise ValueError("DataTree must only contain one node.") + if not dt.has_data: + raise ValueError("DataTree must contain data.") + + ds = dt.to_dataset() + if len(ds.data_vars) != 1: + raise ValueError("DataTree must have exactly one data variable.") + + varname = list(ds.data_vars)[0] + da = ds.to_array().isel(variable=0).drop_vars("variable") + return da.rename(varname) + + +def collapse_datatree_into_dataset(dt: DataTree, dim: str) -> xr.Dataset: + """ + Take a ``DataTree`` and collapse **all subtrees** in it into a single ``xr.Dataset`` along dim. + All datasets in the ``DataTree`` must have the same dimensions and each dimension must have a coordinate. + + Parameters + ---------- + dt : DataTree + The DataTree to collapse. + dim : str + The dimension to concatenate the datasets along. + + Returns + ------- + xr.Dataset + The collapsed dataset. + + Raises + ------ + ValueError + If all datasets do not have the same dimensions. + If any dimension does not have a coordinate. + """ + # TODO: could potentially be replaced by DataTree.merge_child_nodes in the future? + datasets = [subtree.to_dataset() for subtree in dt.subtree if not subtree.is_empty] + + # Check if all datasets have the same dimensions + first_dims = set(datasets[0].dims) + if not all(set(ds.dims) == first_dims for ds in datasets): + raise ValueError("All datasets must have the same dimensions") + + # Check that all dimensions have coordinates + for ds in datasets: + for ds_dim in ds.dims: + if ds[ds_dim].coords == {}: + raise ValueError( + f"Dimension '{ds_dim}' must have a coordinate/coordinates." 
+ ) + + # Concatenate datasets along the specified dimension + ds = xr.concat(datasets, dim=dim) + ds = ds.assign_coords( + {dim: [subtree.name for subtree in dt.subtree if not subtree.is_empty]} + ) + + return ds + + +def stack_linear_regression_datatrees( + predictors: DataTree, + target: DataTree, + weights: DataTree | None, + *, + stacking_dims: list[str], + collapse_dim: str = "scenario", + stacked_dim: str = "sample", +) -> tuple[DataTree, xr.Dataset, xr.Dataset | None]: + """ + prepares data for Linear Regression: + 1. Broadcasts predictors to target + 2. Collapses DataTrees into DataSets + 3. Stacks the DataSets along the stacking dimension(s) + + Parameters + ---------- + predictors : DataTree + A ``DataTree`` of ``xr.Dataset`` objects used as predictors. The ``DataTree`` + must have subtrees for each predictor each of which has to have at least one + leaf, holding a ``xr.Dataset`` representing a scenario. The subtrees of + different predictors must be isomorphic (i.e. have the same scenarios). The ``xr.Dataset`` + must at least contain `dim` and each ``xr.Dataset`` must only hold one data variable. + target : DataTree + A ``DataTree`` holding the targets. Must be isomorphic to the predictor subtrees, i.e. + have the same scenarios. Each leaf must hold a ``xr.Dataset`` which must be at least 2D + and contain `dim`, but may also contain a dimension for ensemble members. + weights : DataTree, default: None. + Individual weights for each sample, must be isomorphic to target. Must at least contain + `dim`, and must have the ensemble member dimension if target has it. + stacking_dims : list[str] + Dimension(s) to stack. + collapse_dim : str, default: "scenario" + Dimension along which to collapse the DataTrees, will automatically be added to the + stacking dims. + stacked_dim : str, default: "sample" + Name of the stacked dimension.
+ + Returns + ------- + tuple + Tuple of the prepared predictors, target and weights, where the predictors and target are + stacked along the stacking dimensions and the weights are stacked along the stacking dimensions + and the ensemble member dimension. + + Notes + ----- + Dimensions which exist along the target but are not in the stacking_dims will be excluded from the + broadcasting of the predictors. + """ + + stacking_dims_all = stacking_dims + [collapse_dim] + + # exclude target dimensions from broadcasting which are not in the stacking_dims + exclude_dim = set(target.leaves[0].ds.dims) - set(stacking_dims) + + # predictors need to be + predictors_stacked = DataTree() + for key, subtree in predictors.items(): + # 1) broadcast to target + pred_broadcast = subtree.broadcast_like(target, exclude=exclude_dim) + # 2) collapsed into DataSets + predictor_ds = collapse_datatree_into_dataset(pred_broadcast, dim=collapse_dim) + # 3) stacked + predictors_stacked[key] = DataTree( + predictor_ds.stack( + {stacked_dim: stacking_dims_all}, create_index=False + ).dropna(dim=stacked_dim) + ) + + # target needs to be + # 1) collapsed into DataSet + target_ds = collapse_datatree_into_dataset(target, dim=collapse_dim) + # 2) stacked + target_stacked = target_ds.stack( + {stacked_dim: stacking_dims_all}, create_index=False + ).dropna(dim=stacked_dim) + + # weights need to be + if weights is not None: + # 1) collapsed into DataSet + weights_ds = collapse_datatree_into_dataset(weights, dim=collapse_dim) + # 2) stacked + weights_stacked = weights_ds.stack( + {stacked_dim: stacking_dims_all}, create_index=False + ).dropna(dim=stacked_dim) + else: + weights_stacked = None + + return predictors_stacked, target_stacked, weights_stacked diff --git a/mesmer/core/utils.py b/mesmer/core/utils.py index fcffe866..1ac37cca 100644 --- a/mesmer/core/utils.py +++ b/mesmer/core/utils.py @@ -1,4 +1,5 @@ import warnings +from collections.abc import Iterable import numpy as np import pandas as pd 
@@ -22,7 +23,7 @@ def create_equal_dim_names(dim, suffixes): return tuple(f"{dim}{suffix}" for suffix in suffixes) -def _minimize_local_discrete(func, sequence, **kwargs): +def _minimize_local_discrete(func, sequence: Iterable, **kwargs): """find the local minimum for a function that consumes discrete input Parameters @@ -150,7 +151,7 @@ def _check_dataset_form( obj, name: str = "obj", *, - required_vars: str | set[str] = set(), + required_vars: str | set[str] | None = set(), optional_vars: str | set[str] = set(), requires_other_vars: bool = False, ): @@ -199,9 +200,9 @@ def _check_dataarray_form( obj, name: str = "obj", *, - ndim: int = None, + ndim: int | tuple[int, ...] | None = None, required_dims: str | set[str] = set(), - shape=None, + shape: tuple[int, ...] | None = None, ): """check if a dataset conforms to some conditions @@ -209,7 +210,7 @@ def _check_dataarray_form( object to check. name : str, default: 'obj' Name to use in error messages. - ndim, int, optional + ndim : int | tuple of int, optional Number of required dimensions, can be a tuple of int if several are possible. 
required_dims: str, set of str, optional Names of dims that are required for obj diff --git a/mesmer/core/volc.py b/mesmer/core/volc.py index 28e693aa..3bed78ac 100644 --- a/mesmer/core/volc.py +++ b/mesmer/core/volc.py @@ -1,6 +1,7 @@ import warnings import xarray as xr +from datatree import DataTree from mesmer.core._data import load_stratospheric_aerosol_optical_depth_obs from mesmer.core.utils import _check_dataarray_form @@ -98,7 +99,7 @@ def fit_volcanic_influence(tas_residuals, hist_period, *, dim="time", version="2 # TODO: name of 'aod' lr.fit( - predictors={"aod": aod}, + predictors=DataTree(aod, name="aod"), target=tas_residuals, dim=dim, fit_intercept=False, diff --git a/mesmer/core/weighted.py b/mesmer/core/weighted.py index 4d3a9d46..32406c2f 100644 --- a/mesmer/core/weighted.py +++ b/mesmer/core/weighted.py @@ -1,7 +1,9 @@ import warnings +from collections.abc import Hashable, Iterable import numpy as np import xarray as xr +from datatree import DataTree, map_over_subtree def _weighted_if_dim(obj, weights, dims): @@ -106,3 +108,77 @@ def global_mean(data, weights=None, x_dim="lon", y_dim="lat"): weights = lat_weights(data[y_dim]) return weighted_mean(data, weights, [x_dim, y_dim]) + + +def create_equal_scenario_weights_from_datatree( + dt: DataTree, ens_dim: str = "member", exclude: Iterable[Hashable] | None = None +) -> DataTree: + """ + Create a DataTree isomorphic to ``dt``, holding the weights for each scenario to weight the ensemble members of each + scenario such that each scenario contributes equally to some fitting procedure. + The weight of each member = 1 / number of members in the scenario, so weights = 1 / ds[ens_dim].size. + + Thus, if all scenarios have equal amounts of members, all weights will be equal. + If one scenario has more members than the others, the weights will be smaller for each member of this scenario.
+ + Parameters: + ----------- + dt : DataTree + DataTree holding the ``xr.Datasets`` for which the weights should be created. Each dataset must have at least + ens_dim as a dimension, but can have more dimensions. + ens_dim : str + Name of the dimension along which the weights should be created. Default is "member". + exclude : Iterable[Hashable] | None + Name of one or several dimensions to exclude from the dataset before calculating the weights. Default is None. + Internally, these dimensions are dropped before calculating the weights. If None, the returned ``DataTree`` is + isomorphic to ``dt``. + + Returns: + -------- + DataTree + DataTree holding the weights for each scenario. + + Example: + -------- + dt = DataTree() + dt["ssp119"] = DataTree(xr.Dataset({"tas": xr.DataArray([1, 2, 3], dims="member")})) + dt["ssp585"] = DataTree(xr.Dataset({"tas": xr.DataArray([4, 5], dims="member")})) + create_equal_scenario_weights_from_datatree(dt) + # Output: + # DataTree({ + # "ssp119": DataTree({"weights": xr.DataArray([0.333333, 0.333333, 0.333333], dims="member")}), + # "ssp585": DataTree({"weights": xr.DataArray([0.5, 0.5], dims="member")}) + # }) + + """ + if dt.depth > 1: + raise ValueError(f"DataTree must have a depth of 1, not {dt.depth}.") + + def _create_weights(ds: xr.Dataset) -> xr.DataArray: + + if ens_dim not in ds.dims: + raise ValueError(f"Member dimension '{ens_dim}' not found in dataset.") + + data_vars = list(ds.keys()) + if len(data_vars) > 1: + raise ValueError("Dataset must have only one data variable.") + + # Get the dimensions to calculate the weights for and make sure they are in the right order + dims = [ + dim + for dim in ds[data_vars[0]].dims + if exclude is None or dim not in exclude + ] + + # Create a DataArray of ones with the remaining dimensions + shape = [ds.sizes[dim] for dim in dims] + coords = {dim: ds.coords[dim] for dim in dims} + ones = xr.DataArray(np.ones(shape), coords=coords, dims=list(dims)) + + weights = 
ones.rename("weights") + + return weights / ds[ens_dim].size + + weights = map_over_subtree(_create_weights)(dt) + + return weights diff --git a/mesmer/stats/__init__.py b/mesmer/stats/__init__.py index 8030e6f1..0b1c3488 100644 --- a/mesmer/stats/__init__.py +++ b/mesmer/stats/__init__.py @@ -1,12 +1,12 @@ from mesmer.stats._auto_regression import ( - _fit_auto_regression_scen_ens, - _select_ar_order_scen_ens, draw_auto_regression_correlated, draw_auto_regression_monthly, draw_auto_regression_uncorrelated, fit_auto_regression, fit_auto_regression_monthly, + fit_auto_regression_scen_ens, select_ar_order, + select_ar_order_scen_ens, ) from mesmer.stats._gaspari_cohn import gaspari_cohn, gaspari_cohn_correlation_matrices from mesmer.stats._harmonic_model import fit_harmonic_model, predict_harmonic_model @@ -27,8 +27,8 @@ __all__ = [ # auto regression - "_fit_auto_regression_scen_ens", - "_select_ar_order_scen_ens", + "fit_auto_regression_scen_ens", + "select_ar_order_scen_ens", "draw_auto_regression_correlated", "draw_auto_regression_uncorrelated", "fit_auto_regression", @@ -40,6 +40,7 @@ "gaspari_cohn", # linear regression "LinearRegression", + "prep_linear_regression_data", # localized covariance "adjust_covariance_ar1", "find_localized_empirical_covariance", diff --git a/mesmer/stats/_auto_regression.py b/mesmer/stats/_auto_regression.py index eb6c32cf..e6db1126 100644 --- a/mesmer/stats/_auto_regression.py +++ b/mesmer/stats/_auto_regression.py @@ -1,21 +1,80 @@ import warnings +from typing import Literal import numpy as np import pandas as pd import scipy import xarray as xr +from datatree import DataTree, map_over_subtree + +from mesmer.core.datatree import collapse_datatree_into_dataset +from mesmer.core.utils import ( + LinAlgWarning, + _check_dataarray_form, + _check_dataset_form, +) + + +def select_ar_order_scen_ens( + obs: list[xr.DataArray] | DataTree, + dim: str, + ens_dim: str | None, + maxlag: int, + ic: Literal["bic", "aic", "hqic"] = "bic", +) -> 
xr.DataArray: + """ + Select the order of an autoregressive process and potentially calculate the median + over ensemble members and scenarios + + Parameters + ---------- + obs : DataTree or list of DataArray + A ``DataTree`` or a list of ``xr.DataArray`` to estimate the auto regression order over. + dim : str + Dimension along which to determine the order. + ens_dim : str + Dimension name of the ensemble members. + maxlag : int + The maximum lag to consider. + ic : {'aic', 'hqic', 'bic'}, default 'bic' + The information criterion to use in the selection. -from mesmer.core.utils import LinAlgWarning, _check_dataarray_form, _check_dataset_form + Returns + ------- + selected_ar_order : DataArray + Array indicating the selected order with the same size as the input but ``dim`` + removed. + Notes + ----- + Calculates the median auto regression order, first over the ensemble members, + then over all scenarios. + """ -def _select_ar_order_scen_ens(*objs, dim, ens_dim, maxlag, ic="bic"): + if isinstance(obs, list): + warnings.warn( + "Passing a list of DataArrays will be deprecated in the future. Please use a DataTree instead.", + DeprecationWarning, + ) + return _select_ar_order_scen_ens_list(obs, dim, ens_dim, maxlag, ic) + elif isinstance(obs, DataTree): + return _select_ar_order_scen_ens_dt(obs, dim, ens_dim, maxlag, ic) + + +def _select_ar_order_scen_ens_list( + objs: list[xr.DataArray], + dim: str, + ens_dim: str | None, + maxlag: int, + ic: Literal["bic", "aic", "hqic"] = "bic", +) -> xr.DataArray: """ Select the order of an autoregressive process and potentially calculate the median over ensemble members and scenarios Parameters ---------- - *objs : iterable of DataArray + objs : iterable of DataArray A list of ``xr.DataArray`` to estimate the auto regression order over. dim : str Dimension along which to determine the order.
@@ -41,7 +100,6 @@ def _select_ar_order_scen_ens(*objs, dim, ens_dim, maxlag, ic="bic"): ar_order_scen = list() for obj in objs: res = select_ar_order(obj, dim=dim, maxlag=maxlag, ic=ic) - if ens_dim in res.dims: res = res.quantile(dim=ens_dim, q=0.5, method="nearest") @@ -57,14 +115,140 @@ def _select_ar_order_scen_ens(*objs, dim, ens_dim, maxlag, ic="bic"): return ar_order -def _fit_auto_regression_scen_ens(*objs, dim, ens_dim, lags): +def _select_ar_order_scen_ens_dt( + dt: DataTree, + dim: str, + ens_dim: str | None, + maxlag: int, + ic: Literal["bic", "aic", "hqic"] = "bic", +) -> xr.DataArray: + """ + Select the order of an autoregressive process and potentially calculate the median + over ensemble members and scenarios + + Parameters + ---------- + dt : a DataTree + A DataTree holding one or several ``xr.Dataset`` to estimate the auto regression order over, + each representing one scenario, potentially with several ensemble members along `ens_dim`. + Each ``xr.DataSet`` should only hold one variable, the one for which to estimate the autoregression. + dim : str + Dimension along which to determine the order. + ens_dim : str + Dimension name of the ensemble members. Must be the same for all scenarios and have coordinates if not None. + maxlag : int + The maximum lag to consider. + ic : {'aic', 'hqic', 'bic'}, default 'bic' + The information criterion to use in the selection. + + Returns + ------- + selected_ar_order : DataArray + Array indicating the selected order with the same size as the input but ``dim`` + removed. + + Notes + ----- + Calculates the median auto regression order, first over the ensemble members, + then over all scenarios. 
+ """ + + _select_ar_order_dt = map_over_subtree(_select_ar_order_ds) + + ar_order_scen = _select_ar_order_dt(dt, dim=dim, maxlag=maxlag, ic=ic) + + def ens_quantile(ds, ens_dim): + if ens_dim in ds.dims: + return ds.quantile(dim=ens_dim, q=0.5, method="nearest") + return ds + + ens_quantile_dt = map_over_subtree(ens_quantile) + ar_odrer_ens_median = ens_quantile_dt(ar_order_scen, ens_dim) + + ar_odrer_ens_median_ds = collapse_datatree_into_dataset( + ar_odrer_ens_median, dim="scen" + ) + + ar_order = ar_odrer_ens_median_ds.quantile( + dim="scen", q=0.5, method="nearest" + ).selected_order + + if not np.isnan(ar_order).any(): + ar_order = ar_order.astype(int) + + return ar_order + + +def _select_ar_order_ds( + ds: xr.Dataset, dim: str, maxlag: int, ic: Literal["aic", "bic", "hqic"] = "bic" +) -> xr.DataArray: + + data_vars = list(ds.keys()) + if len(data_vars) > 1: + raise ValueError("Dataset must have only one data variable.") + + res = ds.map(select_ar_order, args=(dim, maxlag, ic)) + res = res.rename({data_vars[0]: "selected_order"}) + + return res.selected_order + + +def fit_auto_regression_scen_ens( + obj: DataTree | list[xr.DataArray], + dim: str, + ens_dim: str | None, + lags: int | xr.DataArray, +) -> xr.Dataset: """ fit an auto regression and potentially calculate the mean over ensemble members and scenarios Parameters ---------- - *objs : iterable of DataArray + obj : a DataTree or list of ``xr.DataArray``s + A ``DataTree`` holding one or several ``xr.Dataset`` or a list of ``xr.DataArray``s to estimate the auto regression order over, + each representing one scenario, potentially with several ensemble members along `ens_dim`. + If a ``DataTree``, each ``xr.DataSet`` should only hold one variable, the one for which to estimate the autoregression. + dim : str + Dimension along which to fit the auto regression (often time). + ens_dim : str + Dimension name of the ensemble members, None if no ensemble is provided. 
Must be the same for all scenarios and have coordinates if not None. + lags : int + The number of lags to include in the model. + + Returns + ------- + :obj:`xr.Dataset` + Dataset containing the estimated parameters of the ``intercept``, the AR + ``coeffs`` and the ``variance`` of the residuals. + + Notes + ----- + If `ens_dim` is not `None`, calculates the mean auto regression first over all ensemble + members and then over scenarios. This is done to weight scenarios equally, consequently + ensemble members are not weighted equally, if the number of members differs between scenarios. + If no ensemble members are provided, the mean is calculated over scenarios only. + """ + if isinstance(obj, list): + warnings.warn( + "Passing a list of DataArrays will be deprecated in the future. Please use a DataTree instead.", + DeprecationWarning, + ) + return _fit_auto_regression_scen_ens_list(obj, dim, ens_dim, lags) + elif isinstance(obj, DataTree): + return _fit_auto_regression_scen_ens_dt(obj, dim, ens_dim, lags) + + +def _fit_auto_regression_scen_ens_list( + objs: list[xr.DataArray], dim: str, ens_dim: str | None, lags: int | xr.DataArray +) -> xr.Dataset: + """ + fit an auto regression and potentially calculate the mean over ensemble members + and scenarios + + Parameters + ---------- + objs : iterable of DataArray A list of ``xr.DataArray`` to estimate the auto regression over, each representing one scenario, potentially with several ensemble members along `ens_dim`. 
@@ -107,6 +291,74 @@ def _fit_auto_regression_scen_ens(*objs, dim, ens_dim, lags): return ar_params +def _fit_auto_regression_scen_ens_dt( + dt: DataTree, dim: str, ens_dim: str | None, lags: int | xr.DataArray +) -> xr.Dataset: + """ + fit an auto regression and potentially calculate the mean over ensemble members + and scenarios + + Parameters + ---------- + dt : a DataTree + A ``DataTree`` holding one or several ``xr.Dataset`` to estimate the auto regression order over, + each representing one scenario, potentially with several ensemble members along `ens_dim`. + Each ``xr.DataSet`` should only hold one variable, the one for which to estimate the autoregression. + dim : str + Dimension along which to fit the auto regression (often time). + ens_dim : str + Dimension name of the ensemble members, None if no ensemble is provided. Must be the same for all scenarios and have coordinates if not None. + lags : int + The number of lags to include in the model. + + Returns + ------- + :obj:`xr.Dataset` + Dataset containing the estimated parameters of the ``intercept``, the AR + ``coeffs`` and the ``variance`` of the residuals. + + Notes + ----- + If `ens_dim` is not `None`, calculates the mean auto regression first over all ensemble + members and then over scenarios. This is done to weight scenarios equally, consequently + ensemble members are not weighted equally, if the number of members differs between scenarios. + If no ensemble members are provided, the mean is calculated over scenarios only. + """ + _fit_auto_regression_dt = map_over_subtree(_fit_auto_regression_ds) + ar_params_scen = _fit_auto_regression_dt(dt, dim=dim, lags=int(lags)) + + # TODO: think about weighting! 
see https://github.com/MESMER-group/mesmer/issues/307 + def ens_mean(ds, ens_dim): + if ens_dim in ds.dims: + return ds.mean(ens_dim) + return ds + + ens_mean_dt = map_over_subtree(ens_mean) + ar_params_scen = ens_mean_dt(ar_params_scen, ens_dim) + + ar_params_scen = collapse_datatree_into_dataset(ar_params_scen, dim="scen") + + # return the mean over all scenarios + ar_params = ar_params_scen.mean("scen") + + return ar_params + + +def _fit_auto_regression_ds( + ds: xr.Dataset, + dim: str, + lags: int, +) -> xr.Dataset: + + data_vars = list(ds.keys()) + if len(data_vars) > 1: + raise ValueError("Dataset must have only one data variable.") + + res = fit_auto_regression(ds[data_vars[0]], dim, lags) + + return res + + # ====================================================================================== @@ -263,7 +515,7 @@ def draw_auto_regression_uncorrelated( # also to draw univariate realizations # check the input _check_dataset_form( - ar_params, "ar_params", required_vars=("intercept", "coeffs", "variance") + ar_params, "ar_params", required_vars={"intercept", "coeffs", "variance"} ) if ( @@ -278,6 +530,9 @@ def draw_auto_regression_uncorrelated( # _draw_ar_corr_xr_internal expects 2D arrays ar_params = ar_params.expand_dims("__gridpoint__") + if isinstance(seed, xr.Dataset): + seed = int(seed.seed.values) + result = _draw_ar_corr_xr_internal( intercept=ar_params.intercept, coeffs=ar_params.coeffs, @@ -293,7 +548,7 @@ def draw_auto_regression_uncorrelated( # remove the "__gridpoint__" dim again result = result.squeeze(dim="__gridpoint__", drop=True) - return result + return result.rename("samples") def draw_auto_regression_correlated( @@ -355,15 +610,18 @@ def draw_auto_regression_correlated( """ # check the input - _check_dataset_form(ar_params, "ar_params", required_vars=("intercept", "coeffs")) + _check_dataset_form(ar_params, "ar_params", required_vars={"intercept", "coeffs"}) _check_dataarray_form(ar_params.intercept, "intercept", ndim=1) (dim,), size = 
ar_params.intercept.dims, ar_params.intercept.size _check_dataarray_form( - ar_params.coeffs, "coeffs", ndim=2, required_dims=("lags", dim) + ar_params.coeffs, "coeffs", ndim=2, required_dims={"lags", dim} ) _check_dataarray_form(covariance, "covariance", ndim=2, shape=(size, size)) + if isinstance(seed, xr.Dataset): + seed = int(seed.seed.values) + result = _draw_ar_corr_xr_internal( intercept=ar_params.intercept, coeffs=ar_params.coeffs, @@ -376,7 +634,7 @@ def draw_auto_regression_correlated( realisation_dim=realisation_dim, ) - return result + return result.rename("samples") def _draw_ar_corr_xr_internal( @@ -521,7 +779,9 @@ def _draw_innovations_correlated_np( return innovations -def fit_auto_regression(data, dim, lags): +def fit_auto_regression( + data: xr.DataArray, dim: str, lags: int | list[int] +) -> xr.Dataset: """fit an auto regression Parameters @@ -530,8 +790,9 @@ def fit_auto_regression(data, dim, lags): A ``xr.DataArray`` to estimate the auto regression over. dim : str Dimension along which to fit the auto regression. - lags : int - The number of lags to include in the model. + lags : int | list + The number of lags or list of lags to include in the model. + If int, then all lags up to ``lags`` will be included. 
Returns ------- @@ -774,17 +1035,17 @@ def draw_auto_regression_monthly( """ # check input - _check_dataset_form(ar_params, "ar_params", required_vars=("intercept", "slope")) + _check_dataset_form(ar_params, "ar_params", required_vars={"intercept", "slope"}) month_dim, gridcell_dim = ar_params.intercept.dims n_months, size = ar_params.intercept.shape _check_dataarray_form( ar_params.intercept, "intercept", ndim=2, - required_dims=(month_dim, gridcell_dim), + required_dims={month_dim, gridcell_dim}, ) _check_dataarray_form( - ar_params.slope, "slope", ndim=2, required_dims=(month_dim, gridcell_dim) + ar_params.slope, "slope", ndim=2, required_dims={month_dim, gridcell_dim} ) _check_dataarray_form( covariance, "covariance", ndim=3, shape=(n_months, size, size) diff --git a/mesmer/stats/_linear_regression.py b/mesmer/stats/_linear_regression.py index efef8382..db60df05 100644 --- a/mesmer/stats/_linear_regression.py +++ b/mesmer/stats/_linear_regression.py @@ -1,9 +1,20 @@ -from collections.abc import Mapping +import warnings import numpy as np import xarray as xr +from datatree import DataTree -from mesmer.core.utils import _check_dataarray_form, _check_dataset_form, _to_set +from mesmer.core.datatree import ( + _extract_single_dataarray_from_dt, + collapse_datatree_into_dataset, +) +from mesmer.core.utils import ( + _check_dataarray_form, + _check_dataset_form, + _to_set, +) + +# TODO: deprecate predictor dicts? class LinearRegression: @@ -14,7 +25,7 @@ def __init__(self): def fit( self, - predictors: Mapping[str, xr.DataArray], + predictors: dict[str, xr.DataArray] | DataTree, target: xr.DataArray, dim: str, weights: xr.DataArray | None = None, @@ -25,9 +36,9 @@ def fit( Parameters ---------- - predictors : dict of xr.DataArray - A dict of DataArray objects used as predictors. Must be 1D and contain - `dim`. + predictors : dict of xr.DataArray | DataTree + A dict of DataArray objects used as predictors or a DataTree, holding each + predictor in a leaf. 
Each predictor must be 1D and contain `dim`. target : xr.DataArray Target DataArray. Must be 2D and contain `dim`. dim : str @@ -52,16 +63,17 @@ def fit( def predict( self, - predictors: Mapping[str, xr.DataArray], + predictors: dict[str, xr.DataArray] | DataTree, exclude=None, - ): + ) -> xr.DataArray: """ Predict using the linear model. Parameters ---------- - predictors : dict of xr.DataArray - A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. + predictors : dict of xr.DataArray | DataTree + A dict of DataArray objects used as predictors or a DataTree, holding each + predictor in a leaf. Each predictor must be 1D and contain `dim`. exclude : str or set of str, default: None Set of variables to exclude in the prediction. May include ``"intercept"`` to initialize the prediction with 0. @@ -88,7 +100,7 @@ def predict( if available_predictors - required_predictors: superfluous = sorted(available_predictors - required_predictors) superfluous = "', '".join(superfluous) - raise ValueError(f"Superfluous predictors: '{superfluous}'") + warnings.warn(f"Superfluous predictors: '{superfluous}', will be ignored.") if "intercept" in exclude: prediction = xr.zeros_like(params.intercept) @@ -96,22 +108,29 @@ def predict( prediction = params.intercept for key in required_predictors: - prediction = prediction + predictors[key] * params[key] + prediction = predictors[key] * params[key] + prediction - return prediction + prediction = ( + _extract_single_dataarray_from_dt(prediction) + if isinstance(prediction, DataTree) + else prediction + ) + + return prediction.T.rename("prediction") def residuals( self, - predictors: Mapping[str, xr.DataArray], + predictors: dict[str, xr.DataArray] | DataTree, target: xr.DataArray, - ): + ) -> xr.DataArray: """ Calculate the residuals of the fitted linear model Parameters ---------- predictors : dict of xr.DataArray - A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. 
+ A dict of DataArray objects used as predictors or a DataTree, holding each + predictor in a leaf. Each predictor must be 1D and contain `dim`. target : xr.DataArray Target DataArray. Must be 2D and contain `dim`. @@ -126,7 +145,7 @@ def residuals( residuals = target - prediction - return residuals + return residuals.rename("residuals") @property def params(self): @@ -182,12 +201,12 @@ def to_netcdf(self, filename, **kwargs): Additional keyword arguments passed to ``xr.Dataset.to_netcf`` """ - params = self.params() + params = self.params params.to_netcdf(filename, **kwargs) def _fit_linear_regression_xr( - predictors: Mapping[str, xr.DataArray], + predictors: dict[str, xr.DataArray] | DataTree, target: xr.DataArray, dim: str, weights: xr.DataArray | None = None, @@ -198,8 +217,9 @@ def _fit_linear_regression_xr( Parameters ---------- - predictors : dict of xr.DataArray - A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. + predictors : dict of xr.DataArray | DataTree + A dict of DataArray objects used as predictors or a DataTree, holding each + predictor in a leaf. Each predictor must be 1D and contain `dim`. target : xr.DataArray Target DataArray. Must be 2D and contain `dim`. dim : str @@ -217,8 +237,10 @@ def _fit_linear_regression_xr( individual DataArray. """ - if not isinstance(predictors, Mapping): - raise TypeError(f"predictors should be a dict, got {type(predictors)}.") + if not isinstance(predictors, dict | DataTree): + raise TypeError( + f"predictors should be a dict or DataTree, got {type(predictors)}." 
+ ) if ("weights" in predictors) or ("intercept" in predictors): raise ValueError( @@ -229,14 +251,25 @@ def _fit_linear_regression_xr( raise ValueError("dim cannot currently be 'predictor'.") for key, pred in predictors.items(): + pred = ( + _extract_single_dataarray_from_dt(pred) + if isinstance(pred, DataTree) + else pred + ) _check_dataarray_form(pred, ndim=1, required_dims=dim, name=f"predictor: {key}") - predictors_concat = xr.concat( - tuple(predictors.values()), - dim="predictor", - join="exact", - coords="minimal", - ) + if isinstance(predictors, dict): + predictors_concat = xr.concat( + tuple(predictors.values()), + dim="predictor", + join="exact", + coords="minimal", + ) + else: + predictors_concat = collapse_datatree_into_dataset(predictors, dim="predictor") + predictors_concat = ( + predictors_concat.to_array().isel(variable=0).drop_vars("variable") + ) _check_dataarray_form(target, required_dims=dim, name="target") diff --git a/tests/integration/test_calibrate_mesmer.py b/tests/integration/test_calibrate_mesmer.py index b45ae474..1d302845 100644 --- a/tests/integration/test_calibrate_mesmer.py +++ b/tests/integration/test_calibrate_mesmer.py @@ -9,6 +9,7 @@ @pytest.mark.filterwarnings("ignore:No local minimum found") +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") @pytest.mark.parametrize( "scenarios, use_tas2, use_hfds, outname", ( @@ -46,7 +47,7 @@ False, True, "tas_hfds/one_scen_one_ens", - marks=pytest.mark.slow, + # marks=pytest.mark.slow, ), # tas, tas**2, and hfds pytest.param( diff --git a/tests/integration/test_calibrate_mesmer_newcodepath.py b/tests/integration/test_calibrate_mesmer_newcodepath.py index 7e642c3c..05dc98d1 100644 --- a/tests/integration/test_calibrate_mesmer_newcodepath.py +++ b/tests/integration/test_calibrate_mesmer_newcodepath.py @@ -2,8 +2,11 @@ import joblib import numpy as np +import pandas import pytest import xarray as xr +from datatree import DataTree, map_over_subtree +from 
filefinder import FileContainer, FileFinder import mesmer @@ -19,57 +22,56 @@ False, "tas/one_scen_one_ens", ), - # TODO: Add the other test cases too - # pytest.param( - # ["h-ssp585"], - # False, - # False, - # "tas/one_scen_multi_ens", - # marks=pytest.mark.slow, - # ), - # pytest.param( - # ["h-ssp126", "h-ssp585"], - # False, - # False, - # "tas/multi_scen_multi_ens", - # ), - # # tas and tas**2 - # pytest.param( - # ["h-ssp126"], - # True, - # False, - # "tas_tas2/one_scen_one_ens", - # marks=pytest.mark.slow, - # ), - # # tas and hfds - # pytest.param( - # ["h-ssp126"], - # False, - # True, - # "tas_hfds/one_scen_one_ens", - # marks=pytest.mark.slow, - # ), - # # tas, tas**2, and hfds - # pytest.param( - # ["h-ssp126"], - # True, - # True, - # "tas_tas2_hfds/one_scen_one_ens", - # ), - # pytest.param( - # ["h-ssp585"], - # True, - # True, - # "tas_tas2_hfds/one_scen_multi_ens", - # marks=pytest.mark.slow, - # ), - # pytest.param( - # ["h-ssp126", "h-ssp585"], - # True, - # True, - # "tas_tas2_hfds/multi_scen_multi_ens", - # marks=pytest.mark.slow, - # ), + pytest.param( + ["ssp585"], + False, + False, + "tas/one_scen_multi_ens", + marks=pytest.mark.slow, + ), + pytest.param( + ["ssp126", "ssp585"], + False, + False, + "tas/multi_scen_multi_ens", + ), + # tas and tas**2 + pytest.param( + ["ssp126"], + True, + False, + "tas_tas2/one_scen_one_ens", + marks=pytest.mark.slow, + ), + # tas and hfds + pytest.param( + ["ssp126"], + False, + True, + "tas_hfds/one_scen_one_ens", + marks=pytest.mark.slow, + ), + # tas, tas**2, and hfds + pytest.param( + ["ssp126"], + True, + True, + "tas_tas2_hfds/one_scen_one_ens", + ), + pytest.param( + ["ssp585"], + True, + True, + "tas_tas2_hfds/one_scen_multi_ens", + marks=pytest.mark.slow, + ), + pytest.param( + ["ssp126", "ssp585"], + True, + True, + "tas_tas2_hfds/multi_scen_multi_ens", + marks=pytest.mark.slow, + ), ), ) def test_calibrate_mesmer( @@ -86,46 +88,102 @@ def test_calibrate_mesmer( REFERENCE_PERIOD = slice("1850", 
"1900") HIST_PERIOD = slice("1850", "2014") - PROJ_PERIOD = slice("2015", "2100") LOCALISATION_RADII = range(1750, 2001, 250) esm = "IPSL-CM6A-LR" - scenario = scenarios[0] test_cmip_generation = 6 # define paths and load data TEST_DATA_PATH = pathlib.Path(test_data_root_dir) - TEST_PATH = TEST_DATA_PATH / "output" / "tas" / "one_scen_one_ens" + TEST_PATH = TEST_DATA_PATH / "output" / outname cmip_data_path = ( TEST_DATA_PATH / "calibrate-coarse-grid" / f"cmip{test_cmip_generation}-ng" ) - path_tas = cmip_data_path / "tas" / "ann" / "g025" + CMIP_FILEFINDER = FileFinder( + path_pattern=cmip_data_path / "{variable}/{time_res}/{resolution}", + file_pattern="{variable}_{time_res}_{model}_{scenario}_{member}_{resolution}.nc", + ) + + fc_scens = CMIP_FILEFINDER.find_files( + variable="tas", scenario=scenarios, model=esm, resolution="g025", time_res="ann" + ) + + # only get the historical members that are also in the future scenarios, but only once + unique_scen_members = fc_scens.df.member.unique() + + fc_hist = CMIP_FILEFINDER.find_files( + variable="tas", + scenario="historical", + model=esm, + resolution="g025", + time_res="ann", + member=unique_scen_members, + ) - fN_hist = path_tas / f"tas_ann_{esm}_historical_r1i1p1f1_g025.nc" - fN_proj = path_tas / f"tas_ann_{esm}_{scenario}_r1i1p1f1_g025.nc" + fc_all = FileContainer(pandas.concat([fc_hist.df, fc_scens.df])) + + scenarios_whist = scenarios.copy() + scenarios_whist.append("historical") + + # load data for each scenario + dt = DataTree() + for scen in scenarios_whist: + files = fc_all.search(scenario=scen) + + # load all members for a scenario + members = [] + for fN, meta in files: + ds = xr.open_dataset(fN, use_cftime=True) + # drop unnecessary variables + ds = ds.drop_vars(["height", "time_bnds", "file_qf"], errors="ignore") + # assign member-ID as coordinate + ds = ds.assign_coords({"member": meta["member"]}) + members.append(ds) + + # create a Dataset that holds each member along the member dimension + 
scen_data = xr.concat(members, dim="member") + # put the scenario dataset into the DataTree + dt[f"{scen}"] = DataTree(scen_data) + + # load additional data + if use_hfds: + fc_hfds = CMIP_FILEFINDER.find_files( + variable="hfds", + scenario=scenarios_whist, + model=esm, + resolution="g025", + time_res="ann", + member=unique_scen_members, + ) - tas = xr.open_mfdataset( - [fN_hist, fN_proj], - combine="by_coords", - use_cftime=True, - combine_attrs="override", - data_vars="minimal", - compat="override", - coords="minimal", - drop_variables=["height", "file_qf"], - ).load() + dt_hfds = DataTree() + for scen in scenarios_whist: + files = fc_hfds.search(scenario=scen) + + members = [] + for fN, meta in files: + ds = xr.open_dataset(fN, use_cftime=True) + ds = ds.drop_vars( + ["height", "time_bnds", "file_qf", "area"], errors="ignore" + ) + ds = ds.assign_coords({"member": meta["member"]}) + members.append(ds) + + scen_data = xr.concat(members, dim="member") + dt_hfds[f"{scen}"] = DataTree(scen_data) # data preprocessing # create global mean tas anomlies timeseries - tas = mesmer.grid.wrap_to_180(tas) + dt = map_over_subtree(mesmer.grid.wrap_to_180)(dt) # convert the 0..360 grid to a -180..180 grid to be consistent with legacy code - ref = tas.sel(time=REFERENCE_PERIOD).mean("time", keep_attrs=True) - tas = tas - ref - tas_globmean = mesmer.weighted.global_mean(tas) + # calculate anomalies w.r.t. 
the reference period + ref = dt["historical"].sel(time=REFERENCE_PERIOD).mean("time") + tas_anoms = dt - ref.ds + tas_globmean = map_over_subtree(mesmer.weighted.global_mean)(tas_anoms) # create local gridded tas data def mask_and_stack(ds, threshold_land): @@ -134,95 +192,118 @@ def mask_and_stack(ds, threshold_land): ds = mesmer.grid.stack_lat_lon(ds) return ds - tas_stacked = mask_and_stack(tas, threshold_land=THRESHOLD_LAND) + tas_stacked = map_over_subtree(mask_and_stack)( + tas_anoms, threshold_land=THRESHOLD_LAND + ) # train global trend module - tas_globmean_lowess = mesmer.stats.lowess( - tas_globmean, "time", n_steps=50, use_coords=False + tas_globmean_smoothed = map_over_subtree(mesmer.stats.lowess)( + tas_globmean.mean(dim="member"), "time", n_steps=50, use_coords=False + ) + hist_lowess_residuals = ( + tas_globmean["historical"] - tas_globmean_smoothed["historical"] ) - tas_lowess_residuals = tas_globmean - tas_globmean_lowess volcanic_params = mesmer.volc.fit_volcanic_influence( - tas_lowess_residuals.tas, hist_period=HIST_PERIOD, dim="time" + hist_lowess_residuals.tas, hist_period=HIST_PERIOD, dim="time" ) - tas_globmean_volc = mesmer.volc.superimpose_volcanic_influence( - tas_globmean_lowess, volcanic_params, hist_period=HIST_PERIOD, dim="time" + tas_globmean_smoothed["historical"] = mesmer.volc.superimpose_volcanic_influence( + tas_globmean_smoothed["historical"], + volcanic_params, + hist_period=HIST_PERIOD, + dim="time", ) # train global variability module - def _split_hist_proj( - obj, dim="time", hist_period=HIST_PERIOD, proj_period=PROJ_PERIOD - ): - hist = obj.sel({dim: hist_period}) - proj = obj.sel({dim: proj_period}) - - return hist, proj - - tas_hist_globmean_smooth_volc, tas_proj_smooth = _split_hist_proj(tas_globmean_volc) - - tas_hist_resid_novolc = tas_globmean - tas_hist_globmean_smooth_volc - tas_proj_resid = tas_globmean - tas_proj_smooth - - data = (tas_hist_resid_novolc.tas, tas_proj_resid.tas) + tas_resid_novolc = tas_globmean - 
tas_globmean_smoothed - ar_order = mesmer.stats._select_ar_order_scen_ens( - *data, dim="time", ens_dim="ens", maxlag=12, ic="bic" + ar_order = mesmer.stats.select_ar_order_scen_ens( + tas_resid_novolc, dim="time", ens_dim="member", maxlag=12, ic="bic" ) - global_ar_params = mesmer.stats._fit_auto_regression_scen_ens( - *data, dim="time", ens_dim="ens", lags=ar_order + global_ar_params = mesmer.stats.fit_auto_regression_scen_ens( + tas_resid_novolc, dim="time", ens_dim="member", lags=ar_order ) + if use_hfds: + hfds_ref = dt_hfds["historical"].sel(time=REFERENCE_PERIOD).mean("time") + hfds_anoms = dt_hfds - hfds_ref.ds + hfds_globmean = map_over_subtree(mesmer.weighted.global_mean)(hfds_anoms) + hfds_globmean_smoothed = map_over_subtree(mesmer.stats.lowess)( + hfds_globmean.mean(dim="member"), "time", n_steps=50, use_coords=False + ) + # train local forced response module - predictors_split = { - "tas_globmean": [tas_hist_globmean_smooth_volc.tas, tas_proj_smooth.tas], - "tas_globmean_resid": [tas_hist_resid_novolc.tas, tas_proj_resid.tas], - } + # broadcast so all datasets have all the dimensions + # gridcell can be excluded because it will be mapped in the Linear Regression + target = tas_stacked + predictors = DataTree.from_dict( + {"tas": tas_globmean_smoothed, "tas_resids": tas_resid_novolc} + ) + if use_tas2: + predictors["tas2"] = tas_globmean_smoothed**2 + if use_hfds: + predictors["hfds"] = hfds_globmean_smoothed - predictors = dict() - for key, value in predictors_split.items(): - predictors[key] = xr.concat(value, dim="time") + weights = mesmer.weighted.create_equal_scenario_weights_from_datatree( + target, ens_dim="member", exclude="gridcell" + ) + + predictors_stacked, target_stacked, weights_stacked = ( + mesmer.core.datatree.stack_linear_regression_datatrees( + predictors, target, weights, stacking_dims=["member", "time"] + ) + ) local_forced_response_lr = mesmer.stats.LinearRegression() local_forced_response_lr.fit( - predictors=predictors, - 
target=tas_stacked.tas, - dim="time", # switch to sample? + predictors=predictors_stacked, + target=target_stacked.tas, + dim="sample", + weights=weights_stacked.weights, ) # train local variability module # train local AR process tas_stacked_residuals = local_forced_response_lr.residuals( - predictors=predictors, target=tas_stacked.tas - ) - - tas_stacked_residuals_hist, tas_stacked_residuals_proj = _split_hist_proj( - tas_stacked_residuals - ) + predictors=predictors_stacked, target=target_stacked.tas + ).T + + tas_un_stacked_residuals = tas_stacked_residuals.set_index( + sample=("time", "member", "scenario") + ).unstack("sample") + + dt_resids = DataTree() + for scenario in tas_un_stacked_residuals.scenario.values: + dt_resids[scenario] = DataTree( + tas_un_stacked_residuals.sel(scenario=scenario) + .dropna("member", how="all") + .dropna("time") + .drop_vars("scenario") + .rename("residuals") + ) - data = (tas_stacked_residuals_hist, tas_stacked_residuals_proj) - local_ar_params = mesmer.stats._fit_auto_regression_scen_ens( - *data, - ens_dim="none", + local_ar_params = mesmer.stats.fit_auto_regression_scen_ens( + dt_resids, + ens_dim="member", dim="time", lags=1, ) # train covariance - geodist = mesmer.geospatial.geodist_exact(tas_stacked.lon, tas_stacked.lat) + geodist = mesmer.geospatial.geodist_exact( + tas_stacked["historical"].ds.lon, tas_stacked["historical"].ds.lat + ) phi_gc_localizer = mesmer.stats.gaspari_cohn_correlation_matrices( geodist, localisation_radii=LOCALISATION_RADII ) - weights = xr.ones_like(tas_globmean.tas) # equal weights (for now?) 
- weights.name = "weights" - - dim = "time" # rename to "sample" + dim = "sample" k_folds = 30 localized_ecov = mesmer.stats.find_localized_empirical_covariance( - tas_stacked_residuals, weights, phi_gc_localizer, dim, k_folds + tas_stacked_residuals, weights_stacked.weights, phi_gc_localizer, dim, k_folds ) localized_ecov["localized_covariance_adjusted"] = ( @@ -251,7 +332,7 @@ def assert_params_allclose( fN_bundle = TEST_PATH / "test-mesmer-bundle.pkl" bundle = joblib.load(fN_bundle) - # TODO: Test volcanic influence params too + # TODO: Test volcanic influence params too (not in bundle) # global variability np.testing.assert_allclose( diff --git a/tests/unit/test_auto_regression.py b/tests/unit/test_auto_regression.py index 92d7c395..4793ccfa 100644 --- a/tests/unit/test_auto_regression.py +++ b/tests/unit/test_auto_regression.py @@ -4,6 +4,7 @@ import pandas as pd import pytest import xarray as xr +from datatree import DataTree, map_over_subtree from packaging.version import Version import mesmer @@ -172,6 +173,49 @@ def test_draw_auto_regression_uncorrelated( ) +def test_draw_auto_regression_uncorrelated_dt(ar_params_1D): + seeds = DataTree.from_dict( + { + "scen1": xr.DataArray(np.array([25])).rename("seed"), + "scen2": xr.DataArray(np.array([42])).rename("seed"), + } + ) + n_realization = 10 + n_ts = 20 + + result = map_over_subtree(mesmer.stats.draw_auto_regression_uncorrelated)( + ar_params_1D, + time=n_ts, + realisation=n_realization, + seed=seeds, + buffer=10, + time_dim="time", + realisation_dim="realisation", + ) + + assert result["scen1"].to_dataset().var() is not result["scen2"].to_dataset().var() + _check_dataset_form( + result["scen1"].to_dataset(), "result", required_vars={"samples"} + ) + _check_dataset_form( + result["scen2"].to_dataset(), "result", required_vars={"samples"} + ) + _check_dataarray_form( + result["scen1"].samples, + "samples", + ndim=2, + required_dims={"time", "realisation"}, + shape=(n_ts, n_realization), + ) + 
_check_dataarray_form( + result["scen2"].samples, + "samples", + ndim=2, + required_dims={"time", "realisation"}, + shape=(n_ts, n_realization), + ) + + @pytest.mark.parametrize("dim", ("time", "realisation")) @pytest.mark.parametrize("wrong_coords", (None, 2.0, np.array([1, 2]), xr.Dataset())) def test_draw_auto_regression_uncorrelated_wrong_coords( @@ -278,6 +322,50 @@ def test_draw_auto_regression_correlated( ) +def test_draw_auto_regression_correlated_dt(ar_params_2D, covariance): + seeds = DataTree.from_dict( + { + "scen1": xr.DataArray(np.array([25])).rename("seed"), + "scen2": xr.DataArray(np.array([42])).rename("seed"), + } + ) + n_realization = 10 + n_ts = 20 + + result = map_over_subtree(mesmer.stats.draw_auto_regression_correlated)( + ar_params_2D, + covariance, + time=n_ts, + realisation=n_realization, + seed=seeds, + buffer=10, + time_dim="time", + realisation_dim="realisation", + ) + + assert result["scen1"].to_dataset().var() is not result["scen2"].to_dataset().var() + _check_dataset_form( + result["scen1"].to_dataset(), "result", required_vars={"samples"} + ) + _check_dataset_form( + result["scen2"].to_dataset(), "result", required_vars={"samples"} + ) + _check_dataarray_form( + result["scen1"].samples, + "samples", + ndim=3, + required_dims={"time", "realisation", "gridcell"}, + shape=(n_ts, 2, n_realization), + ) + _check_dataarray_form( + result["scen2"].samples, + "samples", + ndim=3, + required_dims={"time", "realisation", "gridcell"}, + shape=(n_ts, 2, n_realization), + ) + + @pytest.mark.parametrize("dim", ("time", "realisation")) @pytest.mark.parametrize("wrong_coords", (None, 2.0, np.array([1, 2]), xr.Dataset())) def test_draw_auto_regression_correlated_wrong_coords( @@ -530,7 +618,7 @@ def test_fit_auto_regression_xr_1D(lags): _check_dataset_form( res, "_fit_auto_regression_result", - required_vars=["intercept", "coeffs", "variance"], + required_vars={"intercept", "coeffs", "variance"}, ) _check_dataarray_form(res.intercept, "intercept", 
ndim=0, shape=()) @@ -555,7 +643,7 @@ def test_fit_auto_regression_xr_2D(lags): _check_dataset_form( res, "_fit_auto_regression_result", - required_vars=["intercept", "coeffs", "variance"], + required_vars={"intercept", "coeffs", "variance"}, ) _check_dataarray_form(res.intercept, "intercept", ndim=1, shape=(n_cells,)) diff --git a/tests/unit/test_auto_regression_scen_ens.py b/tests/unit/test_auto_regression_scen_ens.py index 7dcc8bb0..c50791a0 100644 --- a/tests/unit/test_auto_regression_scen_ens.py +++ b/tests/unit/test_auto_regression_scen_ens.py @@ -1,6 +1,7 @@ import numpy as np import pytest import xarray as xr +from datatree import DataTree from statsmodels.tsa.arima_process import ArmaProcess import mesmer @@ -16,15 +17,15 @@ def generate_ar_samples(ar, std=1, n_timesteps=100, n_ens=4): da = xr.DataArray(data, dims=("time", "ens"), coords={"ens": ens}) - return da + return da.rename("data") def test_select_ar_order_scen_ens_one_scen(): - da = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + dt = DataTree(generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4)) - result = mesmer.stats._select_ar_order_scen_ens( - da, dim="time", ens_dim="ens", maxlag=5 + result = mesmer.stats.select_ar_order_scen_ens( + dt, dim="time", ens_dim="ens", maxlag=5 ) expected = xr.DataArray(3, coords={"quantile": 0.5}) @@ -37,8 +38,10 @@ def test_select_ar_order_scen_ens_multi_scen(): da1 = generate_ar_samples([1, 0.5, 0.3], n_timesteps=100, n_ens=4) da2 = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) - result = mesmer.stats._select_ar_order_scen_ens( - da1, da2, dim="time", ens_dim="ens", maxlag=5 + dt = DataTree.from_dict({"scen1": da1, "scen2": da2}) + + result = mesmer.stats.select_ar_order_scen_ens( + dt, dim="time", ens_dim="ens", maxlag=5 ) expected = xr.DataArray(2, coords={"quantile": 0.5}) @@ -48,10 +51,10 @@ def test_select_ar_order_scen_ens_multi_scen(): def test_select_ar_order_scen_ens_no_ens_dim(): - da = 
generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + dt = DataTree(generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4)) - result = mesmer.stats._select_ar_order_scen_ens( - da, dim="time", ens_dim=None, maxlag=5 + result = mesmer.stats.select_ar_order_scen_ens( + dt, dim="time", ens_dim=None, maxlag=5 ) ens = [0, 1, 2, 3] @@ -68,8 +71,8 @@ def test_fit_auto_regression_scen_ens_one_scen(std): n_timesteps = 100 da = generate_ar_samples([1, 0.5, 0.3, 0.4], std, n_timesteps=n_timesteps, n_ens=4) - result = mesmer.stats._fit_auto_regression_scen_ens( - da, dim="time", ens_dim="ens", lags=3 + result = mesmer.stats.fit_auto_regression_scen_ens( + DataTree(da), dim="time", ens_dim="ens", lags=3 ) expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3) @@ -83,8 +86,10 @@ def test_fit_auto_regression_scen_ens_multi_scen(): da1 = generate_ar_samples([1, 0.5, 0.3], n_timesteps=100, n_ens=4) da2 = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=5) - result = mesmer.stats._fit_auto_regression_scen_ens( - da1, da2, dim="time", ens_dim="ens", lags=3 + dt = DataTree.from_dict({"scen1": da1, "scen2": da2}) + + result = mesmer.stats.fit_auto_regression_scen_ens( + dt, dim="time", ens_dim="ens", lags=3 ) da = xr.concat([da1, da2], dim="scen") @@ -101,10 +106,135 @@ def test_fit_auto_regression_scen_ens_no_ens_dim(): da = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) # simply fits each ens individually, no averaging - result = mesmer.stats._fit_auto_regression_scen_ens( - da, dim="time", ens_dim=None, lags=3 + result = mesmer.stats.fit_auto_regression_scen_ens( + DataTree(da), dim="time", ens_dim=None, lags=3 ) expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3) xr.testing.assert_allclose(result, expected) + + +def test_fit_auto_regression_scen_ens_no_ens_dim_multi_scen(): + + da1 = ( + generate_ar_samples([1, 0.5, 0.3], n_timesteps=100, n_ens=1) + .sel(ens=0) + .drop_vars("ens") + ) + da2 = 
generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=5) + + dt = DataTree.from_dict({"scen1": da1, "scen2": da2}) + + with pytest.raises(ValueError, match="All datasets must have the same dimensions"): + mesmer.stats.fit_auto_regression_scen_ens(dt, dim="time", ens_dim=None, lags=3) + + +# =============================================================================== +# tests for list of scenarios +# =============================================================================== +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") +def test_select_ar_order_scen_ens_one_scen_list(): + + da = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + + result = mesmer.stats.select_ar_order_scen_ens( + [da], dim="time", ens_dim="ens", maxlag=5 + ) + + expected = xr.DataArray(3, coords={"quantile": 0.5}) + + xr.testing.assert_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") +def test_select_ar_order_scen_ens_multi_scen_list(): + + da1 = generate_ar_samples([1, 0.5, 0.3], n_timesteps=100, n_ens=4) + da2 = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + + result = mesmer.stats.select_ar_order_scen_ens( + [da1, da2], dim="time", ens_dim="ens", maxlag=5 + ) + + expected = xr.DataArray(2, coords={"quantile": 0.5}) + + xr.testing.assert_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") +def test_select_ar_order_scen_ens_no_ens_dim_list(): + + da = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + + result = mesmer.stats.select_ar_order_scen_ens( + [da], dim="time", ens_dim=None, maxlag=5 + ) + + ens = [0, 1, 2, 3] + expected = xr.DataArray( + [3, 1, 3, 3], dims="ens", coords={"quantile": 0.5, "ens": ens} + ) + + xr.testing.assert_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") 
+@pytest.mark.parametrize("std", [1, 0.1, 0.5]) +def test_fit_auto_regression_scen_ens_one_scen_list(std): + + n_timesteps = 100 + da = generate_ar_samples([1, 0.5, 0.3, 0.4], std, n_timesteps=n_timesteps, n_ens=4) + + result = mesmer.stats.fit_auto_regression_scen_ens( + [da], dim="time", ens_dim="ens", lags=3 + ) + + expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3) + expected = expected.mean("ens") + + xr.testing.assert_allclose(result, expected) + np.testing.assert_allclose(np.sqrt(result.variance), std, rtol=1e-1) + + +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") +def test_fit_auto_regression_scen_ens_multi_scen_list(): + da1 = generate_ar_samples([1, 0.5, 0.3], n_timesteps=100, n_ens=4) + da2 = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=5) + + result = mesmer.stats.fit_auto_regression_scen_ens( + [da1, da2], dim="time", ens_dim="ens", lags=3 + ) + + da = xr.concat([da1, da2], dim="scen") + da = da.stack(scen_ens=("scen", "ens")).dropna("scen_ens") + expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3) + expected = expected.unstack("scen_ens") + expected = expected.mean("ens").mean("scen") + + xr.testing.assert_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:Passing a list of DataArrays will be deprecated") +def test_fit_auto_regression_scen_ens_no_ens_dim_list(): + + da = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + + # simply fits each ens individually, no averaging + result = mesmer.stats.fit_auto_regression_scen_ens( + [da], dim="time", ens_dim=None, lags=3 + ) + + expected = mesmer.stats.fit_auto_regression(da, dim="time", lags=3) + xr.testing.assert_allclose(result, expected) + + +def test_fit_auto_regression_scen_ens_deprec_warning(): + da = generate_ar_samples([1, 0.5, 0.3, 0.4], n_timesteps=100, n_ens=4) + + with pytest.warns( + DeprecationWarning, match="Passing a list of DataArrays will be deprecated" + ): + 
mesmer.stats.fit_auto_regression_scen_ens( + [da], dim="time", ens_dim="ens", lags=3 + ) diff --git a/tests/unit/test_datatree.py b/tests/unit/test_datatree.py new file mode 100644 index 00000000..e47f3b35 --- /dev/null +++ b/tests/unit/test_datatree.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import xarray as xr +from datatree import DataTree + +import mesmer +from mesmer.core.utils import _check_dataarray_form +from mesmer.testing import trend_data_1D, trend_data_2D + + +def test_collapse_datatree_into_dataset(): + n_ts = 30 + da1 = trend_data_1D(n_timesteps=n_ts).rename("tas") + da2 = da1 * 2 + da3 = da1 * 3 + + leaf1 = xr.concat([da1, da2, da3], dim="member").assign_coords( + {"member": np.arange(3)} + ) + leaf2 = xr.concat([da1, da2], dim="member").assign_coords({"member": np.arange(2)}) + + dt = DataTree.from_dict({"scen1": leaf1, "scen2": leaf2}) + + collapse_dim = "scenario" + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + + assert isinstance(res, xr.Dataset) + assert collapse_dim in res.dims + assert (res[collapse_dim] == ["scen1", "scen2"]).all() + assert len(res.dims) == 3 + assert np.isnan(res.sel(scenario="scen2", member=2)).all() + + # error if data set has no coords along dim (bc then it is not concatenable if lengths differ) + leaf_missing_coords = leaf1.drop_vars("member") + dt = DataTree.from_dict({"scen1": leaf_missing_coords, "scen2": leaf2}) + with pytest.raises(ValueError, match="Dimension 'member' must have a coordinate"): + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + + # Dimension along which to concatenate already exists + leaf1_scen = leaf1.assign_coords({"scenario": "scen1"}).expand_dims(collapse_dim) + leaf2_scen = leaf2.assign_coords({"scenario": "scen2"}).expand_dims(collapse_dim) + dt = DataTree.from_dict({"scen1": leaf1_scen, "scen2": leaf2_scen}) + + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + assert isinstance(res, 
xr.Dataset) + + scen1 = res.sel(scenario="scen1").tas + xr.testing.assert_equal(scen1.drop_vars("scenario"), leaf1) + + # only one leaf works + dt = DataTree.from_dict({"scen1": leaf1}) + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + + assert isinstance(res, xr.Dataset) + assert collapse_dim in res.dims + assert (res[collapse_dim] == ["scen1"]).all() + assert len(res.dims) == 3 + + xr.testing.assert_equal(scen1.drop_vars(collapse_dim), leaf1) + + # nested DataTree works + dt = DataTree() + dt["scen1/sub_scen1"] = DataTree(leaf1) + dt["scen1/sub_scen2"] = DataTree(leaf2) + dt["scen2"] = DataTree(leaf2) + + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + assert isinstance(res, xr.Dataset) + assert collapse_dim in res.dims + assert len(res.dims) == 3 + assert (res[collapse_dim] == ["sub_scen1", "sub_scen2", "scen2"]).all() + + # more than one datavariable - works and fills with nans if necessary + ds = da3.rename("tas2") + + leaf3 = xr.merge( + [da1.assign_coords({"member": 1}), ds.assign_coords({"member": 1})] + ).expand_dims("member") + dt = DataTree.from_dict({"scen1": leaf1, "scen2": leaf2, "scen3": leaf3}) + + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + assert isinstance(res, xr.Dataset) + assert collapse_dim in res.dims + assert len(res.dims) == 3 + assert (res[collapse_dim] == ["scen1", "scen2", "scen3"]).all() + assert len(res.data_vars) == 2 + assert np.isnan(res.sel(scenario="scen1").tas2).all() + + # two time dimensions that have different length fills missing values with nans + da_with_different_time = da1.shift(time=1) + + badleaf = da_with_different_time.assign_coords({"member": 0}).expand_dims("member") + dt = DataTree.from_dict({"scen1": leaf1, "scen2": badleaf}) + + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + + assert np.isnan(res.sel(scenario="scen2", time=leaf1.time)).all() + + # missing dimension raises error + 
leaf_missing_dim = leaf1.sel(member=0).drop_vars("member") * 2 + + dt = DataTree.from_dict({"scen1": leaf1, "scen2": leaf_missing_dim}) + with pytest.raises(ValueError, match="All datasets must have the same dimensions"): + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim="scenario") + + # different dimensions raises error + leaf_diff_dim = leaf1.sel(member=0).rename({"member": "ens"}) * 2 + + dt = DataTree.from_dict({"scen1": leaf1, "scen2": leaf_diff_dim}) + with pytest.raises(ValueError, match="All datasets must have the same dimensions"): + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim="scenario") + + # make sure it also works with stacked dimension + # NOTE: only works if the stacked dimension has the same size on all datasets + n_lat, n_lon = 2, 3 + da1 = mesmer.testing.trend_data_2D(n_timesteps=n_ts, n_lat=n_lat, n_lon=n_lon) + da2 = mesmer.testing.trend_data_2D(n_timesteps=n_ts, n_lat=n_lat, n_lon=n_lon) + + dt = DataTree.from_dict({"mem1": da1, "mem2": da2}) + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim="members") + + # only one leaf also works + da = trend_data_1D(n_timesteps=n_ts).rename("tas") + ds = xr.Dataset({"tas": da}) + dt = DataTree.from_dict({"ds": ds}) + + collapse_dim = "scenario" + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + + expected = ds.expand_dims(collapse_dim).assign_coords( + {collapse_dim: np.array(["ds"])} + ) + xr.testing.assert_equal(res, expected) + + # empty nodes are removed before concatenating + # NOTE: implicitly this is already there in the other tests, since the root node is always empty + # but it is nice to have it explicitly too + dt = DataTree.from_dict({"scen1": leaf1, "scen2": DataTree()}) + res = mesmer.datatree.collapse_datatree_into_dataset(dt, dim=collapse_dim) + expected = ( + xr.Dataset({"tas": leaf1}) + .expand_dims(collapse_dim) + .assign_coords({collapse_dim: np.array(["scen1"])}) + ) + xr.testing.assert_equal(res, expected) + + +def 
test_extract_single_dataarray_from_dt(): + da = trend_data_1D(n_timesteps=30).rename("tas") + dt = DataTree.from_dict({"/": xr.Dataset({"tas": da})}) + + res = mesmer.datatree._extract_single_dataarray_from_dt(dt) + xr.testing.assert_equal(res, da) + + dt = DataTree.from_dict({"/": xr.Dataset({"tas": da, "tas2": da})}) + with pytest.raises( + ValueError, match="DataTree must have exactly one data variable." + ): + res = mesmer.datatree._extract_single_dataarray_from_dt(dt) + + dt = DataTree.from_dict( + {"scen1": xr.Dataset({"tas": da, "tas2": da}), "scen2": xr.Dataset({"tas": da})} + ) + + with pytest.raises(ValueError, match="DataTree must only contain one node."): + mesmer.datatree._extract_single_dataarray_from_dt(dt) + + with pytest.raises(ValueError, match="DataTree must contain data."): + mesmer.datatree._extract_single_dataarray_from_dt(DataTree()) + + +def test_stack_linear_regression_datatrees(): + n_ts, n_lat, n_lon = 30, 2, 3 + member_dim = "member" + time_dim = "time" + stacking_dims = [time_dim, member_dim] + collapse_dim = "scenario" + stacked_dim = "sample" + + d2D_1 = trend_data_2D(n_timesteps=n_ts, n_lat=n_lat, n_lon=n_lon) + d2D_2 = d2D_1 * 2 + d2D_3 = d2D_1 * 3 + d2D_4 = d2D_1 * 4 + d2D_5 = d2D_1 * 5 + + leaf1 = xr.concat([d2D_1, d2D_2, d2D_3], dim=member_dim).assign_coords( + {member_dim: np.arange(3)} + ) + leaf2 = xr.concat([d2D_4, d2D_5], dim=member_dim).assign_coords( + {member_dim: np.arange(2)} + ) + + target = DataTree.from_dict({"scen1": leaf1, "scen2": leaf2}) + + d1D_1 = trend_data_1D(n_timesteps=n_ts) + d1D_2 = d1D_1 * 2 + d1D_3 = d1D_1 * 3 + d1D_4 = d1D_1 * 4 + predictors = DataTree.from_dict( + { + "pred1": DataTree.from_dict({"scen1": d1D_1, "scen2": d1D_2}), + "pred2": DataTree.from_dict({"scen1": d1D_3, "scen2": d1D_4}), + } + ) + + weights = mesmer.weighted.create_equal_scenario_weights_from_datatree( + target, exclude="cells" + ) + + predictors_stacked, target_stacked, weights_stacked = ( + 
mesmer.datatree.stack_linear_regression_datatrees( + predictors, + target, + weights, + stacking_dims=stacking_dims, + collapse_dim=collapse_dim, + stacked_dim=stacked_dim, + ) + ) + + n_samples = n_ts * (2 + 3) # 3 members for scen1, 2 members for scen2 + + for pred in predictors_stacked.children: + da = predictors_stacked[pred].to_dataset().data + _check_dataarray_form( + da, name="pred1", ndim=1, required_dims={"sample"}, shape=(n_samples,) + ) + + _check_dataarray_form( + target_stacked.data, + ndim=2, + required_dims={"cells", "sample"}, + shape=(n_lat * n_lon, n_samples), + ) + _check_dataarray_form( + weights_stacked.weights, ndim=1, required_dims={"sample"}, shape=(n_samples,) + ) + + # check if datasets align + assert xr.align( + target_stacked, predictors_stacked["pred1"].to_dataset(), join="exact" + ) == (target_stacked, predictors_stacked["pred1"].to_dataset()) + assert xr.align( + target_stacked, predictors_stacked["pred2"].to_dataset(), join="exact" + ) == (target_stacked, predictors_stacked["pred2"].to_dataset()) + assert xr.align(target_stacked, weights_stacked, join="exact") == ( + target_stacked, + weights_stacked, + ) + + predictors_stacked, target_stacked, weights_stacked = ( + mesmer.datatree.stack_linear_regression_datatrees( + predictors, + target, + None, + stacking_dims=stacking_dims, + collapse_dim=collapse_dim, + stacked_dim=stacked_dim, + ) + ) + assert weights_stacked is None, "Weights should be None if not provided" diff --git a/tests/unit/test_linear_regression.py b/tests/unit/test_linear_regression.py index 46bc43e0..7621abf3 100644 --- a/tests/unit/test_linear_regression.py +++ b/tests/unit/test_linear_regression.py @@ -4,6 +4,7 @@ import numpy.testing as npt import pytest import xarray as xr +from datatree import DataTree import mesmer from mesmer.testing import trend_data_1D, trend_data_2D @@ -87,10 +88,16 @@ def test_lr_predict(as_2D): ) lr.params = params if as_2D else params.squeeze() - tas = xr.DataArray([0, 1, 2], 
dims="time") + tas = xr.DataArray([0, 1, 2], dims="time").rename("tas") result = lr.predict({"tas": tas}) - expected = xr.DataArray([[5, 8, 11]], dims=("x", "time")) + expected = xr.DataArray([[5, 8, 11]], dims=("x", "time")).rename("tas") + expected = expected if as_2D else expected.squeeze() + + xr.testing.assert_equal(result, expected) + + result = lr.predict(DataTree.from_dict({"tas": tas})) + expected = xr.DataArray([[5, 8, 11]], dims=("x", "time")).rename("tas") expected = expected if as_2D else expected.squeeze() xr.testing.assert_equal(result, expected) @@ -115,11 +122,22 @@ def test_lr_predict_missing_superfluous(): with pytest.raises(ValueError, match="Missing predictors: 'tas'"): lr.predict({"tas2": None}) - with pytest.raises(ValueError, match="Superfluous predictors: 'something else'"): - lr.predict({"tas": None, "tas2": None, "something else": None}) + with pytest.warns(match="Superfluous predictors: 'something else'"): + lr.predict({"tas": 1, "tas2": 2, "something else": None}) + + with pytest.warns(match="Superfluous predictors: 'bar', 'foo'"): + lr.predict({"tas": 1, "tas2": 2, "foo": None, "bar": None}) + + with pytest.warns(match="Superfluous predictors: 'tas2'"): + lr.predict({"tas": 1, "tas2": 2}, exclude="tas2") - with pytest.raises(ValueError, match="Superfluous predictors: 'bar', 'foo'"): - lr.predict({"tas": None, "tas2": None, "foo": None, "bar": None}) + with pytest.warns(match="Superfluous predictors: 'tas2'"): + lr.predict( + DataTree.from_dict( + {"tas": xr.Dataset({"tas": 1}), "tas2": xr.Dataset({"tas2": 2})} + ), + exclude="tas2", + ) @pytest.mark.parametrize("as_2D", [True, False]) @@ -157,6 +175,41 @@ def test_lr_predict_exclude(as_2D): xr.testing.assert_equal(result, expected) +@pytest.mark.parametrize("as_2D", [True, False]) +def test_lr_predict_exclude_dt(as_2D): + lr = mesmer.stats.LinearRegression() + + params = xr.Dataset( + data_vars={ + "intercept": ("x", [5]), + "fit_intercept": True, + "tas": ("x", [3]), + "tas2": ("x", 
[1]), + } + ) + lr.params = params if as_2D else params.squeeze() + + tas = xr.DataArray([0, 1, 2], dims="time").rename("tas") + + result = lr.predict(DataTree.from_dict({"tas": tas}), exclude="tas2") + expected = xr.DataArray([[5, 8, 11]], dims=("x", "time")) + expected = expected if as_2D else expected.squeeze() + + xr.testing.assert_equal(result, expected) + + result = lr.predict(DataTree.from_dict({"tas": tas}), exclude={"tas2"}) + expected = xr.DataArray([[5, 8, 11]], dims=("x", "time")) + expected = expected if as_2D else expected.squeeze() + + xr.testing.assert_equal(result, expected) + + result = lr.predict(DataTree.from_dict({}), exclude={"tas", "tas2"}) + expected = xr.DataArray([5], dims="x") + expected = expected if as_2D else expected.squeeze() + + xr.testing.assert_equal(result, expected) + + @pytest.mark.parametrize("as_2D", [True, False]) def test_lr_predict_exclude_intercept(as_2D): lr = mesmer.stats.LinearRegression() @@ -206,6 +259,27 @@ def test_LR_residuals(as_2D): xr.testing.assert_equal(expected, result) +@pytest.mark.parametrize("as_2D", [True, False]) +def test_LR_residuals_dt(as_2D): + + lr = mesmer.stats.LinearRegression() + + params = xr.Dataset( + data_vars={"intercept": ("x", [5]), "fit_intercept": True, "tas": ("x", [0])} + ) + lr.params = params if as_2D else params.squeeze() + + tas = xr.DataArray([0, 1, 2], dims="time").rename("tas") + target = xr.DataArray([[5, 8, 0]], dims=("x", "time")) + target = target if as_2D else target.squeeze() + + result = lr.residuals(DataTree.from_dict({"tas": tas}), target) + expected = xr.DataArray([[0, 3, -5]], dims=("x", "time")) + expected = expected if as_2D else expected.squeeze() + + xr.testing.assert_equal(expected, result) + + # TEST XARRAY WRAPPER & LinearRegression().fit @pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION) def test_linear_regression_errors(lr_method_or_function): @@ -288,6 +362,20 @@ def test_missing_dim(pred0, pred1, tgt, weights, name): with 
pytest.raises(ValueError, match="dim cannot currently be 'predictor'."): lr_method_or_function({"pred0": pred0}, tgt, dim="predictor") + # test predictors have to be dict or DataTree + with pytest.raises( + TypeError, match="predictors should be a dict or DataTree, got ." + ): + lr_method_or_function([pred0, pred1], tgt, dim="time") + + # test DataTree depth + with pytest.raises(ValueError, match="DataTree must only contain one node."): + lr_method_or_function( + DataTree.from_dict({"scen0": DataTree.from_dict({"pred1": pred1})}), + tgt, + dim="time", + ) + @pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION) @pytest.mark.parametrize("intercept", (0, 3.14)) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a47720cb..c0bd490f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -7,7 +7,7 @@ import mesmer.core.utils -def make_dummy_yearly_data(freq, calendar="standard"): +def make_dummy_yearly_data(freq: str, calendar: str = "standard"): if freq == "YM": freq = "AS-JUL" if Version(pd.__version__) < Version("2.2") else "YS-JUL" time = xr.date_range( @@ -17,7 +17,7 @@ def make_dummy_yearly_data(freq, calendar="standard"): time = xr.date_range(start="2000", periods=5, freq=freq, calendar=calendar) data = xr.DataArray([1.0, 2.0, 3.0, 4.0, 5.0], dims=("time"), coords={"time": time}) - return data + return data.rename("tas") def make_dummy_monthly_data(freq, calendar="standard"): @@ -31,7 +31,7 @@ def make_dummy_monthly_data(freq, calendar="standard"): ) data = xr.DataArray(np.ones(5 * 12), dims=("time"), coords={"time": time}) - return data + return data.rename("tas") @pytest.mark.parametrize("freq_y", ["YM", "YS", "YE", "YS-JUL", "YS-NOV"]) diff --git a/tests/unit/test_weighted.py b/tests/unit/test_weighted.py index c0390938..efc045d2 100644 --- a/tests/unit/test_weighted.py +++ b/tests/unit/test_weighted.py @@ -1,6 +1,8 @@ +import datatree.testing import numpy as np import pytest import xarray as xr +from 
datatree import DataTree import mesmer @@ -170,3 +172,129 @@ def test_global_mean_weights_passed(as_dataset): expected = data.mean(("lat", "lon")) xr.testing.assert_allclose(result, expected) + + +def test_create_equal_sceanrio_weights_from_datatree(): + dt = DataTree() + + n_members_ssp119 = 3 + n_members_ssp585 = 2 + n_gridcells = 3 + n_ts = 30 + + dt["ssp119"] = DataTree( + xr.Dataset({"tas": xr.DataArray(np.arange(n_members_ssp119), dims="member")}) + ) + dt["ssp585"] = DataTree( + xr.Dataset({"tas": xr.DataArray(np.arange(n_members_ssp585), dims="member")}) + ) + result1 = mesmer.weighted.create_equal_scenario_weights_from_datatree(dt) + expected = DataTree.from_dict( + { + "ssp119": DataTree( + xr.DataArray( + np.ones(n_members_ssp119) / n_members_ssp119, + dims="member", + coords={"member": np.arange(n_members_ssp119)}, + ).rename("weights") + ), + "ssp585": DataTree( + xr.DataArray( + np.ones(n_members_ssp585) / n_members_ssp585, + dims="member", + coords={"member": np.arange(n_members_ssp585)}, + ).rename("weights") + ), + } + ) + + datatree.testing.assert_isomorphic(result1, expected) + datatree.testing.assert_equal(result1, expected) + + dt["ssp119"] = DataTree( + dt.ssp119.ds.expand_dims(gridcell=np.arange(n_gridcells), axis=1) + ) + dt["ssp585"] = DataTree( + dt.ssp585.ds.expand_dims(gridcell=np.arange(n_gridcells), axis=1) + ) + + result2 = mesmer.weighted.create_equal_scenario_weights_from_datatree( + dt, ens_dim="member", exclude={"gridcell"} + ) + datatree.testing.assert_equal(result2, expected) + + dt["ssp119"] = DataTree(dt.ssp119.ds.expand_dims(time=np.arange(n_ts), axis=1)) + dt["ssp585"] = DataTree(dt.ssp585.ds.expand_dims(time=np.arange(n_ts), axis=1)) + + result3 = mesmer.weighted.create_equal_scenario_weights_from_datatree( + dt, exclude={"gridcell"} + ) + expected = DataTree.from_dict( + { + "ssp119": DataTree( + xr.DataArray( + np.ones((n_members_ssp119, n_ts)) / n_members_ssp119, + dims=["member", "time"], + coords={ + "member": 
np.arange(n_members_ssp119), + "time": np.arange(n_ts), + }, + ).rename("weights") + ), + "ssp585": DataTree( + xr.DataArray( + np.ones((n_members_ssp585, n_ts)) / n_members_ssp585, + dims=["member", "time"], + coords={ + "member": np.arange(n_members_ssp585), + "time": np.arange(n_ts), + }, + ).rename("weights") + ), + } + ) + + # datatree.testing.assert_equal(result3, expected) + xr.testing.assert_equal(result3.ssp119.weights, expected.ssp119.weights) + xr.testing.assert_equal(result3.ssp585.weights, expected.ssp585.weights) + + result4 = mesmer.weighted.create_equal_scenario_weights_from_datatree( + dt, exclude={"time", "gridcell"} + ) + datatree.testing.assert_equal(result4, result1) + + +def test_create_equal_sceanrio_weights_from_datatree_checks(): + + dt = DataTree() + dt["ssp119"] = DataTree(xr.Dataset({"tas": xr.DataArray([1, 2, 3], dims="member")})) + dt["ssp585"] = DataTree(xr.Dataset({"tas": xr.DataArray([4, 5], dims="member")})) + + # too deep + dt_too_deep = dt.copy() + dt_too_deep["ssp585/1"] = DataTree( + xr.Dataset({"tas": xr.DataArray([4, 5], dims="member")}) + ) + with pytest.raises(ValueError, match="DataTree must have a depth of 1, not 2."): + mesmer.weighted.create_equal_scenario_weights_from_datatree(dt_too_deep) + + # missing member dimension + dt_no_member = dt.copy() + dt_no_member["ssp119"] = DataTree(dt_no_member.ssp119.ds.sel(member=1)) + with pytest.raises( + ValueError, match="Member dimension 'member' not found in dataset." + ): + mesmer.weighted.create_equal_scenario_weights_from_datatree(dt_no_member) + + # multiple data variables + dt_multiple_vars = dt.copy() + dt_multiple_vars["ssp119"] = DataTree( + xr.Dataset( + { + "tas": xr.DataArray([4, 5], dims="member"), + "tas2": xr.DataArray([4, 5], dims="member"), + } + ) + ) + with pytest.raises(ValueError, match="Dataset must have only one data variable."): + mesmer.weighted.create_equal_scenario_weights_from_datatree(dt_multiple_vars)