Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix .hist and .plot.hist when passing existing figure (#37278) #37467

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ Other enhancements
- :class:`Index` with object dtype supports division and multiplication (:issue:`34160`)
- :meth:`DataFrame.explode` and :meth:`Series.explode` now support exploding of sets (:issue:`35614`)
- :meth:`DataFrame.hist` now supports time series (datetime) data (:issue:`32590`)
- :meth:`DataFrame.hist` and :meth:`DataFrame.plot.hist` can now be called with an existing matplotlib ``Figure`` object via added ``figure`` argument (:issue:`37278`)
- ``Styler`` now allows direct CSS class name addition to individual data cells (:issue:`36159`)
- :meth:`Rolling.mean()` and :meth:`Rolling.sum()` use Kahan summation to calculate the mean to avoid numerical problems (:issue:`10319`, :issue:`11645`, :issue:`13254`, :issue:`32761`, :issue:`36031`)
- :meth:`DatetimeIndex.searchsorted`, :meth:`TimedeltaIndex.searchsorted`, :meth:`PeriodIndex.searchsorted`, and :meth:`Series.searchsorted` with datetimelike dtypes will now try to cast string arguments (listlike and scalar) to the matching datetimelike type (:issue:`36346`)
Expand Down
6 changes: 6 additions & 0 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from pandas.core.base import PandasObject

if TYPE_CHECKING:
from matplotlib.figure import Figure

from pandas import DataFrame


Expand Down Expand Up @@ -107,6 +109,7 @@ def hist_frame(
xrot: Optional[float] = None,
ylabelsize: Optional[int] = None,
yrot: Optional[float] = None,
figure: Optional["Figure"] = None,
ax=None,
sharex: bool = False,
sharey: bool = False,
Expand Down Expand Up @@ -146,6 +149,8 @@ def hist_frame(
yrot : float, default None
Rotation of y axis labels. For example, a value of 90 displays the
y labels rotated 90 degrees clockwise.
figure : Matplotlib Figure object, default None
The figure to plot the histogram on.
ax : Matplotlib axes object, default None
The axes to plot the histogram on.
sharex : bool, default True if ax is None else False
Expand Down Expand Up @@ -217,6 +222,7 @@ def hist_frame(
xrot=xrot,
ylabelsize=ylabelsize,
yrot=yrot,
figure=figure,
ax=ax,
sharex=sharex,
sharey=sharey,
Expand Down
4 changes: 4 additions & 0 deletions pandas/plotting/_matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def plot(data, kind, **kwargs):
# work)
import matplotlib.pyplot as plt

if kwargs.get("figure"):
kwargs["fig"] = kwargs.get("figure")
kwargs["ax"] = kwargs["figure"].gca()
kwargs.pop("reuse_plot", None)
if kwargs.pop("reuse_plot", False):
ax = kwargs.get("ax")
if ax is None and len(plt.get_fignums()) > 0:
Expand Down
6 changes: 5 additions & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,16 @@ def _setup_subplots(self):
sharex=self.sharex,
sharey=self.sharey,
figsize=self.figsize,
figure=self.fig,
ax=self.ax,
layout=self.layout,
layout_type=self._layout_type,
)
else:
if self.ax is None:
if self.fig is not None:
fig = self.fig
axes = fig.add_subplot(111)
elif self.ax is None:
fig = self.plt.figure(figsize=self.figsize)
axes = fig.add_subplot(111)
else:
Expand Down
18 changes: 16 additions & 2 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import numpy as np

Expand All @@ -16,6 +16,7 @@

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure


class HistPlot(LinePlot):
Expand Down Expand Up @@ -181,6 +182,7 @@ def _grouped_plot(
column=None,
by=None,
numeric_only=True,
figure: Optional["Figure"] = None,
figsize=None,
sharex=True,
sharey=True,
Expand All @@ -203,7 +205,13 @@ def _grouped_plot(

naxes = len(grouped)
fig, axes = create_subplots(
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
naxes=naxes,
figure=figure,
figsize=figsize,
sharex=sharex,
sharey=sharey,
ax=ax,
layout=layout,
)

_axes = flatten_axes(axes)
Expand All @@ -222,6 +230,7 @@ def _grouped_hist(
data,
column=None,
by=None,
figure=None,
ax=None,
bins=50,
figsize=None,
Expand All @@ -245,6 +254,7 @@ def _grouped_hist(
data : Series/DataFrame
column : object, optional
by : object, optional
figure: figure, optional
ax : axes, optional
bins : int, default 50
figsize : tuple, optional
Expand Down Expand Up @@ -282,6 +292,7 @@ def plot_group(group, ax):
data,
column=column,
by=by,
figure=figure,
sharex=sharex,
sharey=sharey,
ax=ax,
Expand Down Expand Up @@ -381,6 +392,7 @@ def hist_frame(
xrot=None,
ylabelsize=None,
yrot=None,
figure=None,
ax=None,
sharex=False,
sharey=False,
Expand All @@ -397,6 +409,7 @@ def hist_frame(
data,
column=column,
by=by,
figure=figure,
ax=ax,
grid=grid,
figsize=figsize,
Expand Down Expand Up @@ -430,6 +443,7 @@ def hist_frame(

fig, axes = create_subplots(
naxes=naxes,
figure=figure,
ax=ax,
squeeze=False,
sharex=sharex,
Expand Down
11 changes: 9 additions & 2 deletions pandas/plotting/_matplotlib/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# being a bit too dynamic
from math import ceil
from typing import TYPE_CHECKING, Iterable, List, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple, Union
import warnings

import matplotlib.table
Expand All @@ -17,6 +17,7 @@
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.axis import Axis
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.table import Table

Expand Down Expand Up @@ -106,6 +107,7 @@ def create_subplots(
sharey: bool = False,
squeeze: bool = True,
subplot_kw=None,
figure: Optional["Figure"] = None,
ax=None,
layout=None,
layout_type: str = "box",
Expand Down Expand Up @@ -145,6 +147,9 @@ def create_subplots(
Dict with keywords passed to the add_subplot() call used to create each
subplots.

figure : Matplotlib figure object, optional
Existing figure to be used for plotting.

ax : Matplotlib axis object, optional

layout : tuple
Expand Down Expand Up @@ -190,7 +195,9 @@ def create_subplots(
if subplot_kw is None:
subplot_kw = {}

if ax is None:
if figure is not None:
fig = figure
elif ax is None:
fig = plt.figure(**fig_kw)
else:
if is_list_like(ax):
Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/plotting/test_hist_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ def test_hist_with_legend_raises(self, by):
with pytest.raises(ValueError, match="Cannot use both legend and label"):
s.hist(legend=True, by=by, label="c")

def test_hist_with_figure_argument(self):
# GH37278
index = 15 * ["1"] + 15 * ["2"]
s = Series(np.random.randn(30), index=index, name="a")
_check_plot_works(s.hist, figure=self.plt.figure())
_check_plot_works(s.plot.hist, figure=self.plt.figure())


@td.skip_if_no_mpl
class TestDataFramePlots(TestPlotBase):
Expand Down Expand Up @@ -395,6 +402,13 @@ def test_hist_with_legend_raises(self, by, column):
with pytest.raises(ValueError, match="Cannot use both legend and label"):
df.hist(legend=True, by=by, column=column, label="d")

def test_hist_with_figure_argument(self):
# GH37278
index = Index(15 * ["1"] + 15 * ["2"], name="c")
df = DataFrame(np.random.randn(30, 2), index=index, columns=["a", "b"])
_check_plot_works(df.hist, figure=self.plt.figure())
_check_plot_works(df.plot.hist, figure=self.plt.figure())


@td.skip_if_no_mpl
class TestDataFrameGroupByPlots(TestPlotBase):
Expand Down