Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

199 nlmodgwfic slow for uniform starting head #200

Merged
merged 5 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions nlmod/gwf/gwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ def npf(
gwf,
pname=pname,
icelltype=icelltype,
k=k.data,
k33=k33.data,
k=k,
k33=k33,
save_flows=save_flows,
**kwargs,
)
Expand Down Expand Up @@ -381,7 +381,7 @@ def ghb(
bhead = f"{da_name}_peil"
cond = f"{da_name}_cond"

mask_arr = _get_value_from_ds_datavar(ds, "cond", cond)
mask_arr = _get_value_from_ds_datavar(ds, "cond", cond, return_da=True)
mask = mask_arr != 0

ghb_rec = grid.da_to_reclist(
Expand Down Expand Up @@ -466,7 +466,7 @@ def drn(
elev = f"{da_name}_peil"
cond = f"{da_name}_cond"

mask_arr = _get_value_from_ds_datavar(ds, "cond", cond)
mask_arr = _get_value_from_ds_datavar(ds, "cond", cond, return_da=True)
mask = mask_arr != 0

first_active_layer = layer is None
Expand Down Expand Up @@ -636,7 +636,7 @@ def chd(
)
mask = kwargs.pop("chd")

maskarr = _get_value_from_ds_datavar(ds, "mask", mask)
maskarr = _get_value_from_ds_datavar(ds, "mask", mask, return_da=True)
mask = maskarr != 0

# get the stress_period_data
Expand Down Expand Up @@ -693,7 +693,7 @@ def surface_drain_from_ds(ds, gwf, resistance, elev="ahn", pname="drn", **kwargs

ds.attrs["surface_drn_resistance"] = resistance

maskarr = _get_value_from_ds_datavar(ds, "elev", elev)
maskarr = _get_value_from_ds_datavar(ds, "elev", elev, return_da=True)
mask = maskarr.notnull()

drn_rec = grid.da_to_reclist(
Expand Down
10 changes: 9 additions & 1 deletion nlmod/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def _get_value_from_ds_attr(ds, varname, attr=None, value=None, warn=True):
return value


def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True, return_da=False):
"""Internal function to get value from dataset data variables.

Parameters
Expand All @@ -551,6 +551,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
the same as varname. If not passed as string, it is treated as data
warn : bool, optional
log warning if value not found
return_da : bool, optional
if True a dataarray can be returned, if False a dataarray is always
converted to a numpy array before being returned. The default is False.

Returns
-------
Expand Down Expand Up @@ -597,4 +600,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
f"to function or check whether 'ds.{datavar}' was set correctly."
)
logger.warning(msg)

if not return_da:
if isinstance(value, xr.DataArray):
value = value.values

return value
10 changes: 6 additions & 4 deletions tests/test_003_mfpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,22 @@ def get_value_from_ds_datavar():
ds["test_var"] = ("layer", "y", "x"), np.arange(np.product(shape)).reshape(shape)

# get value from ds
v0 = nlmod.util._get_value_from_ds_datavar(ds, "test_var", "test_var")
v0 = nlmod.util._get_value_from_ds_datavar(
ds, "test_var", "test_var", return_da=True
)
xr.testing.assert_equal(ds["test_var"], v0)

# get value from ds, variable and stored name are different
v1 = nlmod.util._get_value_from_ds_datavar(ds, "test", "test_var")
xr.testing.assert_equal(ds["test_var"], v1)
xr.testing.assert_equal(ds["test_var"].values, v1)

# do not get value from ds, value is Data Array, should log info msg
v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0)
v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True)
xr.testing.assert_equal(ds["test_var"], v2)

# do not get value from ds, value is Data Array, no msg
v0.name = "test2"
v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0)
v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True)
assert (v0 == v3).all()

# return None, value is str but not in dataset, should log warning
Expand Down