Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursively compute variables #263

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aospy/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def _get_pressure_vals(self, var, start_date, end_date):
try:
ps = self._ps_data
except AttributeError:
self._ps_data = self.data_loader.load_variable(
self._ps_data = self.data_loader.recursively_compute_variable(
self.ps, start_date, end_date, self.time_offset,
**self.data_loader_attrs)
name = self._ps_data.name
Expand Down Expand Up @@ -362,9 +362,9 @@ def _get_input_data(self, var, start_date, end_date):
cond_pfull = ((not hasattr(self, internal_names.PFULL_STR))
and var.def_vert and
self.dtype_in_vert == internal_names.ETA_STR)
data = self.data_loader.load_variable(var, start_date, end_date,
self.time_offset,
**self.data_loader_attrs)
data = self.data_loader.recursively_compute_variable(
var, start_date, end_date, self.time_offset,
**self.data_loader_attrs)
name = data.name
data = self._add_grid_attributes(data.to_dataset(name=data.name))
data = data[name]
Expand Down
12 changes: 12 additions & 0 deletions aospy/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,18 @@ def load_variable(self, var=None, start_date=None, end_date=None,
return times.sel_time(da, np.datetime64(start_date_xarray),
np.datetime64(end_date_xarray)).load()

def recursively_compute_variable(self, var, start_date=None, end_date=None,
time_offset=None, **DataAttrs):
"""Compute a variable recursively, loading data where needed"""
if var.variables is None:
return self.load_variable(var, start_date, end_date, time_offset,
**DataAttrs)
else:
data = [self.recursively_compute_variable(
v, start_date, end_date, time_offset, **DataAttrs)
for v in var.variables]
return var.func(*data).rename(var.name)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, good suggestion to add a test for this. One minor change and things work smoothly!


@staticmethod
def _maybe_apply_time_shift(da, time_offset=None, **DataAttrs):
"""Apply specified time shift to DataArray"""
Expand Down
53 changes: 52 additions & 1 deletion aospy/test/test_calc_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import xarray as xr

from aospy import Var
from aospy.calc import Calc, _add_metadata_as_attrs
from .data.objects.examples import (
example_proj, example_model, example_run, var_not_time_defined,
condensation_rain, precip, sphum, globe, sahel
condensation_rain, convection_rain, precip, sphum, globe, sahel
)


Expand Down Expand Up @@ -207,5 +208,55 @@ def test_attrs(units, description, dtype_out_vert, expected_units,
assert expected_description == arr.attrs['description']


@pytest.fixture()
def recursive_test_params():
basic_params = {
'proj': example_proj,
'model': example_model,
'run': example_run,
'var': condensation_rain,
'date_range': (datetime.datetime(4, 1, 1),
datetime.datetime(6, 12, 31)),
'intvl_in': 'monthly',
'dtype_in_time': 'ts'
}

recursive_condensation_rain = Var(
name='recursive_condensation_rain',
variables=(precip, convection_rain), func=lambda x, y: x - y,
def_time=True)
recursive_params = {
'proj': example_proj,
'model': example_model,
'run': example_run,
'var': recursive_condensation_rain,
'date_range': (datetime.datetime(4, 1, 1),
datetime.datetime(6, 12, 31)),
'intvl_in': 'monthly',
'dtype_in_time': 'ts'
}
yield (basic_params, recursive_params)
for direc in [example_proj.direc_out, example_proj.tar_direc_out]:
shutil.rmtree(direc)


def test_recursive_calculation(recursive_test_params):
basic_params, recursive_params = recursive_test_params

calc = Calc(intvl_out='ann', dtype_out_time='av', **basic_params)
calc = calc.compute()
expected = xr.open_dataset(calc.path_out['av'],
autoclose=True)['condensation_rain']
_test_files_and_attrs(calc, 'av')

calc = Calc(intvl_out='ann', dtype_out_time='av', **recursive_params)
calc = calc.compute()
result = xr.open_dataset(
calc.path_out['av'], autoclose=True)['recursive_condensation_rain']
_test_files_and_attrs(calc, 'av')

xr.testing.assert_equal(expected, result)


if __name__ == '__main__':
unittest.main()
37 changes: 37 additions & 0 deletions aospy/test/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
import xarray as xr

from aospy import Var
from aospy.data_loader import (DataLoader, DictDataLoader, GFDLDataLoader,
NestedDictDataLoader, grid_attrs_to_aospy_names,
set_grid_attrs_as_coords, _sel_var,
Expand Down Expand Up @@ -609,6 +610,42 @@ def convert_all_to_missing_val(ds, **kwargs):
expected_num_non_missing = 0
self.assertEqual(num_non_missing, expected_num_non_missing)

def test_recursively_compute_variable_native(self):
result = self.data_loader.recursively_compute_variable(
condensation_rain, datetime(5, 1, 1), datetime(5, 12, 31),
intvl_in='monthly')
filepath = os.path.join(os.path.split(ROOT_PATH)[0], 'netcdf',
'00050101.precip_monthly.nc')
expected = _open_ds_catch_warnings(filepath)['condensation_rain']
np.testing.assert_array_equal(result.values, expected.values)

def test_recursively_compute_variable_one_level(self):
one_level = Var(
name='one_level', variables=(condensation_rain, condensation_rain),
func=lambda x, y: x + y)
result = self.data_loader.recursively_compute_variable(
one_level, datetime(5, 1, 1), datetime(5, 12, 31),
intvl_in='monthly')
filepath = os.path.join(os.path.split(ROOT_PATH)[0], 'netcdf',
'00050101.precip_monthly.nc')
expected = 2. * _open_ds_catch_warnings(filepath)['condensation_rain']
np.testing.assert_array_equal(result.values, expected.values)

def test_recursively_compute_variable_multi_level(self):
one_level = Var(
name='one_level', variables=(condensation_rain, condensation_rain),
func=lambda x, y: x + y)
multi_level = Var(
name='multi_level', variables=(one_level, condensation_rain),
func=lambda x, y: x + y)
result = self.data_loader.recursively_compute_variable(
multi_level, datetime(5, 1, 1), datetime(5, 12, 31),
intvl_in='monthly')
filepath = os.path.join(os.path.split(ROOT_PATH)[0], 'netcdf',
'00050101.precip_monthly.nc')
expected = 3. * _open_ds_catch_warnings(filepath)['condensation_rain']
np.testing.assert_array_equal(result.values, expected.values)


if __name__ == '__main__':
unittest.main()
11 changes: 8 additions & 3 deletions aospy/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,14 @@ def data_name_gfdl(name, domain, data_type, intvl_type, data_yr,

def dmget(files_list):
"""Call GFDL command 'dmget' to access archived files."""
if isinstance(files_list, str):
files_list = [files_list]

archive_files = []
for f in files_list:
if f.startswith('/archive'):
archive_files.append(f)
try:
if isinstance(files_list, str):
files_list = [files_list]
subprocess.call(['dmget'] + files_list)
subprocess.call(['dmget'] + archive_files)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was finding that the test suite would hang if dmget was called on an existing file that was not stored in /archive. I was able to reproduce this outside of aospy, so it has nothing to do with the changes I made to allow for the recursive calculation of variables; they must have made some updates to dmget. The logic here filters out files that aren't stored on archive before passing them to dmget.

except OSError:
logging.debug('dmget command not found in this machine')
3 changes: 3 additions & 0 deletions docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Enhancements
<http://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html>`_
for more information (closes :issue:`236` via :pull:`240`). By `Spencer
Clark <https://github.com/spencerkclark>`_.
- Allow for variables to be functions of other computed variables (closes
:issue:`3` via :pull:`263`). By `Spencer
Clark <https://github.com/spencerkclark>`_.

Bug Fixes
~~~~~~~~~
Expand Down