From 003c6e7aa4ac60e4d0ab8db81b17c821e4eba750 Mon Sep 17 00:00:00 2001 From: Bas des Tombe Date: Sun, 28 Apr 2024 09:28:03 +0200 Subject: [PATCH 01/24] Creating new cache used full ds instead of reduced ds Required to ensure that the cached function does not use data_vars that not explicitly required --- nlmod/cache.py | 134 ++++++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 56 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index d20a3aeb..dc1d11b5 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -54,7 +54,7 @@ def clear_cache(cachedir): def cache_netcdf(coords_2d=False, coords_3d=False, coords_time=False, datavars=None, coords=None, attrs=None): - """decorator to read/write the result of a function from/to a file to speed + """Decorator to read/write the result of a function from/to a file to speed up function calls with the same arguments. Should only be applied to functions that: @@ -118,27 +118,54 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): fname_cache = os.path.join(cachedir, cachename) # netcdf file fname_pickle_cache = fname_cache.replace(".nc", ".pklz") - # create dictionary with function arguments - func_args_dic = {f"arg{i}": args[i] for i in range(len(args))} - func_args_dic.update(kwargs) + # adjust args and kwargs with minimal dataset + args_adj = [] + kwargs_adj = {} - # remove xarray dataset from function arguments - dataset = None - for key in list(func_args_dic.keys()): - if isinstance(func_args_dic[key], xr.Dataset): - if dataset is not None: - raise TypeError( - "Function was called with multiple xarray dataset arguments. Currently unsupported." - ) - dataset_received = func_args_dic.pop(key) - dataset = ds_contains( - dataset_received, + datasets = [] + func_args_dic = {} + + for i, arg in enumerate(args): + if isinstance(arg, xr.Dataset): + arg_adj = ds_contains( + arg, coords_2d=coords_2d, coords_3d=coords_3d, coords_time=coords_time, datavars=datavars, coords=coords, attrs=attrs) + args_adj.append(arg_adj) + datasets.append(arg_adj) + else: + args_adj.append(arg) + func_args_dic[f"arg{i}"] = arg + + for key, arg in kwargs.items(): + if isinstance(arg, xr.Dataset): + arg_adj = ds_contains( + arg, + coords_2d=coords_2d, + coords_3d=coords_3d, + coords_time=coords_time, + datavars=datavars, + coords=coords, + attrs=attrs) + kwargs_adj[key] = arg_adj + datasets.append(arg_adj) + else: + kwargs_adj[key] = arg + func_args_dic[key] = arg + + if len(datasets) == 0: + dataset = None + elif len(datasets) == 1: + dataset = datasets[0] + else: + msg = "Function was called with multiple xarray dataset arguments. Currently unsupported." 
+ raise NotImplementedError( + msg + ) # only use cache if the cache file and the pickled function arguments exist if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache): @@ -190,7 +217,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): return cached_ds # create cache - result = func(*args, **kwargs) + result = func(*args_adj, **kwargs_adj) logger.info(f"caching data -> {cachename}") if isinstance(result, xr.DataArray): @@ -205,7 +232,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): # write netcdf cache # check if dataset is chunked for writing with dask.delayed - first_data_var = list(result.data_vars.keys())[0] + first_data_var = next(iter(result.data_vars.keys())) if result[first_data_var].chunks: delayed = result.to_netcdf(fname_cache, compute=False) with ProgressBar(): @@ -230,16 +257,16 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): with open(fname_pickle_cache, "wb") as fpklz: pickle.dump(func_args_dic, fpklz) else: - raise TypeError(f"expected xarray Dataset, got {type(result)} instead") - result = _check_for_data_array(result) - return result + msg = f"expected xarray Dataset, got {type(result)} instead" + raise TypeError(msg) + return _check_for_data_array(result) return wrapper return decorator def cache_pickle(func): - """decorator to read/write the result of a function from/to a file to speed + """Decorator to read/write the result of a function from/to a file to speed up function calls with the same arguments. Should only be applied to functions that: @@ -262,7 +289,6 @@ def cache_pickle(func): docstring with a "Returns" heading. If this is not the case an error is raised when trying to decorate the function. """ - # add cachedir and cachename to docstring _update_docstring_and_signature(func) @@ -346,14 +372,15 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs): with open(fname_pickle_cache, "wb") as fpklz: pickle.dump(func_args_dic, fpklz) else: - raise TypeError(f"expected DataFrame, got {type(result)} instead") + msg = f"expected DataFrame, got {type(result)} instead" + raise TypeError(msg) return result return decorator def _same_function_arguments(func_args_dic, func_args_dic_cache): - """checks if two dictionaries with function arguments are identical by + """Checks if two dictionaries with function arguments are identical by checking: 1. if they have the same keys 2. if the items have the same type @@ -361,7 +388,7 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache): float, bool, str, bytes, list, tuple, dict, np.ndarray, xr.DataArray, - flopy.mf6.ModflowGwf) + flopy.mf6.ModflowGwf). 
Parameters ---------- @@ -381,7 +408,7 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache): """ for key, item in func_args_dic.items(): # check if cache and function call have same argument names - if key not in func_args_dic_cache.keys(): + if key not in func_args_dic_cache: logger.info( "cache was created using different function arguments, do not use cached data" ) @@ -510,16 +537,9 @@ def _update_docstring_and_signature(func): cur_param = cur_param[:-1] else: add_kwargs = None - new_param = cur_param + ( - inspect.Parameter( - "cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None - ), - inspect.Parameter( - "cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None - ), - ) + new_param = (*cur_param, inspect.Parameter("cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None), inspect.Parameter("cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)) if add_kwargs is not None: - new_param = new_param + (add_kwargs,) + new_param = (*new_param, add_kwargs) sig = sig.replace(parameters=new_param) func.__signature__ = sig @@ -541,7 +561,7 @@ def _update_docstring_and_signature(func): " filename of netcdf cache. If None no cache is used." " Default is None.\n\n Returns" ) - new_doc = "".join((mod_before, after)) + new_doc = f"{mod_before}{after}" func.__doc__ = new_doc return @@ -569,10 +589,7 @@ def _check_for_data_array(ds): """ if "__xarray_dataarray_variable__" in ds: - if "spatial_ref" in ds: - spatial_ref = ds.spatial_ref - else: - spatial_ref = None + spatial_ref = ds.spatial_ref if "spatial_ref" in ds else None # the method returns a DataArray, so we return only this DataArray ds = ds["__xarray_dataarray_variable__"] if spatial_ref is not None: @@ -611,17 +628,17 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar """ # Return the full dataset if not configured if ds is None: - raise ValueError("No dataset provided") - elif not coords_2d and not coords_3d and not datavars and not coords and not attrs: + msg = "No dataset provided" + raise ValueError(msg) + if not coords_2d and not coords_3d and not datavars and not coords and not attrs: return ds - else: - # Initialize lists - if datavars is None: - datavars = [] - if coords is None: - coords = [] - if attrs is None: - attrs = [] + # Initialize lists + if datavars is None: + datavars = [] + if coords is None: + coords = [] + if attrs is None: + attrs = [] # Add coords, datavars and attrs via shorthands if coords_2d or coords_3d: @@ -629,7 +646,7 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar coords.append("y") datavars.append("area") attrs.append("extent") - + if "gridtype" in ds.attrs: attrs.append("gridtype") @@ -651,23 +668,28 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar # User-friendly error messages if "northsea" in datavars and "northsea" not in ds.data_vars: - raise ValueError("Northsea not in dataset. Run nlmod.read.rws.add_northsea() first.") + msg = "Northsea not in dataset. Run nlmod.read.rws.add_northsea() first." + raise ValueError(msg) if "time" in coords and "time" not in ds.coords: - raise ValueError("time not in dataset. Run nlmod.time.set_ds_time() first.") + msg = "time not in dataset. Run nlmod.time.set_ds_time() first." 
+ raise ValueError(msg) # User-unfriendly error messages for datavar in datavars: if datavar not in ds.data_vars: - raise ValueError(f"{datavar} not in dataset.data_vars") + msg = f"{datavar} not in dataset.data_vars" + raise ValueError(msg) for coord in coords: if coord not in ds.coords: - raise ValueError(f"{coord} not in dataset.coords") + msg = f"{coord} not in dataset.coords" + raise ValueError(msg) for attr in attrs: if attr not in ds.attrs: - raise ValueError(f"{attr} not in dataset.attrs") + msg = f"{attr} not in dataset.attrs" + raise ValueError(msg) # Return only the required data return xr.Dataset( From 5268ca834bc4ea70ab76b33b7ce7ba1b1a9258ba Mon Sep 17 00:00:00 2001 From: Bas des Tombe Date: Sun, 28 Apr 2024 09:33:07 +0200 Subject: [PATCH 02/24] Format cache to please codacy --- nlmod/cache.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index dc1d11b5..e4a958c7 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -537,7 +537,11 @@ def _update_docstring_and_signature(func): cur_param = cur_param[:-1] else: add_kwargs = None - new_param = (*cur_param, inspect.Parameter("cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None), inspect.Parameter("cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)) + new_param = ( + *cur_param, + inspect.Parameter("cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None), + inspect.Parameter("cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None) + ) if add_kwargs is not None: new_param = (*new_param, add_kwargs) sig = sig.replace(parameters=new_param) From d0cd821a95e0102a688aea7a048bb69c0de227d4 Mon Sep 17 00:00:00 2001 From: Bas des Tombe Date: Wed, 1 May 2024 12:43:25 +0200 Subject: [PATCH 03/24] Better support for vertex grids --- nlmod/cache.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index e4a958c7..bd06d299 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -636,6 +636,9 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar raise ValueError(msg) if not coords_2d and not coords_3d and not datavars and not coords and not attrs: return ds + + isvertex = ds.attrs["gridtype"] == "vertex" + # Initialize lists if datavars is None: datavars = [] @@ -650,9 +653,15 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar coords.append("y") datavars.append("area") attrs.append("extent") - - if "gridtype" in ds.attrs: - attrs.append("gridtype") + attrs.append("gridtype") + + if isvertex: + coords.append("icell2d") + coords.append("iv") + coords.append("icv") + datavars.append("xv") + datavars.append("yv") + datavars.append("icvert") if "angrot" in ds.attrs: attrs.append("angrot") From bb66042d4f8e463f217cab38622a064763e02046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:22:11 +0200 Subject: [PATCH 04/24] remove option to store delr/delc as attrs --- nlmod/dims/base.py | 17 ++++++----------- pyproject.toml | 4 ++++ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/nlmod/dims/base.py b/nlmod/dims/base.py index c652d246..3ec3f946 100644 --- a/nlmod/dims/base.py +++ b/nlmod/dims/base.py @@ -15,7 +15,7 @@ def set_ds_attrs(ds, model_name, model_ws, mfversion="mf6", exe_name=None): - """set the attribute of a model dataset. + """Set the attribute of a model dataset. 
Parameters ---------- @@ -161,8 +161,9 @@ def to_model_ds( if delc is None: delc = delr if isinstance(delr, (numbers.Number)) and isinstance(delc, (numbers.Number)): - ds["area"] = ("y", "x"), ds.delr * ds.delc * np.ones( - (ds.sizes["y"], ds.sizes["x"]) + ds["area"] = ( + ("y", "x"), + delr * delc * np.ones((ds.sizes["y"], ds.sizes["x"])), ) elif isinstance(delr, np.ndarray) and isinstance(delc, np.ndarray): ds["area"] = ("y", "x"), np.outer(delc, delr) @@ -367,15 +368,9 @@ def _get_structured_grid_ds( ) # set delr and delc delr = np.diff(xedges) - if len(np.unique(delr)) == 1: - ds.attrs["delr"] = np.unique(delr)[0] - else: - ds["delr"] = ("x"), delr + ds["delr"] = ("x"), delr delc = -np.diff(yedges) - if len(np.unique(delc)) == 1: - ds.attrs["delc"] = np.unique(delc)[0] - else: - ds["delc"] = ("y"), delc + ds["delc"] = ("y"), delc if crs is not None: ds.rio.set_crs(crs) diff --git a/pyproject.toml b/pyproject.toml index 6729fa58..dd9eaf64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,10 @@ line-length = 88 [tool.isort] profile = "black" +[tool.ruff] +line-length = 88 +extend-include = ["*.ipynb"] + [tool.pytest.ini_options] addopts = "--strict-markers --durations=0 --cov-report xml:coverage.xml --cov nlmod -v" markers = ["notebooks: run notebooks", "slow: slow tests", "skip: skip tests"] From 21ac2267f88524a310fcb6dfd000d3a7babb7392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:22:34 +0200 Subject: [PATCH 05/24] remove option to store delr/delc as attrs --- nlmod/dims/grid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nlmod/dims/grid.py b/nlmod/dims/grid.py index 133a7538..ce528a42 100644 --- a/nlmod/dims/grid.py +++ b/nlmod/dims/grid.py @@ -214,11 +214,11 @@ def modelgrid_from_ds(ds, rotated=True, nlay=None, top=None, botm=None, **kwargs if "delc" in ds: delc = ds["delc"].values else: - delc = np.array([ds.delc] * ds.sizes["y"]) + raise KeyError("delc not in dataset") if "delr" in ds: delr = ds["delr"].values else: - delr = np.array([ds.delr] * ds.sizes["x"]) + raise KeyError("delr not in dataset") modelgrid = StructuredGrid( delc=delc, delr=delr, From 1d51ae9703a3354a6cd0531f12dc49a1112ef397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:23:39 +0200 Subject: [PATCH 06/24] update polygons_from_model_ds to to use new delr/delc datavars --- nlmod/dims/grid.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/nlmod/dims/grid.py b/nlmod/dims/grid.py index ce528a42..68a4c9b9 100644 --- a/nlmod/dims/grid.py +++ b/nlmod/dims/grid.py @@ -1921,7 +1921,7 @@ def mask_model_edge(ds, idomain=None): def polygons_from_model_ds(model_ds): - """create polygons of each cell in a model dataset. + """Create polygons of each cell in a model dataset. 
Parameters ---------- @@ -1940,19 +1940,33 @@ def polygons_from_model_ds(model_ds): """ if model_ds.gridtype == "structured": - # check if coördinates are consistent with delr/delc values + # TODO: update with new delr/delc calculation + # check if coordinates are consistent with delr/delc values delr_x = np.unique(model_ds.x.values[1:] - model_ds.x.values[:-1]) delc_y = np.unique(model_ds.y.values[:-1] - model_ds.y.values[1:]) - if not ((delr_x == model_ds.delr) and (delc_y == model_ds.delc)): + + delr = np.unique(model_ds.delr) + if len(delr) > 1: + raise ValueError("delr is variable") + else: + delr = delr.item() + + delc = np.unique(model_ds.delc) + if len(delc) > 1: + raise ValueError("delc is variable") + else: + delc = delc.item() + + if not ((delr_x == delr) and (delc_y == delc)): raise ValueError( "delr and delc attributes of model_ds inconsistent " "with x and y coordinates" ) - xmins = model_ds.x - (model_ds.delr * 0.5) - xmaxs = model_ds.x + (model_ds.delr * 0.5) - ymins = model_ds.y - (model_ds.delc * 0.5) - ymaxs = model_ds.y + (model_ds.delc * 0.5) + xmins = model_ds.x - (delr * 0.5) + xmaxs = model_ds.x + (delr * 0.5) + ymins = model_ds.y - (delc * 0.5) + ymaxs = model_ds.y + (delc * 0.5) polygons = [ Polygon( [ From 0896d04e7a5feb5f938b110cfa2279eecd6f4b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:25:38 +0200 Subject: [PATCH 07/24] add delr and delc to datavars in ds_to_structured_grid --- nlmod/dims/resample.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nlmod/dims/resample.py b/nlmod/dims/resample.py index c5b536af..d956f9ae 100644 --- a/nlmod/dims/resample.py +++ b/nlmod/dims/resample.py @@ -153,20 +153,27 @@ def ds_to_structured_grid( # add new attributes attrs["gridtype"] = "structured" + if isinstance(delr, numbers.Number) and isinstance(delc, numbers.Number): - attrs["delr"] = delr - attrs["delc"] = delc + delr = np.full_like(x, delr) + delc = np.full_like(y, delc) if method in ["nearest", "linear"] and angrot == 0.0: ds_out = ds_in.interp( x=x, y=y, method=method, kwargs={"fill_value": "extrapolate"} ) ds_out.attrs = attrs + # ds_out = ds_out.expand_dims({"ncol": len(x), "nrow": len(y)}) + ds_out["delr"] = ("ncol",), delr + ds_out["delc"] = ("nrow",), delc return ds_out ds_out = xr.Dataset(coords={"y": y, "x": x, "layer": ds_in.layer.data}, attrs=attrs) for var in ds_in.data_vars: ds_out[var] = structured_da_to_ds(ds_in[var], ds_out, method=method) + # ds_out = ds_out.expand_dims({"ncol": len(x), "nrow": len(y)}) + ds_out["delr"] = ("x",), delr + ds_out["delc"] = ("y",), delc return ds_out From 724a3bf62708aba5f9c130b8cec277edfbe44efc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:26:28 +0200 Subject: [PATCH 08/24] remove delr/delc attr check in structured_da_to_ds --- nlmod/dims/resample.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nlmod/dims/resample.py b/nlmod/dims/resample.py index d956f9ae..73a13427 100644 --- a/nlmod/dims/resample.py +++ b/nlmod/dims/resample.py @@ -526,9 +526,6 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.NaN): # xmin, xmax, ymin, ymax dx = (ds.attrs["extent"][1] - ds.attrs["extent"][0]) / len(ds.x) dy = (ds.attrs["extent"][3] - ds.attrs["extent"][2]) / len(ds.y) - elif "delr" in ds.attrs and "delc" in ds.attrs: - dx = ds.attrs["delr"] - dy = ds.attrs["delc"] else: raise ValueError( "No extent or delr and delc in ds. Cannot determine affine." 
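
Taken together, patches 01-03 above change `nlmod.cache.cache_netcdf` from a plain decorator into a decorator factory: the wrapped function no longer receives the full model dataset but a copy reduced by `ds_contains` to only the coordinates, data variables and attributes named in the decorator call, and the cache-validation hash is computed from that reduced dataset. Patches 04-08 then move delr/delc from dataset attrs to proper data variables. A minimal sketch of the resulting usage, assuming a hypothetical cached function `get_surface_level` and a hypothetical data variable "ahn" present in an nlmod model dataset:

    import xarray as xr
    from nlmod import cache

    @cache.cache_netcdf(coords_2d=True, datavars=["ahn"])
    def get_surface_level(ds):
        """Derive a surface-level dataset from the reduced model dataset.

        Returns
        -------
        xr.Dataset
        """
        # ds is the trimmed dataset built by ds_contains(): only the x/y
        # coords, the "area" and "ahn" data variables and the "extent"/
        # "gridtype" attrs are present here, so changing unrelated data
        # variables in the full model dataset no longer invalidates the
        # cache entry for this function.
        return xr.Dataset({"surface_level": ds["ahn"].fillna(0.0)})

    # cachedir and cachename are appended to the signature by
    # _update_docstring_and_signature; passing both activates the cache:
    # out = get_surface_level(model_ds, cachedir="cache",
    #                         cachename="surface_level.nc")

Note that the "Returns" heading in the docstring is required: per the module docstring, an error is raised at decoration time when it is missing. And because delr/delc become data variables rather than attrs in the patches above, a cached function that needs cell sizes must now list them explicitly in `datavars` (as the rws.py functions do later in this series).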
From c86820d2a0a5ecb06118a95dc04a0261649ffdb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:28:28 +0200 Subject: [PATCH 09/24] get_recharge uses coords3d to determine first active layer --- nlmod/read/knmi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nlmod/read/knmi.py b/nlmod/read/knmi.py index 9fcc6527..b142848c 100644 --- a/nlmod/read/knmi.py +++ b/nlmod/read/knmi.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -@cache.cache_netcdf(coords_2d=True, coords_time=True) +@cache.cache_netcdf(coords_3d=True, coords_time=True) def get_recharge(ds, method="linear", most_common_station=False): """add multiple recharge packages to the groundwater flow model with knmi data by following these steps: From 9f5a5dd0d5f0ffb33ad980f70f52ff6c8502d2d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:29:09 +0200 Subject: [PATCH 10/24] add delr/delc to datavars for surface water functions --- nlmod/read/rws.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nlmod/read/rws.py b/nlmod/read/rws.py index 01d5b4e2..4b00e3c3 100644 --- a/nlmod/read/rws.py +++ b/nlmod/read/rws.py @@ -37,7 +37,8 @@ def get_gdf_surface_water(ds): return gdf_swater -@cache.cache_netcdf(coords_2d=True) +# TODO: temporary fix until delr/delc are removed completely +@cache.cache_netcdf(coords_3d=True, datavars=["delr", "delc"]) def get_surface_water(ds, da_basename): """create 3 data-arrays from the shapefile with surface water: @@ -91,7 +92,8 @@ def get_surface_water(ds, da_basename): return ds_out -@cache.cache_netcdf(coords_2d=True) +# TODO: temporary fix until delr/delc are removed completely +@cache.cache_netcdf(coords_2d=True, datavars=["delc", "delr"]) def get_northsea(ds, da_name="northsea"): """Get Dataset which is 1 at the northsea and 0 everywhere else. Sea is defined by rws surface water shapefile. From a04f7ab1cbcb28877bf8d2501cc3c2ea0003be8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:29:46 +0200 Subject: [PATCH 11/24] formatting --- nlmod/cache.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index bd06d299..eeae9d37 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -53,7 +53,14 @@ def clear_cache(cachedir): logger.info(f"removed {fname_nc}") -def cache_netcdf(coords_2d=False, coords_3d=False, coords_time=False, datavars=None, coords=None, attrs=None): +def cache_netcdf( + coords_2d=False, + coords_3d=False, + coords_time=False, + datavars=None, + coords=None, + attrs=None, +): """Decorator to read/write the result of a function from/to a file to speed up function calls with the same arguments. 
Should only be applied to functions that: From d11ea6b7932a92bfd1f318005bbc16997356e492 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:30:00 +0200 Subject: [PATCH 12/24] formatting --- nlmod/cache.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index eeae9d37..852ed57b 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -141,7 +141,8 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): coords_time=coords_time, datavars=datavars, coords=coords, - attrs=attrs) + attrs=attrs, + ) args_adj.append(arg_adj) datasets.append(arg_adj) else: @@ -157,7 +158,8 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): coords_time=coords_time, datavars=datavars, coords=coords, - attrs=attrs) + attrs=attrs, + ) kwargs_adj[key] = arg_adj datasets.append(arg_adj) else: @@ -170,9 +172,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): dataset = datasets[0] else: msg = "Function was called with multiple xarray dataset arguments. Currently unsupported." - raise NotImplementedError( - msg - ) + raise NotImplementedError(msg) # only use cache if the cache file and the pickled function arguments exist if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache): From 530eed486fa369485b2e1f055db9b2eb4f9ea38b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:31:10 +0200 Subject: [PATCH 13/24] remove coords that number from 0 to N (these are not listed in coords) --- nlmod/cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index 852ed57b..0fdfa9b5 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -663,9 +663,9 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar attrs.append("gridtype") if isvertex: - coords.append("icell2d") - coords.append("iv") - coords.append("icv") + # coords.append("icell2d") # not in coords, runs from 0 to N + # coords.append("iv") # not in coords, runs from 0 to N + # coords.append("icv") # not in coords, runs from 0 to N datavars.append("xv") datavars.append("yv") datavars.append("icvert") From 6ef29f6b8ec424960c931a64bd13ed6cba43c738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:31:29 +0200 Subject: [PATCH 14/24] remove delr and delc from datavars is vertexgrid --- nlmod/cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nlmod/cache.py b/nlmod/cache.py index 0fdfa9b5..bada4fb8 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -669,6 +669,11 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar datavars.append("xv") datavars.append("yv") datavars.append("icvert") + # TODO: temporary fix until delr/delc are completely removed + if "delr" in datavars: + datavars.remove("delr") + if "delc" in datavars: + datavars.remove("delc") if "angrot" in ds.attrs: attrs.append("angrot") From 901e52d159b3ec88c641bda120d5be6f27802615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:31:53 +0200 Subject: [PATCH 15/24] do not check time attrs start and time_units --- nlmod/cache.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index bada4fb8..a02cafab 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -688,8 +688,9 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar datavars.append("steady") 
datavars.append("nstp") datavars.append("tsmult") - attrs.append("start") - attrs.append("time_units") + # time_attrs = [] + # time_attrs.append("start") + # time_attrs.append("time_units") # User-friendly error messages if "northsea" in datavars and "northsea" not in ds.data_vars: From f89ea5765dfd44ada5fd956e1c84ca44f99ad75f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:32:18 +0200 Subject: [PATCH 16/24] do not check time coord attributes --- nlmod/cache.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index a02cafab..138a6aba 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -717,8 +717,14 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar msg = f"{attr} not in dataset.attrs" raise ValueError(msg) + # if coords_time: + # for t_attr in time_attrs: + # if t_attr not in ds["time"].attrs: + # raise ValueError(f'{t_attr} not in dataset["time"].attrs') + # Return only the required data return xr.Dataset( data_vars={k: ds.data_vars[k] for k in datavars}, coords={k: ds.coords[k] for k in coords}, - attrs={k: ds.attrs[k] for k in attrs}) + attrs={k: ds.attrs[k] for k in attrs}, + ) From b43926032d1f8b1c044959257fae47272b7959fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:32:26 +0200 Subject: [PATCH 17/24] formatting --- nlmod/cache.py | 61 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index 138a6aba..e005ce97 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -207,10 +207,14 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): if dataset is not None: # Check the coords of the dataset argument - func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(dict(dataset.coords)) + func_args_dic["_dataset_coords_hash"] = dask.base.tokenize( + dict(dataset.coords) + ) # Check the data_vars of the dataset argument - func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(dict(dataset.data_vars)) + func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize( + dict(dataset.data_vars) + ) # check if cache was created with same function arguments as # function call @@ -257,8 +261,12 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): # Add dataset argument hash to pickle if dataset is not None: - func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(dict(dataset.coords)) - func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(dict(dataset.data_vars)) + func_args_dic["_dataset_coords_hash"] = dask.base.tokenize( + dict(dataset.coords) + ) + func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize( + dict(dataset.data_vars) + ) # pickle function arguments with open(fname_pickle_cache, "wb") as fpklz: @@ -267,6 +275,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs): msg = f"expected xarray Dataset, got {type(result)} instead" raise TypeError(msg) return _check_for_data_array(result) + return wrapper return decorator @@ -473,15 +482,23 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache): mfgrid1 = {k: v for k, v in item.mfgrid.__dict__.items() if k not in excl} mfgrid2 = {k: v for k, v in i2.mfgrid.__dict__.items() if k not in excl} - is_same_length_props = all(np.all(np.size(v) == np.size(mfgrid2[k])) for k, v in mfgrid1.items()) + is_same_length_props = all( + np.all(np.size(v) == np.size(mfgrid2[k])) for k, v in mfgrid1.items() + ) - if not 
is_method_equal or mfgrid1.keys() != mfgrid2.keys() or not is_same_length_props: + if ( + not is_method_equal + or mfgrid1.keys() != mfgrid2.keys() + or not is_same_length_props + ): logger.info( "cache was created using different gridintersect, do not use cached data" ) return False - is_other_props_equal = all(np.all(v == mfgrid2[k]) for k, v in mfgrid1.items()) + is_other_props_equal = all( + np.all(v == mfgrid2[k]) for k, v in mfgrid1.items() + ) if not is_other_props_equal: logger.info( @@ -545,9 +562,13 @@ def _update_docstring_and_signature(func): else: add_kwargs = None new_param = ( - *cur_param, - inspect.Parameter("cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None), - inspect.Parameter("cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None) + *cur_param, + inspect.Parameter( + "cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None + ), + inspect.Parameter( + "cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None + ), ) if add_kwargs is not None: new_param = (*new_param, add_kwargs) @@ -578,9 +599,8 @@ def _update_docstring_and_signature(func): def _check_for_data_array(ds): - """ - Check if the saved NetCDF-file represents a DataArray or a Dataset, and return this - data-variable. + """Check if the saved NetCDF-file represents a DataArray or a Dataset, and return + this data-variable. The file contains a DataArray when a variable called "__xarray_dataarray_variable__" is present in the Dataset. If so, return a DataArray, otherwise return the Dataset. @@ -597,7 +617,6 @@ def _check_for_data_array(ds): ------- ds : xr.Dataset or xr.DataArray A Dataset or DataArray containing the cached data. - """ if "__xarray_dataarray_variable__" in ds: spatial_ref = ds.spatial_ref if "spatial_ref" in ds else None @@ -608,9 +627,16 @@ def _check_for_data_array(ds): return ds -def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavars=None, coords=None, attrs=None): - """ - Returns a Dataset containing only the required data. +def ds_contains( + ds, + coords_2d=False, + coords_3d=False, + coords_time=False, + datavars=None, + coords=None, + attrs=None, +): + """Returns a Dataset containing only the required data. If all kwargs are left to their defaults, the function returns the full dataset. @@ -635,7 +661,6 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar ------- ds : xr.Dataset A Dataset containing only the required data. 
- """ # Return the full dataset if not configured if ds is None: From 6ac9638064b46df111bef099827e7366db2bdae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:33:14 +0200 Subject: [PATCH 18/24] skip delr/delc when writing files --- nlmod/gis.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nlmod/gis.py b/nlmod/gis.py index 17d4709c..3998791f 100644 --- a/nlmod/gis.py +++ b/nlmod/gis.py @@ -143,6 +143,9 @@ def struc_da_to_gdf(model_ds, data_variables, polygons=None, dealing_with_time=" raise ValueError( f"expected dimensions ('layer', 'y', 'x'), got {da.dims}" ) + # TODO: remove when delr/delc are removed as data vars + elif da_name in ["delr", "delc"]: + continue else: raise NotImplementedError( f"expected two or three dimensions got {no_dims} for data variable {da_name}" @@ -286,7 +289,9 @@ def ds_to_vector_file( if model_ds.gridtype == "structured": gdf = struc_da_to_gdf(model_ds, (da_name,), polygons=polygons) elif model_ds.gridtype == "vertex": - gdf = vertex_da_to_gdf(model_ds, (da_name,), polygons=polygons) + # TODO: remove when delr/dec are removed + if da_name not in ["delr", "delc"]: + gdf = vertex_da_to_gdf(model_ds, (da_name,), polygons=polygons) if driver == "GPKG": gdf.to_file(fname_gpkg, layer=da_name, driver=driver) else: From 0a20adad5ba60b1badec1c5e15770ad6fec995b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:33:40 +0200 Subject: [PATCH 19/24] rename bathymetry tests --- tests/test_004_northsea.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_004_northsea.py b/tests/test_004_northsea.py index 2e49171d..d71ccea4 100644 --- a/tests/test_004_northsea.py +++ b/tests/test_004_northsea.py @@ -60,7 +60,7 @@ def test_get_bathymetry_seamodel(): assert (~ds_bathymetry.bathymetry.isnull()).sum() > 0 -def test_get_bathymetrie_nosea(): +def test_get_bathymetry_nosea(): # model without sea ds = test_001_model.get_ds_from_cache("small_model") ds.update(nlmod.read.rws.get_northsea(ds)) @@ -69,7 +69,7 @@ def test_get_bathymetrie_nosea(): assert (~ds_bathymetry.bathymetry.isnull()).sum() == 0 -def test_add_bathymetrie_to_top_bot_kh_kv_seamodel(): +def test_add_bathymetry_to_top_bot_kh_kv_seamodel(): # model with sea ds = test_001_model.get_ds_from_cache("basic_sea_model") ds.update(nlmod.read.rws.get_northsea(ds)) From 92362c5c16e6a85671835aeb9a78208211d7ff37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:34:08 +0200 Subject: [PATCH 20/24] comment still in dutch: bathymetry --- nlmod/read/jarkus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nlmod/read/jarkus.py b/nlmod/read/jarkus.py index 6654352d..25e70331 100644 --- a/nlmod/read/jarkus.py +++ b/nlmod/read/jarkus.py @@ -75,7 +75,7 @@ def get_bathymetry(ds, northsea, kind="jarkus", method="average"): # fill nan values in bathymetry da_bathymetry_filled = fillnan_da(da_bathymetry_raw) - # bathymetrie mag nooit groter zijn dan NAP 0.0 + # bathymetry can never be larger than NAP 0.0 da_bathymetry_filled = xr.where(da_bathymetry_filled > 0, 0, da_bathymetry_filled) # bathymetry projected on model grid From eefdc2cfa05f6caa1b24e1217aeebb08c1c5339d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:34:31 +0200 Subject: [PATCH 21/24] comment still in dutch: bathymetry --- nlmod/read/jarkus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nlmod/read/jarkus.py 
b/nlmod/read/jarkus.py index 25e70331..5d962973 100644 --- a/nlmod/read/jarkus.py +++ b/nlmod/read/jarkus.py @@ -272,7 +272,7 @@ def add_bathymetry_to_top_bot_kh_kv(ds, bathymetry, fill_mask, kh_sea=10, kv_sea ds["kv"][lay] = xr.where(fill_mask, kv_sea, ds["kv"][lay]) - # reset bot for all layers based on bathymetrie + # reset bot for all layers based on bathymetry for lay in range(1, ds.sizes["layer"]): ds["botm"][lay] = np.where( ds["botm"][lay] > ds["botm"][lay - 1], From 75b5ffbbf1d20cef31a15de216ad9e169c110db1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:35:29 +0200 Subject: [PATCH 22/24] use last delr to set extent --- docs/examples/02_surface_water.ipynb | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/examples/02_surface_water.ipynb b/docs/examples/02_surface_water.ipynb index f5d8cb8ae..c0a7033b 100644 --- a/docs/examples/02_surface_water.ipynb +++ b/docs/examples/02_surface_water.ipynb @@ -490,7 +490,7 @@ "xlim = ax.get_xlim()\n", "ylim = ax.get_ylim()\n", "gwf.modelgrid.plot(ax=ax)\n", - "ax.set_xlim(xlim[0], xlim[0] + ds.delr * 1.1)\n", + "ax.set_xlim(xlim[0], xlim[0] + ds.delr[-1] * 1.1)\n", "ax.set_ylim(ylim)\n", "ax.set_title(f\"Surface water shapes in cell: {cid}\")" ] @@ -788,6 +788,14 @@ "cbar = fig.colorbar(qm, shrink=1.0)\n", "cbar.set_label(\"head [m NAP]\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cbc42bf", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From a5c32647cbbcfb6c7de2176c314230c0739fdf9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=ADd=20Brakenhoff?= Date: Mon, 13 May 2024 16:36:04 +0200 Subject: [PATCH 23/24] use brackets (was failing for me without) --- docs/examples/cache_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/cache_example.py b/docs/examples/cache_example.py index d1ded5d6..5d33d7b0 100644 --- a/docs/examples/cache_example.py +++ b/docs/examples/cache_example.py @@ -3,7 +3,7 @@ import xarray as xr -@nlmod.cache.cache_netcdf +@nlmod.cache.cache_netcdf() def func_to_create_a_dataset(number): """create a dataarray as an example for the caching method. From f51a15cd76f7b1b1ee73e555079185db46f55a4d Mon Sep 17 00:00:00 2001 From: Bas des Tombe Date: Tue, 14 May 2024 09:45:57 +0200 Subject: [PATCH 24/24] Cache: Implemented inline suggestions made by David --- nlmod/cache.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/nlmod/cache.py b/nlmod/cache.py index e005ce97..6ccd717f 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -688,12 +688,10 @@ def ds_contains( attrs.append("gridtype") if isvertex: - # coords.append("icell2d") # not in coords, runs from 0 to N - # coords.append("iv") # not in coords, runs from 0 to N - # coords.append("icv") # not in coords, runs from 0 to N datavars.append("xv") datavars.append("yv") datavars.append("icvert") + # TODO: temporary fix until delr/delc are completely removed if "delr" in datavars: datavars.remove("delr") @@ -713,18 +711,25 @@ def ds_contains( datavars.append("steady") datavars.append("nstp") datavars.append("tsmult") - # time_attrs = [] - # time_attrs.append("start") - # time_attrs.append("time_units") - # User-friendly error messages + # User-friendly error messages if missing from ds if "northsea" in datavars and "northsea" not in ds.data_vars: msg = "Northsea not in dataset. Run nlmod.read.rws.add_northsea() first." 
raise ValueError(msg) - if "time" in coords and "time" not in ds.coords: - msg = "time not in dataset. Run nlmod.time.set_ds_time() first." - raise ValueError(msg) + if coords_time: + if "time" not in ds.coords: + msg = "time not in dataset. Run nlmod.time.set_ds_time() first." + raise ValueError(msg) + + # Check if time-coord is complete + time_attrs_required = ["start", "time_units"] + + for t_attr in time_attrs_required: + if t_attr not in ds["time"].attrs: + msg = f"{t_attr} not in dataset['time'].attrs. " +\ + "Run nlmod.time.set_ds_time() to set time." + raise ValueError(msg) # User-unfriendly error messages for datavar in datavars: @@ -742,11 +747,6 @@ def ds_contains( msg = f"{attr} not in dataset.attrs" raise ValueError(msg) - # if coords_time: - # for t_attr in time_attrs: - # if t_attr not in ds["time"].attrs: - # raise ValueError(f'{t_attr} not in dataset["time"].attrs') - # Return only the required data return xr.Dataset( data_vars={k: ds.data_vars[k] for k in datavars},