From 17d17fbea1126cc3c6e32f761a1598d47741dde8 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 6 Mar 2024 15:48:50 -0500 Subject: [PATCH 1/7] Added basic_operations.py from pyFV3 to ndsl/stencils --- ndsl/stencils/basic_operations.py | 46 +++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 ndsl/stencils/basic_operations.py diff --git a/ndsl/stencils/basic_operations.py b/ndsl/stencils/basic_operations.py new file mode 100644 index 0000000..55307c5 --- /dev/null +++ b/ndsl/stencils/basic_operations.py @@ -0,0 +1,46 @@ +import gt4py.cartesian.gtscript as gtscript +from gt4py.cartesian.gtscript import PARALLEL, computation, interval + +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ + + +def copy_defn(q_in: FloatField, q_out: FloatField): + """Copy q_in to q_out. + + Args: + q_in: input field + q_out: output field + """ + with computation(PARALLEL), interval(...): + q_out = q_in + + +def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): + with computation(PARALLEL), interval(...): + q_out = q_out * adjustment + + +def set_value_defn(q_out: FloatField, value: Float): + with computation(PARALLEL), interval(...): + q_out = value + + +def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): + with computation(PARALLEL), interval(...): + q_out = q_out / adjustment + + +@gtscript.function +def sign(a, b): + asignb = abs(a) + if b > 0: + asignb = asignb + else: + asignb = -asignb + return asignb + + +@gtscript.function +def dim(a, b): + diff = a - b if a - b > 0 else 0 + return diff From 9f9e25c4df67fcacce435dc40663f91c2eba47d4 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 20 Mar 2024 16:07:22 -0400 Subject: [PATCH 2/7] Added rough draft of unit tests --- tests/test_basic_operations.py | 205 +++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 tests/test_basic_operations.py diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py new file mode 100644 index 0000000..83b9f81 --- /dev/null +++ b/tests/test_basic_operations.py @@ -0,0 +1,205 @@ +from gt4py.storage import full, ones, zeros + +from ndsl import ( + CompilationConfig, + DaceConfig, + DaCeOrchestration, + GridIndexing, + RunMode, + StencilConfig, + StencilFactory, +) +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ +from ndsl.stencils import basic_operations as basic + + +nx = 20 +ny = 20 +nz = 79 +nhalo = 3 +backend = "numpy" + +dace_config = DaceConfig( + communicator=None, backend=backend, orchestration=DaCeOrchestration.Python +) + +compilation_config = CompilationConfig( + backend=backend, + rebuild=True, + validate_args=True, + format_source=False, + device_sync=False, + run_mode=RunMode.BuildAndRun, + use_minimal_caching=False, +) + +stencil_config = StencilConfig( + compare_to_numpy=False, + compilation_config=compilation_config, + dace_config=dace_config, +) + +grid_indexing = GridIndexing( + domain=(nx, ny, nz), + n_halo=nhalo, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, +) + +stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) + + +class Copy: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._copy_stencil = stencil_factory.from_origin_domain( + basic.copy_defn, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + f_in: FloatField, + f_out: FloatField, + ): + self._copy_stencil(f_in, f_out) + + +class AdjustmentFactor: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._adjustmentfactor_stencil = stencil_factory.from_origin_domain( + basic.adjustmentfactor_stencil_defn, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + factor: FloatFieldIJ, + f_out: FloatField, + ): + self._adjustmentfactor_stencil(factor, f_out) + + +class SetValue: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._set_value_stencil = stencil_factory.from_origin_domain( + basic.set_value_defn, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + f_out: FloatField, + value: Float, + ): + self._set_value_stencil(f_out, value) + + +class AdjustDivide: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._adjust_divide_stencil = stencil_factory.from_origin_domain( + basic.adjust_divide_stencil, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + factor: FloatField, + f_out: FloatField, + ): + self._adjust_divide_stencil(factor, f_out) + + +def test_copy(): + copy = Copy(stencil_factory) + + infield = zeros( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + ) + + outfield = ones( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + ) + + copy(f_in=infield, f_out=outfield) + + assert infield.all() == outfield.all() + + +def test_adjustmentfactor(): + adfact = AdjustmentFactor(stencil_factory) + + factorfield = ones( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo) + ) + + outfield = ones( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + ) + + testfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo), + fill_value=26.0, + ) + + adfact(factor=factorfield, f_out=outfield) + assert outfield.any() == testfield.any() + + +def test_setvalue(): + setvalue = SetValue(stencil_factory) + + outfield = zeros( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + ) + + testfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=2.0, + ) + + setvalue(f_out=outfield, value=2.0) + + assert outfield.any() == testfield.any() + + +def test_adjustdivide(): + addiv = AdjustDivide(stencil_factory) + + factorfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=2.0, + ) + + outfield = ones( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + ) + + testfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo), + fill_value=13.0, + ) + + addiv(factor=factorfield, f_out=outfield) + assert outfield.any() == testfield.any() From a15accc632a0150aa1f38bb2f724ca874d4e1aa1 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 27 Mar 2024 15:17:35 -0400 Subject: [PATCH 3/7] Updated unit test --- tests/test_basic_operations.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 83b9f81..f90d794 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -13,10 +13,10 @@ from ndsl.stencils import basic_operations as basic -nx = 20 -ny = 20 -nz = 79 -nhalo = 3 +nx = 5 +ny = 5 +nz = 1 +nhalo = 0 backend = "numpy" dace_config = DaceConfig( @@ -122,35 +122,35 @@ def __call__( def test_copy(): copy = Copy(stencil_factory) - infield = zeros( + infield = ones( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) ) - outfield = ones( + outfield = zeros( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) ) copy(f_in=infield, f_out=outfield) - assert infield.all() == outfield.all() + assert infield.any() == outfield.any() def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) - factorfield = ones( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo) + factorfield = full( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), fill_value=2.0 ) - outfield = ones( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + outfield = full( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), fill_value=2.0 ) testfield = full( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=26.0, + fill_value=4.0, ) adfact(factor=factorfield, f_out=outfield) @@ -188,18 +188,20 @@ def test_adjustdivide(): fill_value=2.0, ) - outfield = ones( + outfield = full( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=4.0, ) testfield = full( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=13.0, + fill_value=2.0, ) - + addiv(factor=factorfield, f_out=outfield) + assert outfield.any() == testfield.any() From 330e52d974cd5fe449fc4c693a8228a840994408 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 27 Mar 2024 15:19:38 -0400 Subject: [PATCH 4/7] Linting --- tests/test_basic_operations.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index f90d794..1ca1dd4 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -139,11 +139,17 @@ def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) factorfield = full( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), fill_value=2.0 + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo), + fill_value=2.0, ) outfield = full( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), fill_value=2.0 + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=2.0, ) testfield = full( @@ -201,7 +207,7 @@ def test_adjustdivide(): shape=(nx + 2 * nhalo, ny + 2 * nhalo), fill_value=2.0, ) - + addiv(factor=factorfield, f_out=outfield) assert outfield.any() == testfield.any() From 5e4daf4a59b463c6e112244bc02c2d56fc5972f8 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 28 Mar 2024 15:35:33 -0400 Subject: [PATCH 5/7] Updated docstrings of basic_operations and using Quantity instead of gtstorage for testing --- ndsl/stencils/basic_operations.py | 52 ++++++++++++++- tests/test_basic_operations.py | 103 +++++++++++++++--------------- 2 files changed, 102 insertions(+), 53 deletions(-) diff --git a/ndsl/stencils/basic_operations.py b/ndsl/stencils/basic_operations.py index 55307c5..b46123a 100644 --- a/ndsl/stencils/basic_operations.py +++ b/ndsl/stencils/basic_operations.py @@ -5,7 +5,8 @@ def copy_defn(q_in: FloatField, q_out: FloatField): - """Copy q_in to q_out. + """ + Copy q_in to q_out. Args: q_in: input field @@ -16,22 +17,62 @@ def copy_defn(q_in: FloatField, q_out: FloatField): def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): + """ + Multiplies every element of q_out + by every element of the adjustment + field over the interval, replacing + the elements of q_out by the result + of the multiplication. + + Args: + adjustment: adjustment field + q_out: output field + """ with computation(PARALLEL), interval(...): q_out = q_out * adjustment def set_value_defn(q_out: FloatField, value: Float): + """ + Sets every element of q_out to the + value specified by value argument. + + Args: + q_out: output field + value: NDSL Float type + """ with computation(PARALLEL), interval(...): q_out = value def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): + """ + Divides every element of q_out + by every element of the adjustment + field over the interval, replacing + the elements of q_out by the result + of the multiplication. + + Args: + adjustment: adjustment field + q_out: output field + """ with computation(PARALLEL), interval(...): q_out = q_out / adjustment @gtscript.function def sign(a, b): + """ + Defines asignb as the absolute value + of a, and checks if b is positive + or negative, assigning the analogus + sign value to asignb. asignb is returned + + Args: + a: A number + b: A number + """ asignb = abs(a) if b > 0: asignb = asignb @@ -42,5 +83,14 @@ def sign(a, b): @gtscript.function def dim(a, b): + """ + Performs a check on the difference + between the values in arguments + a and b. The variable diff is set + to the difference between a and b + when the difference is positive, + otherwise it is set to zero. The + function returns the diff variable. + """ diff = a - b if a - b > 0 else 0 return diff diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 1ca1dd4..8fd87ad 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -1,14 +1,16 @@ -from gt4py.storage import full, ones, zeros +import numpy as np from ndsl import ( CompilationConfig, DaceConfig, DaCeOrchestration, GridIndexing, + Quantity, RunMode, StencilConfig, StencilFactory, ) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.stencils import basic_operations as basic @@ -122,92 +124,89 @@ def __call__( def test_copy(): copy = Copy(stencil_factory) - infield = ones( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + infield = Quantity( + data=np.zeros([5, 5, 1]), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - outfield = zeros( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + outfield = Quantity( + data=np.ones([5, 5, 1]), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - copy(f_in=infield, f_out=outfield) + copy(f_in=infield.data, f_out=outfield.data) - assert infield.any() == outfield.any() + assert infield.data.any() == outfield.data.any() def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) - factorfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=2.0, + factorfield = Quantity( + data=np.full(shape=[5, 5], fill_value=2.0), + dims=[X_DIM, Y_DIM], + units="m", ) - outfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=2.0, + outfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - testfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=4.0, + testfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=4.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - adfact(factor=factorfield, f_out=outfield) - assert outfield.any() == testfield.any() + adfact(factor=factorfield.data, f_out=outfield.data) + assert outfield.data.any() == testfield.data.any() def test_setvalue(): setvalue = SetValue(stencil_factory) - outfield = zeros( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + outfield = Quantity( + data=np.zeros(shape=[5, 5, 1]), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - testfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=2.0, + testfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - setvalue(f_out=outfield, value=2.0) + setvalue(f_out=outfield.data, value=2.0) - assert outfield.any() == testfield.any() + assert outfield.data.any() == testfield.data.any() def test_adjustdivide(): addiv = AdjustDivide(stencil_factory) - factorfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=2.0, + factorfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - outfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=4.0, + outfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - testfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=2.0, + testfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=1.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - addiv(factor=factorfield, f_out=outfield) + addiv(factor=factorfield.data, f_out=outfield.data) - assert outfield.any() == testfield.any() + assert outfield.data.any() == testfield.data.any() From 7f0d0f0c260eb15f353f970098803c33651d133a Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Fri, 29 Mar 2024 13:46:32 -0400 Subject: [PATCH 6/7] Updated logic of unit tests in test_basic_operations.py --- tests/test_basic_operations.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 8fd87ad..e761b8b 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -15,9 +15,9 @@ from ndsl.stencils import basic_operations as basic -nx = 5 -ny = 5 -nz = 1 +nx = 20 +ny = 20 +nz = 79 nhalo = 0 backend = "numpy" @@ -125,88 +125,88 @@ def test_copy(): copy = Copy(stencil_factory) infield = Quantity( - data=np.zeros([5, 5, 1]), + data=np.zeros([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) outfield = Quantity( - data=np.ones([5, 5, 1]), + data=np.ones([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) copy(f_in=infield.data, f_out=outfield.data) - assert infield.data.any() == outfield.data.any() + assert (infield.data == outfield.data).any() def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) factorfield = Quantity( - data=np.full(shape=[5, 5], fill_value=2.0), + data=np.full(shape=[20, 20], fill_value=2.0), dims=[X_DIM, Y_DIM], units="m", ) outfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) testfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=4.0), + data=np.full(shape=[20, 20, 79], fill_value=4.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) adfact(factor=factorfield.data, f_out=outfield.data) - assert outfield.data.any() == testfield.data.any() + assert (outfield.data == testfield.data).any() def test_setvalue(): setvalue = SetValue(stencil_factory) outfield = Quantity( - data=np.zeros(shape=[5, 5, 1]), + data=np.zeros(shape=[20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) testfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) setvalue(f_out=outfield.data, value=2.0) - assert outfield.data.any() == testfield.data.any() + assert (outfield.data == testfield.data).any() def test_adjustdivide(): addiv = AdjustDivide(stencil_factory) factorfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) outfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) testfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=1.0), + data=np.full(shape=[20, 20, 79], fill_value=1.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) addiv(factor=factorfield.data, f_out=outfield.data) - assert outfield.data.any() == testfield.data.any() + assert (outfield.data == testfield.data).any() From b08654231a233571c37fd247db1d68cb927168b2 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Fri, 29 Mar 2024 15:05:01 -0400 Subject: [PATCH 7/7] Using .all() instead of .any() --- tests/test_basic_operations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index e761b8b..0d70724 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -138,7 +138,7 @@ def test_copy(): copy(f_in=infield.data, f_out=outfield.data) - assert (infield.data == outfield.data).any() + assert (infield.data == outfield.data).all() def test_adjustmentfactor(): @@ -163,7 +163,7 @@ def test_adjustmentfactor(): ) adfact(factor=factorfield.data, f_out=outfield.data) - assert (outfield.data == testfield.data).any() + assert (outfield.data == testfield.data).all() def test_setvalue(): @@ -183,7 +183,7 @@ def test_setvalue(): setvalue(f_out=outfield.data, value=2.0) - assert (outfield.data == testfield.data).any() + assert (outfield.data == testfield.data).all() def test_adjustdivide(): @@ -209,4 +209,4 @@ def test_adjustdivide(): addiv(factor=factorfield.data, f_out=outfield.data) - assert (outfield.data == testfield.data).any() + assert (outfield.data == testfield.data).all()