From 15144f66635d46d895b9967ed80a65709561d3c3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 14 Sep 2021 20:42:07 +0200 Subject: [PATCH 1/6] Import matplotlib once per module --- xarray/plot/dataset_plot.py | 17 +++++++++-------- xarray/plot/facetgrid.py | 13 ++++++------- xarray/plot/plot.py | 11 +++++------ xarray/plot/utils.py | 34 +++++++++++++++------------------- 4 files changed, 35 insertions(+), 40 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index c1aedd570bc..2bf9657a514 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -11,9 +11,15 @@ _is_numeric, _process_cmap_cbar_kwargs, get_axis, + import_matplotlib_pyplot, label_from_attrs, ) +try: + plt = import_matplotlib_pyplot() +except ImportError: + plt = None + # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) @@ -134,8 +140,7 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None) # copied from seaborn def _parse_size(data, norm): - - import matplotlib as mpl + mpl = plt.matplotlib if data is None: return None @@ -544,8 +549,6 @@ def quiver(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. """ - import matplotlib as mpl - if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") @@ -560,7 +563,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( + cmap_params["norm"] = plt.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) @@ -576,8 +579,6 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. """ - import matplotlib as mpl - if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for streamplot plots.") @@ -613,7 +614,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( + cmap_params["norm"] = plt.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 28dd82e76f5..88baa7bb230 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -13,6 +13,11 @@ label_from_attrs, ) +try: + plt = import_matplotlib_pyplot() +except ImportError: + plt = None + # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams _FONTSIZE = "small" @@ -116,8 +121,6 @@ def __init__( """ - plt = import_matplotlib_pyplot() - # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique rep_row = row is not None and not data[row].to_index().is_unique @@ -519,10 +522,8 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar self: FacetGrid object """ - import matplotlib as mpl - if size is None: - size = mpl.rcParams["axes.labelsize"] + size = plt.rcParams["axes.labelsize"] nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) @@ -619,8 +620,6 @@ def map(self, func, *args, **kwargs): self : FacetGrid object """ - plt = import_matplotlib_pyplot() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): if namedict is not None: data = self.data.loc[namedict] diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index e20b6568e79..4a4a4d2252f 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -34,6 +34,11 @@ legend_elements, ) +try: + plt = import_matplotlib_pyplot() +except ImportError: + plt = None + # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) @@ -83,8 +88,6 @@ def _parse_size(data, norm, width): If the data is categorical, normalize it to numbers. """ - plt = import_matplotlib_pyplot() - if data is None: return None @@ -682,8 +685,6 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ - plt = import_matplotlib_pyplot() - # Handle facetgrids first if row or col: allargs = locals().copy() @@ -1111,8 +1112,6 @@ def newplotfunc( allargs["plotfunc"] = globals()[plotfunc.__name__] return _easy_facetgrid(darray, kind="dataarray", **allargs) - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index af5859c1f14..71928005a00 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -41,6 +41,12 @@ def import_matplotlib_pyplot(): return plt +try: + plt = import_matplotlib_pyplot() +except ImportError: + plt = None + + def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -58,7 +64,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): """ Build a discrete colormap and normalization of the data. """ - import matplotlib as mpl + mpl = plt.matplotlib if len(levels) == 1: levels = [levels[0], levels[0]] @@ -109,8 +115,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): def _color_palette(cmap, n_colors): - import matplotlib.pyplot as plt - from matplotlib.colors import ListedColormap + ListedColormap = plt.matplotlib.colors.ListedColormap colors_i = np.linspace(0, 1.0, n_colors) if isinstance(cmap, (list, tuple)): @@ -171,7 +176,7 @@ def _determine_cmap_params( cmap_params : dict Use depends on the type of the plotting function """ - import matplotlib as mpl + mpl = plt.matplotlib if isinstance(levels, Iterable): levels = sorted(levels) @@ -279,13 +284,13 @@ def _determine_cmap_params( levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks - ticker = mpl.ticker.MaxNLocator(levels - 1) + ticker = plt.MaxNLocator(levels - 1) levels = ticker.tick_values(vmin, vmax) vmin, vmax = levels[0], levels[-1] # GH3734 if vmin == vmax: - vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax) + vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax) if extend is None: extend = _determine_extend(calc_data, vmin, vmax) @@ -415,10 +420,7 @@ def _assert_valid_xy(darray, xy, name): def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): - try: - import matplotlib as mpl - import matplotlib.pyplot as plt - except ImportError: + if plt is None: raise ImportError("matplotlib is required for plot.utils.get_axis") if figsize is not None: @@ -431,7 +433,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: - width, height = mpl.rcParams["figure.figsize"] + width, height = plt.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) _, ax = plt.subplots(figsize=figsize) @@ -448,9 +450,6 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): def _maybe_gca(**kwargs): - - import matplotlib.pyplot as plt - # can call gcf unconditionally: either it exists or would be created by plt.axes f = plt.gcf() @@ -908,9 +907,7 @@ def _process_cmap_cbar_kwargs( def _get_nice_quiver_magnitude(u, v): - import matplotlib as mpl - - ticker = mpl.ticker.MaxNLocator(3) + ticker = plt.MaxNLocator(3) mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) magnitude = ticker.tick_values(0, mean)[-2] return magnitude @@ -985,7 +982,7 @@ def legend_elements( """ import warnings - import matplotlib as mpl + mpl = plt.matplotlib mlines = mpl.lines @@ -1122,7 +1119,6 @@ def _legend_add_subtitle(handles, labels, text, func): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) From 5c6fff27e327d4416688652e70f58f7fccece8a2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 14 Sep 2021 20:49:40 +0200 Subject: [PATCH 2/6] already imported in utils --- xarray/plot/dataset_plot.py | 7 +------ xarray/plot/facetgrid.py | 7 +------ xarray/plot/plot.py | 7 +------ 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 2bf9657a514..7288a368e47 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -11,15 +11,10 @@ _is_numeric, _process_cmap_cbar_kwargs, get_axis, - import_matplotlib_pyplot, label_from_attrs, + plt, ) -try: - plt = import_matplotlib_pyplot() -except ImportError: - plt = None - # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 88baa7bb230..b384dea0571 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,15 +9,10 @@ _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, + plt, ) -try: - plt = import_matplotlib_pyplot() -except ImportError: - plt = None - # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams _FONTSIZE = "small" diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 4a4a4d2252f..cd797ebefd2 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,16 +29,11 @@ _resolve_intervals_2dplot, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, legend_elements, + plt, ) -try: - plt = import_matplotlib_pyplot() -except ImportError: - plt = None - # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) From 2d07c9bfb59c7ea76267bed30bf461dc77bdcfd6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 14 Sep 2021 21:24:29 +0200 Subject: [PATCH 3/6] Add benchmark --- asv_bench/benchmarks/import_xarray.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 asv_bench/benchmarks/import_xarray.py diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py new file mode 100644 index 00000000000..d1e472f6337 --- /dev/null +++ b/asv_bench/benchmarks/import_xarray.py @@ -0,0 +1,9 @@ +class ImportXarray: + def setup(self, *args, **kwargs): + def import_xr(): + import xarray + + self._import_xr = import_xr + + def import_xarray(self): + self._import_xr() From 7f402faf679ad739c57f6e1f72efffe08378ee54 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 14 Sep 2021 21:28:55 +0200 Subject: [PATCH 4/6] Update import_xarray.py --- asv_bench/benchmarks/import_xarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py index d1e472f6337..45ca126aa31 100644 --- a/asv_bench/benchmarks/import_xarray.py +++ b/asv_bench/benchmarks/import_xarray.py @@ -1,7 +1,7 @@ class ImportXarray: def setup(self, *args, **kwargs): def import_xr(): - import xarray + import xarray # noqa: f401 self._import_xr = import_xr From d7f92df352b887343f13d4a069a261418bd4bba1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 14 Sep 2021 21:32:32 +0200 Subject: [PATCH 5/6] Update import_xarray.py --- asv_bench/benchmarks/import_xarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py index 45ca126aa31..971d90baac8 100644 --- a/asv_bench/benchmarks/import_xarray.py +++ b/asv_bench/benchmarks/import_xarray.py @@ -1,7 +1,7 @@ class ImportXarray: def setup(self, *args, **kwargs): def import_xr(): - import xarray # noqa: f401 + import xarray # noqa: F401 self._import_xr = import_xr From 8a41948a0332b937bf20293849e271a25b09d35e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 14 Sep 2021 21:53:01 +0200 Subject: [PATCH 6/6] Update import_xarray.py --- asv_bench/benchmarks/import_xarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py index 971d90baac8..94652e3b82a 100644 --- a/asv_bench/benchmarks/import_xarray.py +++ b/asv_bench/benchmarks/import_xarray.py @@ -5,5 +5,5 @@ def import_xr(): self._import_xr = import_xr - def import_xarray(self): + def time_import_xarray(self): self._import_xr()