
Implement DataTree as data structure #537

Status: Open. Wants to merge 49 commits into base: main.
Commits (49), all authored by veni-vidi-vici-dormivi:

- e1a2dc4 add datatree to dependencies (Sep 20, 2024)
- 0fec170 add collapse datatree util (Sep 23, 2024)
- 23b7a51 switch utils in init (Sep 23, 2024)
- 4d253c9 nit (Sep 23, 2024)
- 8fc314f implement autoregression (Sep 30, 2024)
- de4850f test new autoregression functionality (Sep 30, 2024)
- c80869d implement scenario weights func (Sep 30, 2024)
- ce198ac adapt collapse datatree (Sep 30, 2024)
- d99973e linting (Sep 30, 2024)
- 397cef6 add fielfinder dependency (Oct 1, 2024)
- bb81f54 change and add integration tests (not consistent) (Oct 1, 2024)
- 353d561 adapt old codepath autoregression_scen_ens (Oct 2, 2024)
- 025fba4 linting (Oct 2, 2024)
- 83cf397 Merge branch 'main' into trees (Oct 2, 2024)
- 109cf92 update dependencies filefinder (Oct 2, 2024)
- be4c96c implement and deprecate list in autoregression_scen_ens (Oct 2, 2024)
- 3c51385 fix autoregression call in test (Oct 2, 2024)
- 66e15f0 downgrade datatree and add pip dep (Oct 8, 2024)
- 5a29072 add activating env (Oct 8, 2024)
- 147fdc7 remove activate again (Oct 8, 2024)
- c7a7f1c Merge branch 'main' into trees (Oct 8, 2024)
- 356cfa9 try verifying micromamba path (Oct 8, 2024)
- 24a5691 forgot to add it (Oct 8, 2024)
- fc10078 revert changes in ci-workflow (Oct 8, 2024)
- 0286046 expand weighting function (Oct 9, 2024)
- 452e0b6 add todo (Oct 10, 2024)
- 6943c62 expand collapse_dt tests (Oct 11, 2024)
- 93b5b4c refine collapse datatree (Oct 14, 2024)
- 4e6063c linting in weighted (Oct 14, 2024)
- 876ffcb add datatree to arraydict (Oct 14, 2024)
- 03f0083 implement seed dict in autoregression (Oct 14, 2024)
- 0137af4 adapt linear regression (Oct 15, 2024)
- 156ede7 adapt volc (Oct 15, 2024)
- df8e814 init (Oct 15, 2024)
- e791502 linting (Oct 15, 2024)
- d37cbdd add tas**2 test (Oct 15, 2024)
- 5b8751f add hfds tests (Oct 15, 2024)
- 046347c test stack_linear_regression_data (Oct 17, 2024)
- 7f38ce3 nit (Oct 17, 2024)
- dd31eed Update mesmer/core/weighted.py (Oct 17, 2024)
- 18d8b07 Update mesmer/core/utils.py (Oct 17, 2024)
- a62c6a4 Update mesmer/core/utils.py (Oct 18, 2024)
- f13734c Update mesmer/core/utils.py (Oct 18, 2024)
- b5be5fd broadcast~ed~ (Oct 17, 2024)
- 5223e26 outsurce datatree utils (Oct 18, 2024)
- 1f5be62 renaming (Oct 18, 2024)
- c6aee41 get rid of dt to arraydict (Oct 18, 2024)
- 422347e fixes (Oct 18, 2024)
- 94c2e42 linting (Oct 18, 2024)
4 changes: 4 additions & 0 deletions ci/requirements/min-all-deps.yml
@@ -22,6 +22,10 @@ dependencies:
- shapely=2.0
- statsmodels=0.14
- xarray=2023.7
- xarray-datatree=0.0.13
# for testing
- pytest
- pytest-cov
- pip
- pip:
- https://github.com/mathause/filefinder/archive/refs/tags/v0.3.0.zip
2 changes: 2 additions & 0 deletions docs/source/installation.rst
@@ -5,6 +5,7 @@ Required dependencies
---------------------

- Python (3.10 or later)
- `filefinder <https://github.com/mathause/filefinder>`__
- `dask <https://dask.org/>`__
- `joblib <https://joblib.readthedocs.io/en/latest/>`__
- `netcdf4 <https://unidata.github.io/netcdf4-python/>`__
@@ -18,6 +19,7 @@ Required dependencies
- `scipy <https://scipy.org/>`__
- `statsmodels <https://www.statsmodels.org/stable/index.html>`__
- `xarray <http://xarray.pydata.org/>`__
- `xarray-datatree <https://xarray-datatree.readthedocs.io/en/latest/index.html>`__

Optional dependencies
---------------------
4 changes: 4 additions & 0 deletions environment.yml
@@ -22,8 +22,12 @@ dependencies:
- scipy
- statsmodels>=0.13
- xarray>=2023.04 # because pandas 2 is required
- xarray-datatree
# for testing
- black!=23
- pytest
- pytest-cov
- ruff
- pip
- pip:
- https://github.com/mathause/filefinder/archive/refs/tags/v0.3.0.zip
3 changes: 2 additions & 1 deletion mesmer/__init__.py
@@ -11,7 +11,7 @@

from . import calibrate_mesmer, core, create_emulations, io, stats, testing, utils
from .core import _data as data
from .core import geospatial, grid, mask, volc, weighted
from .core import datatree, geospatial, grid, mask, volc, weighted

# "legacy" modules
__all__ = [
@@ -25,6 +25,7 @@
__all__ += [
"core",
"data",
"datatree",
"geospatial",
"grid",
"mask",
10 changes: 5 additions & 5 deletions mesmer/calibrate_mesmer/train_gv.py
@@ -11,7 +11,7 @@
import xarray as xr

from mesmer.io.save_mesmer_bundle import save_mesmer_data
from mesmer.stats import _fit_auto_regression_scen_ens, _select_ar_order_scen_ens
from mesmer.stats import fit_auto_regression_scen_ens, select_ar_order_scen_ens


def train_gv(gv, targ, esm, cfg, save_params=True, **kwargs):
@@ -172,11 +172,11 @@ def train_gv_AR(params_gv, gv, max_lag, sel_crit):
# create temporary DataArray objects
data = [xr.DataArray(data, dims=["run", "time"]) for data in gv.values()]

AR_order = _select_ar_order_scen_ens(
*data, dim="time", ens_dim="run", maxlag=max_lag, ic=sel_crit
AR_order = select_ar_order_scen_ens(
data, dim="time", ens_dim="run", maxlag=max_lag, ic=sel_crit
)
params = _fit_auto_regression_scen_ens(
*data, dim="time", ens_dim="run", lags=AR_order
params = fit_auto_regression_scen_ens(
data, dim="time", ens_dim="run", lags=AR_order
)

# TODO: remove np.float64(...) (only here so the tests pass)
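The hunk above changes `select_ar_order_scen_ens` and `fit_auto_regression_scen_ens` from star-args (`*data`) to a single list argument (and drops the leading underscore from the names). The commit log mentions the old star-args path is deprecated rather than removed; a dependency-free sketch of that migration pattern, with hypothetical function names (not MESMER's actual API):

```python
import warnings

def fit_scen_ens(objs, *, dim, lags):
    # new-style API: one explicit list of per-scenario objects
    if not isinstance(objs, (list, tuple)):
        raise TypeError("pass the scenarios as a list, e.g. fit_scen_ens([a, b], ...)")
    return {"n_scenarios": len(objs), "dim": dim, "lags": lags}

def fit_scen_ens_legacy(*objs, dim, lags):
    # old star-args API, kept only to ease the transition
    warnings.warn("star-args are deprecated, pass a list instead", FutureWarning)
    return fit_scen_ens(list(objs), dim=dim, lags=lags)

result = fit_scen_ens(["scen_a", "scen_b"], dim="time", lags=1)
```

Call sites like the one in `train_gv_AR` then only need to drop the `*` when switching over.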
4 changes: 2 additions & 2 deletions mesmer/calibrate_mesmer/train_lv.py
@@ -10,9 +10,9 @@

from mesmer.io.save_mesmer_bundle import save_mesmer_data
from mesmer.stats import (
_fit_auto_regression_scen_ens,
adjust_covariance_ar1,
find_localized_empirical_covariance,
fit_auto_regression_scen_ens,
)

from .train_utils import get_scenario_weights, stack_predictors_and_targets
@@ -233,7 +233,7 @@ def train_lv_AR1_sci(params_lv, targs, y, wgt_scen_eq, aux, cfg):
dims = ("run", "time", "cell")
data = [xr.DataArray(data, dims=dims) for data in targ.values()]

params = _fit_auto_regression_scen_ens(*data, dim="time", ens_dim="run", lags=1)
params = fit_auto_regression_scen_ens(data, dim="time", ens_dim="run", lags=1)

params_lv["AR1_int"][targ_name] = params.intercept.values
params_lv["AR1_coef"][targ_name] = params.coeffs.values.squeeze()
161 changes: 161 additions & 0 deletions mesmer/core/datatree.py
@@ -0,0 +1,161 @@
import xarray as xr
from datatree import DataTree


def _extract_single_dataarray_from_dt(dt: DataTree) -> xr.DataArray:
"""
    Extract the single ``DataArray`` from a ``DataTree`` that holds exactly one node containing one ``Dataset`` with a single data variable.
"""
# assert only one node in dt
if not len(list(dt.subtree)) == 1:
raise ValueError("DataTree must only contain one node.")
if not dt.has_data:
raise ValueError("DataTree must contain data.")

ds = dt.to_dataset()
if len(ds.data_vars) != 1:
raise ValueError("DataTree must have exactly one data variable.")

varname = list(ds.data_vars)[0]
da = ds.to_array().isel(variable=0).drop_vars("variable")
return da.rename(varname)
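The invariants checked above (exactly one node, which must hold data, with exactly one data variable) can be sketched without the datatree dependency using a plain dict as a stand-in for the node's dataset; names here are hypothetical, not the actual implementation:

```python
def extract_single_variable(node: dict) -> tuple[str, list]:
    # the node must hold data ...
    if not node:
        raise ValueError("DataTree must contain data.")
    # ... and exactly one data variable
    if len(node) != 1:
        raise ValueError("DataTree must have exactly one data variable.")
    ((name, values),) = node.items()
    return name, values

name, values = extract_single_variable({"tas": [0.5, 0.7]})
```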


def collapse_datatree_into_dataset(dt: DataTree, dim: str) -> xr.Dataset:
"""
Take a ``DataTree`` and collapse **all subtrees** in it into a single ``xr.Dataset`` along dim.
All datasets in the ``DataTree`` must have the same dimensions and each dimension must have a coordinate.

Parameters
----------
dt : DataTree
The DataTree to collapse.
dim : str
The dimension to concatenate the datasets along.

Returns
-------
xr.Dataset
The collapsed dataset.

Raises
------
ValueError
If all datasets do not have the same dimensions.
If any dimension does not have a coordinate.
"""
# TODO: could potentially be replaced by DataTree.merge_child_nodes in the future?
datasets = [subtree.to_dataset() for subtree in dt.subtree if not subtree.is_empty]

# Check if all datasets have the same dimensions
first_dims = set(datasets[0].dims)
if not all(set(ds.dims) == first_dims for ds in datasets):
raise ValueError("All datasets must have the same dimensions")

# Check that all dimensions have coordinates
for ds in datasets:
for ds_dim in ds.dims:
if ds[ds_dim].coords == {}:
raise ValueError(
f"Dimension '{ds_dim}' must have a coordinate/coordinates."
)

# Concatenate datasets along the specified dimension
ds = xr.concat(datasets, dim=dim)
ds = ds.assign_coords(
{dim: [subtree.name for subtree in dt.subtree if not subtree.is_empty]}
)

return ds
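The collapse logic above can be illustrated without xarray using plain dicts: every non-empty node contributes one dataset, all datasets must share the same dimensions, and the node names become the coordinate of the new dimension. A simplified, dependency-free sketch (hypothetical names, not the real implementation):

```python
def collapse_tree(tree: dict, dim: str) -> dict:
    # keep only non-empty nodes, mirroring `if not subtree.is_empty`
    datasets = {name: ds for name, ds in tree.items() if ds}

    # all datasets must have the same dimensions
    first_dims = set(next(iter(datasets.values()))["dims"])
    if any(set(ds["dims"]) != first_dims for ds in datasets.values()):
        raise ValueError("All datasets must have the same dimensions")

    # "concatenate" along the new dimension; node names become its coordinate
    return {
        "dims": sorted(first_dims | {dim}),
        dim: list(datasets),
        "data": [ds["data"] for ds in datasets.values()],
    }

tree = {
    "ssp119": {"dims": ["time"], "data": [1.0, 2.0]},
    "ssp585": {"dims": ["time"], "data": [3.0, 4.0]},
}
collapsed = collapse_tree(tree, dim="scenario")
```

In the real function, `xr.concat` does the concatenation and `assign_coords` attaches the node names along `dim`.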


def stack_linear_regression_datatrees(
predictors: DataTree,
target: DataTree,
weights: DataTree | None,
*,
stacking_dims: list[str],
collapse_dim: str = "scenario",
stacked_dim: str = "sample",
) -> tuple[DataTree, xr.Dataset, xr.Dataset | None]:
"""
    Prepare data for linear regression:
1. Broadcasts predictors to target
2. Collapses DataTrees into DataSets
3. Stacks the DataSets along the stacking dimension(s)

Parameters
----------
    predictors : DataTree
        A ``DataTree`` of ``xr.Dataset`` objects used as predictors. The ``DataTree``
        must have one subtree per predictor, each of which has at least one
        leaf holding an ``xr.Dataset`` representing a scenario. The subtrees of
        different predictors must be isomorphic (i.e. have the same scenarios). Each ``xr.Dataset``
        must at least contain the stacking dimension(s) and must only hold one data variable.
    target : DataTree
        A ``DataTree`` holding the targets. Must be isomorphic to the predictor subtrees, i.e.
        have the same scenarios. Each leaf must hold an ``xr.Dataset`` which must be at least 2D
        and contain the stacking dimension(s), but may also contain a dimension for ensemble members.
    weights : DataTree or None, default: None
        Individual weights for each sample, must be isomorphic to target. Must at least contain
        the stacking dimension(s), and must have the ensemble member dimension if target has it.
stacking_dims : list[str]
Dimension(s) to stack.
collapse_dim : str, default: "scenario"
Dimension along which to collapse the DataTrees, will automatically be added to the
stacking dims.
stacked_dim : str, default: "sample"
Name of the stacked dimension.

Returns
-------
tuple
Tuple of the prepared predictors, target and weights, where the predictors and target are
stacked along the stacking dimensions and the weights are stacked along the stacking dimensions
and the ensemble member dimension.

Notes
-----
Dimensions which exist along the target but are not in the stacking_dims will be excluded from the
broadcasting of the predictors.
"""

stacking_dims_all = stacking_dims + [collapse_dim]

# exclude target dimensions from broadcasting which are not in the stacking_dims
exclude_dim = set(target.leaves[0].ds.dims) - set(stacking_dims)

# predictors need to be
predictors_stacked = DataTree()
for key, subtree in predictors.items():
# 1) broadcast to target
pred_broadcast = subtree.broadcast_like(target, exclude=exclude_dim)
# 2) collapsed into DataSets
predictor_ds = collapse_datatree_into_dataset(pred_broadcast, dim=collapse_dim)
# 3) stacked
predictors_stacked[key] = DataTree(
predictor_ds.stack(
{stacked_dim: stacking_dims_all}, create_index=False
).dropna(dim=stacked_dim)
)

# target needs to be
# 1) collapsed into DataSet
target_ds = collapse_datatree_into_dataset(target, dim=collapse_dim)
# 2) stacked
target_stacked = target_ds.stack(
{stacked_dim: stacking_dims_all}, create_index=False
).dropna(dim=stacked_dim)

# weights need to be
if weights is not None:
# 1) collapsed into DataSet
weights_ds = collapse_datatree_into_dataset(weights, dim=collapse_dim)
# 2) stacked
weights_stacked = weights_ds.stack(
{stacked_dim: stacking_dims_all}, create_index=False
).dropna(dim=stacked_dim)
else:
weights_stacked = None

return predictors_stacked, target_stacked, weights_stacked
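The three steps the function applies (broadcast, collapse, stack-and-drop-NaN) amount to flattening per-scenario series into one sample axis and discarding missing samples. A hedged, dependency-free illustration of that control flow (not the real xarray stacking):

```python
import math

def stack_scenarios(per_scenario: dict) -> list:
    # collapse the scenario "tree" and stack (scenario, time) into one
    # sample axis, dropping NaN samples like `.dropna(dim=stacked_dim)`
    samples = []
    for scenario, values in per_scenario.items():
        for t, value in enumerate(values):
            if not math.isnan(value):
                samples.append((scenario, t, value))
    return samples

target = {"hist": [0.1, float("nan"), 0.3], "ssp585": [1.0, 2.0]}
samples = stack_scenarios(target)
```

In the real function, `create_index=False` keeps the stack cheap and the `(scenario, time)` labels survive as non-index coordinates.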
11 changes: 6 additions & 5 deletions mesmer/core/utils.py
@@ -1,4 +1,5 @@
import warnings
from collections.abc import Iterable

import numpy as np
import pandas as pd
@@ -22,7 +23,7 @@ def create_equal_dim_names(dim, suffixes):
return tuple(f"{dim}{suffix}" for suffix in suffixes)


def _minimize_local_discrete(func, sequence, **kwargs):
def _minimize_local_discrete(func, sequence: Iterable, **kwargs):
Review comment (Member):
These typing fixes are good - could extract them (but also ok to keep them if that's too annoying)
Just FYI - the list of which 'protocol' needs to have which methods: https://docs.python.org/3/library/collections.abc.html

"""find the local minimum for a function that consumes discrete input

Parameters
@@ -150,7 +151,7 @@ def _check_dataset_form(
obj,
name: str = "obj",
*,
required_vars: str | set[str] = set(),
required_vars: str | set[str] | None = set(),
optional_vars: str | set[str] = set(),
requires_other_vars: bool = False,
):
@@ -199,17 +200,17 @@ def _check_dataarray_form(
obj,
name: str = "obj",
*,
ndim: int = None,
ndim: int | tuple[int, ...] | None = None,
required_dims: str | set[str] = set(),
shape=None,
shape: tuple[int, ...] | None = None,
):
"""check if a dataset conforms to some conditions

obj: Any
object to check.
name : str, default: 'obj'
Name to use in error messages.
ndim, int, optional
ndim : int | tuple of int, optional
Number of required dimensions, can be a tuple of int if several are possible.
required_dims: str, set of str, optional
Names of dims that are required for obj
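The widened annotation `ndim: int | tuple[int, ...] | None` above implies the check now accepts several allowed dimensionalities. A small sketch of such a validator, under the assumption that a single int is normalized to a one-element tuple (hypothetical helper, not MESMER's actual `_check_dataarray_form`):

```python
from __future__ import annotations

def check_ndim(ndim_actual: int, ndim: int | tuple[int, ...] | None, name: str = "obj") -> None:
    # None means: no constraint on the number of dimensions
    if ndim is None:
        return
    # normalize a single int to a one-element tuple, then test membership
    allowed = (ndim,) if isinstance(ndim, int) else tuple(ndim)
    if ndim_actual not in allowed:
        expected = " or ".join(str(n) for n in allowed)
        raise ValueError(f"{name} should be {expected}-dimensional, but is {ndim_actual}D")

check_ndim(2, ndim=(1, 2))  # ok: 2 is among the allowed dimensionalities
```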
3 changes: 2 additions & 1 deletion mesmer/core/volc.py
@@ -1,6 +1,7 @@
import warnings

import xarray as xr
from datatree import DataTree

from mesmer.core._data import load_stratospheric_aerosol_optical_depth_obs
from mesmer.core.utils import _check_dataarray_form
@@ -98,7 +99,7 @@ def fit_volcanic_influence(tas_residuals, hist_period, *, dim="time", version="2

# TODO: name of 'aod'
lr.fit(
predictors={"aod": aod},
predictors=DataTree(aod, name="aod"),
target=tas_residuals,
dim=dim,
fit_intercept=False,