Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursively compute variables #263

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aospy/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def _get_pressure_vals(self, var, start_date, end_date):
try:
ps = self._ps_data
except AttributeError:
self._ps_data = self.data_loader.load_variable(
self._ps_data = self.data_loader.recursively_compute_variable(
self.ps, start_date, end_date, self.time_offset,
**self.data_loader_attrs)
name = self._ps_data.name
Expand Down Expand Up @@ -362,9 +362,9 @@ def _get_input_data(self, var, start_date, end_date):
cond_pfull = ((not hasattr(self, internal_names.PFULL_STR))
and var.def_vert and
self.dtype_in_vert == internal_names.ETA_STR)
data = self.data_loader.load_variable(var, start_date, end_date,
self.time_offset,
**self.data_loader_attrs)
data = self.data_loader.recursively_compute_variable(
var, start_date, end_date, self.time_offset,
**self.data_loader_attrs)
name = data.name
data = self._add_grid_attributes(data.to_dataset(name=data.name))
data = data[name]
Expand Down
36 changes: 36 additions & 0 deletions aospy/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,42 @@ def load_variable(self, var=None, start_date=None, end_date=None,
return times.sel_time(da, np.datetime64(start_date_xarray),
np.datetime64(end_date_xarray)).load()

def recursively_compute_variable(self, var, start_date=None, end_date=None,
time_offset=None, **DataAttrs):
"""Compute a variable recursively, loading data where needed

An obvious requirement here is that the variable must eventually be
able to be expressed in terms of model-native quantities; otherwise the
recursion will never stop.

Parameters
----------
var : Var
aospy Var object
start_date : datetime.datetime
start date for interval
end_date : datetime.datetime
end date for interval
time_offset : dict
Option to add a time offset to the time coordinate to correct for
incorrect metadata.
**DataAttrs
Attributes needed to identify a unique set of files to load from

Returns
-------
da : DataArray
DataArray for the specified variable, date range, and interval in
"""
if var.variables is None:
return self.load_variable(var, start_date, end_date, time_offset,
**DataAttrs)
else:
data = [self.recursively_compute_variable(
v, start_date, end_date, time_offset, **DataAttrs)
for v in var.variables]
return var.func(*data).rename(var.name)

@staticmethod
def _maybe_apply_time_shift(da, time_offset=None, **DataAttrs):
"""Apply specified time shift to DataArray"""
Expand Down
61 changes: 50 additions & 11 deletions aospy/test/test_calc_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import xarray as xr

from aospy import Var
from aospy.calc import Calc, _add_metadata_as_attrs
from .data.objects.examples import (
example_proj, example_model, example_run, var_not_time_defined,
condensation_rain, precip, sphum, globe, sahel
condensation_rain, convection_rain, precip, sphum, globe, sahel
)


Expand All @@ -38,18 +39,21 @@ def _test_files_and_attrs(calc, dtype_out):
_test_output_attrs(calc, dtype_out)


_BASIC_TEST_PARAMS = {
'proj': example_proj,
'model': example_model,
'run': example_run,
'var': condensation_rain,
'date_range': (datetime.datetime(4, 1, 1),
datetime.datetime(6, 12, 31)),
'intvl_in': 'monthly',
'dtype_in_time': 'ts'
}


class TestCalcBasic(unittest.TestCase):
def setUp(self):
self.test_params = {
'proj': example_proj,
'model': example_model,
'run': example_run,
'var': condensation_rain,
'date_range': (datetime.datetime(4, 1, 1),
datetime.datetime(6, 12, 31)),
'intvl_in': 'monthly',
'dtype_in_time': 'ts'
}
self.test_params = _BASIC_TEST_PARAMS.copy()

def tearDown(self):
for direc in [example_proj.direc_out, example_proj.tar_direc_out]:
Expand Down Expand Up @@ -207,5 +211,40 @@ def test_attrs(units, description, dtype_out_vert, expected_units,
assert expected_description == arr.attrs['description']


@pytest.fixture()
def recursive_test_params():
basic_params = _BASIC_TEST_PARAMS.copy()
recursive_params = basic_params.copy()

recursive_condensation_rain = Var(
name='recursive_condensation_rain',
variables=(precip, convection_rain), func=lambda x, y: x - y,
def_time=True)
recursive_params['var'] = recursive_condensation_rain

yield (basic_params, recursive_params)

for direc in [example_proj.direc_out, example_proj.tar_direc_out]:
shutil.rmtree(direc)


def test_recursive_calculation(recursive_test_params):
basic_params, recursive_params = recursive_test_params

calc = Calc(intvl_out='ann', dtype_out_time='av', **basic_params)
calc = calc.compute()
expected = xr.open_dataset(
calc.path_out['av'], autoclose=True)['condensation_rain']
_test_files_and_attrs(calc, 'av')

calc = Calc(intvl_out='ann', dtype_out_time='av', **recursive_params)
calc = calc.compute()
result = xr.open_dataset(
calc.path_out['av'], autoclose=True)['recursive_condensation_rain']
_test_files_and_attrs(calc, 'av')

xr.testing.assert_equal(expected, result)


if __name__ == '__main__':
unittest.main()
37 changes: 37 additions & 0 deletions aospy/test/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
import xarray as xr

from aospy import Var
from aospy.data_loader import (DataLoader, DictDataLoader, GFDLDataLoader,
NestedDictDataLoader, grid_attrs_to_aospy_names,
set_grid_attrs_as_coords, _sel_var,
Expand Down Expand Up @@ -609,6 +610,42 @@ def convert_all_to_missing_val(ds, **kwargs):
expected_num_non_missing = 0
self.assertEqual(num_non_missing, expected_num_non_missing)

def test_recursively_compute_variable_native(self):
result = self.data_loader.recursively_compute_variable(
condensation_rain, datetime(5, 1, 1), datetime(5, 12, 31),
intvl_in='monthly')
filepath = os.path.join(os.path.split(ROOT_PATH)[0], 'netcdf',
'00050101.precip_monthly.nc')
expected = _open_ds_catch_warnings(filepath)['condensation_rain']
np.testing.assert_array_equal(result.values, expected.values)

def test_recursively_compute_variable_one_level(self):
one_level = Var(
name='one_level', variables=(condensation_rain, condensation_rain),
func=lambda x, y: x + y)
result = self.data_loader.recursively_compute_variable(
one_level, datetime(5, 1, 1), datetime(5, 12, 31),
intvl_in='monthly')
filepath = os.path.join(os.path.split(ROOT_PATH)[0], 'netcdf',
'00050101.precip_monthly.nc')
expected = 2. * _open_ds_catch_warnings(filepath)['condensation_rain']
np.testing.assert_array_equal(result.values, expected.values)

def test_recursively_compute_variable_multi_level(self):
one_level = Var(
name='one_level', variables=(condensation_rain, condensation_rain),
func=lambda x, y: x + y)
multi_level = Var(
name='multi_level', variables=(one_level, condensation_rain),
func=lambda x, y: x + y)
result = self.data_loader.recursively_compute_variable(
multi_level, datetime(5, 1, 1), datetime(5, 12, 31),
intvl_in='monthly')
filepath = os.path.join(os.path.split(ROOT_PATH)[0], 'netcdf',
'00050101.precip_monthly.nc')
expected = 3. * _open_ds_catch_warnings(filepath)['condensation_rain']
np.testing.assert_array_equal(result.values, expected.values)


if __name__ == '__main__':
unittest.main()
11 changes: 8 additions & 3 deletions aospy/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,14 @@ def data_name_gfdl(name, domain, data_type, intvl_type, data_yr,

def dmget(files_list):
"""Call GFDL command 'dmget' to access archived files."""
if isinstance(files_list, str):
files_list = [files_list]

archive_files = []
for f in files_list:
if f.startswith('/archive'):
archive_files.append(f)
try:
if isinstance(files_list, str):
files_list = [files_list]
subprocess.call(['dmget'] + files_list)
subprocess.call(['dmget'] + archive_files)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was finding that the test suite would hang if dmget was called on an existing file that was not stored in /archive. I was able to reproduce this outside of aospy, so it has nothing to do with the changes I made to allow for the recursive calculation of variables; they must have made some updates to dmget. The logic here filters out files that aren't stored on archive before passing them to dmget.

except OSError:
logging.debug('dmget command not found in this machine')
21 changes: 21 additions & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,27 @@ prior ``Var`` constuctors. These signify the function to use and the
physical quantities to pass to that function in order to compute the
quantity.

As of aospy version 0.3, ``Var`` objects are computed
recursively; this means that as long as things eventually lead back to
model-native quantities, you can express a computed variable (i.e. one
with ``func`` and ``variables`` attributes) in terms of other computed
variables. For example we could equivalently express the
``precip_conv_frac`` more simply as the following:

.. ipython::

precip_conv_frac = Var(
name='precip_conv_frac',
def_time=True,
variables=(precip_convective, precip_total),
func=lambda conv, total: conv / total,
)

In this case, aospy will automatically know to load in
``precip_largescale`` and ``precip_convective`` in order to compute
``precip_total`` before passing it along to the function specified
in ``precip_conv_frac``. Any depth of recursion is supported.

.. note::

Although ``variables`` is passed a tuple of ``Var`` objects
Expand Down
3 changes: 3 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Enhancements
<http://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html>`_
for more information (closes :issue:`236` via :pull:`240`). By `Spencer
Clark <https://github.com/spencerkclark>`_.
- Allow for variables to be functions of other computed variables (closes
:issue:`3` via :pull:`263`). By `Spencer
Clark <https://github.com/spencerkclark>`_.

Bug Fixes
~~~~~~~~~
Expand Down