Skip to content

Commit

Permalink
Fixes issue #8193
Browse files Browse the repository at this point in the history
  • Loading branch information
Leostayner authored and Marco Gorelli committed Jan 16, 2020
1 parent 5d49730 commit 8f7ee75
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ Plotting
- Bug in color validation incorrectly raising for non-color styles (:issue:`29122`).
- Allow :meth:`DataFrame.plot.scatter` to plot ``objects`` and ``datetime`` type data (:issue:`18755`, :issue:`30391`)
- Bug in :meth:`DataFrame.hist`, ``xrot=0`` does not work with ``by`` and subplots (:issue:`30288`).
- :func:`.plot` for line/bar now accepts color by dictonary (:issue:`8193`).

Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
8 changes: 6 additions & 2 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,10 @@ def _apply_style_colors(self, colors, kwds, col_num, label):
has_color = "color" in kwds or self.colormap is not None
nocolor_style = style is None or re.match("[a-z]+", style) is None
if (has_color or self.subplots) and nocolor_style:
kwds["color"] = colors[col_num % len(colors)]
if isinstance(colors, dict):
kwds["color"] = colors[label]
else:
kwds["color"] = colors[col_num % len(colors)]
return style, kwds

def _get_colors(self, num_colors=None, color_kwds="color"):
Expand Down Expand Up @@ -1341,12 +1344,13 @@ def _make_plot(self):

pos_prior = neg_prior = np.zeros(len(self.data))
K = self.nseries

for i, (label, y) in enumerate(self._iter_data(fillna=0)):
ax = self._get_ax(i)
kwds = self.kwds.copy()
if self._is_series:
kwds["color"] = colors
elif isinstance(colors, dict):
kwds["color"] = colors[label]
else:
kwds["color"] = colors[i % ncolors]

Expand Down
6 changes: 5 additions & 1 deletion pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def _get_standard_colors(
warnings.warn(
"'color' and 'colormap' cannot be used simultaneously. Using 'color'"
)
colors = list(color) if is_list_like(color) else color
colors = (
list(color)
if is_list_like(color) and not isinstance(color, dict)
else color
)
else:
if color_type == "default":
# need to call list() on the result to copy so we don't
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/plotting/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,24 @@ def test_get_standard_colors_no_appending(self):
color_list = cm.gnuplot(np.linspace(0, 1, 16))
p = df.A.plot.bar(figsize=(16, 7), color=color_list)
assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()

@pytest.mark.slow
def test_dictionary_color(self):
# issue-8193
# Test plot color dictionary format
data_files = ["a", "b"]

expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]

df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}

# Bar color test
ax = df1.plot(kind="bar", color=dic_color)
colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
assert all(color == expected[index] for index, color in enumerate(colors))

# Line color test
ax = df1.plot(kind="line", color=dic_color)
colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
assert all(color == expected[index] for index, color in enumerate(colors))

0 comments on commit 8f7ee75

Please sign in to comment.