Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: accept a dictionary in plot colors #31071

Merged
merged 4 commits into from
Jan 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ I/O
Plotting
^^^^^^^^

-
- :func:`.plot` for line/bar now accepts color by dictionary (:issue:`8193`).
-

Groupby/resample/rolling
Expand Down
173 changes: 103 additions & 70 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,45 @@ def hist_frame(
"""


_bar_or_line_doc = """
Parameters
----------
x : label or position, optional
Allows plotting of one column versus another. If not specified,
the index of the DataFrame is used.
y : label or position, optional
Allows plotting of one column versus another. If not specified,
all numerical columns are used.
color : str, array_like, or dict, optional
The color for each of the DataFrame's columns. Possible values are:

- A single color string referred to by name, RGB or RGBA code,
for instance 'red' or '#a98d19'.

- A sequence of color strings referred to by name, RGB or RGBA
code, which will be used for each column cyclically. For
instance, with ['green', 'yellow'], each column's %(kind)s will be
filled in green or yellow, alternately.

- A dict of the form {column name : color}, so that each column will be
colored accordingly. For example, if your columns are called `a` and
`b`, then passing {'a': 'green', 'b': 'red'} will color %(kind)ss for
column `a` in green and %(kind)ss for column `b` in red.

.. versionadded:: 1.1.0

**kwargs
Additional keyword arguments are documented in
:meth:`DataFrame.plot`.

Returns
-------
matplotlib.axes.Axes or np.ndarray of them
An ndarray is returned with one :class:`matplotlib.axes.Axes`
per column when ``subplots=True``.
"""


@Substitution(backend="")
@Appender(_boxplot_doc)
def boxplot(
Expand Down Expand Up @@ -848,31 +887,8 @@ def __call__(self, *args, **kwargs):

__call__.__doc__ = __doc__

def line(self, x=None, y=None, **kwargs):
@Appender(
"""
Plot Series or DataFrame as lines.

This function is useful to plot lines using DataFrame's values
as coordinates.

Parameters
----------
x : int or str, optional
Columns to use for the horizontal axis.
Either the location or the label of the columns to be used.
By default, it will use the DataFrame indices.
y : int, str, or list of them, optional
The values to be plotted.
Either the location or the label of the columns to be used.
By default, it will use the remaining DataFrame numeric columns.
**kwargs
Keyword arguments to pass on to :meth:`DataFrame.plot`.

Returns
-------
:class:`matplotlib.axes.Axes` or :class:`numpy.ndarray`
Return an ndarray when ``subplots=True``.

See Also
--------
matplotlib.pyplot.plot : Plot y versus x as lines and/or markers.
Expand Down Expand Up @@ -907,6 +923,16 @@ def line(self, x=None, y=None, **kwargs):
>>> type(axes)
<class 'numpy.ndarray'>

.. plot::
:context: close-figs

Let's repeat the same example, but specifying colors for
each column (in this case, for each animal).

>>> axes = df.plot.line(
... subplots=True, color={"pig": "pink", "horse": "#742802"}
... )

.. plot::
:context: close-figs

Expand All @@ -915,36 +941,20 @@ def line(self, x=None, y=None, **kwargs):

>>> lines = df.plot.line(x='pig', y='horse')
"""
return self(kind="line", x=x, y=y, **kwargs)

def bar(self, x=None, y=None, **kwargs):
)
@Substitution(kind="line")
@Appender(_bar_or_line_doc)
def line(self, x=None, y=None, **kwargs):
"""
Vertical bar plot.

A bar plot is a plot that presents categorical data with
rectangular bars with lengths proportional to the values that they
represent. A bar plot shows comparisons among discrete categories. One
axis of the plot shows the specific categories being compared, and the
other axis represents a measured value.

Parameters
----------
x : label or position, optional
Allows plotting of one column versus another. If not specified,
the index of the DataFrame is used.
y : label or position, optional
Allows plotting of one column versus another. If not specified,
all numerical columns are used.
**kwargs
Additional keyword arguments are documented in
:meth:`DataFrame.plot`.
Plot Series or DataFrame as lines.

Returns
-------
matplotlib.axes.Axes or np.ndarray of them
An ndarray is returned with one :class:`matplotlib.axes.Axes`
per column when ``subplots=True``.
This function is useful to plot lines using DataFrame's values
as coordinates.
"""
return self(kind="line", x=x, y=y, **kwargs)

@Appender(
"""
See Also
--------
DataFrame.plot.barh : Horizontal bar plot.
Expand Down Expand Up @@ -986,6 +996,17 @@ def bar(self, x=None, y=None, **kwargs):
>>> axes = df.plot.bar(rot=0, subplots=True)
>>> axes[1].legend(loc=2) # doctest: +SKIP

If you don't like the default colors, you can specify how you'd
like each column to be colored.

.. plot::
:context: close-figs

>>> axes = df.plot.bar(
... rot=0, subplots=True, color={"speed": "red", "lifespan": "green"}
... )
>>> axes[1].legend(loc=2) # doctest: +SKIP

Plot a single column.

.. plot::
Expand All @@ -999,32 +1020,24 @@ def bar(self, x=None, y=None, **kwargs):
:context: close-figs

>>> ax = df.plot.bar(x='lifespan', rot=0)
"""
)
@Substitution(kind="bar")
@Appender(_bar_or_line_doc)
def bar(self, x=None, y=None, **kwargs):
"""
return self(kind="bar", x=x, y=y, **kwargs)

def barh(self, x=None, y=None, **kwargs):
"""
Make a horizontal bar plot.
Vertical bar plot.

A horizontal bar plot is a plot that presents quantitative data with
A bar plot is a plot that presents categorical data with
rectangular bars with lengths proportional to the values that they
represent. A bar plot shows comparisons among discrete categories. One
axis of the plot shows the specific categories being compared, and the
other axis represents a measured value.
"""
return self(kind="bar", x=x, y=y, **kwargs)

Parameters
----------
x : label or position, default DataFrame.index
Column to be used for categories.
y : label or position, default All numeric columns in dataframe
Columns to be plotted from the DataFrame.
**kwargs
Keyword arguments to pass on to :meth:`DataFrame.plot`.

Returns
-------
:class:`matplotlib.axes.Axes` or numpy.ndarray of them

@Appender(
"""
See Also
--------
DataFrame.plot.bar: Vertical bar plot.
Expand Down Expand Up @@ -1054,6 +1067,13 @@ def barh(self, x=None, y=None, **kwargs):
... 'lifespan': lifespan}, index=index)
>>> ax = df.plot.barh()

We can specify colors for each column.

.. plot::
:context: close-figs

>>> ax = df.plot.barh(color={"speed": "red", "lifespan": "green"})

Plot a column of the DataFrame to a horizontal bar plot

.. plot::
Expand All @@ -1079,6 +1099,19 @@ def barh(self, x=None, y=None, **kwargs):
>>> df = pd.DataFrame({'speed': speed,
... 'lifespan': lifespan}, index=index)
>>> ax = df.plot.barh(x='lifespan')
"""
)
@Substitution(kind="bar")
@Appender(_bar_or_line_doc)
def barh(self, x=None, y=None, **kwargs):
"""
Make a horizontal bar plot.

A horizontal bar plot is a plot that presents quantitative data with
rectangular bars with lengths proportional to the values that they
represent. A bar plot shows comparisons among discrete categories. One
axis of the plot shows the specific categories being compared, and the
other axis represents a measured value.
"""
return self(kind="barh", x=x, y=y, **kwargs)

Expand Down
7 changes: 6 additions & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,10 @@ def _apply_style_colors(self, colors, kwds, col_num, label):
has_color = "color" in kwds or self.colormap is not None
nocolor_style = style is None or re.match("[a-z]+", style) is None
if (has_color or self.subplots) and nocolor_style:
kwds["color"] = colors[col_num % len(colors)]
if isinstance(colors, dict):
kwds["color"] = colors[label]
else:
kwds["color"] = colors[col_num % len(colors)]
return style, kwds

def _get_colors(self, num_colors=None, color_kwds="color"):
Expand Down Expand Up @@ -1347,6 +1350,8 @@ def _make_plot(self):
kwds = self.kwds.copy()
if self._is_series:
kwds["color"] = colors
elif isinstance(colors, dict):
kwds["color"] = colors[label]
else:
kwds["color"] = colors[i % ncolors]

Expand Down
6 changes: 5 additions & 1 deletion pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def _get_standard_colors(
warnings.warn(
"'color' and 'colormap' cannot be used simultaneously. Using 'color'"
)
colors = list(color) if is_list_like(color) else color
colors = (
list(color)
if is_list_like(color) and not isinstance(color, dict)
else color
)
else:
if color_type == "default":
# need to call list() on the result to copy so we don't
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/plotting/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,24 @@ def test_get_standard_colors_no_appending(self):
color_list = cm.gnuplot(np.linspace(0, 1, 16))
p = df.A.plot.bar(figsize=(16, 7), color=color_list)
assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()

@pytest.mark.slow
def test_dictionary_color(self):
    """Check that a ``{column: color}`` dict passed as ``color`` is honored.

    Covers issue 8193 for both ``kind="bar"`` and ``kind="line"``: each
    column must be drawn in the color stored under its own label.
    """
    # issue-8193
    # Test plot color dictionary format
    data_files = ["a", "b"]

    # Expected colors, listed in column order ("a" first, then "b").
    expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]

    df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
    # The dict lists "b" before "a", i.e. the opposite of the column order,
    # so assigning colors by position instead of by label would fail below.
    dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}

    # Bar color test
    ax = df1.plot(kind="bar", color=dic_color)
    # [0:3:2] selects children 0 and 2 -- presumably the first bar of each
    # column (two rows per column); [0:-1] drops the trailing alpha channel.
    colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
    # Compare the full lists: an all(...) over enumerate(colors) would pass
    # vacuously if no artists were found.
    assert colors == expected

    # Line color test
    ax = df1.plot(kind="line", color=dic_color)
    colors = [line.get_color() for line in ax.get_lines()[0:2]]
    assert colors == expected