Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up missing._get_interpolator #4776

Merged
merged 3 commits into from
Jan 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def bfill(arr, dim=None, limit=None):
).transpose(*arr.dims)


def _import_interpolant(interpolant, method):
"""Import interpolant from scipy.interpolate."""
try:
from scipy import interpolate

return getattr(interpolate, interpolant)
except ImportError as e:
raise ImportError(f"Interpolation with method {method} requires scipy.") from e


def _get_interpolator(method, vectorizeable_only=False, **kwargs):
"""helper function to select the appropriate interpolator class

Expand All @@ -459,12 +469,6 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs):
"akima",
]

has_scipy = True
try:
from scipy import interpolate
except ImportError:
has_scipy = False

# prioritize scipy.interpolate
if (
method == "linear"
Expand All @@ -475,32 +479,29 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs):
interp_class = NumpyInterpolator

elif method in valid_methods:
if not has_scipy:
raise ImportError("Interpolation with method `%s` requires scipy" % method)

if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif vectorizeable_only:
raise ValueError(
"{} is not a vectorizeable interpolator. "
"Available methods are {}".format(method, interp1d_methods)
f"{method} is not a vectorizeable interpolator. "
f"Available methods are {interp1d_methods}"
)
elif method == "barycentric":
interp_class = interpolate.BarycentricInterpolator
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method == "krog":
interp_class = interpolate.KroghInterpolator
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
interp_class = interpolate.PchipInterpolator
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
kwargs.update(method=method)
interp_class = SplineInterpolator
elif method == "akima":
interp_class = interpolate.Akima1DInterpolator
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError("%s is not a valid scipy interpolator" % method)
raise ValueError(f"{method} is not a valid scipy interpolator")
else:
raise ValueError("%s is not a valid interpolator" % method)
raise ValueError(f"{method} is not a valid interpolator")

return interp_class, kwargs

Expand All @@ -512,18 +513,13 @@ def _get_interpolator_nd(method, **kwargs):
"""
valid_methods = ["linear", "nearest"]

try:
from scipy import interpolate
except ImportError:
raise ImportError("Interpolation with method `%s` requires scipy" % method)

if method in valid_methods:
kwargs.update(method=method)
interp_class = interpolate.interpn
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
"%s is not a valid interpolator for interpolating "
"over multiple dimensions." % method
f"{method} is not a valid interpolator for interpolating "
"over multiple dimensions."
)

return interp_class, kwargs
Expand Down