Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved docs for NDSL examples #63

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
910 changes: 89 additions & 821 deletions examples/NDSL/01_gt4py_basics.ipynb
100755 → 100644

Large diffs are not rendered by default.

205 changes: 32 additions & 173 deletions examples/NDSL/02_NDSL_basics.ipynb

Large diffs are not rendered by default.

73 changes: 13 additions & 60 deletions examples/NDSL/03_orchestration_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://spcl.github.io/dace-webclient/dist/sdfv.js\"></script>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2024-05-29 13:16:33|INFO|rank 0|ndsl.logging:Constant selected: ConstantVersions.GFS\n"
]
}
],
"outputs": [],
"source": [
"import numpy as np\n",
"from gt4py.cartesian.gtscript import (\n",
Expand All @@ -57,8 +37,7 @@
")\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n",
"from ndsl.dsl.typing import FloatField, Float\n",
"\n",
"from orch_boilerplate import get_one_tile_factory_orchestrated"
"from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu"
]
},
{
Expand All @@ -70,14 +49,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def localsum_stencil(\n",
" field: FloatField, # type: ignore\n",
" field: FloatField, # type: ignore\n",
" result: FloatField, # type: ignore\n",
" weight: Float, # type: ignore\n",
" weight: Float, # type: ignore\n",
"):\n",
" with computation(PARALLEL), interval(...):\n",
" result = weight * (\n",
Expand All @@ -94,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -119,8 +98,8 @@
" self._n_halo = quantity_factory.sizer.n_halo\n",
"\n",
" def __call__(self, in_field: FloatField, out_result: FloatField) -> None:\n",
" self._local_sum(in_field, out_result, 2.0) # GT4Py Stencil\n",
" tmp_field = out_result[:, :, :] + 2 # Regular Python code\n",
" self._local_sum(in_field, out_result, 2.0) # GT4Py Stencil\n",
" tmp_field = out_result[:, :, :] + 2 # Regular Python code\n",
" self._local_sum(tmp_field, out_result, 2.0) # GT4Py Stencil"
]
},
Expand All @@ -133,51 +112,25 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2024-05-29 13:16:33|INFO|rank 0|ndsl.logging:[DaCeOrchestration.BuildAndRun] Rank 0 reading/writing cache .gt_cache_FV3_A\n",
"2024-05-29 13:16:33|INFO|rank 0|ndsl.logging:Building DaCe orchestration\n",
"Inlined 2 SDFGs.\n",
"Fused 4 states.\n",
"Inferred 2 optional arrays.\n",
"SDFG 0: Eliminated 1 arrays: {'out_result_0'}.\n",
"Fused 2 states.\n",
"Inferred 4 optional arrays.\n",
"Inlined 2 SDFGs.\n",
"2024-05-29 13:16:34|INFO|rank 0|ndsl.logging:[DaCeOrchestration.BuildAndRun] LocalSum___call__:\n",
"StorageType.Default:\n",
" Alloc ref 0.01 mb\n",
" Alloc unref 0.00 mb\n",
" Pooled 0.00 mb\n",
" Top lvl alloc: 0.01mb\n",
"\n",
"[DaCe Config] Rank 0 loading SDFG /home/ckung/Documents/Code/SMT-Nebulae-Tutorial/tutorial/NDSL/.gt_cache_FV3_A/dacecache/LocalSum___call__\n"
]
}
],
"outputs": [],
"source": [
"# ----- Driver ----- #\n",
"\n",
"if __name__ == \"__main__\":\n",
" # Settings\n",
" backend = \"dace:cpu\"\n",
" dtype = np.float64\n",
" origin = (0, 0, 0)\n",
" rebuild = True\n",
" tile_size = (3, 3, 3)\n",
"\n",
" # Setup\n",
" stencil_factory, qty_factory = get_one_tile_factory_orchestrated(\n",
" stencil_factory, qty_factory = get_factories_single_tile_orchestrated_cpu(\n",
" nx=tile_size[0],\n",
" ny=tile_size[1],\n",
" nz=tile_size[2],\n",
" nhalo=2,\n",
" backend=backend,\n",
" )\n",
" local_sum = LocalSum(stencil_factory, qty_factory)\n",
"\n",
Expand Down Expand Up @@ -206,7 +159,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
14 changes: 14 additions & 0 deletions examples/NDSL/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# NDSL examples

This folder contains a couple of Jupyter notebooks with the following examples:

- Getting started with GT4Py: [GT4Py basics](./01_gt4py_basics.ipynb)
- Getting started with NDSL middleware: [NDSL basics](./02_NDSL_basics.ipynb)
- Combining stencil and non-stencil code with DaCe: [Orchestration basics](./03_orchestration_basics.ipynb)

## Quickstart

1. Make sure you fulfill the [requirements of NDSL](../../README.md#quickstart), e.g. python version.
2. Create a virtual environment `python -m venv .venv`, and activate it `source .venv/bin/activate`.
3. Install NDSL into your environment `pip install ../../[demos]`.
4. In VSCode, install the Jupyter extension, select your virtual environment as kernel and run the notebooks.
73 changes: 0 additions & 73 deletions examples/NDSL/basic_boilerplate.py

This file was deleted.

64 changes: 0 additions & 64 deletions examples/NDSL/orch_boilerplate.py

This file was deleted.

4 changes: 2 additions & 2 deletions ndsl/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _get_factories(
def get_factories_single_tile_orchestrated_cpu(
nx, ny, nz, nhalo
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for orchestrated CPU, on a single tile toplogy."""
"""Build a Stencil & Quantity factory for orchestrated CPU, on a single tile topology."""
return _get_factories(
nx=nx,
ny=ny,
Expand All @@ -97,7 +97,7 @@ def get_factories_single_tile_orchestrated_cpu(
def get_factories_single_tile_numpy(
nx, ny, nz, nhalo
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for Numpy, on a single tile toplogy."""
"""Build a Stencil & Quantity factory for Numpy, on a single tile topology."""
return _get_factories(
nx=nx,
ny=ny,
Expand Down
17 changes: 17 additions & 0 deletions ndsl/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, cast

import matplotlib.pyplot as plt
import numpy as np

import ndsl.constants as constants
Expand Down Expand Up @@ -593,6 +594,22 @@ def transpose(
transposed._attrs = self._attrs
return transposed

def plot_k_level(self, k_index=0):
field = self.data
print(
"Min and max values:",
field[:, :, k_index].min(),
field[:, :, k_index].max(),
)
plt.xlabel("I")
plt.ylabel("J")

im = plt.imshow(field[:, :, k_index].transpose(), origin="lower")

plt.colorbar(im)
plt.title("Plot at K = " + str(k_index))
plt.show()


def transpose_sequence(sequence, order):
return sequence.__class__(sequence[i] for i in order)
Expand Down
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@ def local_pkg(name: str, relative_path: str) -> str:

test_requirements = ["pytest", "pytest-subtests", "coverage"]
develop_requirements = test_requirements + ["pre-commit"]
demos_requirements = ["ipython", "ipykernel"]

extras_requires = {"test": test_requirements, "develop": develop_requirements}
extras_requires = {
"test": test_requirements,
"develop": develop_requirements,
"demos": demos_requirements,
}

requirements: List[str] = [
local_pkg("gt4py", "external/gt4py"),
Expand All @@ -29,6 +34,7 @@ def local_pkg(name: str, relative_path: str) -> str:
"h5netcdf", # for xarray
"dask", # for xarray
"numpy==1.26.4",
"matplotlib", # for plotting in boilerplate
romanc marked this conversation as resolved.
Show resolved Hide resolved
]


Expand Down
Loading