Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Refactor on-demand-import mechanism #3270

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/developing/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ and contributing code!
creating_frontend
external_analysis
deprecating_features
managing_optional_dependencies
38 changes: 38 additions & 0 deletions doc/source/developing/managing_optional_dependencies.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
How to manage soft (optional) dependencies
------------------------------------------

We might sometimes rely on heavy external libraries to support some features
outside of yt's core. Typically, many frontends require HDF5 support, provided
by the ``h5py`` library, but many users do not need it and shouldn't have to
install such a large library to get yt.

A mechanism to support soft dependencies is implemented in the
``yt/utilities/on_demand_imports.py`` module. Existing soft dependencies are
listed in an ``optional_dependencies`` dictionary using package names for keys.
The dictionary is then processed to generate drop-in replacements for the actual
packages, which behave as normal instances of the packages they mirror:

.. code-block:: python

from yt.utilities.on_demand_imports import h5py


def process(file):
with h5py.File(file, mode="r") as fh:
# perform interesting operations
...

In case the package is missing, an ``ImportError`` will be raised at runtime if
and only if the code using it is run. In the example above, this means that as
long as the ``process`` function is not run, no error will be raised. An
implication is that soft dependencies cannot be used unconditionally in the
global scope of a module, because that scope is executed at yt's import time;
their use should rather be restricted to lower-level scopes, inside functions or
properly guarded conditional blocks.

Note that mocking an external package using the ``OnDemandImport`` class also
delays any actual import until the first time the package is used (this is how
we can afford to fail only when we have to). A direct implication is that
objects defined this way can safely be imported at the top level of a module,
where the bulk of import statements live, without affecting import durations in
any significant way.
2 changes: 1 addition & 1 deletion yt/data_objects/construction_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2595,7 +2595,7 @@ def export_sketchfab(

@parallel_root_only
def _upload_to_sketchfab(self, data, files):
from yt.utilities.on_demand_imports import _requests as requests
from yt.utilities.on_demand_imports import requests

SKETCHFAB_DOMAIN = "sketchfab.com"
SKETCHFAB_API_URL = f"https://api.{SKETCHFAB_DOMAIN}/v2/models"
Expand Down
14 changes: 7 additions & 7 deletions yt/data_objects/data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
YTSpatialFieldUnitError,
)
from yt.utilities.object_registries import data_object_registry
from yt.utilities.on_demand_imports import _firefly as firefly
from yt.utilities.on_demand_imports import Firefly
from yt.utilities.parameter_file_storage import ParameterFileStore


Expand Down Expand Up @@ -518,7 +518,7 @@ def to_dataframe(self, fields):
>>> dd = ds.all_data()
>>> df = dd.to_dataframe([("gas", "density"), ("gas", "temperature")])
"""
from yt.utilities.on_demand_imports import _pandas as pd
from yt.utilities.on_demand_imports import pandas as pd

data = {}
fields = self._determine_fields(fields)
Expand Down Expand Up @@ -794,8 +794,8 @@ def create_firefly_object(
## for safety, in case someone passes a float just cast it
default_decimation_factor = int(default_decimation_factor)

## initialize a firefly reader instance
reader = firefly.data_reader.Reader(
## initialize a Firefly reader instance
reader = Firefly.data_reader.Reader(
JSONdir=JSONdir, clean_JSONdir=True, **kwargs
)

Expand Down Expand Up @@ -853,8 +853,8 @@ def create_firefly_object(
tracked_filter_flags = np.ones(len(tracked_names))
tracked_colormap_flags = np.ones(len(tracked_names))

## create a firefly ParticleGroup for this particle type
pg = firefly.data_reader.ParticleGroup(
## create a Firefly ParticleGroup for this particle type
pg = Firefly.data_reader.ParticleGroup(
UIname=ptype,
coordinates=self[ptype, "relative_particle_position"].in_units(
coordinate_units
Expand All @@ -866,7 +866,7 @@ def create_firefly_object(
decimation_factor=default_decimation_factor,
)

## bind this particle group to the firefly reader object
## bind this particle group to the Firefly reader object
reader.addParticleGroup(pg)

return reader
Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/particle_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from yt.units.yt_array import array_like_field
from yt.utilities.exceptions import YTIllDefinedParticleData
from yt.utilities.lib.particle_mesh_operations import CICSample_3
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py
from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_root_only


Expand Down
2 changes: 1 addition & 1 deletion yt/data_objects/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def to_dataframe(self, fields=None, only_used=False):
>>> df1 = p.to_dataframe()
>>> df2 = p.to_dataframe(fields=("gas", "density"), only_used=True)
"""
from yt.utilities.on_demand_imports import _pandas as pd
from yt.utilities.on_demand_imports import pandas as pd

idxs, masked, fields = self._export_prep(fields, only_used)
pdata = {self.x_field[-1]: self.x[idxs]}
Expand Down
6 changes: 3 additions & 3 deletions yt/data_objects/selection_objects/cut_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from yt.funcs import iter_fields, validate_object, validate_sequence
from yt.geometry.selection_routines import points_in_cells
from yt.utilities.exceptions import YTIllDefinedCutRegion
from yt.utilities.on_demand_imports import _scipy
from yt.utilities.on_demand_imports import scipy


