From b83f8828d3ecaf16f45a93f38a53c32a797f496c Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 14 Feb 2024 12:43:02 -0500 Subject: [PATCH 01/14] Adding and updating __init__ files for modules --- ndsl/__init__.py | 6 ++++++ ndsl/buffer.py | 3 ++- ndsl/comm/__init__.py | 15 +++++++++++++++ ndsl/dsl/__init__.py | 5 +++-- ndsl/dsl/caches/__init__.py | 1 + ndsl/dsl/dace/__init__.py | 4 ++-- ndsl/halo/__init__.py | 6 ++++++ ndsl/performance/__init__.py | 7 +++++++ ndsl/stencils/__init__.py | 4 ++++ ndsl/stencils/testing/__init__.py | 1 + tests/checkpointer/__init__.py | 0 tests/dsl/__init__.py | 2 ++ tests/mpi/__init__.py | 0 tests/mpi/test_mpi_halo_update.py | 3 ++- tests/mpi/test_mpi_mock.py | 3 ++- tests/quantity/__init__.py | 0 16 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 ndsl/dsl/caches/__init__.py create mode 100644 tests/checkpointer/__init__.py create mode 100644 tests/dsl/__init__.py create mode 100644 tests/mpi/__init__.py create mode 100644 tests/quantity/__init__.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index e023b27..89d35b8 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1 +1,7 @@ +from .constants import ConstantVersions +from .exceptions import OutOfBoundsError from .logging import ndsl_log +from .optional_imports import RaiseWhenAccessed +from .types import Allocator, AsyncRequest, NumpyModule +from .units import UnitsError +from .utils import MetaEnumStr diff --git a/ndsl/buffer.py b/ndsl/buffer.py index 05cd643..bc829e7 100644 --- a/ndsl/buffer.py +++ b/ndsl/buffer.py @@ -4,7 +4,6 @@ import numpy as np from numpy.lib.index_tricks import IndexExpression -from ndsl.performance.timer import NullTimer, Timer from ndsl.types import Allocator from ndsl.utils import ( device_synchronize, @@ -13,6 +12,8 @@ safe_mpi_allocate, ) +from .performance.timer import NullTimer, Timer + BufferKey = Tuple[Callable, Iterable[int], type] BUFFER_CACHE: Dict[BufferKey, List["Buffer"]] = {} diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index e69de29..92675c0 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -0,0 +1,15 @@ +from .boundary import SimpleBoundary +from .caching_comm import ( + CachingCommData, + CachingCommReader, + CachingCommWriter, + CachingRequestReader, + CachingRequestWriter, + NullRequest, +) +from .comm_abc import Comm, Request +from .communicator import CubedSphereCommunicator, TileCommunicator +from .local_comm import AsyncResult, ConcurrencyError, LocalComm +from .mpi import MPIComm +from .null_comm import NullAsyncResult, NullComm +from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 5dd32b2..c013460 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -6,12 +6,13 @@ from .dace.dace_config import DaceConfig, DaCeOrchestration from .dace.orchestration import orchestrate, orchestrate_function from .stencil import ( - CompilationConfig, + CompareToNumpyStencil, FrozenStencil, GridIndexing, - StencilConfig, StencilFactory, + TimingCollector, ) +from .stencil_config import CompilationConfig, RunMode, StencilConfig if MPI is not None: diff --git a/ndsl/dsl/caches/__init__.py b/ndsl/dsl/caches/__init__.py new file mode 100644 index 0000000..3417ff0 --- /dev/null +++ b/ndsl/dsl/caches/__init__.py @@ -0,0 +1 @@ +from .codepath import FV3CodePath \ No newline at end of file diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index bcae0c4..d2c64b0 100644 --- a/ndsl/dsl/dace/__init__.py +++ 
b/ndsl/dsl/dace/__init__.py @@ -1,2 +1,2 @@ -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.dace.orchestration import orchestrate +from .dace_config import DaceConfig +from .orchestration import orchestrate diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index e69de29..b32a693 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -0,0 +1,6 @@ +from .data_transformer import ( + HaloDataTransformerCPU, + HaloDataTransformerGPU, + HaloExchangeSpec, +) +from .updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index 28e03bc..fd79608 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1,2 +1,9 @@ +from .collector import ( + AbstractPerformanceCollector, + NullPerformanceCollector, + PerformanceCollector, +) from .config import PerformanceConfig +from .profiler import NullProfiler, Profiler +from .report import Experiment, Report, TimeReport from .timer import NullTimer, Timer diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index d3ec452..641e032 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1 +1,5 @@ +from .c2l_ord import CubedToLatLon +from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid + + __version__ = "0.2.0" diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index d676e87..d66176c 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -4,6 +4,7 @@ ParallelTranslate2Py, ParallelTranslate2PyState, ParallelTranslateBaseSlicing, + ParallelTranslateGrid, ) from .savepoint import SavepointCase, Translate, dataset_to_dict from .temporaries import assert_same_temporaries, copy_temporaries diff --git a/tests/checkpointer/__init__.py b/tests/checkpointer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py new file mode 100644 index 0000000..10487c1 --- /dev/null +++ b/tests/dsl/__init__.py @@ -0,0 +1,2 @@ +from .test_stencil_wrapper import MockFieldInfo +from .test_caches import OrchestratedProgam \ No newline at end of file diff --git a/tests/mpi/__init__.py b/tests/mpi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 4d5133d..7343bf8 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -1,7 +1,6 @@ import copy import pytest -from mpi_comm import MPI from ndsl.comm._boundary_utils import get_boundary_slice from ndsl.comm.communicator import CubedSphereCommunicator @@ -25,6 +24,8 @@ ) from ndsl.quantity import Quantity +from .mpi_comm import MPI + @pytest.fixture def dtype(numpy): diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index c9d3d61..4d4d24a 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,10 +1,11 @@ import numpy as np import pytest -from mpi_comm import MPI from ndsl.comm.communicator import recv_buffer from ndsl.testing import ConcurrencyError, DummyComm +from .mpi_comm import MPI + worker_function_list = [] diff --git a/tests/quantity/__init__.py b/tests/quantity/__init__.py new file mode 100644 index 0000000..e69de29 From 2b1cb3ae3d0a21b972a5adaa8472e5c47d149318 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Mon, 26 Feb 2024 15:45:32 -0500 Subject: [PATCH 02/14] Changes suggested from PR and updates to what is exposed --- 
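This patch reworks what the top-level `ndsl` package re-exports so that downstream callers can use flat imports instead of deep submodule paths, while code inside the library moves to absolute `ndsl.<submodule>` imports. A minimal sketch of the intended downstream change, assuming only the re-exports added to ndsl/__init__.py in the hunks below:

    # Before this patch: reach into concrete submodules
    from ndsl.performance.timer import Timer
    from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner

    # After this patch: the same names are re-exported at the package root
    from ndsl import CubedSpherePartitioner, TilePartitioner, Timer

The examples/ and tests/ hunks below apply exactly this substitution; modules inside ndsl itself (e.g. ndsl/buffer.py) instead keep importing from the concrete submodules such as ndsl.performance.timer, presumably to avoid import cycles during package initialization.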
examples/mpi/global_timings.py | 2 +- examples/mpi/zarr_monitor.py | 11 ++++--- ndsl/__init__.py | 34 ++++++++++++++++++--- ndsl/buffer.py | 3 +- ndsl/checkpointer/__init__.py | 9 ------ ndsl/checkpointer/null.py | 2 +- ndsl/checkpointer/snapshots.py | 3 +- ndsl/checkpointer/thresholds.py | 4 +-- ndsl/checkpointer/validation.py | 10 ++++-- ndsl/comm/__init__.py | 16 ++-------- ndsl/comm/boundary.py | 4 +-- ndsl/dsl/__init__.py | 11 ++----- ndsl/dsl/caches/__init__.py | 2 +- ndsl/dsl/dace/__init__.py | 11 +++++-- ndsl/dsl/stencil.py | 2 +- ndsl/grid/__init__.py | 10 ------ ndsl/grid/generation.py | 25 ++++++++------- ndsl/grid/geometry.py | 5 ++- ndsl/grid/global_setup.py | 7 ++--- ndsl/grid/helper.py | 3 +- ndsl/halo/__init__.py | 8 ++--- ndsl/initialization/__init__.py | 2 +- ndsl/initialization/allocator.py | 8 ++--- ndsl/monitor/__init__.py | 1 - ndsl/monitor/convert.py | 2 +- ndsl/monitor/netcdf_monitor.py | 9 +++--- ndsl/performance/__init__.py | 8 ----- ndsl/performance/collector.py | 3 +- ndsl/performance/config.py | 5 ++- ndsl/stencils/__init__.py | 4 --- ndsl/stencils/c2l_ord.py | 2 +- ndsl/stencils/testing/__init__.py | 16 ++-------- ndsl/stencils/testing/conftest.py | 3 +- ndsl/stencils/testing/parallel_translate.py | 6 ++-- ndsl/stencils/testing/savepoint.py | 2 +- tests/checkpointer/test_snapshot.py | 2 +- tests/checkpointer/test_thresholds.py | 2 +- tests/checkpointer/test_validation.py | 7 +++-- tests/dsl/__init__.py | 4 +-- tests/dsl/test_caches.py | 9 +++--- tests/dsl/test_compilation_config.py | 10 ++++-- tests/dsl/test_dace_config.py | 10 ++---- tests/dsl/test_skip_passes.py | 6 ++-- tests/dsl/test_stencil.py | 7 +---- tests/dsl/test_stencil_config.py | 3 +- tests/dsl/test_stencil_factory.py | 13 ++++---- tests/dsl/test_stencil_wrapper.py | 17 +++++++---- tests/mpi/test_mpi_halo_update.py | 12 +++++--- tests/mpi/test_mpi_mock.py | 3 +- tests/quantity/test_boundary.py | 2 +- tests/quantity/test_deepcopy.py | 2 +- tests/quantity/test_quantity.py | 2 +- tests/quantity/test_storage.py | 2 +- tests/quantity/test_transpose.py | 2 +- tests/quantity/test_view.py | 2 +- tests/test_caching_comm.py | 16 ++++++---- tests/test_cube_scatter_gather.py | 13 +++++--- tests/test_decomposition.py | 5 ++- tests/test_dimension_sizer.py | 3 +- tests/test_g2g_communication.py | 13 +++++--- tests/test_halo_data_transformer.py | 3 +- tests/test_halo_update.py | 19 +++++++----- tests/test_halo_update_ranks.py | 13 +++++--- tests/test_legacy_restart.py | 11 ++++--- tests/test_local_comm.py | 2 +- tests/test_netcdf_monitor.py | 13 +++++--- tests/test_null_comm.py | 9 ++++-- tests/test_partitioner.py | 4 +-- tests/test_partitioner_boundaries.py | 7 ++--- tests/test_sync_shared_boundary.py | 13 +++++--- tests/test_tile_scatter.py | 5 +-- tests/test_tile_scatter_gather.py | 5 +-- tests/test_timer.py | 2 +- tests/test_zarr_monitor.py | 11 ++++--- 74 files changed, 264 insertions(+), 273 deletions(-) diff --git a/examples/mpi/global_timings.py b/examples/mpi/global_timings.py index 0921acd..9e3ecdb 100644 --- a/examples/mpi/global_timings.py +++ b/examples/mpi/global_timings.py @@ -3,7 +3,7 @@ import numpy as np from mpi4py import MPI -from ndsl.performance.timer import Timer +from ndsl import Timer @contextlib.contextmanager diff --git a/examples/mpi/zarr_monitor.py b/examples/mpi/zarr_monitor.py index 20d418f..0c089af 100644 --- a/examples/mpi/zarr_monitor.py +++ b/examples/mpi/zarr_monitor.py @@ -5,11 +5,14 @@ import zarr from mpi4py import MPI -from ndsl.comm.partitioner import 
CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSpherePartitioner, + QuantityFactory, + SubtileGridSizer, + TilePartitioner, + ZarrMonitor, +) from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer -from ndsl.monitor import ZarrMonitor OUTPUT_PATH = "output/zarr_monitor.zarr" diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 89d35b8..06a77ac 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,7 +1,31 @@ -from .constants import ConstantVersions +from .checkpointer import SnapshotCheckpointer +from .comm import ( + CachingCommReader, + CachingCommWriter, + ConcurrencyError, + CubedSphereCommunicator, + CubedSpherePartitioner, + LocalComm, + NullComm, + TileCommunicator, + TilePartitioner, +) +from .dsl import ( + CompareToNumpyStencil, + CompilationConfig, + DaceConfig, + DaCeOrchestration, + FrozenStencil, + GridIndexing, + RunMode, + StencilConfig, + StencilFactory, +) from .exceptions import OutOfBoundsError +from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater +from .initialization import QuantityFactory, SubtileGridSizer from .logging import ndsl_log -from .optional_imports import RaiseWhenAccessed -from .types import Allocator, AsyncRequest, NumpyModule -from .units import UnitsError -from .utils import MetaEnumStr +from .monitor import NetCDFMonitor, ZarrMonitor +from .performance import NullTimer, Timer +from .quantity import Quantity, QuantityHaloSpec +from .testing import DummyComm diff --git a/ndsl/buffer.py b/ndsl/buffer.py index bc829e7..05cd643 100644 --- a/ndsl/buffer.py +++ b/ndsl/buffer.py @@ -4,6 +4,7 @@ import numpy as np from numpy.lib.index_tricks import IndexExpression +from ndsl.performance.timer import NullTimer, Timer from ndsl.types import Allocator from ndsl.utils import ( device_synchronize, @@ -12,8 +13,6 @@ safe_mpi_allocate, ) -from .performance.timer import NullTimer, Timer - BufferKey = Tuple[Callable, Iterable[int], type] BUFFER_CACHE: Dict[BufferKey, List["Buffer"]] = {} diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index a51a4d9..46d32a6 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1,10 +1 @@ -from .base import Checkpointer -from .null import NullCheckpointer from .snapshots import SnapshotCheckpointer -from .thresholds import ( - InsufficientTrialsError, - SavepointThresholds, - Threshold, - ThresholdCalibrationCheckpointer, -) -from .validation import ValidationCheckpointer diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index e707d58..fbc7875 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,4 +1,4 @@ -from .base import Checkpointer +from ndsl.checkpointer.base import Checkpointer class NullCheckpointer(Checkpointer): diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index 11b7b89..aa806b2 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -2,11 +2,10 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp from ndsl.optional_imports import xarray as xr -from .base import Checkpointer - def make_dims(savepoint_dim, label, data_list): """ diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index 86133a8..ded73b3 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -5,8 +5,8 @@ import numpy as np -from ..quantity 
import Quantity -from .base import Checkpointer +from ndsl.checkpointer.base import Checkpointer +from ndsl.quantity import Quantity try: diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 360c6a1..8af1131 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -5,11 +5,15 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer +from ndsl.checkpointer.thresholds import ( + ArrayLike, + SavepointName, + SavepointThresholds, + cast_to_ndarray, +) from ndsl.optional_imports import xarray as xr -from .base import Checkpointer -from .thresholds import ArrayLike, SavepointName, SavepointThresholds, cast_to_ndarray - def _clip_pace_array_to_target( array: np.ndarray, target_shape: Tuple[int, ...] diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 92675c0..0e86fe0 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -1,15 +1,5 @@ -from .boundary import SimpleBoundary -from .caching_comm import ( - CachingCommData, - CachingCommReader, - CachingCommWriter, - CachingRequestReader, - CachingRequestWriter, - NullRequest, -) -from .comm_abc import Comm, Request +from .caching_comm import CachingCommReader, CachingCommWriter from .communicator import CubedSphereCommunicator, TileCommunicator -from .local_comm import AsyncResult, ConcurrencyError, LocalComm -from .mpi import MPIComm -from .null_comm import NullAsyncResult, NullComm +from .local_comm import ConcurrencyError, LocalComm +from .null_comm import NullComm from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/comm/boundary.py b/ndsl/comm/boundary.py index 540f025..020798c 100644 --- a/ndsl/comm/boundary.py +++ b/ndsl/comm/boundary.py @@ -1,8 +1,8 @@ import dataclasses from typing import Tuple -from ..quantity import Quantity, QuantityHaloSpec -from ._boundary_utils import get_boundary_slice +from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.quantity import Quantity, QuantityHaloSpec @dataclasses.dataclass diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index c013460..1331294 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -3,15 +3,8 @@ from ndsl.comm.mpi import MPI from . 
import dace -from .dace.dace_config import DaceConfig, DaCeOrchestration -from .dace.orchestration import orchestrate, orchestrate_function -from .stencil import ( - CompareToNumpyStencil, - FrozenStencil, - GridIndexing, - StencilFactory, - TimingCollector, -) +from .dace import DaceConfig, DaCeOrchestration, orchestrate, orchestrate_function +from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory from .stencil_config import CompilationConfig, RunMode, StencilConfig diff --git a/ndsl/dsl/caches/__init__.py b/ndsl/dsl/caches/__init__.py index 3417ff0..4fbb20e 100644 --- a/ndsl/dsl/caches/__init__.py +++ b/ndsl/dsl/caches/__init__.py @@ -1 +1 @@ -from .codepath import FV3CodePath \ No newline at end of file +from .codepath import FV3CodePath diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index d2c64b0..0f1edcb 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,2 +1,9 @@ -from .dace_config import DaceConfig -from .orchestration import orchestrate +from .dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG +from .orchestration import ( + _LazyComputepathFunction, + _LazyComputepathMethod, + orchestrate, + orchestrate_function, +) +from .utils import ArrayReport, DaCeProgress, MaxBandwithBenchmarkProgram, StorageReport +from .wrapped_halo_exchange import WrappedHaloUpdater diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index a7a4f94..77efd67 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -31,7 +31,7 @@ from ndsl.dsl.dace.orchestration import SDFGConvertible from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from ndsl.dsl.typing import Float, Index3D, cast_to_index3d -from ndsl.initialization import GridSizer, SubtileGridSizer +from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.quantity import Quantity diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index 5e48874..a7692a8 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -1,7 +1,6 @@ # flake8: noqa: F401 from .eta import set_hybrid_pressure_coefficients -from .generation import GridDefinitions, MetricTerms from .gnomonic import ( great_circle_distance_along_axis, great_circle_distance_lon_lat, @@ -11,13 +10,4 @@ xyz_midpoint, xyz_to_lon_lat, ) -from .helper import ( - AngleGridData, - ContravariantGridData, - DampingCoefficients, - DriverGridData, - GridData, - HorizontalGridData, - VerticalGridData, -) from .stretch_transformation import direct_transform diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index b38dbf2..12275d7 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -21,17 +21,7 @@ from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import Float from ndsl.grid import eta -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer -from ndsl.quantity import Quantity -from ndsl.stencils.corners import ( - fill_corners_2d, - fill_corners_agrid, - fill_corners_cgrid, - fill_corners_dgrid, -) - -from .geometry import ( +from ndsl.grid.geometry import ( calc_unit_vector_south, calc_unit_vector_west, calculate_divg_del6, @@ -47,7 +37,7 @@ supergrid_corner_fix, unit_vector_lonlat, ) -from .gnomonic import ( +from ndsl.grid.gnomonic import ( get_area, great_circle_distance_along_axis, local_gnomonic_ed, @@ -59,7 +49,16 @@ set_tile_border_dxc, set_tile_border_dyc, ) -from .mirror import mirror_grid +from ndsl.grid.mirror import mirror_grid +from 
ndsl.initialization.allocator import QuantityFactory +from ndsl.initialization.sizer import SubtileGridSizer +from ndsl.quantity import Quantity +from ndsl.stencils.corners import ( + fill_corners_2d, + fill_corners_agrid, + fill_corners_cgrid, + fill_corners_dgrid, +) # TODO: when every environment in python3.8, remove diff --git a/ndsl/grid/geometry.py b/ndsl/grid/geometry.py index 5b2ec02..804be0f 100644 --- a/ndsl/grid/geometry.py +++ b/ndsl/grid/geometry.py @@ -1,7 +1,5 @@ from ndsl.comm.partitioner import TilePartitioner -from ndsl.quantity import Quantity - -from .gnomonic import ( +from ndsl.grid.gnomonic import ( get_lonlat_vect, get_unit_vector_direction, great_circle_distance_lon_lat, @@ -10,6 +8,7 @@ spherical_cos, xyz_midpoint, ) +from ndsl.quantity import Quantity def get_center_vector( diff --git a/ndsl/grid/global_setup.py b/ndsl/grid/global_setup.py index 46c0c90..a0237ec 100644 --- a/ndsl/grid/global_setup.py +++ b/ndsl/grid/global_setup.py @@ -1,16 +1,15 @@ import math from ndsl.constants import PI, RADIUS - -from .generation import MetricTerms -from .gnomonic import ( +from ndsl.grid.generation import MetricTerms +from ndsl.grid.gnomonic import ( _cart_to_latlon, _check_shapes, _latlon2xyz, _mirror_latlon, symm_ed, ) -from .mirror import _rot_3d +from ndsl.grid.mirror import _rot_3d def gnomonic_grid(grid_type: int, lon, lat, np): diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 89a8c0e..ee97a6b 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -13,11 +13,10 @@ import ndsl.constants as constants from ndsl.constants import Z_DIM, Z_INTERFACE_DIM from ndsl.filesystem import get_fs +from ndsl.grid.generation import MetricTerms from ndsl.initialization import QuantityFactory from ndsl.quantity import Quantity -from .generation import MetricTerms - @dataclasses.dataclass(frozen=True) class DampingCoefficients: diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index b32a693..823bd22 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -1,6 +1,2 @@ -from .data_transformer import ( - HaloDataTransformerCPU, - HaloDataTransformerGPU, - HaloExchangeSpec, -) -from .updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from .data_transformer import HaloDataTransformer, HaloExchangeSpec +from .updater import HaloUpdater diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index fe15db8..50fd2f8 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -1,2 +1,2 @@ from .allocator import QuantityFactory -from .sizer import GridSizer, SubtileGridSizer +from .sizer import SubtileGridSizer diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py index cbbd78d..5320e4c 100644 --- a/ndsl/initialization/allocator.py +++ b/ndsl/initialization/allocator.py @@ -2,10 +2,10 @@ import numpy as np -from ..constants import SPATIAL_DIMS -from ..optional_imports import gt4py -from ..quantity import Quantity, QuantityHaloSpec -from .sizer import GridSizer +from ndsl.constants import SPATIAL_DIMS +from ndsl.initialization.sizer import GridSizer +from ndsl.optional_imports import gt4py +from ndsl.quantity import Quantity, QuantityHaloSpec class StorageNumpy: diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index a0c7e03..26b38cc 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -1,3 +1,2 @@ from .netcdf_monitor import NetCDFMonitor -from .protocol import Monitor from .zarr_monitor import ZarrMonitor diff --git 
a/ndsl/monitor/convert.py b/ndsl/monitor/convert.py index ad05b27..a62af01 100644 --- a/ndsl/monitor/convert.py +++ b/ndsl/monitor/convert.py @@ -1,6 +1,6 @@ import numpy as np -from ..optional_imports import cupy +from ndsl.optional_imports import cupy def to_numpy(array, dtype=None) -> np.ndarray: diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 3d950ae..8a0b96f 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -6,12 +6,11 @@ import numpy as np from ndsl.comm.communicator import Communicator +from ndsl.filesystem import get_fs +from ndsl.logging import ndsl_log +from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import xarray as xr - -from ..filesystem import get_fs -from ..logging import ndsl_log -from ..quantity import Quantity -from .convert import to_numpy +from ndsl.quantity import Quantity class _TimeChunkedVariable: diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index fd79608..b0a65f3 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1,9 +1 @@ -from .collector import ( - AbstractPerformanceCollector, - NullPerformanceCollector, - PerformanceCollector, -) -from .config import PerformanceConfig -from .profiler import NullProfiler, Profiler -from .report import Experiment, Report, TimeReport from .timer import NullTimer, Timer diff --git a/ndsl/performance/collector.py b/ndsl/performance/collector.py index 4df0440..8ec7a81 100644 --- a/ndsl/performance/collector.py +++ b/ndsl/performance/collector.py @@ -11,6 +11,7 @@ from ndsl.performance.report import ( Report, TimeReport, + collect_data_and_write_to_file, collect_keys_from_data, gather_hit_counts, get_experiment_info, @@ -19,8 +20,6 @@ from ndsl.performance.timer import NullTimer, Timer from ndsl.utils import GPU_AVAILABLE -from .report import collect_data_and_write_to_file - class AbstractPerformanceCollector(Protocol): total_timer: Timer diff --git a/ndsl/performance/config.py b/ndsl/performance/config.py index fa1ce8e..99e6109 100644 --- a/ndsl/performance/config.py +++ b/ndsl/performance/config.py @@ -1,13 +1,12 @@ import dataclasses from ndsl.comm.comm_abc import Comm -from ndsl.performance.profiler import NullProfiler, Profiler - -from .collector import ( +from ndsl.performance.collector import ( AbstractPerformanceCollector, NullPerformanceCollector, PerformanceCollector, ) +from ndsl.performance.profiler import NullProfiler, Profiler @dataclasses.dataclass diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 641e032..d3ec452 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,5 +1 @@ -from .c2l_ord import CubedToLatLon -from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid - - __version__ = "0.2.0" diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 4e18c1f..67f2b5a 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -13,7 +13,7 @@ from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ -from ndsl.grid import GridData +from ndsl.grid.helper import GridData from ndsl.initialization.allocator import QuantityFactory diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index d66176c..61a6483 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -1,16 +1,4 @@ from . 
import parallel_translate, translate -from .parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .savepoint import SavepointCase, Translate, dataset_to_dict +from .savepoint import dataset_to_dict from .temporaries import assert_same_temporaries, copy_temporaries -from .translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) +from .translate import pad_field_in_j, read_serialized_data diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 65ac8c0..d000e1f 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -17,8 +17,9 @@ from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.namelist import Namelist -from ndsl.stencils.testing import ParallelTranslate, TranslateGrid +from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict +from ndsl.stencils.testing.translate import TranslateGrid @pytest.fixture() diff --git a/ndsl/stencils/testing/parallel_translate.py b/ndsl/stencils/testing/parallel_translate.py index 7481a41..e066999 100644 --- a/ndsl/stencils/testing/parallel_translate.py +++ b/ndsl/stencils/testing/parallel_translate.py @@ -8,8 +8,10 @@ from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS from ndsl.dsl import gt4py_utils as utils from ndsl.quantity import Quantity - -from .translate import TranslateFortranData2Py, read_serialized_data +from ndsl.stencils.testing.translate import ( + TranslateFortranData2Py, + read_serialized_data, +) class ParallelTranslate: diff --git a/ndsl/stencils/testing/savepoint.py b/ndsl/stencils/testing/savepoint.py index 04d01e2..77d7191 100644 --- a/ndsl/stencils/testing/savepoint.py +++ b/ndsl/stencils/testing/savepoint.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr -from .grid import Grid # type: ignore +from ndsl.stencils.testing.grid import Grid # type: ignore def dataset_to_dict(ds: xr.Dataset) -> Dict[str, Union[np.ndarray, float, int]]: diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index 89d368e..797b470 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer import SnapshotCheckpointer +from ndsl.checkpointer.snapshots import SnapshotCheckpointer from ndsl.optional_imports import xarray as xr diff --git a/tests/checkpointer/test_thresholds.py b/tests/checkpointer/test_thresholds.py index 8bf70b0..851a966 100644 --- a/tests/checkpointer/test_thresholds.py +++ b/tests/checkpointer/test_thresholds.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer import ( +from ndsl.checkpointer.thresholds import ( InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer, diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index 0c08d52..b696aca 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -4,8 +4,11 @@ import numpy as np import pytest -from ndsl.checkpointer import SavepointThresholds, Threshold, ValidationCheckpointer -from ndsl.checkpointer.validation import _clip_pace_array_to_target +from ndsl.checkpointer.thresholds import SavepointThresholds, Threshold 
+from ndsl.checkpointer.validation import ( + ValidationCheckpointer, + _clip_pace_array_to_target, +) from ndsl.optional_imports import xarray as xr diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py index 10487c1..d5af6cf 100644 --- a/tests/dsl/__init__.py +++ b/tests/dsl/__init__.py @@ -1,2 +1,2 @@ -from .test_stencil_wrapper import MockFieldInfo -from .test_caches import OrchestratedProgam \ No newline at end of file +# from .test_caches import OrchestratedProgam +# from .test_stencil_wrapper import MockFieldInfo diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index a7218b0..f90a182 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -2,15 +2,16 @@ from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval from gt4py.storage import empty, ones -from ndsl.comm.mpi import MPI -from ndsl.dsl.dace import orchestrate -from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration -from ndsl.dsl.stencil import ( +from ndsl import ( CompilationConfig, + DaceConfig, + DaCeOrchestration, GridIndexing, StencilConfig, StencilFactory, ) +from ndsl.comm.mpi import MPI +from ndsl.dsl.dace import orchestrate def _make_storage( diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 14b240a..62049d9 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -3,9 +3,13 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner -from ndsl.dsl.stencil import CompilationConfig, RunMode +from ndsl import ( + CompilationConfig, + CubedSphereCommunicator, + CubedSpherePartitioner, + RunMode, + TilePartitioner, +) def test_safety_checks(): diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index 0aca176..c044cb1 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -1,12 +1,8 @@ import unittest.mock -from ndsl.comm.communicator import CubedSpherePartitioner, TilePartitioner -from ndsl.dsl.dace.dace_config import DaceConfig, _determine_compiling_ranks -from ndsl.dsl.dace.orchestration import ( - DaCeOrchestration, - orchestrate, - orchestrate_function, -) +from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner +from ndsl.dsl.dace.dace_config import _determine_compiling_ranks +from ndsl.dsl.dace.orchestration import orchestrate, orchestrate_function """ diff --git a/tests/dsl/test_skip_passes.py b/tests/dsl/test_skip_passes.py index c1f3a71..e0173b7 100644 --- a/tests/dsl/test_skip_passes.py +++ b/tests/dsl/test_skip_passes.py @@ -7,14 +7,14 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline from gt4py.cartesian.gtscript import PARALLEL, computation, interval -from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.stencil import ( +from ndsl import ( CompilationConfig, + DaceConfig, GridIndexing, StencilConfig, StencilFactory, ) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import FloatField diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 18cb99a..180b7ba 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -1,12 +1,7 @@ from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval from gt4py.storage import empty, ones -from ndsl.dsl.stencil import ( - CompilationConfig, - GridIndexing, - StencilConfig, - StencilFactory, -) +from ndsl 
import CompilationConfig, GridIndexing, StencilConfig, StencilFactory def _make_storage( diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 45891df..7e6b4da 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -1,7 +1,6 @@ import pytest -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.stencil import CompilationConfig, StencilConfig +from ndsl import CompilationConfig, DaceConfig, StencilConfig @pytest.mark.parametrize("validate_args", [True, False]) diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 364e5a3..756de95 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -2,17 +2,18 @@ import pytest from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region -from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import ( +from ndsl import ( CompareToNumpyStencil, + CompilationConfig, + DaceConfig, FrozenStencil, GridIndexing, + StencilConfig, StencilFactory, - get_stencils_with_varied_bounds, ) -from ndsl.dsl.stencil_config import CompilationConfig, StencilConfig +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py_utils import make_storage_from_shape +from ndsl.dsl.stencil import get_stencils_with_varied_bounds from ndsl.dsl.typing import FloatField diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index ba3da53..cfe56de 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -6,12 +6,17 @@ import pytest from gt4py.cartesian.gtscript import PARALLEL, computation, interval -from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration +from ndsl import ( + CompilationConfig, + DaceConfig, + DaCeOrchestration, + FrozenStencil, + Quantity, + StencilConfig, +) from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import FrozenStencil, _convert_quantities_to_storage -from ndsl.dsl.stencil_config import CompilationConfig, StencilConfig +from ndsl.dsl.stencil import _convert_quantities_to_storage from ndsl.dsl.typing import Float, FloatField -from ndsl.quantity import Quantity def get_stencil_config( @@ -280,14 +285,14 @@ def test_backend_options( "backend": "numpy", "rebuild": True, "format_source": False, - "name": "test_stencil_wrapper.copy_stencil", + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", }, "cuda": { "backend": "cuda", "rebuild": True, "device_sync": False, "format_source": False, - "name": "test_stencil_wrapper.copy_stencil", + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", }, } diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 7343bf8..ab11b16 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -2,9 +2,13 @@ import pytest +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + Quantity, + TilePartitioner, +) from ndsl.comm._boundary_utils import get_boundary_slice -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -22,9 +26,7 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.quantity import Quantity - -from .mpi_comm import MPI +from tests.mpi.mpi_comm import MPI @pytest.fixture diff --git 
a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 4d4d24a..d099d76 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -3,8 +3,7 @@ from ndsl.comm.communicator import recv_buffer from ndsl.testing import ConcurrencyError, DummyComm - -from .mpi_comm import MPI +from tests.mpi.mpi_comm import MPI worker_function_list = [] diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index 42db16a..a4f8e81 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from ndsl import Quantity from ndsl.comm._boundary_utils import _shift_boundary_slice, get_boundary_slice from ndsl.constants import ( EAST, @@ -12,7 +13,6 @@ Y_DIM, Z_DIM, ) -from ndsl.quantity import Quantity def boundary_data(quantity, boundary_type, n_points, interior=True): diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index c44ea39..a7b1564 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -3,7 +3,7 @@ import numpy as np -from ndsl.quantity import Quantity +from ndsl import Quantity def test_deepcopy_copy_is_editable_by_view(): diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 7d0d75f..a6de628 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -2,7 +2,7 @@ import pytest import ndsl.quantity as qty -from ndsl.quantity import Quantity +from ndsl import Quantity try: diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 172d78d..2cdb8d4 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.quantity import Quantity +from ndsl import Quantity try: diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index be1569a..5e52727 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -1,5 +1,6 @@ import pytest +from ndsl import Quantity from ndsl.constants import ( X_DIM, X_DIMS, @@ -10,7 +11,6 @@ Z_DIM, Z_DIMS, ) -from ndsl.quantity import Quantity @pytest.fixture diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index a1ba5e5..7324509 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -1,8 +1,8 @@ import numpy as np import pytest +from ndsl import Quantity from ndsl.constants import X_DIM, Y_DIM -from ndsl.quantity import Quantity @pytest.fixture diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index 6481315..b28eba1 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -4,13 +4,17 @@ import numpy as np -from ndsl.comm.caching_comm import CachingCommReader, CachingCommWriter -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.local_comm import LocalComm -from ndsl.comm.null_comm import NullComm -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CachingCommReader, + CachingCommWriter, + CubedSphereCommunicator, + CubedSpherePartitioner, + LocalComm, + NullComm, + Quantity, + TilePartitioner, +) from ndsl.constants import X_DIM, Y_DIM -from ndsl.quantity import Quantity def test_halo_update_integration(): diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 855422b..7966142 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -3,8 +3,14 @@ import pytest -from 
ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import ( HORIZONTAL_DIMS, TILE_DIM, @@ -15,9 +21,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm try: diff --git a/tests/test_decomposition.py b/tests/test_decomposition.py index de4d40c..bf7363e 100644 --- a/tests/test_decomposition.py +++ b/tests/test_decomposition.py @@ -4,6 +4,7 @@ import pytest +from ndsl import CubedSpherePartitioner, TilePartitioner from ndsl.comm.decomposition import ( block_waiting_for_compilation, build_cache_path, @@ -11,9 +12,7 @@ determine_rank_is_compiling, unblock_waiting_tiles, ) -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner - -from .mpi.mpi_comm import MPI +from tests.mpi.mpi_comm import MPI @pytest.mark.parametrize( diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index 3f2cdde..a401e69 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -2,6 +2,7 @@ import pytest +from ndsl import QuantityFactory, SubtileGridSizer from ndsl.constants import ( N_HALO_DEFAULT, X_DIM, @@ -11,8 +12,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer @pytest.fixture(params=[48, 96]) diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index a05b0fb..28f1af7 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -8,12 +8,15 @@ import numpy as np import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm try: diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 10d7f99..7e1b9f0 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from ndsl import HaloDataTransformer, HaloExchangeSpec, Quantity, QuantityHaloSpec from ndsl.buffer import Buffer from ndsl.comm import _boundary_utils from ndsl.constants import ( @@ -22,9 +23,7 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.halo.data_transformer import HaloDataTransformer, HaloExchangeSpec from ndsl.halo.rotate import rotate_scalar_data, rotate_vector_data -from ndsl.quantity import Quantity, QuantityHaloSpec @pytest.fixture diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index c17903c..d0536b2 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -3,10 +3,20 @@ import pytest +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + HaloUpdater, + OutOfBoundsError, + Quantity, + QuantityHaloSpec, + TileCommunicator, + TilePartitioner, + Timer, +) from ndsl.buffer import BUFFER_CACHE from ndsl.comm._boundary_utils import get_boundary_slice -from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, 
TilePartitioner from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -24,11 +34,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.exceptions import OutOfBoundsError -from ndsl.halo.updater import HaloUpdater -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity, QuantityHaloSpec -from ndsl.testing import DummyComm @pytest.fixture diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index e33f0d6..8ec77cc 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -1,7 +1,13 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -10,9 +16,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm @pytest.fixture diff --git a/tests/test_legacy_restart.py b/tests/test_legacy_restart.py index 3728b94..2034c04 100644 --- a/tests/test_legacy_restart.py +++ b/tests/test_legacy_restart.py @@ -12,17 +12,20 @@ import pytest import ndsl.io as io -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, +) from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM -from ndsl.quantity import Quantity from ndsl.restart._legacy_restart import ( _apply_dims, get_rank_suffix, map_keys, open_restart, ) -from ndsl.testing import DummyComm requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") diff --git a/tests/test_local_comm.py b/tests/test_local_comm.py index 0b8072a..c549ee2 100644 --- a/tests/test_local_comm.py +++ b/tests/test_local_comm.py @@ -1,7 +1,7 @@ import numpy import pytest -from ndsl.comm.local_comm import LocalComm +from ndsl import LocalComm @pytest.fixture diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 6e20537..7a21dd7 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -6,12 +6,15 @@ import numpy as np import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner -from ndsl.monitor import NetCDFMonitor +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + NetCDFMonitor, + Quantity, + TilePartitioner, +) from ndsl.optional_imports import xarray as xr -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") diff --git a/tests/test_null_comm.py b/tests/test_null_comm.py index 0a38476..74065f6 100644 --- a/tests/test_null_comm.py +++ b/tests/test_null_comm.py @@ -1,6 +1,9 @@ -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.null_comm import NullComm -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + NullComm, + TilePartitioner, +) def test_can_create_cube_communicator(): diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index 99f1fb6..6bd15ed 100644 --- a/tests/test_partitioner.py +++ 
b/tests/test_partitioner.py @@ -1,9 +1,8 @@ import numpy as np import pytest +from ndsl import CubedSpherePartitioner, Quantity, TilePartitioner from ndsl.comm.partitioner import ( - CubedSpherePartitioner, - TilePartitioner, _subtile_extents_from_tile_metadata, get_tile_index, get_tile_number, @@ -18,7 +17,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.quantity import Quantity rank_list = [] diff --git a/tests/test_partitioner_boundaries.py b/tests/test_partitioner_boundaries.py index 71574ee..eff528f 100644 --- a/tests/test_partitioner_boundaries.py +++ b/tests/test_partitioner_boundaries.py @@ -1,10 +1,7 @@ import pytest -from ndsl.comm.partitioner import ( - CubedSpherePartitioner, - TilePartitioner, - rotate_subtile_rank, -) +from ndsl import CubedSpherePartitioner, TilePartitioner +from ndsl.comm.partitioner import rotate_subtile_rank from ndsl.constants import ( BOUNDARY_TYPES, CORNER_BOUNDARY_TYPES, diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 3771100..3e0930a 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -1,11 +1,14 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm @pytest.fixture diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index 26aa3d0..d768bb1 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -1,10 +1,7 @@ import pytest -from ndsl.comm.communicator import TileCommunicator -from ndsl.comm.partitioner import TilePartitioner +from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm def rank_scatter_results(communicator_list, quantity): diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 2669f5f..6d56dd6 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -3,8 +3,7 @@ import pytest -from ndsl.comm.communicator import TileCommunicator -from ndsl.comm.partitioner import TilePartitioner +from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import ( HORIZONTAL_DIMS, X_DIM, @@ -14,8 +13,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm try: diff --git a/tests/test_timer.py b/tests/test_timer.py index 0b0cd4b..213a487 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -2,7 +2,7 @@ import pytest -from ndsl.performance.timer import NullTimer, Timer +from ndsl import NullTimer, Timer @pytest.fixture diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index fbe9040..b608ec0 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -12,7 +12,13 @@ import cftime import pytest -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + ZarrMonitor, +) from ndsl.constants import ( X_DIM, X_DIMS, @@ -22,11 +28,8 @@ Y_INTERFACE_DIM, Z_DIM, ) -from ndsl.monitor import ZarrMonitor 
from ndsl.monitor.zarr_monitor import array_chunks, get_calendar from ndsl.optional_imports import xarray as xr -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm requires_zarr = pytest.mark.skipif(zarr is None, reason="zarr is not installed") From 7d333b55d0dba54f2b5295f58c8b1608646c3d0d Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 27 Feb 2024 12:06:01 -0500 Subject: [PATCH 03/14] Changes to missed exposed modules and clean-up of comments --- ndsl/dsl/caches/__init__.py | 1 - ndsl/dsl/dace/__init__.py | 11 ++--------- tests/dsl/__init__.py | 2 -- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/ndsl/dsl/caches/__init__.py b/ndsl/dsl/caches/__init__.py index 4fbb20e..e69de29 100644 --- a/ndsl/dsl/caches/__init__.py +++ b/ndsl/dsl/caches/__init__.py @@ -1 +0,0 @@ -from .codepath import FV3CodePath diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index 0f1edcb..aa19a3d 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,9 +1,2 @@ -from .dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG -from .orchestration import ( - _LazyComputepathFunction, - _LazyComputepathMethod, - orchestrate, - orchestrate_function, -) -from .utils import ArrayReport, DaCeProgress, MaxBandwithBenchmarkProgram, StorageReport -from .wrapped_halo_exchange import WrappedHaloUpdater +from .dace_config import DaceConfig, DaCeOrchestration +from .orchestration import orchestrate, orchestrate_function diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py index d5af6cf..e69de29 100644 --- a/tests/dsl/__init__.py +++ b/tests/dsl/__init__.py @@ -1,2 +0,0 @@ -# from .test_caches import OrchestratedProgam -# from .test_stencil_wrapper import MockFieldInfo From fba0d95674fc9b1849bd91af5ea0e249bd3f8847 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 27 Feb 2024 15:47:23 -0500 Subject: [PATCH 04/14] Adding exposure for modules needed by external modules pyFV3 and pySHiELD --- ndsl/__init__.py | 22 +++++++++++++++++++--- ndsl/checkpointer/__init__.py | 2 ++ ndsl/comm/__init__.py | 4 +++- ndsl/dsl/__init__.py | 8 +++++++- ndsl/dsl/dace/__init__.py | 1 + ndsl/initialization/__init__.py | 2 +- ndsl/performance/__init__.py | 1 + ndsl/stencils/__init__.py | 13 +++++++++++++ ndsl/stencils/testing/__init__.py | 15 ++++++++++++++- ndsl/stencils/testing/grid.py | 4 ++-- tests/checkpointer/test_snapshot.py | 2 +- 11 files changed, 64 insertions(+), 10 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 06a77ac..221cf1b 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,11 +1,14 @@ -from .checkpointer import SnapshotCheckpointer +from .checkpointer import Checkpointer, NullCheckpointer, SnapshotCheckpointer from .comm import ( CachingCommReader, CachingCommWriter, + Comm, + Communicator, ConcurrencyError, CubedSphereCommunicator, CubedSpherePartitioner, LocalComm, + MPIComm, NullComm, TileCommunicator, TilePartitioner, @@ -20,12 +23,25 @@ RunMode, StencilConfig, StencilFactory, + WrappedHaloUpdater, ) from .exceptions import OutOfBoundsError from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater -from .initialization import QuantityFactory, SubtileGridSizer +from .initialization import GridSizer, QuantityFactory, SubtileGridSizer from .logging import ndsl_log from .monitor import NetCDFMonitor, ZarrMonitor -from .performance import NullTimer, Timer +from .performance import NullTimer, PerformanceCollector, Timer from .quantity import Quantity, QuantityHaloSpec +from .stencils import ( + 
CubedToLatLon, + Grid, + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, + TranslateFortranData2Py, + TranslateGrid, +) from .testing import DummyComm +from .utils import MetaEnumStr diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index 46d32a6..d24936c 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1 +1,3 @@ +from .base import Checkpointer +from .null import NullCheckpointer from .snapshots import SnapshotCheckpointer diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 0e86fe0..31319c7 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -1,5 +1,7 @@ from .caching_comm import CachingCommReader, CachingCommWriter -from .communicator import CubedSphereCommunicator, TileCommunicator +from .comm_abc import Comm +from .communicator import Communicator, CubedSphereCommunicator, TileCommunicator from .local_comm import ConcurrencyError, LocalComm +from .mpi import MPIComm from .null_comm import NullComm from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 1331294..269ae95 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -3,7 +3,13 @@ from ndsl.comm.mpi import MPI from . import dace -from .dace import DaceConfig, DaCeOrchestration, orchestrate, orchestrate_function +from .dace import ( + DaceConfig, + DaCeOrchestration, + WrappedHaloUpdater, + orchestrate, + orchestrate_function, +) from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory from .stencil_config import CompilationConfig, RunMode, StencilConfig diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index aa19a3d..c1386ad 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,2 +1,3 @@ from .dace_config import DaceConfig, DaCeOrchestration from .orchestration import orchestrate, orchestrate_function +from .wrapped_halo_exchange import WrappedHaloUpdater diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index 50fd2f8..fe15db8 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -1,2 +1,2 @@ from .allocator import QuantityFactory -from .sizer import SubtileGridSizer +from .sizer import GridSizer, SubtileGridSizer diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index b0a65f3..51e9bf8 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1 +1,2 @@ +from .collector import PerformanceCollector from .timer import NullTimer, Timer diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index d3ec452..6083272 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1 +1,14 @@ +from .c2l_ord import CubedToLatLon +from .testing import ( + Grid, + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, + TranslateFortranData2Py, + TranslateGrid, +) + + __version__ = "0.2.0" diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index 61a6483..3ad9ef9 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -1,4 +1,17 @@ from . 
import parallel_translate, translate +from .grid import Grid # type: ignore +from .parallel_translate import ( + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, +) from .savepoint import dataset_to_dict from .temporaries import assert_same_temporaries, copy_temporaries -from .translate import pad_field_in_j, read_serialized_data +from .translate import ( + TranslateFortranData2Py, + TranslateGrid, + pad_field_in_j, + read_serialized_data, +) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index b6a5513..273a0f3 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -8,13 +8,13 @@ from ndsl.dsl import gt4py_utils as utils from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import Float -from ndsl.grid import ( +from ndsl.grid.generation import GridDefinitions +from ndsl.grid.helper import ( AngleGridData, ContravariantGridData, DampingCoefficients, DriverGridData, GridData, - GridDefinitions, HorizontalGridData, MetricTerms, VerticalGridData, diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index 797b470..a8dd538 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer.snapshots import SnapshotCheckpointer +from ndsl import SnapshotCheckpointer from ndsl.optional_imports import xarray as xr From 8400c83b55fe044fd7285f3b80407c584c127bd7 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 28 Feb 2024 16:00:39 -0500 Subject: [PATCH 05/14] Updated import method to mimic other package styles --- ndsl/__init__.py | 121 ++++++++++++++++++------ ndsl/checkpointer/__init__.py | 3 - ndsl/comm/__init__.py | 7 -- ndsl/dsl/__init__.py | 11 --- ndsl/dsl/dace/__init__.py | 3 - ndsl/dsl/stencil.py | 6 +- ndsl/grid/__init__.py | 13 --- ndsl/grid/helper.py | 2 +- ndsl/halo/__init__.py | 2 - ndsl/initialization/__init__.py | 2 - ndsl/monitor/__init__.py | 2 - ndsl/performance/__init__.py | 2 - ndsl/stencils/__init__.py | 13 --- ndsl/stencils/testing/__init__.py | 17 ---- ndsl/stencils/testing/test_translate.py | 5 +- ndsl/testing/__init__.py | 3 - ndsl/testing/dummy_comm.py | 1 - setup.py | 1 + tests/checkpointer/test_thresholds.py | 6 +- tests/checkpointer/test_validation.py | 7 +- tests/dsl/test_caches.py | 2 +- tests/mpi/test_mpi_mock.py | 2 +- tests/test_halo_data_transformer.py | 9 +- 23 files changed, 113 insertions(+), 127 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 221cf1b..6967d8d 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,47 +1,110 @@ -from .checkpointer import Checkpointer, NullCheckpointer, SnapshotCheckpointer -from .comm import ( +from .buffer import Buffer +from .checkpointer.base import Checkpointer +from .checkpointer.null import NullCheckpointer +from .checkpointer.snapshots import SnapshotCheckpointer, _Snapshots +from .checkpointer.thresholds import ( + InsufficientTrialsError, + SavepointThresholds, + Threshold, + ThresholdCalibrationCheckpointer, +) +from .checkpointer.validation import ValidationCheckpointer +from .comm.boundary import Boundary, SimpleBoundary +from .comm.caching_comm import ( + CachingCommData, CachingCommReader, CachingCommWriter, - Comm, - Communicator, - ConcurrencyError, - CubedSphereCommunicator, - CubedSpherePartitioner, - LocalComm, - MPIComm, - NullComm, - TileCommunicator, - TilePartitioner, -) -from .dsl import ( + 
CachingRequestReader, + CachingRequestWriter, + NullRequest, +) +from .comm.comm_abc import Comm, Request +from .comm.communicator import Communicator, CubedSphereCommunicator, TileCommunicator +from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm +from .comm.mpi import MPIComm +from .comm.null_comm import NullAsyncResult, NullComm +from .comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner +from .constants import ConstantVersions +from .dsl.caches.codepath import FV3CodePath +from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG +from .dsl.dace.orchestration import orchestrate, orchestrate_function +from .dsl.dace.utils import ( + ArrayReport, + DaCeProgress, + MaxBandwithBenchmarkProgram, + StorageReport, +) +from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .dsl.stencil import ( CompareToNumpyStencil, - CompilationConfig, - DaceConfig, - DaCeOrchestration, FrozenStencil, GridIndexing, - RunMode, - StencilConfig, StencilFactory, - WrappedHaloUpdater, + TimingCollector, ) +from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError -from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater -from .initialization import GridSizer, QuantityFactory, SubtileGridSizer +from .grid.eta import HybridPressureCoefficients +from .grid.generation import GridDefinition, GridDefinitions, MetricTerms +from .grid.helper import ( + AngleGridData, + ContravariantGridData, + DampingCoefficients, + DriverGridData, + GridData, + HorizontalGridData, + VerticalGridData, +) +from .halo.data_transformer import ( + HaloDataTransformer, + HaloDataTransformerCPU, + HaloDataTransformerGPU, + HaloExchangeSpec, +) +from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from .initialization.allocator import QuantityFactory, StorageNumpy +from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log -from .monitor import NetCDFMonitor, ZarrMonitor -from .performance import NullTimer, PerformanceCollector, Timer -from .quantity import Quantity, QuantityHaloSpec -from .stencils import ( - CubedToLatLon, - Grid, +from .monitor.netcdf_monitor import NetCDFMonitor +from .monitor.protocol import Protocol +from .monitor.zarr_monitor import ZarrMonitor +from .namelist import Namelist +from .optional_imports import RaiseWhenAccessed +from .performance.collector import ( + AbstractPerformanceCollector, + NullPerformanceCollector, + PerformanceCollector, +) +from .performance.config import PerformanceConfig +from .performance.profiler import NullProfiler, Profiler +from .performance.report import Experiment, Report, TimeReport +from .performance.timer import NullTimer, Timer +from .quantity import ( + BoundaryArrayView, + BoundedArrayView, + Quantity, + QuantityHaloSpec, + QuantityMetadata, +) +from .stencils.c2l_ord import CubedToLatLon +from .stencils.corners import CopyCorners, CopyCornersXY, FillCornersBGrid +from .stencils.testing.grid import Grid # type: ignore +from .stencils.testing.parallel_translate import ( ParallelTranslate, ParallelTranslate2Py, ParallelTranslate2PyState, ParallelTranslateBaseSlicing, ParallelTranslateGrid, +) +from .stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict +from .stencils.testing.temporaries import assert_same_temporaries, copy_temporaries +from .stencils.testing.translate import ( TranslateFortranData2Py, TranslateGrid, + pad_field_in_j, + 
read_serialized_data, ) -from .testing import DummyComm +from .testing.dummy_comm import DummyComm +from .types import Allocator, AsyncRequest, NumpyModule +from .units import UnitsError from .utils import MetaEnumStr diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index d24936c..e69de29 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1,3 +0,0 @@ -from .base import Checkpointer -from .null import NullCheckpointer -from .snapshots import SnapshotCheckpointer diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 31319c7..e69de29 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -1,7 +0,0 @@ -from .caching_comm import CachingCommReader, CachingCommWriter -from .comm_abc import Comm -from .communicator import Communicator, CubedSphereCommunicator, TileCommunicator -from .local_comm import ConcurrencyError, LocalComm -from .mpi import MPIComm -from .null_comm import NullComm -from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 269ae95..ed44420 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -2,17 +2,6 @@ from ndsl.comm.mpi import MPI -from . import dace -from .dace import ( - DaceConfig, - DaCeOrchestration, - WrappedHaloUpdater, - orchestrate, - orchestrate_function, -) -from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory -from .stencil_config import CompilationConfig, RunMode, StencilConfig - if MPI is not None: import os diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index c1386ad..e69de29 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,3 +0,0 @@ -from .dace_config import DaceConfig, DaCeOrchestration -from .orchestration import orchestrate, orchestrate_function -from .wrapped_halo_exchange import WrappedHaloUpdater diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 77efd67..b831672 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -22,7 +22,6 @@ from gt4py.cartesian import gtscript from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline -from ndsl import testing from ndsl.comm.comm_abc import Comm from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles @@ -34,6 +33,9 @@ from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.quantity import Quantity +# from ndsl import testing +from ndsl.testing import comparison + try: import cupy as cp @@ -68,7 +70,7 @@ def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label) -> str: - metric_err = testing.compare_arr(arg, numpy_arg) + metric_err = comparison.compare_arr(arg, numpy_arg) nans_match = np.logical_and(np.isnan(arg), np.isnan(numpy_arg)) n_points = np.product(arg.shape) failures_14 = n_points - np.sum( diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index a7692a8..e69de29 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -1,13 +0,0 @@ -# flake8: noqa: F401 - -from .eta import set_hybrid_pressure_coefficients -from .gnomonic import ( - great_circle_distance_along_axis, - great_circle_distance_lon_lat, - lon_lat_corner_to_cell_center, - lon_lat_midpoint, - lon_lat_to_xyz, - xyz_midpoint, - xyz_to_lon_lat, -) -from .stretch_transformation import direct_transform diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py index 
ee97a6b..fd62d77 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -14,7 +14,7 @@ from ndsl.constants import Z_DIM, Z_INTERFACE_DIM from ndsl.filesystem import get_fs from ndsl.grid.generation import MetricTerms -from ndsl.initialization import QuantityFactory +from ndsl.initialization.allocator import QuantityFactory from ndsl.quantity import Quantity diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index 823bd22..e69de29 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -1,2 +0,0 @@ -from .data_transformer import HaloDataTransformer, HaloExchangeSpec -from .updater import HaloUpdater diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index fe15db8..e69de29 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -1,2 +0,0 @@ -from .allocator import QuantityFactory -from .sizer import GridSizer, SubtileGridSizer diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index 26b38cc..e69de29 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -1,2 +0,0 @@ -from .netcdf_monitor import NetCDFMonitor -from .zarr_monitor import ZarrMonitor diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index 51e9bf8..e69de29 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1,2 +0,0 @@ -from .collector import PerformanceCollector -from .timer import NullTimer, Timer diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 6083272..d3ec452 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,14 +1 @@ -from .c2l_ord import CubedToLatLon -from .testing import ( - Grid, - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, - TranslateFortranData2Py, - TranslateGrid, -) - - __version__ = "0.2.0" diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index 3ad9ef9..e69de29 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -1,17 +0,0 @@ -from . 
import parallel_translate, translate -from .grid import Grid # type: ignore -from .parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .savepoint import dataset_to_dict -from .temporaries import assert_same_temporaries, copy_temporaries -from .translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 2e42e27..29c4ed6 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -14,8 +14,9 @@ from ndsl.dsl.stencil import CompilationConfig, StencilConfig from ndsl.quantity import Quantity from ndsl.restart._legacy_restart import RESTART_PROPERTIES -from ndsl.stencils.testing import SavepointCase, dataset_to_dict -from ndsl.testing import compare_scalar, perturb, success, success_array +from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict +from ndsl.testing.comparison import compare_scalar, success, success_array +from ndsl.testing.perturbation import perturb # this only matters for manually-added print statements diff --git a/ndsl/testing/__init__.py b/ndsl/testing/__init__.py index a1c927e..e69de29 100644 --- a/ndsl/testing/__init__.py +++ b/ndsl/testing/__init__.py @@ -1,3 +0,0 @@ -from .comparison import compare_arr, compare_scalar, success, success_array -from .dummy_comm import ConcurrencyError, DummyComm -from .perturbation import perturb diff --git a/ndsl/testing/dummy_comm.py b/ndsl/testing/dummy_comm.py index b4df234..f3e9381 100644 --- a/ndsl/testing/dummy_comm.py +++ b/ndsl/testing/dummy_comm.py @@ -1,2 +1 @@ -from ndsl.comm.local_comm import ConcurrencyError # noqa from ndsl.comm.local_comm import LocalComm as DummyComm # noqa diff --git a/setup.py b/setup.py index 73ec210..c0d7181 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ def local_pkg(name: str, relative_path: str) -> str: "mpi4py", "cftime", "xarray", + "f90nml>=1.1.0", "fsspec", "netcdf4", "scipy", # restart capacities only diff --git a/tests/checkpointer/test_thresholds.py b/tests/checkpointer/test_thresholds.py index 851a966..90d1f8f 100644 --- a/tests/checkpointer/test_thresholds.py +++ b/tests/checkpointer/test_thresholds.py @@ -1,11 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer.thresholds import ( - InsufficientTrialsError, - Threshold, - ThresholdCalibrationCheckpointer, -) +from ndsl import InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer def test_thresholds_no_trials(): diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index b696aca..091bb7c 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -4,11 +4,8 @@ import numpy as np import pytest -from ndsl.checkpointer.thresholds import SavepointThresholds, Threshold -from ndsl.checkpointer.validation import ( - ValidationCheckpointer, - _clip_pace_array_to_target, -) +from ndsl import SavepointThresholds, Threshold, ValidationCheckpointer +from ndsl.checkpointer.validation import _clip_pace_array_to_target from ndsl.optional_imports import xarray as xr diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index f90a182..893fb89 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -11,7 +11,7 @@ StencilFactory, ) from ndsl.comm.mpi import MPI -from ndsl.dsl.dace import orchestrate 
+from ndsl.dsl.dace.orchestration import orchestrate def _make_storage( diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index d099d76..def0d34 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,8 +1,8 @@ import numpy as np import pytest +from ndsl import ConcurrencyError, DummyComm from ndsl.comm.communicator import recv_buffer -from ndsl.testing import ConcurrencyError, DummyComm from tests.mpi.mpi_comm import MPI diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 7e1b9f0..ec986f8 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -4,8 +4,13 @@ import numpy as np import pytest -from ndsl import HaloDataTransformer, HaloExchangeSpec, Quantity, QuantityHaloSpec -from ndsl.buffer import Buffer +from ndsl import ( + Buffer, + HaloDataTransformer, + HaloExchangeSpec, + Quantity, + QuantityHaloSpec, +) from ndsl.comm import _boundary_utils from ndsl.constants import ( EAST, From ac8be60a6c92e400b9e05987b3741f1eeaaba11f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 1 Mar 2024 12:36:48 -0500 Subject: [PATCH 06/14] Move `DaCe` to top of `master` as of March 1st. --- .gitmodules | 3 +-- external/dace | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 3ef1bc9..43838b0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,5 +3,4 @@ url = https://github.com/GridTools/gt4py.git [submodule "external/dace"] path = external/dace - url = https://github.com/FlorianDeconinck/dace.git - branch = fix/gcc_dies_on_dacecpu + url = https://github.com/spcl/dace.git diff --git a/external/dace b/external/dace index 22982af..b1a7f8a 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 22982afe133bccd906d5eeee448092f5f065ff6a +Subproject commit b1a7f8a6ea76f913a0bf8b32de5bc416697218fd From 1c8b4b9c09fcfa98a83311358ddd0ea7e46f5552 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 5 Mar 2024 16:28:43 -0500 Subject: [PATCH 07/14] Exposure changes and creation of ndsl.typing module --- ndsl/__init__.py | 52 +- ndsl/checkpointer/__init__.py | 9 + ndsl/checkpointer/base.py | 7 - ndsl/checkpointer/null.py | 2 +- ndsl/checkpointer/snapshots.py | 2 +- ndsl/checkpointer/thresholds.py | 2 +- ndsl/checkpointer/validation.py | 2 +- ndsl/comm/__init__.py | 9 + ndsl/comm/communicator.py | 568 +--------------------- ndsl/comm/partitioner.py | 79 +-- ndsl/dsl/caches/cache_location.py | 2 +- ndsl/dsl/dace/dace_config.py | 2 +- ndsl/dsl/dace/wrapped_halo_exchange.py | 2 +- ndsl/dsl/stencil.py | 2 +- ndsl/dsl/stencil_config.py | 3 +- ndsl/grid/__init__.py | 11 + ndsl/grid/generation.py | 2 +- ndsl/halo/updater.py | 2 +- ndsl/monitor/netcdf_monitor.py | 2 +- ndsl/monitor/zarr_monitor.py | 3 +- ndsl/restart/_legacy_restart.py | 2 +- ndsl/stencils/__init__.py | 20 + ndsl/stencils/c2l_ord.py | 2 +- ndsl/stencils/testing/conftest.py | 7 +- ndsl/typing.py | 648 +++++++++++++++++++++++++ tests/checkpointer/test_snapshot.py | 2 +- tests/checkpointer/test_thresholds.py | 6 +- tests/checkpointer/test_validation.py | 2 +- tests/mpi/test_mpi_mock.py | 2 +- tests/test_caching_comm.py | 3 +- 30 files changed, 733 insertions(+), 724 deletions(-) delete mode 100644 ndsl/checkpointer/base.py create mode 100644 ndsl/typing.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 6967d8d..a507302 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,29 +1,10 @@ from .buffer import Buffer -from .checkpointer.base import 
Checkpointer -from .checkpointer.null import NullCheckpointer -from .checkpointer.snapshots import SnapshotCheckpointer, _Snapshots -from .checkpointer.thresholds import ( - InsufficientTrialsError, - SavepointThresholds, - Threshold, - ThresholdCalibrationCheckpointer, -) -from .checkpointer.validation import ValidationCheckpointer from .comm.boundary import Boundary, SimpleBoundary -from .comm.caching_comm import ( - CachingCommData, - CachingCommReader, - CachingCommWriter, - CachingRequestReader, - CachingRequestWriter, - NullRequest, -) -from .comm.comm_abc import Comm, Request -from .comm.communicator import Communicator, CubedSphereCommunicator, TileCommunicator +from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm from .comm.mpi import MPIComm from .comm.null_comm import NullAsyncResult, NullComm -from .comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner +from .comm.partitioner import CubedSpherePartitioner, TilePartitioner from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG @@ -44,17 +25,6 @@ ) from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError -from .grid.eta import HybridPressureCoefficients -from .grid.generation import GridDefinition, GridDefinitions, MetricTerms -from .grid.helper import ( - AngleGridData, - ContravariantGridData, - DampingCoefficients, - DriverGridData, - GridData, - HorizontalGridData, - VerticalGridData, -) from .halo.data_transformer import ( HaloDataTransformer, HaloDataTransformerCPU, @@ -86,24 +56,6 @@ QuantityHaloSpec, QuantityMetadata, ) -from .stencils.c2l_ord import CubedToLatLon -from .stencils.corners import CopyCorners, CopyCornersXY, FillCornersBGrid -from .stencils.testing.grid import Grid # type: ignore -from .stencils.testing.parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict -from .stencils.testing.temporaries import assert_same_temporaries, copy_temporaries -from .stencils.testing.translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) from .testing.dummy_comm import DummyComm from .types import Allocator, AsyncRequest, NumpyModule from .units import UnitsError diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index e69de29..6486d96 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -0,0 +1,9 @@ +from .null import NullCheckpointer +from .snapshots import SnapshotCheckpointer, _Snapshots +from .thresholds import ( + InsufficientTrialsError, + SavepointThresholds, + Threshold, + ThresholdCalibrationCheckpointer, +) +from .validation import ValidationCheckpointer diff --git a/ndsl/checkpointer/base.py b/ndsl/checkpointer/base.py deleted file mode 100644 index 8218bbf..0000000 --- a/ndsl/checkpointer/base.py +++ /dev/null @@ -1,7 +0,0 @@ -import abc - - -class Checkpointer(abc.ABC): - @abc.abstractmethod - def __call__(self, savepoint_name, **kwargs): - ... 
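For illustration only (not part of the patch series): a minimal sketch of a user-defined checkpointer written against the Checkpointer ABC that this patch deletes from ndsl/checkpointer/base.py and recreates in the new ndsl/typing.py. The subclass name and its recording behavior are hypothetical; only the import path and the abstract __call__(savepoint_name, **kwargs) signature come from the patch itself.

# Hypothetical sketch: subclass the Checkpointer ABC now provided by
# ndsl.typing (previously ndsl.checkpointer.base).
from ndsl.typing import Checkpointer


class RecordingCheckpointer(Checkpointer):
    """Toy checkpointer that records savepoint calls instead of storing data."""

    def __init__(self) -> None:
        self.calls = []

    def __call__(self, savepoint_name, **kwargs):
        # The ABC only requires __call__(savepoint_name, **kwargs); here we
        # remember the savepoint name and the names of the variables passed.
        self.calls.append((savepoint_name, sorted(kwargs)))


# Example use (hypothetical field names):
#   ckpt = RecordingCheckpointer()
#   ckpt("savepoint_a", u=u_field, v=v_field)
#   ckpt.calls  ->  [("savepoint_a", ["u", "v"])]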
diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index fbc7875..448b3a6 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,4 +1,4 @@ -from ndsl.checkpointer.base import Checkpointer +from ndsl.typing import Checkpointer class NullCheckpointer(Checkpointer): diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index aa806b2..573701a 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -2,9 +2,9 @@ import numpy as np -from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp from ndsl.optional_imports import xarray as xr +from ndsl.typing import Checkpointer def make_dims(savepoint_dim, label, data_list): diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index ded73b3..2f1af55 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -5,8 +5,8 @@ import numpy as np -from ndsl.checkpointer.base import Checkpointer from ndsl.quantity import Quantity +from ndsl.typing import Checkpointer try: diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 8af1131..12146a5 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -5,7 +5,6 @@ import numpy as np -from ndsl.checkpointer.base import Checkpointer from ndsl.checkpointer.thresholds import ( ArrayLike, SavepointName, @@ -13,6 +12,7 @@ cast_to_ndarray, ) from ndsl.optional_imports import xarray as xr +from ndsl.typing import Checkpointer def _clip_pace_array_to_target( diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index e69de29..289e641 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -0,0 +1,9 @@ +from .caching_comm import ( + CachingCommData, + CachingCommReader, + CachingCommWriter, + CachingRequestReader, + CachingRequestWriter, + NullRequest, +) +from .comm_abc import Comm, Request diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 9149c1e..3f21ee2 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -1,17 +1,11 @@ -import abc -from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast - -import numpy as np +from typing import List, Optional, Sequence, Tuple, Union, cast import ndsl.constants as constants -from ndsl.buffer import array_buffer, recv_buffer, send_buffer -from ndsl.comm.boundary import Boundary -from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner -from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from ndsl.performance.timer import NullTimer, Timer -from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata -from ndsl.types import NumpyModule -from ndsl.utils import device_synchronize +from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest +from ndsl.performance.timer import Timer +from ndsl.quantity import Quantity, QuantityMetadata +from ndsl.typing import Communicator try: @@ -20,29 +14,6 @@ cupy = None -def to_numpy(array, dtype=None) -> np.ndarray: - """ - Input array can be a numpy array or a cupy array. Returns numpy array. - """ - try: - output = np.asarray(array) - except ValueError as err: - if err.args[0] == "object __array__ method not producing an array": - output = cupy.asnumpy(array) - else: - raise err - except TypeError as err: - if err.args[0].startswith( - "Implicit conversion to a NumPy array is not allowed." 
- ): - output = cupy.asnumpy(array) - else: - raise err - if dtype: - output = output.astype(dtype=dtype) - return output - - def bcast_metadata_list(comm, quantity_list): is_root = comm.Get_rank() == constants.ROOT_RANK if is_root: @@ -58,533 +29,6 @@ def bcast_metadata(comm, array): return bcast_metadata_list(comm, [array])[0] -class Communicator(abc.ABC): - def __init__( - self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None - ): - self.comm = comm - self.partitioner: Partitioner = partitioner - self._force_cpu = force_cpu - self._boundaries: Optional[Mapping[int, Boundary]] = None - self._last_halo_tag = 0 - self.timer: Timer = timer if timer is not None else NullTimer() - - @abc.abstractproperty - def tile(self) -> "TileCommunicator": - pass - - @classmethod - @abc.abstractmethod - def from_layout( - cls, - comm, - layout: Tuple[int, int], - force_cpu: bool = False, - timer: Optional[Timer] = None, - ): - pass - - @property - def rank(self) -> int: - """rank of the current process within this communicator""" - return self.comm.Get_rank() - - @property - def size(self) -> int: - """Total number of ranks in this communicator""" - return self.comm.Get_size() - - def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: - """ - Get a numpy-like module depending on configuration and - Quantity original allocator. - """ - if self._force_cpu: - return np - return module - - @staticmethod - def _device_synchronize(): - """Wait for all work that could be in-flight to finish.""" - # this is a method so we can profile it separately from other device syncs - device_synchronize() - - def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Scatter(send, recv, **kwargs) - - def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Gather(send, recv, **kwargs) - - def scatter( - self, - send_quantity: Optional[Quantity] = None, - recv_quantity: Optional[Quantity] = None, - ) -> Quantity: - """Transfer subtile regions of a full-tile quantity - from the tile root rank to all subtiles. - - Args: - send_quantity: quantity to send, only required/used on the tile root rank - recv_quantity: if provided, assign received data into this Quantity. 
- Returns: - recv_quantity - """ - if self.rank == constants.ROOT_RANK and send_quantity is None: - raise TypeError("send_quantity is a required argument on the root rank") - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) - else: - metadata = self.comm.bcast(None, root=constants.ROOT_RANK) - shape = self.partitioner.subtile_extent(metadata, self.rank) - if recv_quantity is None: - recv_quantity = self._get_scatter_recv_quantity(shape, metadata) - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - with array_buffer( - self._maybe_force_cpu(metadata.np).zeros, - (self.partitioner.total_ranks,) + shape, - dtype=metadata.dtype, - ) as sendbuf: - for rank in range(0, self.partitioner.total_ranks): - subtile_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=metadata.dims, - global_extent=metadata.extent, - overlap=True, - ) - sendbuf.assign_from( - send_quantity.view[subtile_slice], - buffer_slice=np.index_exp[rank, :], - ) - self._Scatter( - metadata.np, - sendbuf.array, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - else: - self._Scatter( - metadata.np, - None, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - return recv_quantity - - def _get_gather_recv_quantity( - self, global_extent: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving global data during gather""" - recv_quantity = Quantity( - send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - origin=tuple([0 for dim in send_metadata.dims]), - extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def _get_scatter_recv_quantity( - self, shape: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving subtile data during scatter""" - recv_quantity = Quantity( - send_metadata.np.zeros(shape, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def gather( - self, send_quantity: Quantity, recv_quantity: Quantity = None - ) -> Optional[Quantity]: - """Transfer subtile regions of a full-tile quantity - from each rank to the tile root rank. 
- - Args: - send_quantity: quantity to send - recv_quantity: if provided, assign received data into this Quantity (only - used on the tile root rank) - Returns: - recv_quantity: quantity if on root rank, otherwise None - """ - result: Optional[Quantity] - if self.rank == constants.ROOT_RANK: - with array_buffer( - send_quantity.np.zeros, - (self.partitioner.total_ranks,) + tuple(send_quantity.extent), - dtype=send_quantity.data.dtype, - ) as recvbuf: - self._Gather( - send_quantity.np, - send_quantity.view[:], - recvbuf.array, - root=constants.ROOT_RANK, - ) - if recv_quantity is None: - global_extent = self.partitioner.global_extent( - send_quantity.metadata - ) - recv_quantity = self._get_gather_recv_quantity( - global_extent, send_quantity.metadata - ) - for rank in range(self.partitioner.total_ranks): - to_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=recv_quantity.dims, - global_extent=recv_quantity.extent, - overlap=True, - ) - recvbuf.assign_to( - recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :] - ) - result = recv_quantity - else: - self._Gather( - send_quantity.np, - send_quantity.view[:], - None, - root=constants.ROOT_RANK, - ) - result = None - return result - - def gather_state(self, send_state=None, recv_state=None, transfer_type=None): - """Transfer a state dictionary from subtile ranks to the tile root rank. - - 'time' is assumed to be the same on all ranks, and its value will be set - to the value from the root rank. - - Args: - send_state: the model state to be sent containing the subtile data - recv_state: the pre-allocated state in which to recieve the full tile - state. Only variables which are scattered will be written to. - Returns: - recv_state: on the root rank, the state containing the entire tile - """ - if self.rank == constants.ROOT_RANK and recv_state is None: - recv_state = {} - for name, quantity in send_state.items(): - if name == "time": - if self.rank == constants.ROOT_RANK: - recv_state["time"] = send_state["time"] - else: - gather_value = to_numpy(quantity.view[:], dtype=transfer_type) - gather_quantity = Quantity( - data=gather_value, - dims=quantity.dims, - units=quantity.units, - allow_mismatch_float_precision=True, - ) - if recv_state is not None and name in recv_state: - tile_quantity = self.gather( - gather_quantity, recv_quantity=recv_state[name] - ) - else: - tile_quantity = self.gather(gather_quantity) - if self.rank == constants.ROOT_RANK: - recv_state[name] = tile_quantity - del gather_quantity - return recv_state - - def scatter_state(self, send_state=None, recv_state=None): - """Transfer a state dictionary from the tile root rank to all subtiles. - - Args: - send_state: the model state to be sent containing the entire tile, - required only from the root rank - recv_state: the pre-allocated state in which to recieve the scattered - state. Only variables which are scattered will be written to. 
- Returns: - rank_state: the state corresponding to this rank's subdomain - """ - - def scatter_root(): - if send_state is None: - raise TypeError("send_state is a required argument on the root rank") - name_list = list(send_state.keys()) - while "time" in name_list: - name_list.remove("time") - name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) - array_list = [send_state[name] for name in name_list] - for name, array in zip(name_list, array_list): - if name in recv_state: - self.scatter(send_quantity=array, recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter(send_quantity=array) - recv_state["time"] = self.comm.bcast( - send_state.get("time", None), root=constants.ROOT_RANK - ) - - def scatter_client(): - name_list = self.comm.bcast(None, root=constants.ROOT_RANK) - for name in name_list: - if name in recv_state: - self.scatter(recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter() - recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) - - if recv_state is None: - recv_state = {} - if self.rank == constants.ROOT_RANK: - scatter_root() - else: - scatter_client() - if recv_state["time"] is None: - recv_state.pop("time") - return recv_state - - def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): - """Perform a halo update on a quantity or quantities - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - halo_updater = self.start_halo_update(quantities, n_points) - halo_updater.wait() - - def start_halo_update( - self, quantity: Union[Quantity, List[Quantity]], n_points: int - ) -> HaloUpdater: - """Start an asynchronous halo update on a quantity. - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - specifications = [] - for quantity in quantities: - specification = QuantityHaloSpec( - n_points=n_points, - shape=quantity.data.shape, - strides=quantity.data.strides, - itemsize=quantity.data.itemsize, - origin=quantity.origin, - extent=quantity.extent, - dims=quantity.dims, - numpy_module=self._maybe_force_cpu(quantity.np), - dtype=quantity.metadata.dtype, - ) - specifications.append(specification) - - halo_updater = self.get_scalar_halo_updater(specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(quantities) - return halo_updater - - def vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ): - """Perform a halo update of a horizontal vector quantity or quantities. - - Assumes the x and y dimension indices are the same between the two quantities. 
- - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - halo_updater = self.start_vector_halo_update( - x_quantities, y_quantities, n_points - ) - halo_updater.wait() - - def start_vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ) -> HaloUpdater: - """Start an asynchronous halo update of a horizontal vector quantity. - - Assumes the x and y dimension indices are the same between the two quantities. - - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - x_specifications = [] - y_specifications = [] - for x_quantity, y_quantity in zip(x_quantities, y_quantities): - x_specification = QuantityHaloSpec( - n_points=n_points, - shape=x_quantity.data.shape, - strides=x_quantity.data.strides, - itemsize=x_quantity.data.itemsize, - origin=x_quantity.metadata.origin, - extent=x_quantity.metadata.extent, - dims=x_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(x_quantity.np), - dtype=x_quantity.metadata.dtype, - ) - x_specifications.append(x_specification) - y_specification = QuantityHaloSpec( - n_points=n_points, - shape=y_quantity.data.shape, - strides=y_quantity.data.strides, - itemsize=y_quantity.data.itemsize, - origin=y_quantity.metadata.origin, - extent=y_quantity.metadata.extent, - dims=y_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(y_quantity.np), - dtype=y_quantity.metadata.dtype, - ) - y_specifications.append(y_specification) - - halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(x_quantities, y_quantities) - return halo_updater - - def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): - """ - Synchronize shared points at the edges of a vector interface variable. - - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - """ - req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) - req.wait() - - def start_synchronize_vector_interfaces( - self, x_quantity: Quantity, y_quantity: Quantity - ) -> HaloUpdateRequest: - """ - Synchronize shared points at the edges of a vector interface variable. 
- - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - - Returns: - request: an asynchronous request object with a .wait() method - """ - halo_updater = VectorInterfaceHaloUpdater( - comm=self.comm, - boundaries=self.boundaries, - force_cpu=self._force_cpu, - timer=self.timer, - ) - req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) - return req - - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): - if len(specifications) == 0: - raise RuntimeError("Cannot create updater with specifications list") - if specifications[0].n_points == 0: - raise ValueError("cannot perform a halo update on zero halo points") - return HaloUpdater.from_scalar_specifications( - self, - self._maybe_force_cpu(specifications[0].numpy_module), - specifications, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def get_vector_halo_updater( - self, - specifications_x: List[QuantityHaloSpec], - specifications_y: List[QuantityHaloSpec], - ): - if len(specifications_x) == 0 and len(specifications_y) == 0: - raise RuntimeError("Cannot create updater with empty specifications list") - if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: - raise ValueError("Cannot perform a halo update on zero halo points") - return HaloUpdater.from_vector_specifications( - self, - self._maybe_force_cpu(specifications_x[0].numpy_module), - specifications_x, - specifications_y, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def _get_halo_tag(self) -> int: - self._last_halo_tag += 1 - return self._last_halo_tag - - @property - def boundaries(self) -> Mapping[int, Boundary]: - """boundaries of this tile with neighboring tiles""" - if self._boundaries is None: - self._boundaries = {} - for boundary_type in constants.BOUNDARY_TYPES: - boundary = self.partitioner.boundary(boundary_type, self.rank) - if boundary is not None: - self._boundaries[boundary_type] = boundary - return self._boundaries - - class TileCommunicator(Communicator): """Performs communications within a single tile or region of a tile""" diff --git a/ndsl/comm/partitioner.py b/ndsl/comm/partitioner.py index 6b8750a..e3b2e02 100644 --- a/ndsl/comm/partitioner.py +++ b/ndsl/comm/partitioner.py @@ -1,4 +1,3 @@ -import abc import copy import functools from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast @@ -18,6 +17,7 @@ WEST, ) from ndsl.quantity import Quantity, QuantityMetadata +from ndsl.typing import Partitioner from ndsl.utils import list_by_dims @@ -54,83 +54,6 @@ def get_tile_number(tile_rank: int, total_ranks: int) -> int: return tile_rank // ranks_per_tile + 1 -class Partitioner(abc.ABC): - @abc.abstractmethod - def __init__(self): - self.tile = None - self.layout = None - - @abc.abstractmethod - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: - ... 
- - @abc.abstractmethod - def tile_index(self, rank: int): - pass - - @abc.abstractmethod - def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: - """Return the shape of a full tile representation for the given dimensions. - - Args: - metadata: quantity metadata - - Returns: - extent: shape of full tile representation - """ - pass - - @abc.abstractmethod - def subtile_slice( - self, - rank: int, - global_dims: Sequence[str], - global_extent: Sequence[int], - overlap: bool = False, - ) -> Tuple[Union[int, slice], ...]: - """Return the subtile slice of a given rank on an array. - - Global refers to the domain being partitioned. For example, for a partitioning - of a tile, the tile would be the "global" domain. - - Args: - rank: the rank of the process - global_dims: dimensions of the global quantity being partitioned - global_extent: extent of the global quantity being partitioned - overlap (optional): if True, for interface variables include the part - of the array shared by adjacent ranks in both ranks. If False, ensure - only one of those ranks (the greater rank) is assigned the overlapping - section. Default is False. - - Returns: - subtile_slice: the slice of the global compute domain corresponding - to the subtile compute domain - """ - pass - - @abc.abstractmethod - def subtile_extent( - self, - global_metadata: QuantityMetadata, - rank: int, - ) -> Tuple[int, ...]: - """Return the shape of a single rank representation for the given dimensions. - - Args: - global_metadata: quantity metadata. - rank: rank of the process. - - Returns: - extent: shape of a single rank representation for the given dimensions. - """ - pass - - @property - @abc.abstractmethod - def total_ranks(self) -> int: - pass - - class TilePartitioner(Partitioner): def __init__( self, diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index edf563b..2d973f7 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -1,5 +1,5 @@ -from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.codepath import FV3CodePath +from ndsl.typing import Partitioner def identify_code_path( diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 7671b46..f93d2ba 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -6,12 +6,12 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.frontend.python.parser import DaceProgram -from ndsl.comm.communicator import Communicator, Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import floating_point_precision from ndsl.optional_imports import cupy as cp +from ndsl.typing import Communicator, Partitioner # This can be turned on to revert compilation for orchestration diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index 78a68fa..ca36f3a 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -1,9 +1,9 @@ import dataclasses from typing import Any, List, Optional -from ndsl.comm.communicator import Communicator from ndsl.dsl.dace.orchestration import dace_inhibitor from ndsl.halo.updater import HaloUpdater +from ndsl.typing import Communicator class WrappedHaloUpdater: diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index b831672..f57c139 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -23,7 +23,6 
@@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.comm_abc import Comm -from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI from ndsl.constants import X_DIM, X_DIMS, Y_DIM, Y_DIMS, Z_DIM, Z_DIMS @@ -35,6 +34,7 @@ # from ndsl import testing from ndsl.testing import comparison +from ndsl.typing import Communicator try: diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index 6b8f75e..e1e233b 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -5,11 +5,10 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline -from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches -from ndsl.comm.partitioner import Partitioner from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration from ndsl.dsl.gt4py_utils import is_gpu_backend +from ndsl.typing import Communicator, Partitioner class RunMode(enum.Enum): diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index e69de29..49eccf0 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -0,0 +1,11 @@ +from .eta import HybridPressureCoefficients +from .generation import GridDefinition, GridDefinitions, MetricTerms +from .helper import ( + AngleGridData, + ContravariantGridData, + DampingCoefficients, + DriverGridData, + GridData, + HorizontalGridData, + VerticalGridData, +) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 12275d7..2d6450a 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -5,7 +5,6 @@ import numpy as np -from ndsl.comm.communicator import Communicator from ndsl.constants import ( N_HALO_DEFAULT, PI, @@ -59,6 +58,7 @@ fill_corners_cgrid, fill_corners_dgrid, ) +from ndsl.typing import Communicator # TODO: when every environment in python3.8, remove diff --git a/ndsl/halo/updater.py b/ndsl/halo/updater.py index 665d0b9..7684c56 100644 --- a/ndsl/halo/updater.py +++ b/ndsl/halo/updater.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from ndsl.comm.communicator import Communicator + from ndsl.typing import Communicator _HaloSendTuple = Tuple[AsyncRequest, Buffer] _HaloRecvTuple = Tuple[AsyncRequest, Buffer, np.ndarray] diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 8a0b96f..3073109 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -5,12 +5,12 @@ import fsspec import numpy as np -from ndsl.comm.communicator import Communicator from ndsl.filesystem import get_fs from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity +from ndsl.typing import Communicator class _TimeChunkedVariable: diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 214171b..85e3722 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -4,12 +4,13 @@ import cftime import ndsl.constants as constants -from ndsl.comm.partitioner import Partitioner, subtile_slice +from ndsl.comm.partitioner import subtile_slice from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import cupy from ndsl.optional_imports import xarray as xr from ndsl.optional_imports import zarr +from ndsl.typing import Partitioner from ndsl.utils import list_by_dims diff --git 
a/ndsl/restart/_legacy_restart.py b/ndsl/restart/_legacy_restart.py index 01f9bdb..afa4d52 100644 --- a/ndsl/restart/_legacy_restart.py +++ b/ndsl/restart/_legacy_restart.py @@ -5,11 +5,11 @@ import ndsl.constants as constants import ndsl.filesystem as filesystem import ndsl.io as io -from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import get_tile_index from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity from ndsl.restart._properties import RESTART_PROPERTIES, RestartProperties +from ndsl.typing import Communicator __all__ = ["open_restart"] diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index d3ec452..0fe1672 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1 +1,21 @@ +from .c2l_ord import CubedToLatLon +from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid +from .testing.grid import Grid # type: ignore +from .testing.parallel_translate import ( + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, +) +from .testing.savepoint import SavepointCase, Translate, dataset_to_dict +from .testing.temporaries import assert_same_temporaries, copy_temporaries +from .testing.translate import ( + TranslateFortranData2Py, + TranslateGrid, + pad_field_in_j, + read_serialized_data, +) + + __version__ = "0.2.0" diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 67f2b5a..87d59f2 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -8,13 +8,13 @@ ) import ndsl.dsl.gt4py_utils as utils -from ndsl.comm.communicator import Communicator from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.grid.helper import GridData from ndsl.initialization.allocator import QuantityFactory +from ndsl.typing import Communicator A1 = 0.5625 diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index d000e1f..b3a3a7e 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -8,11 +8,7 @@ import yaml import ndsl.dsl -from ndsl.comm.communicator import ( - Communicator, - CubedSphereCommunicator, - TileCommunicator, -) +from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator from ndsl.comm.mpi import MPI from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig @@ -20,6 +16,7 @@ from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict from ndsl.stencils.testing.translate import TranslateGrid +from ndsl.typing import Communicator @pytest.fixture() diff --git a/ndsl/typing.py b/ndsl/typing.py new file mode 100644 index 0000000..2f815d0 --- /dev/null +++ b/ndsl/typing.py @@ -0,0 +1,648 @@ +import abc +from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast + +import numpy as np + +import ndsl.constants as constants +from ndsl.buffer import array_buffer, recv_buffer, send_buffer +from ndsl.comm.boundary import Boundary +from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from ndsl.performance.timer import NullTimer, Timer +from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata +from ndsl.types import 
NumpyModule +from ndsl.utils import device_synchronize +from ndsl.comm import boundary as bd + +try: + import cupy +except ImportError: + cupy = None + +def to_numpy(array, dtype=None) -> np.ndarray: + """ + Input array can be a numpy array or a cupy array. Returns numpy array. + """ + try: + output = np.asarray(array) + except ValueError as err: + if err.args[0] == "object __array__ method not producing an array": + output = cupy.asnumpy(array) + else: + raise err + except TypeError as err: + if err.args[0].startswith( + "Implicit conversion to a NumPy array is not allowed." + ): + output = cupy.asnumpy(array) + else: + raise err + if dtype: + output = output.astype(dtype=dtype) + return output + +class Checkpointer(abc.ABC): + @abc.abstractmethod + def __call__(self, savepoint_name, **kwargs): + ... + +class Communicator(abc.ABC): + def __init__( + self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None + ): + self.comm = comm + self.partitioner: Partitioner = partitioner + self._force_cpu = force_cpu + self._boundaries: Optional[Mapping[int, Boundary]] = None + self._last_halo_tag = 0 + self.timer: Timer = timer if timer is not None else NullTimer() + + @abc.abstractproperty + def tile(self): + pass + + @classmethod + @abc.abstractmethod + def from_layout( + cls, + comm, + layout: Tuple[int, int], + force_cpu: bool = False, + timer: Optional[Timer] = None, + ): + pass + + @property + def rank(self) -> int: + """rank of the current process within this communicator""" + return self.comm.Get_rank() + + @property + def size(self) -> int: + """Total number of ranks in this communicator""" + return self.comm.Get_size() + + def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: + """ + Get a numpy-like module depending on configuration and + Quantity original allocator. + """ + if self._force_cpu: + return np + return module + + @staticmethod + def _device_synchronize(): + """Wait for all work that could be in-flight to finish.""" + # this is a method so we can profile it separately from other device syncs + device_synchronize() + + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Scatter(send, recv, **kwargs) + + def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Gather(send, recv, **kwargs) + + def scatter( + self, + send_quantity: Optional[Quantity] = None, + recv_quantity: Optional[Quantity] = None, + ) -> Quantity: + """Transfer subtile regions of a full-tile quantity + from the tile root rank to all subtiles. + + Args: + send_quantity: quantity to send, only required/used on the tile root rank + recv_quantity: if provided, assign received data into this Quantity. 
+ Returns: + recv_quantity + """ + if self.rank == constants.ROOT_RANK and send_quantity is None: + raise TypeError("send_quantity is a required argument on the root rank") + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) + else: + metadata = self.comm.bcast(None, root=constants.ROOT_RANK) + shape = self.partitioner.subtile_extent(metadata, self.rank) + if recv_quantity is None: + recv_quantity = self._get_scatter_recv_quantity(shape, metadata) + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + with array_buffer( + self._maybe_force_cpu(metadata.np).zeros, + (self.partitioner.total_ranks,) + shape, + dtype=metadata.dtype, + ) as sendbuf: + for rank in range(0, self.partitioner.total_ranks): + subtile_slice = self.partitioner.subtile_slice( + rank=rank, + global_dims=metadata.dims, + global_extent=metadata.extent, + overlap=True, + ) + sendbuf.assign_from( + send_quantity.view[subtile_slice], + buffer_slice=np.index_exp[rank, :], + ) + self._Scatter( + metadata.np, + sendbuf.array, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + else: + self._Scatter( + metadata.np, + None, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + return recv_quantity + + def _get_gather_recv_quantity( + self, global_extent: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving global data during gather""" + recv_quantity = Quantity( + send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + origin=tuple([0 for dim in send_metadata.dims]), + extent=global_extent, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def _get_scatter_recv_quantity( + self, shape: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving subtile data during scatter""" + recv_quantity = Quantity( + send_metadata.np.zeros(shape, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def gather( + self, send_quantity: Quantity, recv_quantity: Quantity = None + ) -> Optional[Quantity]: + """Transfer subtile regions of a full-tile quantity + from each rank to the tile root rank. 
+
+        Args:
+            send_quantity: quantity to send
+            recv_quantity: if provided, assign received data into this Quantity (only
+                used on the tile root rank)
+        Returns:
+            recv_quantity: quantity if on root rank, otherwise None
+        """
+        result: Optional[Quantity]
+        if self.rank == constants.ROOT_RANK:
+            with array_buffer(
+                send_quantity.np.zeros,
+                (self.partitioner.total_ranks,) + tuple(send_quantity.extent),
+                dtype=send_quantity.data.dtype,
+            ) as recvbuf:
+                self._Gather(
+                    send_quantity.np,
+                    send_quantity.view[:],
+                    recvbuf.array,
+                    root=constants.ROOT_RANK,
+                )
+                if recv_quantity is None:
+                    global_extent = self.partitioner.global_extent(
+                        send_quantity.metadata
+                    )
+                    recv_quantity = self._get_gather_recv_quantity(
+                        global_extent, send_quantity.metadata
+                    )
+                for rank in range(self.partitioner.total_ranks):
+                    to_slice = self.partitioner.subtile_slice(
+                        rank=rank,
+                        global_dims=recv_quantity.dims,
+                        global_extent=recv_quantity.extent,
+                        overlap=True,
+                    )
+                    recvbuf.assign_to(
+                        recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :]
+                    )
+                result = recv_quantity
+        else:
+            self._Gather(
+                send_quantity.np,
+                send_quantity.view[:],
+                None,
+                root=constants.ROOT_RANK,
+            )
+            result = None
+        return result
+
+    def gather_state(self, send_state=None, recv_state=None, transfer_type=None):
+        """Transfer a state dictionary from subtile ranks to the tile root rank.
+
+        'time' is assumed to be the same on all ranks, and its value will be set
+        to the value from the root rank.
+
+        Args:
+            send_state: the model state to be sent containing the subtile data
+            recv_state: the pre-allocated state in which to receive the full tile
+                state. Only variables which are scattered will be written to.
+        Returns:
+            recv_state: on the root rank, the state containing the entire tile
+        """
+        if self.rank == constants.ROOT_RANK and recv_state is None:
+            recv_state = {}
+        for name, quantity in send_state.items():
+            if name == "time":
+                if self.rank == constants.ROOT_RANK:
+                    recv_state["time"] = send_state["time"]
+            else:
+                gather_value = to_numpy(quantity.view[:], dtype=transfer_type)
+                gather_quantity = Quantity(
+                    data=gather_value,
+                    dims=quantity.dims,
+                    units=quantity.units,
+                    allow_mismatch_float_precision=True,
+                )
+                if recv_state is not None and name in recv_state:
+                    tile_quantity = self.gather(
+                        gather_quantity, recv_quantity=recv_state[name]
+                    )
+                else:
+                    tile_quantity = self.gather(gather_quantity)
+                if self.rank == constants.ROOT_RANK:
+                    recv_state[name] = tile_quantity
+                del gather_quantity
+        return recv_state
+
+    def scatter_state(self, send_state=None, recv_state=None):
+        """Transfer a state dictionary from the tile root rank to all subtiles.
+
+        Args:
+            send_state: the model state to be sent containing the entire tile,
+                required only from the root rank
+            recv_state: the pre-allocated state in which to receive the scattered
+                state. Only variables which are scattered will be written to.
+ Returns: + rank_state: the state corresponding to this rank's subdomain + """ + + def scatter_root(): + if send_state is None: + raise TypeError("send_state is a required argument on the root rank") + name_list = list(send_state.keys()) + while "time" in name_list: + name_list.remove("time") + name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) + array_list = [send_state[name] for name in name_list] + for name, array in zip(name_list, array_list): + if name in recv_state: + self.scatter(send_quantity=array, recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter(send_quantity=array) + recv_state["time"] = self.comm.bcast( + send_state.get("time", None), root=constants.ROOT_RANK + ) + + def scatter_client(): + name_list = self.comm.bcast(None, root=constants.ROOT_RANK) + for name in name_list: + if name in recv_state: + self.scatter(recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter() + recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) + + if recv_state is None: + recv_state = {} + if self.rank == constants.ROOT_RANK: + scatter_root() + else: + scatter_client() + if recv_state["time"] is None: + recv_state.pop("time") + return recv_state + + def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): + """Perform a halo update on a quantity or quantities + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + halo_updater = self.start_halo_update(quantities, n_points) + halo_updater.wait() + + def start_halo_update( + self, quantity: Union[Quantity, List[Quantity]], n_points: int + ) -> HaloUpdater: + """Start an asynchronous halo update on a quantity. + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + specifications = [] + for quantity in quantities: + specification = QuantityHaloSpec( + n_points=n_points, + shape=quantity.data.shape, + strides=quantity.data.strides, + itemsize=quantity.data.itemsize, + origin=quantity.origin, + extent=quantity.extent, + dims=quantity.dims, + numpy_module=self._maybe_force_cpu(quantity.np), + dtype=quantity.metadata.dtype, + ) + specifications.append(specification) + + halo_updater = self.get_scalar_halo_updater(specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(quantities) + return halo_updater + + def vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ): + """Perform a halo update of a horizontal vector quantity or quantities. + + Assumes the x and y dimension indices are the same between the two quantities. 
+ + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + halo_updater = self.start_vector_halo_update( + x_quantities, y_quantities, n_points + ) + halo_updater.wait() + + def start_vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ) -> HaloUpdater: + """Start an asynchronous halo update of a horizontal vector quantity. + + Assumes the x and y dimension indices are the same between the two quantities. + + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + x_specifications = [] + y_specifications = [] + for x_quantity, y_quantity in zip(x_quantities, y_quantities): + x_specification = QuantityHaloSpec( + n_points=n_points, + shape=x_quantity.data.shape, + strides=x_quantity.data.strides, + itemsize=x_quantity.data.itemsize, + origin=x_quantity.metadata.origin, + extent=x_quantity.metadata.extent, + dims=x_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(x_quantity.np), + dtype=x_quantity.metadata.dtype, + ) + x_specifications.append(x_specification) + y_specification = QuantityHaloSpec( + n_points=n_points, + shape=y_quantity.data.shape, + strides=y_quantity.data.strides, + itemsize=y_quantity.data.itemsize, + origin=y_quantity.metadata.origin, + extent=y_quantity.metadata.extent, + dims=y_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(y_quantity.np), + dtype=y_quantity.metadata.dtype, + ) + y_specifications.append(y_specification) + + halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(x_quantities, y_quantities) + return halo_updater + + def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): + """ + Synchronize shared points at the edges of a vector interface variable. + + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + """ + req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) + req.wait() + + def start_synchronize_vector_interfaces( + self, x_quantity: Quantity, y_quantity: Quantity + ) -> HaloUpdateRequest: + """ + Synchronize shared points at the edges of a vector interface variable. 
+ + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + + Returns: + request: an asynchronous request object with a .wait() method + """ + halo_updater = VectorInterfaceHaloUpdater( + comm=self.comm, + boundaries=self.boundaries, + force_cpu=self._force_cpu, + timer=self.timer, + ) + req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) + return req + + def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + if len(specifications) == 0: + raise RuntimeError("Cannot create updater with specifications list") + if specifications[0].n_points == 0: + raise ValueError("cannot perform a halo update on zero halo points") + return HaloUpdater.from_scalar_specifications( + self, + self._maybe_force_cpu(specifications[0].numpy_module), + specifications, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def get_vector_halo_updater( + self, + specifications_x: List[QuantityHaloSpec], + specifications_y: List[QuantityHaloSpec], + ): + if len(specifications_x) == 0 and len(specifications_y) == 0: + raise RuntimeError("Cannot create updater with empty specifications list") + if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: + raise ValueError("Cannot perform a halo update on zero halo points") + return HaloUpdater.from_vector_specifications( + self, + self._maybe_force_cpu(specifications_x[0].numpy_module), + specifications_x, + specifications_y, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def _get_halo_tag(self) -> int: + self._last_halo_tag += 1 + return self._last_halo_tag + + @property + def boundaries(self) -> Mapping[int, Boundary]: + """boundaries of this tile with neighboring tiles""" + if self._boundaries is None: + self._boundaries = {} + for boundary_type in constants.BOUNDARY_TYPES: + boundary = self.partitioner.boundary(boundary_type, self.rank) + if boundary is not None: + self._boundaries[boundary_type] = boundary + return self._boundaries + +class Partitioner(abc.ABC): + @abc.abstractmethod + def __init__(self): + self.tile = None + self.layout = None + + @abc.abstractmethod + def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: + ... + + @abc.abstractmethod + def tile_index(self, rank: int): + pass + + @abc.abstractmethod + def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: + """Return the shape of a full tile representation for the given dimensions. + + Args: + metadata: quantity metadata + + Returns: + extent: shape of full tile representation + """ + pass + + @abc.abstractmethod + def subtile_slice( + self, + rank: int, + global_dims: Sequence[str], + global_extent: Sequence[int], + overlap: bool = False, + ) -> Tuple[Union[int, slice], ...]: + """Return the subtile slice of a given rank on an array. + + Global refers to the domain being partitioned. For example, for a partitioning + of a tile, the tile would be the "global" domain. 
+ + Args: + rank: the rank of the process + global_dims: dimensions of the global quantity being partitioned + global_extent: extent of the global quantity being partitioned + overlap (optional): if True, for interface variables include the part + of the array shared by adjacent ranks in both ranks. If False, ensure + only one of those ranks (the greater rank) is assigned the overlapping + section. Default is False. + + Returns: + subtile_slice: the slice of the global compute domain corresponding + to the subtile compute domain + """ + pass + + @abc.abstractmethod + def subtile_extent( + self, + global_metadata: QuantityMetadata, + rank: int, + ) -> Tuple[int, ...]: + """Return the shape of a single rank representation for the given dimensions. + + Args: + global_metadata: quantity metadata. + rank: rank of the process. + + Returns: + extent: shape of a single rank representation for the given dimensions. + """ + pass + + @property + @abc.abstractmethod + def total_ranks(self) -> int: + pass \ No newline at end of file diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index a8dd538..89d368e 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl import SnapshotCheckpointer +from ndsl.checkpointer import SnapshotCheckpointer from ndsl.optional_imports import xarray as xr diff --git a/tests/checkpointer/test_thresholds.py b/tests/checkpointer/test_thresholds.py index 90d1f8f..8bf70b0 100644 --- a/tests/checkpointer/test_thresholds.py +++ b/tests/checkpointer/test_thresholds.py @@ -1,7 +1,11 @@ import numpy as np import pytest -from ndsl import InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer +from ndsl.checkpointer import ( + InsufficientTrialsError, + Threshold, + ThresholdCalibrationCheckpointer, +) def test_thresholds_no_trials(): diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index 091bb7c..0c08d52 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from ndsl import SavepointThresholds, Threshold, ValidationCheckpointer +from ndsl.checkpointer import SavepointThresholds, Threshold, ValidationCheckpointer from ndsl.checkpointer.validation import _clip_pace_array_to_target from ndsl.optional_imports import xarray as xr diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index def0d34..79e71cf 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -2,7 +2,7 @@ import pytest from ndsl import ConcurrencyError, DummyComm -from ndsl.comm.communicator import recv_buffer +from ndsl.buffer import recv_buffer from tests.mpi.mpi_comm import MPI diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index b28eba1..5674bfc 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -5,8 +5,6 @@ import numpy as np from ndsl import ( - CachingCommReader, - CachingCommWriter, CubedSphereCommunicator, CubedSpherePartitioner, LocalComm, @@ -14,6 +12,7 @@ Quantity, TilePartitioner, ) +from ndsl.comm import CachingCommReader, CachingCommWriter from ndsl.constants import X_DIM, Y_DIM From a59736e18a9fa90f65eaae40e5379bab430ba4e5 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 5 Mar 2024 16:33:07 -0500 Subject: [PATCH 08/14] Linting --- ndsl/typing.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ndsl/typing.py 
b/ndsl/typing.py index 2f815d0..aa80c38 100644 --- a/ndsl/typing.py +++ b/ndsl/typing.py @@ -5,19 +5,21 @@ import ndsl.constants as constants from ndsl.buffer import array_buffer, recv_buffer, send_buffer +from ndsl.comm import boundary as bd from ndsl.comm.boundary import Boundary from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from ndsl.performance.timer import NullTimer, Timer from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule from ndsl.utils import device_synchronize -from ndsl.comm import boundary as bd + try: import cupy except ImportError: cupy = None + def to_numpy(array, dtype=None) -> np.ndarray: """ Input array can be a numpy array or a cupy array. Returns numpy array. @@ -40,11 +42,13 @@ def to_numpy(array, dtype=None) -> np.ndarray: output = output.astype(dtype=dtype) return output + class Checkpointer(abc.ABC): @abc.abstractmethod def __call__(self, savepoint_name, **kwargs): ... + class Communicator(abc.ABC): def __init__( self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None @@ -570,7 +574,8 @@ def boundaries(self) -> Mapping[int, Boundary]: if boundary is not None: self._boundaries[boundary_type] = boundary return self._boundaries - + + class Partitioner(abc.ABC): @abc.abstractmethod def __init__(self): @@ -645,4 +650,4 @@ def subtile_extent( @property @abc.abstractmethod def total_ranks(self) -> int: - pass \ No newline at end of file + pass From 4c4daac0eab3c4cb783313700b9e9115d1f9030d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 6 Mar 2024 12:34:52 -0500 Subject: [PATCH 09/14] Update to main to grab `GlobalTable` feature --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index d6dfd6f..66f8447 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit d6dfd6ff46cc1d50b0fb6d05fb0b6271e4a1f5cc +Subproject commit 66f8447398762127ba51c7a335d0da7ada369219 From 9f0477a6a8c5f752e076c833685d37a6f0c944dc Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 6 Mar 2024 15:20:38 -0500 Subject: [PATCH 10/14] Changes as of comments from 6 Mar 2024, from PR 14 --- ndsl/__init__.py | 38 +- ndsl/checkpointer/base.py | 7 + ndsl/checkpointer/null.py | 2 +- ndsl/checkpointer/snapshots.py | 2 +- ndsl/checkpointer/thresholds.py | 2 +- ndsl/checkpointer/validation.py | 2 +- ndsl/comm/communicator.py | 567 ++++++++++++++++++++- ndsl/comm/partitioner.py | 79 ++- ndsl/dsl/caches/cache_location.py | 2 +- ndsl/dsl/dace/dace_config.py | 3 +- ndsl/dsl/dace/wrapped_halo_exchange.py | 2 +- ndsl/dsl/stencil.py | 2 +- ndsl/dsl/stencil_config.py | 3 +- ndsl/grid/__init__.py | 2 +- ndsl/grid/generation.py | 2 +- ndsl/halo/__init__.py | 5 + ndsl/halo/updater.py | 2 +- ndsl/monitor/netcdf_monitor.py | 2 +- ndsl/monitor/zarr_monitor.py | 3 +- ndsl/restart/_legacy_restart.py | 2 +- ndsl/stencils/__init__.py | 16 - ndsl/stencils/c2l_ord.py | 2 +- ndsl/stencils/testing/__init__.py | 16 + ndsl/stencils/testing/conftest.py | 7 +- ndsl/typing.py | 662 +------------------------ tests/dsl/test_stencil_factory.py | 3 +- tests/mpi/test_mpi_mock.py | 3 +- tests/test_halo_data_transformer.py | 11 +- tests/test_halo_update.py | 2 +- 29 files changed, 714 insertions(+), 737 deletions(-) create mode 100644 ndsl/checkpointer/base.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index a507302..f8c6461 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,9 +1,7 @@ -from .buffer import Buffer 
-from .comm.boundary import Boundary, SimpleBoundary from .comm.communicator import CubedSphereCommunicator, TileCommunicator -from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm +from .comm.local_comm import LocalComm from .comm.mpi import MPIComm -from .comm.null_comm import NullAsyncResult, NullComm +from .comm.null_comm import NullComm from .comm.partitioner import CubedSpherePartitioner, TilePartitioner from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath @@ -16,46 +14,24 @@ StorageReport, ) from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater -from .dsl.stencil import ( - CompareToNumpyStencil, - FrozenStencil, - GridIndexing, - StencilFactory, - TimingCollector, -) +from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError -from .halo.data_transformer import ( - HaloDataTransformer, - HaloDataTransformerCPU, - HaloDataTransformerGPU, - HaloExchangeSpec, -) +from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from .initialization.allocator import QuantityFactory, StorageNumpy +from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log from .monitor.netcdf_monitor import NetCDFMonitor from .monitor.protocol import Protocol from .monitor.zarr_monitor import ZarrMonitor from .namelist import Namelist -from .optional_imports import RaiseWhenAccessed -from .performance.collector import ( - AbstractPerformanceCollector, - NullPerformanceCollector, - PerformanceCollector, -) +from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.config import PerformanceConfig from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport from .performance.timer import NullTimer, Timer -from .quantity import ( - BoundaryArrayView, - BoundedArrayView, - Quantity, - QuantityHaloSpec, - QuantityMetadata, -) +from .quantity import Quantity from .testing.dummy_comm import DummyComm from .types import Allocator, AsyncRequest, NumpyModule from .units import UnitsError diff --git a/ndsl/checkpointer/base.py b/ndsl/checkpointer/base.py new file mode 100644 index 0000000..8218bbf --- /dev/null +++ b/ndsl/checkpointer/base.py @@ -0,0 +1,7 @@ +import abc + + +class Checkpointer(abc.ABC): + @abc.abstractmethod + def __call__(self, savepoint_name, **kwargs): + ... 
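The new `ndsl/checkpointer/base.py` above becomes the extension point that the following diffs (null.py, snapshots.py, thresholds.py, validation.py) re-point their imports to. A minimal sketch of a concrete checkpointer built on that interface, assuming only the `__call__(savepoint_name, **kwargs)` contract shown above and an installed ndsl with this patch applied; the subclass name and its recording logic are hypothetical illustrations, not NDSL code:

    # Hypothetical subclass of the relocated Checkpointer ABC (illustration only).
    from ndsl.checkpointer.base import Checkpointer


    class RecordingCheckpointer(Checkpointer):
        """Records which savepoints were hit and which variables were passed."""

        def __init__(self):
            self.calls = []

        def __call__(self, savepoint_name, **kwargs):
            # Store the savepoint name and the sorted variable names seen there.
            self.calls.append((savepoint_name, sorted(kwargs)))


    checkpointer = RecordingCheckpointer()
    checkpointer("after_dyn_core", u=0.0, v=0.0)  # calls -> [("after_dyn_core", ["u", "v"])]
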
diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index 448b3a6..fbc7875 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,4 +1,4 @@ -from ndsl.typing import Checkpointer +from ndsl.checkpointer.base import Checkpointer class NullCheckpointer(Checkpointer): diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index 573701a..aa806b2 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -2,9 +2,9 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp from ndsl.optional_imports import xarray as xr -from ndsl.typing import Checkpointer def make_dims(savepoint_dim, label, data_list): diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index 2f1af55..ded73b3 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -5,8 +5,8 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.quantity import Quantity -from ndsl.typing import Checkpointer try: diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 12146a5..8af1131 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -5,6 +5,7 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.checkpointer.thresholds import ( ArrayLike, SavepointName, @@ -12,7 +13,6 @@ cast_to_ndarray, ) from ndsl.optional_imports import xarray as xr -from ndsl.typing import Checkpointer def _clip_pace_array_to_target( diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 3f21ee2..f1b97c8 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -1,11 +1,16 @@ -from typing import List, Optional, Sequence, Tuple, Union, cast +import abc +from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast + +import numpy as np import ndsl.constants as constants -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner -from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity, QuantityMetadata -from ndsl.typing import Communicator +from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer +from ndsl.comm.boundary import Boundary +from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner +from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from ndsl.performance.timer import NullTimer, Timer +from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata +from ndsl.types import NumpyModule try: @@ -14,6 +19,556 @@ cupy = None +def to_numpy(array, dtype=None) -> np.ndarray: + """ + Input array can be a numpy array or a cupy array. Returns numpy array. + """ + try: + output = np.asarray(array) + except ValueError as err: + if err.args[0] == "object __array__ method not producing an array": + output = cupy.asnumpy(array) + else: + raise err + except TypeError as err: + if err.args[0].startswith( + "Implicit conversion to a NumPy array is not allowed." 
+ ): + output = cupy.asnumpy(array) + else: + raise err + if dtype: + output = output.astype(dtype=dtype) + return output + + +class Communicator(abc.ABC): + def __init__( + self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None + ): + self.comm = comm + self.partitioner: Partitioner = partitioner + self._force_cpu = force_cpu + self._boundaries: Optional[Mapping[int, Boundary]] = None + self._last_halo_tag = 0 + self.timer: Timer = timer if timer is not None else NullTimer() + + @abc.abstractproperty + def tile(self): + pass + + @classmethod + @abc.abstractmethod + def from_layout( + cls, + comm, + layout: Tuple[int, int], + force_cpu: bool = False, + timer: Optional[Timer] = None, + ): + pass + + @property + def rank(self) -> int: + """rank of the current process within this communicator""" + return self.comm.Get_rank() + + @property + def size(self) -> int: + """Total number of ranks in this communicator""" + return self.comm.Get_size() + + def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: + """ + Get a numpy-like module depending on configuration and + Quantity original allocator. + """ + if self._force_cpu: + return np + return module + + @staticmethod + def _device_synchronize(): + """Wait for all work that could be in-flight to finish.""" + # this is a method so we can profile it separately from other device syncs + device_synchronize() + + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Scatter(send, recv, **kwargs) + + def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Gather(send, recv, **kwargs) + + def scatter( + self, + send_quantity: Optional[Quantity] = None, + recv_quantity: Optional[Quantity] = None, + ) -> Quantity: + """Transfer subtile regions of a full-tile quantity + from the tile root rank to all subtiles. + + Args: + send_quantity: quantity to send, only required/used on the tile root rank + recv_quantity: if provided, assign received data into this Quantity. 
+ Returns: + recv_quantity + """ + if self.rank == constants.ROOT_RANK and send_quantity is None: + raise TypeError("send_quantity is a required argument on the root rank") + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) + else: + metadata = self.comm.bcast(None, root=constants.ROOT_RANK) + shape = self.partitioner.subtile_extent(metadata, self.rank) + if recv_quantity is None: + recv_quantity = self._get_scatter_recv_quantity(shape, metadata) + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + with array_buffer( + self._maybe_force_cpu(metadata.np).zeros, + (self.partitioner.total_ranks,) + shape, + dtype=metadata.dtype, + ) as sendbuf: + for rank in range(0, self.partitioner.total_ranks): + subtile_slice = self.partitioner.subtile_slice( + rank=rank, + global_dims=metadata.dims, + global_extent=metadata.extent, + overlap=True, + ) + sendbuf.assign_from( + send_quantity.view[subtile_slice], + buffer_slice=np.index_exp[rank, :], + ) + self._Scatter( + metadata.np, + sendbuf.array, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + else: + self._Scatter( + metadata.np, + None, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + return recv_quantity + + def _get_gather_recv_quantity( + self, global_extent: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving global data during gather""" + recv_quantity = Quantity( + send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + origin=tuple([0 for dim in send_metadata.dims]), + extent=global_extent, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def _get_scatter_recv_quantity( + self, shape: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving subtile data during scatter""" + recv_quantity = Quantity( + send_metadata.np.zeros(shape, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def gather( + self, send_quantity: Quantity, recv_quantity: Quantity = None + ) -> Optional[Quantity]: + """Transfer subtile regions of a full-tile quantity + from each rank to the tile root rank. 
+
+        Args:
+            send_quantity: quantity to send
+            recv_quantity: if provided, assign received data into this Quantity (only
+                used on the tile root rank)
+        Returns:
+            recv_quantity: quantity if on root rank, otherwise None
+        """
+        result: Optional[Quantity]
+        if self.rank == constants.ROOT_RANK:
+            with array_buffer(
+                send_quantity.np.zeros,
+                (self.partitioner.total_ranks,) + tuple(send_quantity.extent),
+                dtype=send_quantity.data.dtype,
+            ) as recvbuf:
+                self._Gather(
+                    send_quantity.np,
+                    send_quantity.view[:],
+                    recvbuf.array,
+                    root=constants.ROOT_RANK,
+                )
+                if recv_quantity is None:
+                    global_extent = self.partitioner.global_extent(
+                        send_quantity.metadata
+                    )
+                    recv_quantity = self._get_gather_recv_quantity(
+                        global_extent, send_quantity.metadata
+                    )
+                for rank in range(self.partitioner.total_ranks):
+                    to_slice = self.partitioner.subtile_slice(
+                        rank=rank,
+                        global_dims=recv_quantity.dims,
+                        global_extent=recv_quantity.extent,
+                        overlap=True,
+                    )
+                    recvbuf.assign_to(
+                        recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :]
+                    )
+                result = recv_quantity
+        else:
+            self._Gather(
+                send_quantity.np,
+                send_quantity.view[:],
+                None,
+                root=constants.ROOT_RANK,
+            )
+            result = None
+        return result
+
+    def gather_state(self, send_state=None, recv_state=None, transfer_type=None):
+        """Transfer a state dictionary from subtile ranks to the tile root rank.
+
+        'time' is assumed to be the same on all ranks, and its value will be set
+        to the value from the root rank.
+
+        Args:
+            send_state: the model state to be sent containing the subtile data
+            recv_state: the pre-allocated state in which to receive the full tile
+                state. Only variables which are scattered will be written to.
+        Returns:
+            recv_state: on the root rank, the state containing the entire tile
+        """
+        if self.rank == constants.ROOT_RANK and recv_state is None:
+            recv_state = {}
+        for name, quantity in send_state.items():
+            if name == "time":
+                if self.rank == constants.ROOT_RANK:
+                    recv_state["time"] = send_state["time"]
+            else:
+                gather_value = to_numpy(quantity.view[:], dtype=transfer_type)
+                gather_quantity = Quantity(
+                    data=gather_value,
+                    dims=quantity.dims,
+                    units=quantity.units,
+                    allow_mismatch_float_precision=True,
+                )
+                if recv_state is not None and name in recv_state:
+                    tile_quantity = self.gather(
+                        gather_quantity, recv_quantity=recv_state[name]
+                    )
+                else:
+                    tile_quantity = self.gather(gather_quantity)
+                if self.rank == constants.ROOT_RANK:
+                    recv_state[name] = tile_quantity
+                del gather_quantity
+        return recv_state
+
+    def scatter_state(self, send_state=None, recv_state=None):
+        """Transfer a state dictionary from the tile root rank to all subtiles.
+
+        Args:
+            send_state: the model state to be sent containing the entire tile,
+                required only from the root rank
+            recv_state: the pre-allocated state in which to receive the scattered
+                state. Only variables which are scattered will be written to.
+ Returns: + rank_state: the state corresponding to this rank's subdomain + """ + + def scatter_root(): + if send_state is None: + raise TypeError("send_state is a required argument on the root rank") + name_list = list(send_state.keys()) + while "time" in name_list: + name_list.remove("time") + name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) + array_list = [send_state[name] for name in name_list] + for name, array in zip(name_list, array_list): + if name in recv_state: + self.scatter(send_quantity=array, recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter(send_quantity=array) + recv_state["time"] = self.comm.bcast( + send_state.get("time", None), root=constants.ROOT_RANK + ) + + def scatter_client(): + name_list = self.comm.bcast(None, root=constants.ROOT_RANK) + for name in name_list: + if name in recv_state: + self.scatter(recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter() + recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) + + if recv_state is None: + recv_state = {} + if self.rank == constants.ROOT_RANK: + scatter_root() + else: + scatter_client() + if recv_state["time"] is None: + recv_state.pop("time") + return recv_state + + def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): + """Perform a halo update on a quantity or quantities + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + halo_updater = self.start_halo_update(quantities, n_points) + halo_updater.wait() + + def start_halo_update( + self, quantity: Union[Quantity, List[Quantity]], n_points: int + ) -> HaloUpdater: + """Start an asynchronous halo update on a quantity. + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + specifications = [] + for quantity in quantities: + specification = QuantityHaloSpec( + n_points=n_points, + shape=quantity.data.shape, + strides=quantity.data.strides, + itemsize=quantity.data.itemsize, + origin=quantity.origin, + extent=quantity.extent, + dims=quantity.dims, + numpy_module=self._maybe_force_cpu(quantity.np), + dtype=quantity.metadata.dtype, + ) + specifications.append(specification) + + halo_updater = self.get_scalar_halo_updater(specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(quantities) + return halo_updater + + def vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ): + """Perform a halo update of a horizontal vector quantity or quantities. + + Assumes the x and y dimension indices are the same between the two quantities. 
+ + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + halo_updater = self.start_vector_halo_update( + x_quantities, y_quantities, n_points + ) + halo_updater.wait() + + def start_vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ) -> HaloUpdater: + """Start an asynchronous halo update of a horizontal vector quantity. + + Assumes the x and y dimension indices are the same between the two quantities. + + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + x_specifications = [] + y_specifications = [] + for x_quantity, y_quantity in zip(x_quantities, y_quantities): + x_specification = QuantityHaloSpec( + n_points=n_points, + shape=x_quantity.data.shape, + strides=x_quantity.data.strides, + itemsize=x_quantity.data.itemsize, + origin=x_quantity.metadata.origin, + extent=x_quantity.metadata.extent, + dims=x_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(x_quantity.np), + dtype=x_quantity.metadata.dtype, + ) + x_specifications.append(x_specification) + y_specification = QuantityHaloSpec( + n_points=n_points, + shape=y_quantity.data.shape, + strides=y_quantity.data.strides, + itemsize=y_quantity.data.itemsize, + origin=y_quantity.metadata.origin, + extent=y_quantity.metadata.extent, + dims=y_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(y_quantity.np), + dtype=y_quantity.metadata.dtype, + ) + y_specifications.append(y_specification) + + halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(x_quantities, y_quantities) + return halo_updater + + def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): + """ + Synchronize shared points at the edges of a vector interface variable. + + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + """ + req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) + req.wait() + + def start_synchronize_vector_interfaces( + self, x_quantity: Quantity, y_quantity: Quantity + ) -> HaloUpdateRequest: + """ + Synchronize shared points at the edges of a vector interface variable. 
+ + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + + Returns: + request: an asynchronous request object with a .wait() method + """ + halo_updater = VectorInterfaceHaloUpdater( + comm=self.comm, + boundaries=self.boundaries, + force_cpu=self._force_cpu, + timer=self.timer, + ) + req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) + return req + + def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + if len(specifications) == 0: + raise RuntimeError("Cannot create updater with specifications list") + if specifications[0].n_points == 0: + raise ValueError("cannot perform a halo update on zero halo points") + return HaloUpdater.from_scalar_specifications( + self, + self._maybe_force_cpu(specifications[0].numpy_module), + specifications, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def get_vector_halo_updater( + self, + specifications_x: List[QuantityHaloSpec], + specifications_y: List[QuantityHaloSpec], + ): + if len(specifications_x) == 0 and len(specifications_y) == 0: + raise RuntimeError("Cannot create updater with empty specifications list") + if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: + raise ValueError("Cannot perform a halo update on zero halo points") + return HaloUpdater.from_vector_specifications( + self, + self._maybe_force_cpu(specifications_x[0].numpy_module), + specifications_x, + specifications_y, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def _get_halo_tag(self) -> int: + self._last_halo_tag += 1 + return self._last_halo_tag + + @property + def boundaries(self) -> Mapping[int, Boundary]: + """boundaries of this tile with neighboring tiles""" + if self._boundaries is None: + self._boundaries = {} + for boundary_type in constants.BOUNDARY_TYPES: + boundary = self.partitioner.boundary(boundary_type, self.rank) + if boundary is not None: + self._boundaries[boundary_type] = boundary + return self._boundaries + + def bcast_metadata_list(comm, quantity_list): is_root = comm.Get_rank() == constants.ROOT_RANK if is_root: diff --git a/ndsl/comm/partitioner.py b/ndsl/comm/partitioner.py index e3b2e02..6b8750a 100644 --- a/ndsl/comm/partitioner.py +++ b/ndsl/comm/partitioner.py @@ -1,3 +1,4 @@ +import abc import copy import functools from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast @@ -17,7 +18,6 @@ WEST, ) from ndsl.quantity import Quantity, QuantityMetadata -from ndsl.typing import Partitioner from ndsl.utils import list_by_dims @@ -54,6 +54,83 @@ def get_tile_number(tile_rank: int, total_ranks: int) -> int: return tile_rank // ranks_per_tile + 1 +class Partitioner(abc.ABC): + @abc.abstractmethod + def __init__(self): + self.tile = None + self.layout = None + + @abc.abstractmethod + def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: + ... 
+ + @abc.abstractmethod + def tile_index(self, rank: int): + pass + + @abc.abstractmethod + def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: + """Return the shape of a full tile representation for the given dimensions. + + Args: + metadata: quantity metadata + + Returns: + extent: shape of full tile representation + """ + pass + + @abc.abstractmethod + def subtile_slice( + self, + rank: int, + global_dims: Sequence[str], + global_extent: Sequence[int], + overlap: bool = False, + ) -> Tuple[Union[int, slice], ...]: + """Return the subtile slice of a given rank on an array. + + Global refers to the domain being partitioned. For example, for a partitioning + of a tile, the tile would be the "global" domain. + + Args: + rank: the rank of the process + global_dims: dimensions of the global quantity being partitioned + global_extent: extent of the global quantity being partitioned + overlap (optional): if True, for interface variables include the part + of the array shared by adjacent ranks in both ranks. If False, ensure + only one of those ranks (the greater rank) is assigned the overlapping + section. Default is False. + + Returns: + subtile_slice: the slice of the global compute domain corresponding + to the subtile compute domain + """ + pass + + @abc.abstractmethod + def subtile_extent( + self, + global_metadata: QuantityMetadata, + rank: int, + ) -> Tuple[int, ...]: + """Return the shape of a single rank representation for the given dimensions. + + Args: + global_metadata: quantity metadata. + rank: rank of the process. + + Returns: + extent: shape of a single rank representation for the given dimensions. + """ + pass + + @property + @abc.abstractmethod + def total_ranks(self) -> int: + pass + + class TilePartitioner(Partitioner): def __init__( self, diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 2d973f7..edf563b 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -1,5 +1,5 @@ +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.codepath import FV3CodePath -from ndsl.typing import Partitioner def identify_code_path( diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index f93d2ba..7f1c147 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -6,12 +6,13 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.frontend.python.parser import DaceProgram +from ndsl.comm.communicator import Communicator +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import floating_point_precision from ndsl.optional_imports import cupy as cp -from ndsl.typing import Communicator, Partitioner # This can be turned on to revert compilation for orchestration diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index ca36f3a..78a68fa 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -1,9 +1,9 @@ import dataclasses from typing import Any, List, Optional +from ndsl.comm.communicator import Communicator from ndsl.dsl.dace.orchestration import dace_inhibitor from ndsl.halo.updater import HaloUpdater -from ndsl.typing import Communicator class WrappedHaloUpdater: diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index f57c139..b831672 100644 --- a/ndsl/dsl/stencil.py +++ 
b/ndsl/dsl/stencil.py @@ -23,6 +23,7 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.comm_abc import Comm +from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI from ndsl.constants import X_DIM, X_DIMS, Y_DIM, Y_DIMS, Z_DIM, Z_DIMS @@ -34,7 +35,6 @@ # from ndsl import testing from ndsl.testing import comparison -from ndsl.typing import Communicator try: diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index e1e233b..6b8f75e 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -5,10 +5,11 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline +from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration from ndsl.dsl.gt4py_utils import is_gpu_backend -from ndsl.typing import Communicator, Partitioner class RunMode(enum.Enum): diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index 49eccf0..fabe72b 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -1,5 +1,5 @@ from .eta import HybridPressureCoefficients -from .generation import GridDefinition, GridDefinitions, MetricTerms +from .generation import GridDefinitions, MetricTerms from .helper import ( AngleGridData, ContravariantGridData, diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 2d6450a..12275d7 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -5,6 +5,7 @@ import numpy as np +from ndsl.comm.communicator import Communicator from ndsl.constants import ( N_HALO_DEFAULT, PI, @@ -58,7 +59,6 @@ fill_corners_cgrid, fill_corners_dgrid, ) -from ndsl.typing import Communicator # TODO: when every environment in python3.8, remove diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index e69de29..e16177d 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -0,0 +1,5 @@ +from .data_transformer import ( + HaloDataTransformer, + HaloDataTransformerCPU, + HaloDataTransformerGPU, +) diff --git a/ndsl/halo/updater.py b/ndsl/halo/updater.py index 7684c56..665d0b9 100644 --- a/ndsl/halo/updater.py +++ b/ndsl/halo/updater.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from ndsl.typing import Communicator + from ndsl.comm.communicator import Communicator _HaloSendTuple = Tuple[AsyncRequest, Buffer] _HaloRecvTuple = Tuple[AsyncRequest, Buffer, np.ndarray] diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 3073109..8a0b96f 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -5,12 +5,12 @@ import fsspec import numpy as np +from ndsl.comm.communicator import Communicator from ndsl.filesystem import get_fs from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity -from ndsl.typing import Communicator class _TimeChunkedVariable: diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 85e3722..214171b 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -4,13 +4,12 @@ import cftime import ndsl.constants as constants -from ndsl.comm.partitioner import subtile_slice +from ndsl.comm.partitioner import Partitioner, subtile_slice from ndsl.logging import ndsl_log 
from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import cupy from ndsl.optional_imports import xarray as xr from ndsl.optional_imports import zarr -from ndsl.typing import Partitioner from ndsl.utils import list_by_dims diff --git a/ndsl/restart/_legacy_restart.py b/ndsl/restart/_legacy_restart.py index afa4d52..01f9bdb 100644 --- a/ndsl/restart/_legacy_restart.py +++ b/ndsl/restart/_legacy_restart.py @@ -5,11 +5,11 @@ import ndsl.constants as constants import ndsl.filesystem as filesystem import ndsl.io as io +from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import get_tile_index from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity from ndsl.restart._properties import RESTART_PROPERTIES, RestartProperties -from ndsl.typing import Communicator __all__ = ["open_restart"] diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 0fe1672..641e032 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,21 +1,5 @@ from .c2l_ord import CubedToLatLon from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid -from .testing.grid import Grid # type: ignore -from .testing.parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .testing.savepoint import SavepointCase, Translate, dataset_to_dict -from .testing.temporaries import assert_same_temporaries, copy_temporaries -from .testing.translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) __version__ = "0.2.0" diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 87d59f2..67f2b5a 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -8,13 +8,13 @@ ) import ndsl.dsl.gt4py_utils as utils +from ndsl.comm.communicator import Communicator from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.grid.helper import GridData from ndsl.initialization.allocator import QuantityFactory -from ndsl.typing import Communicator A1 = 0.5625 diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index e69de29..4be2c60 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -0,0 +1,16 @@ +from .grid import Grid # type: ignore +from .parallel_translate import ( + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, +) +from .savepoint import SavepointCase, Translate, dataset_to_dict +from .temporaries import assert_same_temporaries, copy_temporaries +from .translate import ( + TranslateFortranData2Py, + TranslateGrid, + pad_field_in_j, + read_serialized_data, +) diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index b3a3a7e..d000e1f 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -8,7 +8,11 @@ import yaml import ndsl.dsl -from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator +from ndsl.comm.communicator import ( + Communicator, + CubedSphereCommunicator, + TileCommunicator, +) from ndsl.comm.mpi import MPI from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig 
@@ -16,7 +20,6 @@ from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict from ndsl.stencils.testing.translate import TranslateGrid -from ndsl.typing import Communicator @pytest.fixture() diff --git a/ndsl/typing.py b/ndsl/typing.py index aa80c38..03f9624 100644 --- a/ndsl/typing.py +++ b/ndsl/typing.py @@ -1,653 +1,9 @@ -import abc -from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast - -import numpy as np - -import ndsl.constants as constants -from ndsl.buffer import array_buffer, recv_buffer, send_buffer -from ndsl.comm import boundary as bd -from ndsl.comm.boundary import Boundary -from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from ndsl.performance.timer import NullTimer, Timer -from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata -from ndsl.types import NumpyModule -from ndsl.utils import device_synchronize - - -try: - import cupy -except ImportError: - cupy = None - - -def to_numpy(array, dtype=None) -> np.ndarray: - """ - Input array can be a numpy array or a cupy array. Returns numpy array. - """ - try: - output = np.asarray(array) - except ValueError as err: - if err.args[0] == "object __array__ method not producing an array": - output = cupy.asnumpy(array) - else: - raise err - except TypeError as err: - if err.args[0].startswith( - "Implicit conversion to a NumPy array is not allowed." - ): - output = cupy.asnumpy(array) - else: - raise err - if dtype: - output = output.astype(dtype=dtype) - return output - - -class Checkpointer(abc.ABC): - @abc.abstractmethod - def __call__(self, savepoint_name, **kwargs): - ... - - -class Communicator(abc.ABC): - def __init__( - self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None - ): - self.comm = comm - self.partitioner: Partitioner = partitioner - self._force_cpu = force_cpu - self._boundaries: Optional[Mapping[int, Boundary]] = None - self._last_halo_tag = 0 - self.timer: Timer = timer if timer is not None else NullTimer() - - @abc.abstractproperty - def tile(self): - pass - - @classmethod - @abc.abstractmethod - def from_layout( - cls, - comm, - layout: Tuple[int, int], - force_cpu: bool = False, - timer: Optional[Timer] = None, - ): - pass - - @property - def rank(self) -> int: - """rank of the current process within this communicator""" - return self.comm.Get_rank() - - @property - def size(self) -> int: - """Total number of ranks in this communicator""" - return self.comm.Get_size() - - def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: - """ - Get a numpy-like module depending on configuration and - Quantity original allocator. 
- """ - if self._force_cpu: - return np - return module - - @staticmethod - def _device_synchronize(): - """Wait for all work that could be in-flight to finish.""" - # this is a method so we can profile it separately from other device syncs - device_synchronize() - - def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Scatter(send, recv, **kwargs) - - def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Gather(send, recv, **kwargs) - - def scatter( - self, - send_quantity: Optional[Quantity] = None, - recv_quantity: Optional[Quantity] = None, - ) -> Quantity: - """Transfer subtile regions of a full-tile quantity - from the tile root rank to all subtiles. - - Args: - send_quantity: quantity to send, only required/used on the tile root rank - recv_quantity: if provided, assign received data into this Quantity. - Returns: - recv_quantity - """ - if self.rank == constants.ROOT_RANK and send_quantity is None: - raise TypeError("send_quantity is a required argument on the root rank") - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) - else: - metadata = self.comm.bcast(None, root=constants.ROOT_RANK) - shape = self.partitioner.subtile_extent(metadata, self.rank) - if recv_quantity is None: - recv_quantity = self._get_scatter_recv_quantity(shape, metadata) - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - with array_buffer( - self._maybe_force_cpu(metadata.np).zeros, - (self.partitioner.total_ranks,) + shape, - dtype=metadata.dtype, - ) as sendbuf: - for rank in range(0, self.partitioner.total_ranks): - subtile_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=metadata.dims, - global_extent=metadata.extent, - overlap=True, - ) - sendbuf.assign_from( - send_quantity.view[subtile_slice], - buffer_slice=np.index_exp[rank, :], - ) - self._Scatter( - metadata.np, - sendbuf.array, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - else: - self._Scatter( - metadata.np, - None, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - return recv_quantity - - def _get_gather_recv_quantity( - self, global_extent: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving global data during gather""" - recv_quantity = Quantity( - send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - origin=tuple([0 for dim in send_metadata.dims]), - extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def _get_scatter_recv_quantity( - self, shape: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving subtile data during scatter""" - recv_quantity = Quantity( - send_metadata.np.zeros(shape, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def gather( - self, send_quantity: Quantity, recv_quantity: Quantity = None - ) -> Optional[Quantity]: - """Transfer subtile regions of 
a full-tile quantity - from each rank to the tile root rank. - - Args: - send_quantity: quantity to send - recv_quantity: if provided, assign received data into this Quantity (only - used on the tile root rank) - Returns: - recv_quantity: quantity if on root rank, otherwise None - """ - result: Optional[Quantity] - if self.rank == constants.ROOT_RANK: - with array_buffer( - send_quantity.np.zeros, - (self.partitioner.total_ranks,) + tuple(send_quantity.extent), - dtype=send_quantity.data.dtype, - ) as recvbuf: - self._Gather( - send_quantity.np, - send_quantity.view[:], - recvbuf.array, - root=constants.ROOT_RANK, - ) - if recv_quantity is None: - global_extent = self.partitioner.global_extent( - send_quantity.metadata - ) - recv_quantity = self._get_gather_recv_quantity( - global_extent, send_quantity.metadata - ) - for rank in range(self.partitioner.total_ranks): - to_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=recv_quantity.dims, - global_extent=recv_quantity.extent, - overlap=True, - ) - recvbuf.assign_to( - recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :] - ) - result = recv_quantity - else: - self._Gather( - send_quantity.np, - send_quantity.view[:], - None, - root=constants.ROOT_RANK, - ) - result = None - return result - - def gather_state(self, send_state=None, recv_state=None, transfer_type=None): - """Transfer a state dictionary from subtile ranks to the tile root rank. - - 'time' is assumed to be the same on all ranks, and its value will be set - to the value from the root rank. - - Args: - send_state: the model state to be sent containing the subtile data - recv_state: the pre-allocated state in which to recieve the full tile - state. Only variables which are scattered will be written to. - Returns: - recv_state: on the root rank, the state containing the entire tile - """ - if self.rank == constants.ROOT_RANK and recv_state is None: - recv_state = {} - for name, quantity in send_state.items(): - if name == "time": - if self.rank == constants.ROOT_RANK: - recv_state["time"] = send_state["time"] - else: - gather_value = to_numpy(quantity.view[:], dtype=transfer_type) - gather_quantity = Quantity( - data=gather_value, - dims=quantity.dims, - units=quantity.units, - allow_mismatch_float_precision=True, - ) - if recv_state is not None and name in recv_state: - tile_quantity = self.gather( - gather_quantity, recv_quantity=recv_state[name] - ) - else: - tile_quantity = self.gather(gather_quantity) - if self.rank == constants.ROOT_RANK: - recv_state[name] = tile_quantity - del gather_quantity - return recv_state - - def scatter_state(self, send_state=None, recv_state=None): - """Transfer a state dictionary from the tile root rank to all subtiles. - - Args: - send_state: the model state to be sent containing the entire tile, - required only from the root rank - recv_state: the pre-allocated state in which to recieve the scattered - state. Only variables which are scattered will be written to. 
- Returns: - rank_state: the state corresponding to this rank's subdomain - """ - - def scatter_root(): - if send_state is None: - raise TypeError("send_state is a required argument on the root rank") - name_list = list(send_state.keys()) - while "time" in name_list: - name_list.remove("time") - name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) - array_list = [send_state[name] for name in name_list] - for name, array in zip(name_list, array_list): - if name in recv_state: - self.scatter(send_quantity=array, recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter(send_quantity=array) - recv_state["time"] = self.comm.bcast( - send_state.get("time", None), root=constants.ROOT_RANK - ) - - def scatter_client(): - name_list = self.comm.bcast(None, root=constants.ROOT_RANK) - for name in name_list: - if name in recv_state: - self.scatter(recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter() - recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) - - if recv_state is None: - recv_state = {} - if self.rank == constants.ROOT_RANK: - scatter_root() - else: - scatter_client() - if recv_state["time"] is None: - recv_state.pop("time") - return recv_state - - def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): - """Perform a halo update on a quantity or quantities - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - halo_updater = self.start_halo_update(quantities, n_points) - halo_updater.wait() - - def start_halo_update( - self, quantity: Union[Quantity, List[Quantity]], n_points: int - ) -> HaloUpdater: - """Start an asynchronous halo update on a quantity. - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - specifications = [] - for quantity in quantities: - specification = QuantityHaloSpec( - n_points=n_points, - shape=quantity.data.shape, - strides=quantity.data.strides, - itemsize=quantity.data.itemsize, - origin=quantity.origin, - extent=quantity.extent, - dims=quantity.dims, - numpy_module=self._maybe_force_cpu(quantity.np), - dtype=quantity.metadata.dtype, - ) - specifications.append(specification) - - halo_updater = self.get_scalar_halo_updater(specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(quantities) - return halo_updater - - def vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ): - """Perform a halo update of a horizontal vector quantity or quantities. - - Assumes the x and y dimension indices are the same between the two quantities. 
- - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - halo_updater = self.start_vector_halo_update( - x_quantities, y_quantities, n_points - ) - halo_updater.wait() - - def start_vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ) -> HaloUpdater: - """Start an asynchronous halo update of a horizontal vector quantity. - - Assumes the x and y dimension indices are the same between the two quantities. - - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - x_specifications = [] - y_specifications = [] - for x_quantity, y_quantity in zip(x_quantities, y_quantities): - x_specification = QuantityHaloSpec( - n_points=n_points, - shape=x_quantity.data.shape, - strides=x_quantity.data.strides, - itemsize=x_quantity.data.itemsize, - origin=x_quantity.metadata.origin, - extent=x_quantity.metadata.extent, - dims=x_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(x_quantity.np), - dtype=x_quantity.metadata.dtype, - ) - x_specifications.append(x_specification) - y_specification = QuantityHaloSpec( - n_points=n_points, - shape=y_quantity.data.shape, - strides=y_quantity.data.strides, - itemsize=y_quantity.data.itemsize, - origin=y_quantity.metadata.origin, - extent=y_quantity.metadata.extent, - dims=y_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(y_quantity.np), - dtype=y_quantity.metadata.dtype, - ) - y_specifications.append(y_specification) - - halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(x_quantities, y_quantities) - return halo_updater - - def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): - """ - Synchronize shared points at the edges of a vector interface variable. - - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - """ - req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) - req.wait() - - def start_synchronize_vector_interfaces( - self, x_quantity: Quantity, y_quantity: Quantity - ) -> HaloUpdateRequest: - """ - Synchronize shared points at the edges of a vector interface variable. 
- - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - - Returns: - request: an asynchronous request object with a .wait() method - """ - halo_updater = VectorInterfaceHaloUpdater( - comm=self.comm, - boundaries=self.boundaries, - force_cpu=self._force_cpu, - timer=self.timer, - ) - req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) - return req - - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): - if len(specifications) == 0: - raise RuntimeError("Cannot create updater with specifications list") - if specifications[0].n_points == 0: - raise ValueError("cannot perform a halo update on zero halo points") - return HaloUpdater.from_scalar_specifications( - self, - self._maybe_force_cpu(specifications[0].numpy_module), - specifications, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def get_vector_halo_updater( - self, - specifications_x: List[QuantityHaloSpec], - specifications_y: List[QuantityHaloSpec], - ): - if len(specifications_x) == 0 and len(specifications_y) == 0: - raise RuntimeError("Cannot create updater with empty specifications list") - if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: - raise ValueError("Cannot perform a halo update on zero halo points") - return HaloUpdater.from_vector_specifications( - self, - self._maybe_force_cpu(specifications_x[0].numpy_module), - specifications_x, - specifications_y, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def _get_halo_tag(self) -> int: - self._last_halo_tag += 1 - return self._last_halo_tag - - @property - def boundaries(self) -> Mapping[int, Boundary]: - """boundaries of this tile with neighboring tiles""" - if self._boundaries is None: - self._boundaries = {} - for boundary_type in constants.BOUNDARY_TYPES: - boundary = self.partitioner.boundary(boundary_type, self.rank) - if boundary is not None: - self._boundaries[boundary_type] = boundary - return self._boundaries - - -class Partitioner(abc.ABC): - @abc.abstractmethod - def __init__(self): - self.tile = None - self.layout = None - - @abc.abstractmethod - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: - ... - - @abc.abstractmethod - def tile_index(self, rank: int): - pass - - @abc.abstractmethod - def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: - """Return the shape of a full tile representation for the given dimensions. - - Args: - metadata: quantity metadata - - Returns: - extent: shape of full tile representation - """ - pass - - @abc.abstractmethod - def subtile_slice( - self, - rank: int, - global_dims: Sequence[str], - global_extent: Sequence[int], - overlap: bool = False, - ) -> Tuple[Union[int, slice], ...]: - """Return the subtile slice of a given rank on an array. - - Global refers to the domain being partitioned. For example, for a partitioning - of a tile, the tile would be the "global" domain. 
- - Args: - rank: the rank of the process - global_dims: dimensions of the global quantity being partitioned - global_extent: extent of the global quantity being partitioned - overlap (optional): if True, for interface variables include the part - of the array shared by adjacent ranks in both ranks. If False, ensure - only one of those ranks (the greater rank) is assigned the overlapping - section. Default is False. - - Returns: - subtile_slice: the slice of the global compute domain corresponding - to the subtile compute domain - """ - pass - - @abc.abstractmethod - def subtile_extent( - self, - global_metadata: QuantityMetadata, - rank: int, - ) -> Tuple[int, ...]: - """Return the shape of a single rank representation for the given dimensions. - - Args: - global_metadata: quantity metadata. - rank: rank of the process. - - Returns: - extent: shape of a single rank representation for the given dimensions. - """ - pass - - @property - @abc.abstractmethod - def total_ranks(self) -> int: - pass +# flake8: noqa +from ndsl.checkpointer.base import Checkpointer +from ndsl.comm.communicator import Communicator +from ndsl.comm.local_comm import AsyncResult, ConcurrencyError +from ndsl.comm.null_comm import NullAsyncResult +from ndsl.comm.partitioner import Partitioner +from ndsl.performance.collector import AbstractPerformanceCollector +from ndsl.types import AsyncRequest, NumpyModule +from ndsl.units import UnitsError diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 756de95..ac189ad 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -3,7 +3,6 @@ from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region from ndsl import ( - CompareToNumpyStencil, CompilationConfig, DaceConfig, FrozenStencil, @@ -13,7 +12,7 @@ ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import get_stencils_with_varied_bounds +from ndsl.dsl.stencil import CompareToNumpyStencil, get_stencils_with_varied_bounds from ndsl.dsl.typing import FloatField diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 79e71cf..42fdcbe 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from ndsl import ConcurrencyError, DummyComm +from ndsl import DummyComm from ndsl.buffer import recv_buffer +from ndsl.typing import ConcurrencyError from tests.mpi.mpi_comm import MPI diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index ec986f8..e3f6d85 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -4,13 +4,8 @@ import numpy as np import pytest -from ndsl import ( - Buffer, - HaloDataTransformer, - HaloExchangeSpec, - Quantity, - QuantityHaloSpec, -) +from ndsl import HaloExchangeSpec, Quantity +from ndsl.buffer import Buffer from ndsl.comm import _boundary_utils from ndsl.constants import ( EAST, @@ -28,7 +23,9 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.halo import HaloDataTransformer from ndsl.halo.rotate import rotate_scalar_data, rotate_vector_data +from ndsl.quantity import QuantityHaloSpec @pytest.fixture diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index d0536b2..afff2bd 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -10,7 +10,6 @@ HaloUpdater, OutOfBoundsError, Quantity, - QuantityHaloSpec, TileCommunicator, TilePartitioner, Timer, @@ -34,6 
+33,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.quantity import QuantityHaloSpec @pytest.fixture From a40a026d1743cca1005f02f46d2da58a5997ccda Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 7 Mar 2024 10:39:48 -0500 Subject: [PATCH 11/14] Changes requested 1000 7 Mar 2024, PR 14 from Florian --- ndsl/checkpointer/__init__.py | 2 +- ndsl/comm/__init__.py | 1 - ndsl/comm/communicator.py | 2 +- ndsl/dsl/stencil.py | 2 -- tests/checkpointer/__init__.py | 0 tests/dsl/__init__.py | 0 tests/dsl/test_stencil_wrapper.py | 4 ++-- tests/quantity/__init__.py | 0 8 files changed, 4 insertions(+), 7 deletions(-) delete mode 100644 tests/checkpointer/__init__.py delete mode 100644 tests/dsl/__init__.py delete mode 100644 tests/quantity/__init__.py diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index 6486d96..8fee4dc 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1,5 +1,5 @@ from .null import NullCheckpointer -from .snapshots import SnapshotCheckpointer, _Snapshots +from .snapshots import SnapshotCheckpointer from .thresholds import ( InsufficientTrialsError, SavepointThresholds, diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 289e641..c4a5865 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -4,6 +4,5 @@ CachingCommWriter, CachingRequestReader, CachingRequestWriter, - NullRequest, ) from .comm_abc import Comm, Request diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index f1b97c8..ff270df 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -54,7 +54,7 @@ def __init__( self.timer: Timer = timer if timer is not None else NullTimer() @abc.abstractproperty - def tile(self): + def tile(self) -> "TileCommunicator": pass @classmethod diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index b831672..75ef28e 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -32,8 +32,6 @@ from ndsl.dsl.typing import Float, Index3D, cast_to_index3d from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.quantity import Quantity - -# from ndsl import testing from ndsl.testing import comparison diff --git a/tests/checkpointer/__init__.py b/tests/checkpointer/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index cfe56de..986883d 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -285,14 +285,14 @@ def test_backend_options( "backend": "numpy", "rebuild": True, "format_source": False, - "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + "name": "test_stencil_wrapper.copy_stencil", }, "cuda": { "backend": "cuda", "rebuild": True, "device_sync": False, "format_source": False, - "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + "name": "test_stencil_wrapper.copy_stencil", }, } diff --git a/tests/quantity/__init__.py b/tests/quantity/__init__.py deleted file mode 100644 index e69de29..0000000 From 13f4f1ab86160765bc430ec7522348cafe404815 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 7 Mar 2024 14:15:28 -0500 Subject: [PATCH 12/14] Changes from comments (pokes) --- ndsl/__init__.py | 8 +------- ndsl/initialization/__init__.py | 1 + ndsl/monitor/__init__.py | 3 +++ ndsl/performance/__init__.py | 2 ++ tests/test_cube_scatter_gather.py | 2 +- tests/test_g2g_communication.py | 2 +- tests/test_halo_update.py 
| 2 +- tests/test_halo_update_ranks.py | 2 +- tests/test_netcdf_monitor.py | 2 +- tests/test_sync_shared_boundary.py | 2 +- tests/test_timer.py | 2 +- tests/test_zarr_monitor.py | 10 ++-------- 12 files changed, 16 insertions(+), 22 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index f8c6461..5ceae72 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -22,17 +22,11 @@ from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log -from .monitor.netcdf_monitor import NetCDFMonitor -from .monitor.protocol import Protocol -from .monitor.zarr_monitor import ZarrMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector -from .performance.config import PerformanceConfig from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .performance.timer import NullTimer, Timer from .quantity import Quantity from .testing.dummy_comm import DummyComm -from .types import Allocator, AsyncRequest, NumpyModule -from .units import UnitsError +from .types import Allocator from .utils import MetaEnumStr diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index e69de29..8f40c7a 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -0,0 +1 @@ +from .sizer import GridSizer diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index e69de29..a0c7e03 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -0,0 +1,3 @@ +from .netcdf_monitor import NetCDFMonitor +from .protocol import Monitor +from .zarr_monitor import ZarrMonitor diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index e69de29..28e03bc 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -0,0 +1,2 @@ +from .config import PerformanceConfig +from .timer import NullTimer, Timer diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 7966142..aee2253 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -9,7 +9,6 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import ( HORIZONTAL_DIMS, @@ -21,6 +20,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.performance import Timer try: diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 28f1af7..17a5878 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -14,9 +14,9 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.performance import Timer try: diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index afff2bd..3d3bf50 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -12,7 +12,6 @@ Quantity, TileCommunicator, TilePartitioner, - Timer, ) from ndsl.buffer import BUFFER_CACHE from ndsl.comm._boundary_utils import get_boundary_slice @@ -33,6 +32,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.performance import Timer from ndsl.quantity import QuantityHaloSpec diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 8ec77cc..6ceb488 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -6,7 +6,6 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import ( X_DIM, @@ -16,6 +15,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.performance import 
Timer @pytest.fixture diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 7a21dd7..326739b 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -10,10 +10,10 @@ CubedSphereCommunicator, CubedSpherePartitioner, DummyComm, - NetCDFMonitor, Quantity, TilePartitioner, ) +from ndsl.monitor import NetCDFMonitor from ndsl.optional_imports import xarray as xr diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 3e0930a..7db5a62 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -6,9 +6,9 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM +from ndsl.performance import Timer @pytest.fixture diff --git a/tests/test_timer.py b/tests/test_timer.py index 213a487..bb8ec3a 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -2,7 +2,7 @@ import pytest -from ndsl import NullTimer, Timer +from ndsl.performance import NullTimer, Timer @pytest.fixture diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index b608ec0..e40d521 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -12,13 +12,7 @@ import cftime import pytest -from ndsl import ( - CubedSpherePartitioner, - DummyComm, - Quantity, - TilePartitioner, - ZarrMonitor, -) +from ndsl import CubedSpherePartitioner, DummyComm, Quantity, TilePartitioner from ndsl.constants import ( X_DIM, X_DIMS, @@ -28,7 +22,7 @@ Y_INTERFACE_DIM, Z_DIM, ) -from ndsl.monitor.zarr_monitor import array_chunks, get_calendar +from ndsl.monitor.zarr_monitor import ZarrMonitor, array_chunks, get_calendar from ndsl.optional_imports import xarray as xr From 14dc28b2f7452e46aad02dccdc54428ecd96bd30 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Fri, 8 Mar 2024 11:26:40 -0500 Subject: [PATCH 13/14] Poke changes --- ndsl/__init__.py | 1 + ndsl/monitor/__init__.py | 1 - tests/test_netcdf_monitor.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 5ceae72..a2f771c 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -22,6 +22,7 @@ from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log +from .monitor.netcdf_monitor import NetCDFMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index a0c7e03..5d73231 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -1,3 +1,2 @@ -from .netcdf_monitor import NetCDFMonitor from .protocol import Monitor from .zarr_monitor import ZarrMonitor diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 326739b..7a21dd7 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -10,10 +10,10 @@ CubedSphereCommunicator, CubedSpherePartitioner, DummyComm, + NetCDFMonitor, Quantity, TilePartitioner, ) -from ndsl.monitor import NetCDFMonitor from ndsl.optional_imports import xarray as xr From f813fb162335186b3d1e4c85750f3c3416a48403 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Mon, 11 Mar 2024 11:01:59 -0400 Subject: [PATCH 14/14] Imported UnitsError and ConcurrencyError to exceptions, moved AsyncResult and NullAsyncResult out of typing --- ndsl/exceptions.py | 5 +++++ ndsl/typing.py | 3 --- 
tests/mpi/test_mpi_mock.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ndsl/exceptions.py b/ndsl/exceptions.py index 329b1b5..fa5d118 100644 --- a/ndsl/exceptions.py +++ b/ndsl/exceptions.py @@ -1,2 +1,7 @@ +# flake8: noqa +from ndsl.comm.local_comm import ConcurrencyError +from ndsl.units import UnitsError + + class OutOfBoundsError(ValueError): pass diff --git a/ndsl/typing.py b/ndsl/typing.py index 03f9624..ddbf168 100644 --- a/ndsl/typing.py +++ b/ndsl/typing.py @@ -1,9 +1,6 @@ # flake8: noqa from ndsl.checkpointer.base import Checkpointer from ndsl.comm.communicator import Communicator -from ndsl.comm.local_comm import AsyncResult, ConcurrencyError -from ndsl.comm.null_comm import NullAsyncResult from ndsl.comm.partitioner import Partitioner from ndsl.performance.collector import AbstractPerformanceCollector from ndsl.types import AsyncRequest, NumpyModule -from ndsl.units import UnitsError diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 42fdcbe..b820299 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -3,7 +3,7 @@ from ndsl import DummyComm from ndsl.buffer import recv_buffer -from ndsl.typing import ConcurrencyError +from ndsl.exceptions import ConcurrencyError from tests.mpi.mpi_comm import MPI
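
For orientation, a minimal usage sketch of the import layout once the full series is applied; this assumes the final state of ndsl/exceptions.py and ndsl/typing.py shown in the hunks above (ConcurrencyError and UnitsError re-exported from ndsl.exceptions, the slimmed ndsl.typing keeping only abstract interfaces) and is illustrative only, not part of any patch.

```python
# Illustrative sketch (not part of the patch series): expected import locations
# after PATCH 14/14. Assumes the ndsl package is installed in the post-series state.

# Error types are gathered under ndsl.exceptions.
from ndsl.exceptions import ConcurrencyError, OutOfBoundsError, UnitsError

# ndsl.typing retains the abstract interfaces only.
from ndsl.typing import Communicator, Partitioner


def importable_error_names() -> str:
    """Return the names of the exception classes, confirming they import cleanly."""
    names = sorted(exc.__name__ for exc in (ConcurrencyError, OutOfBoundsError, UnitsError))
    return ", ".join(names)


if __name__ == "__main__":
    print(importable_error_names())
```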