Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define plot methods in class definitions #16913

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5591,9 +5591,8 @@ def isin(self, values):
2 True True
"""
if isinstance(values, dict):
from collections import defaultdict
from pandas.core.reshape.concat import concat
values = defaultdict(list, values)
values = collections.defaultdict(list, values)
return concat((self.iloc[:, [i]].isin(values[col])
for i, col in enumerate(self.columns)), axis=1)
elif isinstance(values, Series):
Expand All @@ -5617,6 +5616,24 @@ def isin(self, values):
values).reshape(self.shape), self.index,
self.columns)

# ----------------------------------------------------------------------
# Add plotting methods to DataFrame
plot = base.AccessorProperty(gfx.FramePlotMethods, gfx.FramePlotMethods)
hist = gfx.hist_frame

@Appender(_shared_docs['boxplot'] % _shared_doc_kwargs)
def boxplot(self, column=None, by=None, ax=None, fontsize=None, rot=0,
grid=True, figsize=None, layout=None, return_type=None, **kwds):
from pandas.plotting._core import boxplot
import matplotlib.pyplot as plt
ax = boxplot(self, column=column, by=by, ax=ax, fontsize=fontsize,
grid=grid, rot=rot, figsize=figsize, layout=layout,
return_type=return_type, **kwds)
plt.draw_if_interactive()
return ax




DataFrame._setup_axes(['index', 'columns'], info_axis=1, stat_axis=0,
axes_are_reversed=True, aliases={'rows': 0})
Expand Down Expand Up @@ -5970,26 +5987,6 @@ def _put_str(s, space):
return ('%s' % s)[:space].ljust(space)


# ----------------------------------------------------------------------
# Add plotting methods to DataFrame
DataFrame.plot = base.AccessorProperty(gfx.FramePlotMethods,
gfx.FramePlotMethods)
DataFrame.hist = gfx.hist_frame


@Appender(_shared_docs['boxplot'] % _shared_doc_kwargs)
def boxplot(self, column=None, by=None, ax=None, fontsize=None, rot=0,
grid=True, figsize=None, layout=None, return_type=None, **kwds):
from pandas.plotting._core import boxplot
import matplotlib.pyplot as plt
ax = boxplot(self, column=column, by=by, ax=ax, fontsize=fontsize,
grid=grid, rot=rot, figsize=figsize, layout=layout,
return_type=return_type, **kwds)
plt.draw_if_interactive()
return ax


DataFrame.boxplot = boxplot

ops.add_flex_arithmetic_methods(DataFrame, **ops.frame_flex_funcs)
ops.add_special_arithmetic_methods(DataFrame, **ops.frame_special_funcs)
16 changes: 8 additions & 8 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
from pandas._libs import index as libindex, tslib as libts, lib, iNaT
from pandas.core.config import get_option

import pandas.plotting._core as gfx # noqa

__all__ = ['Series']

_shared_doc_kwargs = dict(
Expand Down Expand Up @@ -2877,6 +2879,12 @@ def _dir_additions(self):
pass
return rv

# ----------------------------------------------------------------------
# Add plotting methods to Series

plot = base.AccessorProperty(gfx.SeriesPlotMethods, gfx.SeriesPlotMethods)
hist = gfx.hist_series


Series._setup_axes(['index'], info_axis=0, stat_axis=0, aliases={'rows': 0})
Series._add_numeric_operations()
Expand Down Expand Up @@ -3064,14 +3072,6 @@ def create_from_value(value, index, dtype):
return subarr


# ----------------------------------------------------------------------
# Add plotting methods to Series

import pandas.plotting._core as _gfx # noqa

Series.plot = base.AccessorProperty(_gfx.SeriesPlotMethods,
_gfx.SeriesPlotMethods)
Series.hist = _gfx.hist_series

# Add arithmetic!
ops.add_flex_arithmetic_methods(Series, **ops.series_flex_funcs)
Expand Down
54 changes: 28 additions & 26 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

from pandas.util._decorators import cache_readonly
from pandas.core.base import PandasObject

from pandas.core.dtypes.generic import (ABCSeries, ABCDataFrame, ABCIndex,
ABCMultiIndex, ABCPeriodIndex)
from pandas.core.dtypes.missing import notnull
from pandas.core.dtypes.common import (
is_list_like,
Expand All @@ -20,9 +23,8 @@
is_iterator)
from pandas.core.common import AbstractMethodError, isnull, _try_sort
from pandas.core.generic import _shared_docs, _shared_doc_kwargs
from pandas.core.index import Index, MultiIndex
from pandas.core.series import Series, remove_na
from pandas.core.indexes.period import PeriodIndex


from pandas.compat import range, lrange, map, zip, string_types
import pandas.compat as compat
from pandas.io.formats.printing import pprint_thing
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
for kw, err in zip(['xerr', 'yerr'], [xerr, yerr]):
self.errors[kw] = self._parse_errorbars(kw, err)

if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, Index)):
if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)):
secondary_y = [secondary_y]
self.secondary_y = secondary_y

Expand Down Expand Up @@ -334,7 +336,7 @@ def result(self):
def _compute_plot_data(self):
data = self.data

if isinstance(data, Series):
if isinstance(data, ABCSeries):
label = self.label
if label is None and data.name is None:
label = 'None'
Expand Down Expand Up @@ -451,7 +453,7 @@ def _apply_axis_properties(self, axis, rot=None, fontsize=None):

@property
def legend_title(self):
if not isinstance(self.data.columns, MultiIndex):
if not isinstance(self.data.columns, ABCMultiIndex):
name = self.data.columns.name
if name is not None:
name = pprint_thing(name)
Expand Down Expand Up @@ -533,7 +535,7 @@ def _get_xticks(self, convert_period=False):
'datetime64', 'time')

if self.use_index:
if convert_period and isinstance(index, PeriodIndex):
if convert_period and isinstance(index, ABCPeriodIndex):
self.data = self.data.reindex(index=index.sort_values())
x = self.data.index.to_timestamp()._mpl_repr()
elif index.is_numeric():
Expand Down Expand Up @@ -563,7 +565,7 @@ def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
y = np.ma.array(y)
y = np.ma.masked_where(mask, y)

if isinstance(x, Index):
if isinstance(x, ABCIndex):
x = x._mpl_repr()

if is_errorbar:
Expand All @@ -582,7 +584,7 @@ def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
return ax.plot(*args, **kwds)

def _get_index_name(self):
if isinstance(self.data.index, MultiIndex):
if isinstance(self.data.index, ABCMultiIndex):
name = self.data.index.names
if any(x is not None for x in name):
name = ','.join([pprint_thing(x) for x in name])
Expand Down Expand Up @@ -620,7 +622,7 @@ def on_right(self, i):
if isinstance(self.secondary_y, bool):
return self.secondary_y

if isinstance(self.secondary_y, (tuple, list, np.ndarray, Index)):
if isinstance(self.secondary_y, (tuple, list, np.ndarray, ABCIndex)):
return self.data.columns[i] in self.secondary_y

def _apply_style_colors(self, colors, kwds, col_num, label):
Expand Down Expand Up @@ -671,22 +673,21 @@ def _parse_errorbars(self, label, err):
if err is None:
return None

from pandas import DataFrame, Series

def match_labels(data, e):
e = e.reindex_axis(data.index)
return e

# key-matched DataFrame
if isinstance(err, DataFrame):
if isinstance(err, ABCDataFrame):

err = match_labels(self.data, err)
# key-matched dict
elif isinstance(err, dict):
pass

# Series of error values
elif isinstance(err, Series):
elif isinstance(err, ABCSeries):
# broadcast error series across data
err = match_labels(self.data, err)
err = np.atleast_2d(err)
Expand Down Expand Up @@ -732,14 +733,13 @@ def match_labels(data, e):
return err

def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
from pandas import DataFrame
errors = {}

for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]):
if flag:
err = self.errors[kw]
# user provided label-matched dataframe of errors
if isinstance(err, (DataFrame, dict)):
if isinstance(err, (ABCDataFrame, dict)):
if label is not None and label in err.keys():
err = err[label]
else:
Expand Down Expand Up @@ -1376,6 +1376,7 @@ def _plot(cls, ax, y, style=None, bw_method=None, ind=None,
from scipy.stats import gaussian_kde
from scipy import __version__ as spv

from pandas.core.series import remove_na
y = remove_na(y)

if LooseVersion(spv) >= '0.11.0':
Expand Down Expand Up @@ -1494,6 +1495,7 @@ def _args_adjust(self):

@classmethod
def _plot(cls, ax, y, column_num=None, return_type='axes', **kwds):
from pandas.core.series import remove_na
if y.ndim == 2:
y = [remove_na(v) for v in y]
# Boxplot fails with empty arrays, so need to add a NaN
Expand Down Expand Up @@ -1566,6 +1568,7 @@ def maybe_color_bp(self, bp):

def _make_plot(self):
if self.subplots:
from pandas.core.series import Series
self._return_obj = Series()

for i, (label, y) in enumerate(self._iter_data()):
Expand Down Expand Up @@ -1647,17 +1650,16 @@ def _plot(data, x=None, y=None, subplots=False,
else:
raise ValueError("%r is not a valid plot kind" % kind)

from pandas import DataFrame
if kind in _dataframe_kinds:
if isinstance(data, DataFrame):
if isinstance(data, ABCDataFrame):
plot_obj = klass(data, x=x, y=y, subplots=subplots, ax=ax,
kind=kind, **kwds)
else:
raise ValueError("plot kind %r can only be used for data frames"
% kind)

elif kind in _series_kinds:
if isinstance(data, DataFrame):
if isinstance(data, ABCDataFrame):
if y is None and subplots is False:
msg = "{0} requires either y column or 'subplots=True'"
raise ValueError(msg.format(kind))
Expand All @@ -1669,7 +1671,7 @@ def _plot(data, x=None, y=None, subplots=False,
data.index.name = y
plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
else:
if isinstance(data, DataFrame):
if isinstance(data, ABCDataFrame):
if x is not None:
if is_integer(x) and not data.columns.holds_integer():
x = data.columns[x]
Expand Down Expand Up @@ -1952,9 +1954,8 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None,
if return_type not in BoxPlot._valid_return_types:
raise ValueError("return_type must be {'axes', 'dict', 'both'}")

from pandas import Series, DataFrame
if isinstance(data, Series):
data = DataFrame({'x': data})
if isinstance(data, ABCSeries):
data = data.to_frame(name='x')
column = 'x'

def _get_colors():
Expand All @@ -1968,6 +1969,7 @@ def maybe_color_bp(bp):
setp(bp['medians'], color=colors[2], alpha=1)

def plot_group(keys, values, ax):
from pandas.core.series import remove_na
keys = [pprint_thing(x) for x in keys]
values = [remove_na(v) for v in values]
bp = ax.boxplot(values, **kwds)
Expand Down Expand Up @@ -2123,7 +2125,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
return axes

if column is not None:
if not isinstance(column, (list, np.ndarray, Index)):
if not isinstance(column, (list, np.ndarray, ABCIndex)):
column = [column]
data = data[column]
data = data._get_numeric_data()
Expand Down Expand Up @@ -2317,6 +2319,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
figsize=figsize, layout=layout)
axes = _flatten(axes)

from pandas.core.series import Series
ret = Series()
for (key, group), ax in zip(grouped, axes):
d = group.boxplot(ax=ax, column=column, fontsize=fontsize,
Expand Down Expand Up @@ -2344,7 +2347,6 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0, ax=None, **kwargs):
from pandas import DataFrame

if figsize == 'default':
# allowed to specify mpl default with 'default'
Expand All @@ -2365,7 +2367,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,

for i, (key, group) in enumerate(grouped):
ax = _axes[i]
if numeric_only and isinstance(group, DataFrame):
if numeric_only and isinstance(group, ABCDataFrame):
group = group._get_numeric_data()
plotf(group, ax, **kwargs)
ax.set_title(pprint_thing(key))
Expand All @@ -2388,7 +2390,6 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,

_axes = _flatten(axes)

result = Series()
ax_values = []

for i, col in enumerate(columns):
Expand All @@ -2401,6 +2402,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
ax_values.append(re_plotf)
ax.grid(grid)

from pandas.core.series import Series
result = Series(ax_values, index=columns)

# Return axes in multiplot case, maybe revisit later # 985
Expand Down
13 changes: 6 additions & 7 deletions pandas/plotting/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import numpy as np

from pandas.core.dtypes.generic import ABCSeries, ABCDataFrame, ABCIndex

from pandas.core.dtypes.common import is_list_like
from pandas.core.index import Index
from pandas.core.series import Series
from pandas.compat import range


Expand Down Expand Up @@ -44,10 +44,9 @@ def table(ax, data, rowLabels=None, colLabels=None,
-------
matplotlib table object
"""
from pandas import DataFrame
if isinstance(data, Series):
data = DataFrame(data, columns=[data.name])
elif isinstance(data, DataFrame):
if isinstance(data, ABCSeries):
data = data.to_frame()
elif isinstance(data, ABCDataFrame):
pass
else:
raise ValueError('Input data must be DataFrame or Series')
Expand Down Expand Up @@ -341,7 +340,7 @@ def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey):
def _flatten(axes):
if not is_list_like(axes):
return np.array([axes])
elif isinstance(axes, (np.ndarray, Index)):
elif isinstance(axes, (np.ndarray, ABCIndex)):
return axes.ravel()
return np.array(axes)

Expand Down