class YTCutRegion(YTSelectionContainer3D):
Expand Down Expand Up @@ -196,7 +196,7 @@ def _part_ind_KDTree(self, ptype):
dx_loc = dx[lvl_mask]
pos_loc = pos[lvl_mask]

grid_tree = _scipy.spatial.cKDTree(pos_loc, boxsize=1)
grid_tree = scipy.spatial.cKDTree(pos_loc, boxsize=1)

# Compute closest cell for all remaining particles
dist, icell = grid_tree.query(
Expand Down Expand Up @@ -238,7 +238,7 @@ def _part_ind(self, ptype):
# implementation. Else, fall back onto the direct
# brute-force algorithm.
try:
_scipy.spatial.KDTree
scipy.spatial.KDTree
return self._part_ind_KDTree(ptype)
except ImportError:
return self._part_ind_brute_force(ptype)
Expand Down
4 changes: 2 additions & 2 deletions yt/data_objects/selection_objects/spheroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from yt.utilities.exceptions import YTEllipsoidOrdering, YTException, YTSphereTooSmall
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.math_utils import get_rotation_matrix
from yt.utilities.on_demand_imports import _miniball
from yt.utilities.on_demand_imports import miniball


class YTSphere(YTSelectionContainer3D):
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(self, points, ds=None, field_parameters=None, data_source=None):
f"Not enough points. Expected at least 2, got {len(points)}"
)
mylog.debug("Building minimal sphere around points.")
mb = _miniball.Miniball(points)
mb = miniball.Miniball(points)
if not mb.is_valid():
raise YTException("Could not build valid sphere around points.")

Expand Down
2 changes: 1 addition & 1 deletion yt/fields/xray_emission_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BilinearFieldInterpolator,
UnilinearFieldInterpolator,
)
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py

data_version = {"cloudy": 2, "apec": 3}

Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/amrvac/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from yt.geometry.selection_routines import GridSelector
from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.on_demand_imports import _f90nml as f90nml
from yt.utilities.on_demand_imports import f90nml


def read_amrvac_namelist(parfiles):
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/amrvac/tests/test_read_amrvac_namelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from yt.frontends.amrvac.api import read_amrvac_namelist
from yt.testing import requires_module
from yt.utilities.on_demand_imports import _f90nml as f90nml
from yt.utilities.on_demand_imports import f90nml

test_dir = os.path.dirname(os.path.abspath(__file__))
blast_wave_parfile = os.path.join(test_dir, "sample_parfiles", "bw_3d.par")
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/arepo/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from yt.frontends.gadget.api import GadgetHDF5Dataset
from yt.funcs import mylog
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py

from .fields import ArepoFieldInfo

Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/arepo/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from yt.frontends.gadget.api import IOHandlerGadgetHDF5
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py


class IOHandlerArepoHDF5(IOHandlerGadgetHDF5):
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/chombo/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.file_handler import HDF5FileHandler, warn_h5py
from yt.utilities.lib.misc_utilities import get_box_grids_level
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py
from yt.utilities.parallel_tools.parallel_analysis_interface import parallel_root_only

from .fields import (
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/eagle/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from yt.fields.field_info_container import FieldInfoContainer
from yt.frontends.gadget.data_structures import GadgetHDF5Dataset
from yt.frontends.owls.fields import OWLSFieldInfo
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py

from .fields import EagleNetworkFieldInfo

Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/enzo/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from yt.geometry.geometry_handler import YTDataChunk
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py, _libconf as libconf
from yt.utilities.on_demand_imports import h5py, libconf

from .fields import EnzoFieldInfo

Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/enzo/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from yt.geometry.selection_routines import GridSelector
from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py

_convert_mass = ("particle_mass", "mass")

Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/enzo_e/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.cosmology import Cosmology
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _h5py as h5py, _libconf as libconf
from yt.utilities.on_demand_imports import h5py, libconf


class EnzoEGrid(AMRGridPatch):
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/enzo_e/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from yt.utilities.exceptions import YTException
from yt.utilities.io_handler import BaseIOHandler
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py


class EnzoEIOHandler(BaseIOHandler):
Expand Down
2 changes: 1 addition & 1 deletion yt/frontends/enzo_e/tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
data_dir_load,
requires_ds,
)
from yt.utilities.on_demand_imports import _h5py as h5py
from yt.utilities.on_demand_imports import h5py

