Use dask tokenise to make sure the .nc and .pklz belong together, #66 #109

Merged: 6 commits, Oct 28, 2022
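The gist of this change: `dask.base.tokenize` computes a deterministic hash of an xarray Dataset. By storing that hash (under the `_nc_hash` key, see #66) in the pickled function arguments, the cache decorator can verify that a cached `.nc` file and its companion `.pklz` file were written together. Below is a minimal sketch of the mechanism outside the decorator; the toy dataset and file names are invented for illustration:

```python
import pickle

import dask
import xarray as xr

# write a result to a netcdf cache file, then hash its contents with dask
ds = xr.Dataset({"top": ("x", [1.0, 2.0, 3.0])})
ds.to_netcdf("cache.nc")
with xr.open_dataset("cache.nc", mask_and_scale=False) as cached:
    nc_hash = dask.base.tokenize(cached)

# pickle the function arguments together with the netcdf hash
func_args_dic = {"arg0": "some_value", "_nc_hash": nc_hash}
with open("cache_cache.pklz", "wb") as f:
    pickle.dump(func_args_dic, f)

# on a later call, only reuse the cache when the netcdf file still
# hashes to the token stored next to the pickled arguments
with open("cache_cache.pklz", "rb") as f:
    stored = pickle.load(f)
with xr.open_dataset("cache.nc", mask_and_scale=False) as cached:
    cache_is_consistent = dask.base.tokenize(cached) == stored["_nc_hash"]
```

If the `.nc` file is regenerated or replaced while an old `.pklz` stays behind, the tokens no longer match and the decorator recomputes the result instead of serving a stale cache.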
6 changes: 3 additions & 3 deletions docs/examples/01_basic_model.ipynb
@@ -137,7 +137,7 @@
")\n",
"\n",
"if add_northsea:\n",
-" ds = nlmod.mdims.add_northsea(ds)\n"
+" ds = nlmod.mdims.add_northsea(ds, cachedir=cachedir)\n"
]
},
{
@@ -339,7 +339,7 @@
},
"anaconda-cloud": {},
"kernelspec": {
-"display_name": "Python 3.9.7 ('artesia')",
+"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -353,7 +353,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.9.7"
+"version": "3.9.4"
},
"vscode": {
"interpreter": {
6 changes: 3 additions & 3 deletions docs/examples/02_surface_water.ipynb
@@ -311,7 +311,7 @@
"metadata": {},
"outputs": [],
"source": [
-"sfw_grid = nlmod.mdims.gdf2grid(sfw, gwf)\n"
+"sfw_grid = nlmod.mdims.gdf_to_grid(sfw, gwf)\n"
]
},
{
@@ -705,7 +705,7 @@
],
"metadata": {
"kernelspec": {
-"display_name": "Python 3.9.7 ('artesia')",
+"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -719,7 +719,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.9.7"
+"version": "3.9.4"
},
"vscode": {
"interpreter": {
6 changes: 3 additions & 3 deletions docs/examples/03_local_grid_refinement.ipynb
@@ -140,7 +140,7 @@
")\n",
"\n",
"if add_northsea:\n",
-" ds = nlmod.mdims.add_northsea(ds)\n"
+" ds = nlmod.mdims.add_northsea(ds, cachedir=cachedir)\n"
]
},
{
@@ -379,7 +379,7 @@
},
"anaconda-cloud": {},
"kernelspec": {
-"display_name": "Python 3.9.7 ('artesia')",
+"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -393,7 +393,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.9.7"
+"version": "3.9.4"
},
"vscode": {
"interpreter": {
28 changes: 14 additions & 14 deletions docs/examples/07_resampling.ipynb
@@ -605,9 +605,9 @@
"source": [
"fig, axes = plt.subplots(ncols=4, figsize=(20,5))\n",
"\n",
-"da1 = nlmod.mdims.gdf2data_array_struc(point_gdf, gwf, field='values', agg_method='max')\n",
-"da2 = nlmod.mdims.gdf2data_array_struc(point_gdf, gwf, field='values', agg_method='mean')\n",
-"da3 = nlmod.mdims.gdf2data_array_struc(point_gdf, gwf, field='values', agg_method='nearest')\n",
+"da1 = nlmod.mdims.gdf_to_data_array_struc(point_gdf, gwf, field='values', agg_method='max')\n",
+"da2 = nlmod.mdims.gdf_to_data_array_struc(point_gdf, gwf, field='values', agg_method='mean')\n",
+"da3 = nlmod.mdims.gdf_to_data_array_struc(point_gdf, gwf, field='values', agg_method='nearest')\n",
"\n",
"vmin = min(da1.min(), da2.min(), da3.min())\n",
"vmax = max(da1.max(), da2.max(), da3.max())\n",
@@ -646,8 +646,8 @@
"source": [
"fig, axes = plt.subplots(ncols=3, figsize=(15,5))\n",
"\n",
-"da1 = nlmod.mdims.gdf2data_array_struc(point_gdf, gwf, field='values', interp_method='nearest')\n",
-"da2 = nlmod.mdims.gdf2data_array_struc(point_gdf, gwf, field='values', interp_method='linear')\n",
+"da1 = nlmod.mdims.gdf_to_data_array_struc(point_gdf, gwf, field='values', interp_method='nearest')\n",
+"da2 = nlmod.mdims.gdf_to_data_array_struc(point_gdf, gwf, field='values', interp_method='linear')\n",
"\n",
"vmin = min(da1.min(), da2.min())\n",
"vmax = max(da1.max(), da2.max())\n",
@@ -682,9 +682,9 @@
"source": [
"fig, axes = plt.subplots(ncols=4, figsize=(20,5))\n",
"\n",
-"da1 = nlmod.mdims.gdf2data_array_struc(line_gdf, gwf, field='values', agg_method='max_length')\n",
-"da2 = nlmod.mdims.gdf2data_array_struc(line_gdf, gwf, field='values', agg_method='length_weighted')\n",
-"da3 = nlmod.mdims.gdf2data_array_struc(line_gdf, gwf, field='values', agg_method='nearest')\n",
+"da1 = nlmod.mdims.gdf_to_data_array_struc(line_gdf, gwf, field='values', agg_method='max_length')\n",
+"da2 = nlmod.mdims.gdf_to_data_array_struc(line_gdf, gwf, field='values', agg_method='length_weighted')\n",
+"da3 = nlmod.mdims.gdf_to_data_array_struc(line_gdf, gwf, field='values', agg_method='nearest')\n",
"\n",
"vmin = min(da1.min(), da2.min(), da3.min())\n",
"vmax = max(da1.max(), da2.max(), da3.max())\n",
@@ -722,9 +722,9 @@
"source": [
"fig, axes = plt.subplots(ncols=4, figsize=(20,5))\n",
"\n",
-"da1 = nlmod.mdims.gdf2data_array_struc(pol_gdf, gwf, field='values', agg_method='max_area')\n",
-"da2 = nlmod.mdims.gdf2data_array_struc(pol_gdf, gwf, field='values', agg_method='area_weighted')\n",
-"da3 = nlmod.mdims.gdf2data_array_struc(pol_gdf, gwf, field='values', agg_method='nearest')\n",
+"da1 = nlmod.mdims.gdf_to_data_array_struc(pol_gdf, gwf, field='values', agg_method='max_area')\n",
+"da2 = nlmod.mdims.gdf_to_data_array_struc(pol_gdf, gwf, field='values', agg_method='area_weighted')\n",
+"da3 = nlmod.mdims.gdf_to_data_array_struc(pol_gdf, gwf, field='values', agg_method='nearest')\n",
"\n",
"vmin = min(da1.min(), da2.min(), da3.min())\n",
"vmax = max(da1.max(), da2.max(), da3.max())\n",
@@ -760,9 +760,9 @@
"metadata": {},
"outputs": [],
"source": [
-"gdf_point_grid = nlmod.mdims.gdf2grid(point_gdf, gwf)\n",
-"gdf_line_grid = nlmod.mdims.gdf2grid(line_gdf, gwf)\n",
-"gdf_pol_grid = nlmod.mdims.gdf2grid(pol_gdf, gwf)"
+"gdf_point_grid = nlmod.mdims.gdf_to_grid(point_gdf, gwf)\n",
+"gdf_line_grid = nlmod.mdims.gdf_to_grid(line_gdf, gwf)\n",
+"gdf_pol_grid = nlmod.mdims.gdf_to_grid(pol_gdf, gwf)"
]
},
{
6 changes: 3 additions & 3 deletions docs/examples/09_schoonhoven.ipynb
@@ -411,7 +411,7 @@
"\n",
"mg = nlmod.mgrid.modelgrid_from_ds(ds)\n",
"gi = flopy.utils.GridIntersect(mg, method=\"vertex\")\n",
-"bgt_grid = nlmod.mdims.gdf2grid(bgt, ix=gi).set_index(\"cellid\")\n",
+"bgt_grid = nlmod.mdims.gdf_to_grid(bgt, ix=gi).set_index(\"cellid\")\n",
"\n",
"bgt_grid[\"cond\"] = bgt_grid.area / bed_resistance\n",
"mask = bgt_grid[\"bronhouder\"] == \"L0002\"\n",
@@ -719,7 +719,7 @@
],
"metadata": {
"kernelspec": {
-"display_name": "Python 3.9.7 ('artesia')",
+"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -733,7 +733,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.9.7"
+"version": "3.9.4"
},
"vscode": {
"interpreter": {
105 changes: 12 additions & 93 deletions nlmod/cache.py
@@ -12,6 +12,7 @@
import pickle

import flopy
+import dask
import numpy as np
import pandas as pd
import xarray as xr
@@ -131,14 +132,18 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
                    f"module of function {func.__name__} recently modified, not using cache"
                )

+            cached_ds = xr.open_dataset(fname_cache, mask_and_scale=False)
+
+            # add netcdf hash to function arguments dic, see #66
+            func_args_dic["_nc_hash"] = dask.base.tokenize(cached_ds)
+
            # check if cache was created with same function arguments as
            # function call
            argument_check = _same_function_arguments(
                func_args_dic, func_args_dic_cache
            )

            if modification_check and argument_check:
-                cached_ds = xr.open_dataset(fname_cache, mask_and_scale=False)
                if dataset is None:
                    logger.info(f"using cached data -> {cachename}")
                    return cached_ds
@@ -161,103 +166,17 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):

            # write netcdf cache
            result.to_netcdf(fname_cache)
-            # pickle function arguments
-            with open(fname_pickle_cache, "wb") as fpklz:
-                pickle.dump(func_args_dic, fpklz)
-        else:
-            raise TypeError(f"expected xarray Dataset, got {type(result)} instead")
-
-        return result
-
-    return decorator
+
+            # add netcdf hash to function arguments dic, see #66
+            temp = xr.open_dataset(fname_cache, mask_and_scale=False)
+            func_args_dic["_nc_hash"] = dask.base.tokenize(temp)
+            temp.close()

-
-def cache_pklz(func):
-    """decorator to read/write the result of a function from/to a pklz file to
-    speed up function calls with the same arguments. Should only be applied to
-    functions that:
-
-    - return a dictionary
-    - have functions arguments of types that can be checked using the
-      _is_valid_cache functions
-
-    1. The directory and filename of the cache should be defined by the person
-    calling a function with this decorator. If not defined no cache is
-    created nor used.
-    2. Create a new cached file if it is impossible to check if the function
-    arguments used to create the cached file are the same as the current
-    function arguments. This can happen if one of the function arguments has a
-    type that cannot be checked using the _is_valid_cache function.
-    3. Function arguments are pickled together with the cache to check later
-    if the cache is valid.
-    4. This function uses `functools.wraps` and some home made
-    magic in _update_docstring_and_signature to add arguments of the decorator
-    to the decorated function. This assumes that the decorated function has a
-    docstring with a "Returns" heading. If this is not the case an error is
-    raised when trying to decorate the function.
-    """
-
-    _update_docstring_and_signature(func)
-
-    @functools.wraps(func)
-    def decorator(*args, cachedir=None, cachename=None, **kwargs):
-
-        if cachedir is None or cachename is None:
-            return func(*args, **kwargs)
-
-        if not cachename.endswith(".pklz"):
-            cachename += ".pklz"
-
-        fname_cache = os.path.join(cachedir, cachename)  # pklz file
-        fname_args_cache = fname_cache.replace(
-            ".pklz", "_cache.pklz"
-        )  # pickle with function arguments
-
-        # create dictionary with function arguments
-        func_args_dic = {f"arg{i}": args[i] for i in range(len(args))}
-        func_args_dic.update(kwargs)
-
-        # only use cache if the cache file and the pickled function arguments exist
-        if os.path.exists(fname_cache) and os.path.exists(fname_args_cache):
-            with open(fname_args_cache, "rb") as f:
-                func_args_dic_cache = pickle.load(f)
-
-            # check if the module where the function is defined was changed
-            # after the cache was created
-            time_mod_func = _get_modification_time(func)
-            time_mod_cache = os.path.getmtime(fname_cache)
-            modification_check = time_mod_cache > time_mod_func
-
-            if not modification_check:
-                logger.info(
-                    f"module of function {func.__name__} recently modified, not using cache"
-                )
-
-            # check if cache was created with same function arguments as
-            # function call
-            argument_check = _same_function_arguments(
-                func_args_dic, func_args_dic_cache
-            )
-
-            if modification_check and argument_check:
-                with open(fname_cache, "rb") as f:
-                    result = pickle.load(f)
-                logger.info(f"using cached data -> {cachename}")
-                return result
-
-        # create cache
-        result = func(*args, **kwargs)
-        logger.info(f"caching data -> {cachename}")
-
-        if isinstance(result, dict):
-            # write result
-            with open(fname_cache, "wb") as f:
-                pickle.dump(result, f)
            # pickle function arguments
-            with open(fname_args_cache, "wb") as fpklz:
+            with open(fname_pickle_cache, "wb") as fpklz:
                pickle.dump(func_args_dic, fpklz)
        else:
-            raise TypeError(f"expected dictionary, got {type(result)} instead")
+            raise TypeError(f"expected xarray Dataset, got {type(result)} instead")

        return result
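Seen from the caller's side, a function wrapped with the netcdf cache decorator now stores the hash of the written `.nc` file next to the pickled arguments, so a mismatched pair invalidates the cache. A hypothetical usage sketch (the function `get_layer_data` and its contents are made up; the decorator name `cache_netcdf`, the `cachedir`/`cachename` keywords, and the required "Returns" docstring section are assumptions based on this file):

```python
import os

import xarray as xr

from nlmod.cache import cache_netcdf  # assumed import path of the decorator


@cache_netcdf
def get_layer_data(extent):
    """Build a toy dataset for an extent.

    Returns
    -------
    ds : xarray.Dataset
        toy dataset.
    """
    return xr.Dataset({"top": ("x", [float(v) for v in extent])})


os.makedirs("cache", exist_ok=True)

# first call: computes the result and writes layer_data.nc plus a .pklz
# holding the function arguments and the "_nc_hash" of the netcdf file
ds = get_layer_data([0, 1, 2], cachedir="cache", cachename="layer_data")

# second call with identical arguments: the cache is reused only if the
# module is unchanged, the arguments match, and the netcdf file still
# hashes to the stored "_nc_hash"; otherwise the function runs again
ds = get_layer_data([0, 1, 2], cachedir="cache", cachename="layer_data")
```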
2 changes: 1 addition & 1 deletion nlmod/gwf/horizontal_flow_barrier.py
@@ -122,7 +122,7 @@ def line2hfb(gdf, gwf, prevent_rings=True, plot=False):
# for the idea, sea:
# https://gis.stackexchange.com/questions/188755/how-to-snap-a-road-network-to-a-hexagonal-grid-in-qgis

-gdfg = mdims.gdf2grid(gdf, gwf)
+gdfg = mdims.gdf_to_grid(gdf, gwf)

cell2d = pd.DataFrame(gwf.disv.cell2d.array).set_index("icell2d")
vertices = pd.DataFrame(gwf.disv.vertices.array).set_index("iv")
12 changes: 6 additions & 6 deletions nlmod/gwf/recharge.py
@@ -42,7 +42,7 @@ def model_datasets_to_rch(gwf, ds, pname="rch", **kwargs):
recharge = "recharge"
else:
rch_name_arr, rch_unique_dic = _get_unique_series(ds, "recharge", pname)
-ds['rch_name'] = ds['top'].dims, rch_name_arr
+ds["rch_name"] = ds["top"].dims, rch_name_arr
mask = ds["rch_name"] != ""
recharge = "rch_name"

@@ -132,8 +132,8 @@ def model_datasets_to_evt(
rate = "evaporation"
else:
evt_name_arr, evt_unique_dic = _get_unique_series(ds, "evaporation", pname)
-ds['evt_name'] = ds['top'].dims, evt_name_arr
+ds["evt_name"] = ds["top"].dims, evt_name_arr

mask = ds["evt_name"] != ""
rate = "evt_name"

@@ -194,8 +194,8 @@ def _get_unique_series(ds, var, pname):
The values of each of the time series.

"""
-rch_name_arr = np.empty_like(ds['top'].values, dtype='U13')
+rch_name_arr = np.empty_like(ds["top"].values, dtype="U13")

# transient
if ds.gridtype == "structured":
if len(ds[var].dims) != 3:
@@ -224,7 +224,7 @@
mask = mask.reshape(rch_name_arr.shape)
rch_name_arr[mask] = f"{pname}_{i}"
rch_unique_dic[f"{pname}_{i}"] = unique_rch

return rch_name_arr, rch_unique_dic


6 changes: 3 additions & 3 deletions nlmod/gwf/surface_water.py
@@ -9,7 +9,7 @@
from shapely.geometry import Polygon
import flopy

-# from ..mdims.mgrid import gdf2grid
+# from ..mdims.mgrid import gdf_to_grid
from ..read import bgt, waterboard
from ..mdims import resample, mgrid

@@ -631,7 +631,7 @@ def get_gdf(ds=None, extent=None, fname_ahn=None):
extent = [bs[0], bs[2], bs[1], bs[3]]
gdf = add_stages_from_waterboards(gdf, extent=extent)
if ds is not None:
-return mgrid.gdf2grid(gdf, ds).set_index("cellid")
+return mgrid.gdf_to_grid(gdf, ds).set_index("cellid")
return gdf


@@ -695,7 +695,7 @@ def gdf_to_seasonal_pkg(
"""
if gdf.index.name != "cellid":
# if "cellid" not in gdf:
-# gdf = gdf2grid(gdf, gwf)
+# gdf = gdf_to_grid(gdf, gwf)
gdf = gdf.set_index("cellid")
else:
# make sure changes to the DataFrame are temporarily