From 85321728ab2e136318ca736c70a8bc18fbe0012f Mon Sep 17 00:00:00 2001 From: Stephen G Yeager Date: Fri, 3 Mar 2023 16:54:10 -0700 Subject: [PATCH] Mods to eos.py to improve dask compatibility (#149) * eos() mods to improve dask functionality * correcting tab errors * Support scalar depth --------- Co-authored-by: dcherian --- pop_tools/eos.py | 13 ++----------- tests/test_eos.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/pop_tools/eos.py b/pop_tools/eos.py index 943218ab..61e020ea 100644 --- a/pop_tools/eos.py +++ b/pop_tools/eos.py @@ -75,17 +75,11 @@ def eos(salt, temp, return_coefs=False, **kwargs): use_xarray = False if any(isinstance(arg, xr.DataArray) for arg in [salt, temp, d_or_p]): - if not all(isinstance(arg, xr.DataArray) for arg in [salt, temp, d_or_p]): - raise ValueError('cannot operate on mixed types') use_xarray = True # compute pressure if pressure is None: - if use_xarray: - pressure = xr.full_like(depth, fill_value=np.nan) - pressure[:] = 10.0 * compute_pressure(depth.data) # dbar - else: - pressure = 10.0 * compute_pressure(depth) # dbar + pressure = 10.0 * compute_pressure(depth) # dbar # enforce min/max values tmin = -2.0 @@ -96,10 +90,7 @@ def eos(salt, temp, return_coefs=False, **kwargs): if use_xarray: temp = temp.clip(tmin, tmax) salt = salt.clip(smin, smax) - - salt, temp, pressure = xr.broadcast(salt, temp, pressure) - if isinstance(salt.data, dask.array.Array): - pressure = pressure.chunk(salt.chunks) + salt, temp = xr.broadcast(salt, temp) if return_coefs: RHO, dRHOdS, dRHOdT = _compute_eos_coeffs(salt, temp, pressure) diff --git a/tests/test_eos.py b/tests/test_eos.py index 5eb1c5a1..62896baf 100644 --- a/tests/test_eos.py +++ b/tests/test_eos.py @@ -1,6 +1,7 @@ import os import numpy as np +import pytest import xarray as xr import pop_tools @@ -38,6 +39,16 @@ def test_eos_xarray_2(): assert isinstance(drhodT, xr.DataArray) +@pytest.mark.parametrize('depth', (xr.DataArray([2000]), 2000.0)) +def test_eos_xarray_3(depth): + fname = DATASETS.fetch('cesm_pop_monthly.T62_g17.nc') + ds = xr.open_dataset(fname, decode_times=False, decode_coords=False) + rho, drhodS, drhodT = pop_tools.eos(ds.SALT, ds.TEMP, depth=depth, return_coefs=True) + assert isinstance(rho, xr.DataArray) + assert isinstance(drhodS, xr.DataArray) + assert isinstance(drhodT, xr.DataArray) + + def test_eos_ds_dask(): fname = DATASETS.fetch('cesm_pop_monthly.T62_g17.nc') ds = xr.open_dataset(fname, decode_times=False, decode_coords=False, chunks={'z_t': 20})