Skip to content

Commit

Permalink
Use boilerplate exposed by ndsl package
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Cattaneo committed Aug 6, 2024
1 parent 3a250a0 commit 97ee945
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 138 deletions.
2 changes: 1 addition & 1 deletion examples/NDSL/01_gt4py_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"from ndsl.dsl.typing import FloatField\n",
"from ndsl.quantity import Quantity\n",
"import numpy as np\n",
"from basic_boilerplate import plot_field_at_kN"
"from ndsl.boilerplate import plot_field_at_kN"
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions examples/NDSL/02_NDSL_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@
"metadata": {},
"outputs": [],
"source": [
"from basic_boilerplate import get_one_tile_factory, plot_field_at_kN\n",
"from ndsl import StencilFactory\n",
"from ndsl.boilerplate import get_factories_single_tile_numpy, plot_field_at_kN\n",
"\n",
"nx = 6\n",
"ny = 6\n",
"nz = 1\n",
"nhalo = 1\n",
"backend=\"numpy\"\n",
"\n",
"stencil_factory: StencilFactory = get_one_tile_factory(nx, ny, nz, nhalo, backend)"
"stencil_factory, _ = get_factories_single_tile_numpy(nx, ny, nz, nhalo)"
]
},
{
Expand Down Expand Up @@ -81,7 +80,7 @@
"source": [
"### **Creating a class that performs a stencil computation**\n",
"\n",
"Using the `StencilFactory` object created earlier, the code will now create a class `CopyField` that takes `copy_field_stencil` and defines the computation domain from the parameters `origin` and `domain` within `__init__`. `origin` indicates the \"starting\" point of the stencil calculation, and `domain` indicates the extent of the stencil calculation in the three dimensions. Note that when creating `stencil_factory`, a 6 by 6 by 1 sized domain surrounded with a halo layer of size 1 was defined (see the initialization of `grid_indexing` at [basic_boilerplate.py](./basic_boilerplate.py#get_one_tile_factory)). Thus, whenever a `CopyField` object is created, it will perform calculations within the 6 by 6 by 1 domain (specified by `domain=grid_indexing.domain_compute()`), and the `origin` will start at the `[0,0,0]` location of the 6 by 6 by 1 grid (specified by `origin=grid_indexing.origin_compute()`)."
"Using the `StencilFactory` object created earlier, the code will now create a class `CopyField` that takes `copy_field_stencil` and defines the computation domain from the parameters `origin` and `domain` within `__init__`. `origin` indicates the \"starting\" point of the stencil calculation, and `domain` indicates the extent of the stencil calculation in the three dimensions. Note that when creating `stencil_factory`, a 6 by 6 by 1 sized domain surrounded with a halo layer of size 1 was defined. Thus, whenever a `CopyField` object is created, it will perform calculations within the 6 by 6 by 1 domain (specified by `domain=grid_indexing.domain_compute()`), and the `origin` will start at the `[0,0,0]` location of the 6 by 6 by 1 grid (specified by `origin=grid_indexing.origin_compute()`)."
]
},
{
Expand Down Expand Up @@ -129,6 +128,7 @@
"from ndsl.quantity import Quantity\n",
"import numpy as np\n",
"\n",
"backend = stencil_factory.backend\n",
"size = (nx + 2 * nhalo) * (ny + 2 * nhalo) * nz\n",
"shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n",
"\n",
Expand Down
7 changes: 2 additions & 5 deletions examples/NDSL/03_orchestration_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
")\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n",
"from ndsl.dsl.typing import FloatField, Float\n",
"\n",
"from orch_boilerplate import get_one_tile_factory_orchestrated"
"from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu"
]
},
{
Expand Down Expand Up @@ -121,19 +120,17 @@
"\n",
"if __name__ == \"__main__\":\n",
" # Settings\n",
" backend = \"dace:cpu\"\n",
" dtype = np.float64\n",
" origin = (0, 0, 0)\n",
" rebuild = True\n",
" tile_size = (3, 3, 3)\n",
"\n",
" # Setup\n",
" stencil_factory, qty_factory = get_one_tile_factory_orchestrated(\n",
" stencil_factory, qty_factory = get_factories_single_tile_orchestrated_cpu(\n",
" nx=tile_size[0],\n",
" ny=tile_size[1],\n",
" nz=tile_size[2],\n",
" nhalo=2,\n",
" backend=backend,\n",
" )\n",
" local_sum = LocalSum(stencil_factory, qty_factory)\n",
"\n",
Expand Down
8 changes: 3 additions & 5 deletions examples/NDSL/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

This folder contains a couple of Jupyter notebooks with the following examples:

- [GT4Py basics](./01_gt4py_basics.ipynb)
- [NDSL basics](./02_NDSL_basics.ipynb)
- [Orchestration basics](./03_orchestration_basics.ipynb)

wich are supported by `*_boilerplate.py` code.
- Getting started with GT4Py: [GT4Py basics](./01_gt4py_basics.ipynb)
- Getting started with NDSL middleware: [NDSL basics](./02_NDSL_basics.ipynb)
- Combining stencil and non-stencil code with DaCe: [Orchestration basics](./03_orchestration_basics.ipynb)

## Quickstart

Expand Down
58 changes: 0 additions & 58 deletions examples/NDSL/basic_boilerplate.py

This file was deleted.

64 changes: 0 additions & 64 deletions examples/NDSL/orch_boilerplate.py

This file was deleted.

14 changes: 14 additions & 0 deletions ndsl/boilerplate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np

from ndsl import (
Expand Down Expand Up @@ -107,3 +108,16 @@ def get_factories_single_tile_numpy(
orchestration=DaCeOrchestration.Python,
topology="tile",
)


def plot_field_at_kN(field, k_index=0):

print("Min and max values:", field[:, :, k_index].min(), field[:, :, k_index].max())
plt.xlabel("I")
plt.ylabel("J")

im = plt.imshow(field[:, :, k_index].transpose(), origin="lower")

plt.colorbar(im)
plt.title("Plot at K = " + str(k_index))
plt.show()
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def local_pkg(name: str, relative_path: str) -> str:
develop_requirements = test_requirements + ["pre-commit"]
demos_requirements = ["ipython", "ipykernel"]

extras_requires = {"test": test_requirements, "develop": develop_requirements, "demos": demos_requirements}
extras_requires = {
"test": test_requirements,
"develop": develop_requirements,
"demos": demos_requirements,
}

requirements: List[str] = [
local_pkg("gt4py", "external/gt4py"),
Expand All @@ -30,6 +34,7 @@ def local_pkg(name: str, relative_path: str) -> str:
"h5netcdf", # for xarray
"dask", # for xarray
"numpy==1.26.4",
"matplotlib", # for plotting in boilerplate
]


Expand Down

0 comments on commit 97ee945

Please sign in to comment.