Skip to content

Commit

Permalink
Merge pull request #21 from NOAA-GFDL/feature/stencil_basic_ops
Browse files Browse the repository at this point in the history
Added basic_operations.py from pyFV3/stencils to ndsl/stencils
  • Loading branch information
FlorianDeconinck authored Apr 1, 2024
2 parents 0f18047 + b086542 commit 85c2213
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 0 deletions.
96 changes: 96 additions & 0 deletions ndsl/stencils/basic_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import PARALLEL, computation, interval

from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ


def copy_defn(q_in: FloatField, q_out: FloatField):
"""
Copy q_in to q_out.
Args:
q_in: input field
q_out: output field
"""
with computation(PARALLEL), interval(...):
q_out = q_in


def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField):
"""
Multiplies every element of q_out
by every element of the adjustment
field over the interval, replacing
the elements of q_out by the result
of the multiplication.
Args:
adjustment: adjustment field
q_out: output field
"""
with computation(PARALLEL), interval(...):
q_out = q_out * adjustment


def set_value_defn(q_out: FloatField, value: Float):
"""
Sets every element of q_out to the
value specified by value argument.
Args:
q_out: output field
value: NDSL Float type
"""
with computation(PARALLEL), interval(...):
q_out = value


def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField):
"""
Divides every element of q_out
by every element of the adjustment
field over the interval, replacing
the elements of q_out by the result
of the multiplication.
Args:
adjustment: adjustment field
q_out: output field
"""
with computation(PARALLEL), interval(...):
q_out = q_out / adjustment


@gtscript.function
def sign(a, b):
"""
Defines asignb as the absolute value
of a, and checks if b is positive
or negative, assigning the analogus
sign value to asignb. asignb is returned
Args:
a: A number
b: A number
"""
asignb = abs(a)
if b > 0:
asignb = asignb
else:
asignb = -asignb
return asignb


@gtscript.function
def dim(a, b):
"""
Performs a check on the difference
between the values in arguments
a and b. The variable diff is set
to the difference between a and b
when the difference is positive,
otherwise it is set to zero. The
function returns the diff variable.
"""
diff = a - b if a - b > 0 else 0
return diff
212 changes: 212 additions & 0 deletions tests/test_basic_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import numpy as np

from ndsl import (
CompilationConfig,
DaceConfig,
DaCeOrchestration,
GridIndexing,
Quantity,
RunMode,
StencilConfig,
StencilFactory,
)
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ
from ndsl.stencils import basic_operations as basic


nx = 20
ny = 20
nz = 79
nhalo = 0
backend = "numpy"

dace_config = DaceConfig(
communicator=None, backend=backend, orchestration=DaCeOrchestration.Python
)

compilation_config = CompilationConfig(
backend=backend,
rebuild=True,
validate_args=True,
format_source=False,
device_sync=False,
run_mode=RunMode.BuildAndRun,
use_minimal_caching=False,
)

stencil_config = StencilConfig(
compare_to_numpy=False,
compilation_config=compilation_config,
dace_config=dace_config,
)

grid_indexing = GridIndexing(
domain=(nx, ny, nz),
n_halo=nhalo,
south_edge=True,
north_edge=True,
west_edge=True,
east_edge=True,
)

stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing)


class Copy:
def __init__(self, stencil_factory: StencilFactory):
grid_indexing = stencil_factory.grid_indexing
self._copy_stencil = stencil_factory.from_origin_domain(
basic.copy_defn,
origin=grid_indexing.origin_compute(),
domain=grid_indexing.domain_compute(),
)

def __call__(
self,
f_in: FloatField,
f_out: FloatField,
):
self._copy_stencil(f_in, f_out)


class AdjustmentFactor:
def __init__(self, stencil_factory: StencilFactory):
grid_indexing = stencil_factory.grid_indexing
self._adjustmentfactor_stencil = stencil_factory.from_origin_domain(
basic.adjustmentfactor_stencil_defn,
origin=grid_indexing.origin_compute(),
domain=grid_indexing.domain_compute(),
)

def __call__(
self,
factor: FloatFieldIJ,
f_out: FloatField,
):
self._adjustmentfactor_stencil(factor, f_out)


class SetValue:
def __init__(self, stencil_factory: StencilFactory):
grid_indexing = stencil_factory.grid_indexing
self._set_value_stencil = stencil_factory.from_origin_domain(
basic.set_value_defn,
origin=grid_indexing.origin_compute(),
domain=grid_indexing.domain_compute(),
)

def __call__(
self,
f_out: FloatField,
value: Float,
):
self._set_value_stencil(f_out, value)


class AdjustDivide:
def __init__(self, stencil_factory: StencilFactory):
grid_indexing = stencil_factory.grid_indexing
self._adjust_divide_stencil = stencil_factory.from_origin_domain(
basic.adjust_divide_stencil,
origin=grid_indexing.origin_compute(),
domain=grid_indexing.domain_compute(),
)

def __call__(
self,
factor: FloatField,
f_out: FloatField,
):
self._adjust_divide_stencil(factor, f_out)


def test_copy():
copy = Copy(stencil_factory)

infield = Quantity(
data=np.zeros([20, 20, 79]),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

outfield = Quantity(
data=np.ones([20, 20, 79]),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

copy(f_in=infield.data, f_out=outfield.data)

assert (infield.data == outfield.data).all()


def test_adjustmentfactor():
adfact = AdjustmentFactor(stencil_factory)

factorfield = Quantity(
data=np.full(shape=[20, 20], fill_value=2.0),
dims=[X_DIM, Y_DIM],
units="m",
)

outfield = Quantity(
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

testfield = Quantity(
data=np.full(shape=[20, 20, 79], fill_value=4.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

adfact(factor=factorfield.data, f_out=outfield.data)
assert (outfield.data == testfield.data).all()


def test_setvalue():
setvalue = SetValue(stencil_factory)

outfield = Quantity(
data=np.zeros(shape=[20, 20, 79]),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

testfield = Quantity(
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

setvalue(f_out=outfield.data, value=2.0)

assert (outfield.data == testfield.data).all()


def test_adjustdivide():
addiv = AdjustDivide(stencil_factory)

factorfield = Quantity(
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

outfield = Quantity(
data=np.full(shape=[20, 20, 79], fill_value=2.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

testfield = Quantity(
data=np.full(shape=[20, 20, 79], fill_value=1.0),
dims=[X_DIM, Y_DIM, Z_DIM],
units="m",
)

addiv(factor=factorfield.data, f_out=outfield.data)

assert (outfield.data == testfield.data).all()

0 comments on commit 85c2213

Please sign in to comment.