From 0c34b9bd5ae5cc450dea152525d6fc90ee43586c Mon Sep 17 00:00:00 2001 From: krsnik93 Date: Wed, 28 Oct 2020 10:59:56 +0000 Subject: [PATCH] BUG: Fix .hist and .plot.hist when passing existing figure (#37278) --- doc/source/whatsnew/v1.2.0.rst | 1 + pandas/plotting/_core.py | 6 ++++++ pandas/plotting/_matplotlib/__init__.py | 4 ++++ pandas/plotting/_matplotlib/core.py | 6 +++++- pandas/plotting/_matplotlib/hist.py | 18 ++++++++++++++++-- pandas/plotting/_matplotlib/tools.py | 11 +++++++++-- pandas/tests/plotting/test_hist_method.py | 14 ++++++++++++++ 7 files changed, 55 insertions(+), 5 deletions(-) diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 812af544ed9d8..7524d2a4aac8d 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -214,6 +214,7 @@ Other enhancements - :class:`Index` with object dtype supports division and multiplication (:issue:`34160`) - :meth:`DataFrame.explode` and :meth:`Series.explode` now support exploding of sets (:issue:`35614`) - :meth:`DataFrame.hist` now supports time series (datetime) data (:issue:`32590`) +- :meth:`DataFrame.hist` and :meth:`DataFrame.plot.hist` can now be called with an existing matplotlib ``Figure`` object via added ``figure`` argument (:issue:`37278`) - ``Styler`` now allows direct CSS class name addition to individual data cells (:issue:`36159`) - :meth:`Rolling.mean()` and :meth:`Rolling.sum()` use Kahan summation to calculate the mean to avoid numerical problems (:issue:`10319`, :issue:`11645`, :issue:`13254`, :issue:`32761`, :issue:`36031`) - :meth:`DatetimeIndex.searchsorted`, :meth:`TimedeltaIndex.searchsorted`, :meth:`PeriodIndex.searchsorted`, and :meth:`Series.searchsorted` with datetimelike dtypes will now try to cast string arguments (listlike and scalar) to the matching datetimelike type (:issue:`36346`) diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py index e0e35e31d22ac..0f54e8d473e6a 100644 --- a/pandas/plotting/_core.py +++ b/pandas/plotting/_core.py @@ -12,6 +12,8 @@ from pandas.core.base import PandasObject if TYPE_CHECKING: + from matplotlib.figure import Figure + from pandas import DataFrame @@ -107,6 +109,7 @@ def hist_frame( xrot: Optional[float] = None, ylabelsize: Optional[int] = None, yrot: Optional[float] = None, + figure: Optional["Figure"] = None, ax=None, sharex: bool = False, sharey: bool = False, @@ -146,6 +149,8 @@ def hist_frame( yrot : float, default None Rotation of y axis labels. For example, a value of 90 displays the y labels rotated 90 degrees clockwise. + figure : Matplotlib Figure object, default None + The figure to plot the histogram on. ax : Matplotlib axes object, default None The axes to plot the histogram on. sharex : bool, default True if ax is None else False @@ -217,6 +222,7 @@ def hist_frame( xrot=xrot, ylabelsize=ylabelsize, yrot=yrot, + figure=figure, ax=ax, sharex=sharex, sharey=sharey, diff --git a/pandas/plotting/_matplotlib/__init__.py b/pandas/plotting/_matplotlib/__init__.py index 33011e6a66cac..133c16cfaa025 100644 --- a/pandas/plotting/_matplotlib/__init__.py +++ b/pandas/plotting/_matplotlib/__init__.py @@ -51,6 +51,10 @@ def plot(data, kind, **kwargs): # work) import matplotlib.pyplot as plt + if kwargs.get("figure"): + kwargs["fig"] = kwargs.get("figure") + kwargs["ax"] = kwargs["figure"].gca() + kwargs.pop("reuse_plot", None) if kwargs.pop("reuse_plot", False): ax = kwargs.get("ax") if ax is None and len(plt.get_fignums()) > 0: diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 6c9924e0ada79..5c7e88964b8b8 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -325,12 +325,16 @@ def _setup_subplots(self): sharex=self.sharex, sharey=self.sharey, figsize=self.figsize, + figure=self.fig, ax=self.ax, layout=self.layout, layout_type=self._layout_type, ) else: - if self.ax is None: + if self.fig is not None: + fig = self.fig + axes = fig.add_subplot(111) + elif self.ax is None: fig = self.plt.figure(figsize=self.figsize) axes = fig.add_subplot(111) else: diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index 6d22d2ffe4a51..4e0824f36302c 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import numpy as np @@ -16,6 +16,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes + from matplotlib.figure import Figure class HistPlot(LinePlot): @@ -181,6 +182,7 @@ def _grouped_plot( column=None, by=None, numeric_only=True, + figure: Optional["Figure"] = None, figsize=None, sharex=True, sharey=True, @@ -203,7 +205,13 @@ def _grouped_plot( naxes = len(grouped) fig, axes = create_subplots( - naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout + naxes=naxes, + figure=figure, + figsize=figsize, + sharex=sharex, + sharey=sharey, + ax=ax, + layout=layout, ) _axes = flatten_axes(axes) @@ -222,6 +230,7 @@ def _grouped_hist( data, column=None, by=None, + figure=None, ax=None, bins=50, figsize=None, @@ -245,6 +254,7 @@ def _grouped_hist( data : Series/DataFrame column : object, optional by : object, optional + figure: figure, optional ax : axes, optional bins : int, default 50 figsize : tuple, optional @@ -282,6 +292,7 @@ def plot_group(group, ax): data, column=column, by=by, + figure=figure, sharex=sharex, sharey=sharey, ax=ax, @@ -381,6 +392,7 @@ def hist_frame( xrot=None, ylabelsize=None, yrot=None, + figure=None, ax=None, sharex=False, sharey=False, @@ -397,6 +409,7 @@ def hist_frame( data, column=column, by=by, + figure=figure, ax=ax, grid=grid, figsize=figsize, @@ -430,6 +443,7 @@ def hist_frame( fig, axes = create_subplots( naxes=naxes, + figure=figure, ax=ax, squeeze=False, sharex=sharex, diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index bec1f48f5e64a..30dcb83226cc4 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -1,6 +1,6 @@ # being a bit too dynamic from math import ceil -from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple, Union import warnings import matplotlib.table @@ -17,6 +17,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes from matplotlib.axis import Axis + from matplotlib.figure import Figure from matplotlib.lines import Line2D from matplotlib.table import Table @@ -106,6 +107,7 @@ def create_subplots( sharey: bool = False, squeeze: bool = True, subplot_kw=None, + figure: Optional["Figure"] = None, ax=None, layout=None, layout_type: str = "box", @@ -145,6 +147,9 @@ def create_subplots( Dict with keywords passed to the add_subplot() call used to create each subplots. + figure : Matplotlib figure object, optional + Existing figure to be used for plotting. + ax : Matplotlib axis object, optional layout : tuple @@ -190,7 +195,9 @@ def create_subplots( if subplot_kw is None: subplot_kw = {} - if ax is None: + if figure is not None: + fig = figure + elif ax is None: fig = plt.figure(**fig_kw) else: if is_list_like(ax): diff --git a/pandas/tests/plotting/test_hist_method.py b/pandas/tests/plotting/test_hist_method.py index d9a58e808661b..5c23ea7b45e3f 100644 --- a/pandas/tests/plotting/test_hist_method.py +++ b/pandas/tests/plotting/test_hist_method.py @@ -152,6 +152,13 @@ def test_hist_with_legend_raises(self, by): with pytest.raises(ValueError, match="Cannot use both legend and label"): s.hist(legend=True, by=by, label="c") + def test_hist_with_figure_argument(self): + # GH37278 + index = 15 * ["1"] + 15 * ["2"] + s = Series(np.random.randn(30), index=index, name="a") + _check_plot_works(s.hist, figure=self.plt.figure()) + _check_plot_works(s.plot.hist, figure=self.plt.figure()) + @td.skip_if_no_mpl class TestDataFramePlots(TestPlotBase): @@ -395,6 +402,13 @@ def test_hist_with_legend_raises(self, by, column): with pytest.raises(ValueError, match="Cannot use both legend and label"): df.hist(legend=True, by=by, column=column, label="d") + def test_hist_with_figure_argument(self): + # GH37278 + index = Index(15 * ["1"] + 15 * ["2"], name="c") + df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"]) + _check_plot_works(df.hist, figure=self.plt.figure()) + _check_plot_works(df.plot.hist, figure=self.plt.figure()) + @td.skip_if_no_mpl class TestDataFrameGroupByPlots(TestPlotBase):