_fields = (
("gas", "density"),
Expand Down
30 changes: 13 additions & 17 deletions yt/frontends/fits/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from yt.utilities.decompose import decompose_array, get_psize
from yt.utilities.file_handler import FITSFileHandler
from yt.utilities.io_handler import io_registry
from yt.utilities.on_demand_imports import NotAModule, _astropy
from yt.utilities.on_demand_imports import astropy

from .fields import FITSFieldInfo, WCSFITSFieldInfo, YTFITSFieldInfo

Expand Down Expand Up @@ -88,7 +88,7 @@ def _determine_image_units(self, bunit):
try:
try:
# First let AstroPy attempt to figure the unit out
u = 1.0 * _astropy.units.Unit(bunit, format="fits")
u = 1.0 * astropy.units.Unit(bunit, format="fits")
u = YTQuantity.from_astropy(u).units
except ValueError:
try:
Expand Down Expand Up @@ -133,7 +133,7 @@ def _detect_output_fields(self):
for i, fits_file in enumerate(self.dataset._handle._fits_files):
for j, hdu in enumerate(fits_file):
if (
isinstance(hdu, _astropy.pyfits.BinTableHDU)
isinstance(hdu, astropy.io.fits.BinTableHDU)
or hdu.header["naxis"] == 0
):
continue
Expand Down Expand Up @@ -274,14 +274,10 @@ def check_fits_valid(filename):
ext = filename.rsplit(".", 1)[0].rsplit(".", 1)[-1]
if ext.upper() not in ("FITS", "FTS"):
return None
elif isinstance(_astropy.pyfits, NotAModule):
raise RuntimeError(
"This appears to be a FITS file, but AstroPy is not installed."
)
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, append=True)
fileh = _astropy.pyfits.open(filename)
fileh = astropy.io.fits.open(filename)
header, _ = find_primary_header(fileh)
if header["naxis"] >= 2:
return fileh
Expand Down Expand Up @@ -355,24 +351,24 @@ def __init__(
self.nan_mask = nan_mask
self._handle = FITSFileHandler(self.filenames[0])
if isinstance(
self.filenames[0], _astropy.pyfits.hdu.image._ImageBaseHDU
) or isinstance(self.filenames[0], _astropy.pyfits.HDUList):
self.filenames[0], astropy.io.fits.hdu.image._ImageBaseHDU
) or isinstance(self.filenames[0], astropy.io.fits.HDUList):
fn = f"InMemoryFITSFile_{uuid.uuid4().hex}"
else:
fn = self.filenames[0]
self._handle._fits_files.append(self._handle)
if self.num_files > 1:
for fits_file in auxiliary_files:
if isinstance(fits_file, _astropy.pyfits.hdu.image._ImageBaseHDU):
f = _astropy.pyfits.HDUList([fits_file])
elif isinstance(fits_file, _astropy.pyfits.HDUList):
if isinstance(fits_file, astropy.io.fits.hdu.image._ImageBaseHDU):
f = astropy.io.fits.HDUList([fits_file])
elif isinstance(fits_file, astropy.io.fits.HDUList):
f = fits_file
else:
if os.path.exists(fits_file):
fn = fits_file
else:
fn = os.path.join(ytcfg.get("yt", "test_data_dir"), fits_file)
f = _astropy.pyfits.open(
f = astropy.io.fits.open(
fn, memmap=True, do_not_scale_image_data=True, ignore_blank=True
)
self._handle._fits_files.append(f)
Expand Down Expand Up @@ -502,9 +498,9 @@ def _determine_structure(self):
]

def _determine_wcs(self):
wcs = _astropy.pywcs.WCS(header=self.primary_header)
wcs = astropy.wcs.WCS(header=self.primary_header)
if self.naxis == 4:
self.wcs = _astropy.pywcs.WCS(naxis=3)
self.wcs = astropy.wcs.WCS(naxis=3)
self.wcs.wcs.crpix = wcs.wcs.crpix[:3]
self.wcs.wcs.cdelt = wcs.wcs.cdelt[:3]
self.wcs.wcs.crval = wcs.wcs.crval[:3]
Expand Down Expand Up @@ -873,7 +869,7 @@ def _determine_structure(self):
self.naxis = 2

def _determine_wcs(self):
self.wcs = _astropy.pywcs.WCS(naxis=2)
self.wcs = astropy.wcs.WCS(naxis=2)
self.events_info = {}
for k, v in self.primary_header.items():
if k.startswith("TTYP"):
Expand Down
8 changes: 4 additions & 4 deletions yt/frontends/fits/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from yt.fields.derived_field import ValidateSpatial
from yt.units.yt_array import YTArray, YTQuantity
from yt.utilities.logger import ytLogger as mylog
from yt.utilities.on_demand_imports import _astropy
from yt.utilities.on_demand_imports import astropy


def _make_counts(emin, emax):
Expand All @@ -24,8 +24,8 @@ def _counts(field, data):
else:
sigma = None
if sigma is not None and sigma > 0.0:
kern = _astropy.conv.Gaussian2DKernel(x_stddev=sigma)
img[:, :, 0] = _astropy.conv.convolve(img[:, :, 0], kern)
kern = astropy.convolution.Gaussian2DKernel(x_stddev=sigma)
img[:, :, 0] = astropy.convolution.convolve(img[:, :, 0], kern)
return data.ds.arr(img, "counts/pixel")

return _counts
Expand Down Expand Up @@ -209,7 +209,7 @@ class PlotWindowWCS:
"""

def __init__(self, pw):
WCSAxes = _astropy.wcsaxes.WCSAxes
WCSAxes = astropy.visualization.wcsaxes.WCSAxes

if pw.oblique:
raise NotImplementedError("WCS axes are not implemented for oblique plots.")
Expand Down
Loading