Skip to content

Commit

Permalink
Rename Plot.configure -> Plot.layout, add Plot.share, and modify Plot…
Browse files Browse the repository at this point in the history
….layout parameters (#2954)

* Rename Plot.configure -> Plot.layout, add algo parameter

* Document and test layout(algo=)

* Rename layout(figsize=) -> layout(size=)

* Add Plot.share and remove sharex/sharey from Plot.layout

* Cover constrained layout in tests too

* Update nextgen docs

* Handle test backwards compatability

* Update nextgen api docs
  • Loading branch information
mwaskom authored Aug 13, 2022
1 parent b1db0f7 commit b5c4c35
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 42 deletions.
5 changes: 3 additions & 2 deletions doc/nextgen/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ Plot interface

Plot
Plot.add
Plot.scale
Plot.facet
Plot.pair
Plot.configure
Plot.layout
Plot.on
Plot.plot
Plot.save
Plot.scale
Plot.share
Plot.show

Marks
Expand Down
4 changes: 2 additions & 2 deletions doc/nextgen/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@
" .facet(col=\"day\")\n",
" .add(so.Dots(color=\".75\"), col=None)\n",
" .add(so.Dots(), color=\"day\")\n",
" .configure(figsize=(7, 3))\n",
" .layout(size=(7, 3))\n",
")"
]
},
Expand Down Expand Up @@ -805,7 +805,7 @@
"(\n",
" so.Plot(tips)\n",
" .pair(x=tips.columns, wrap=3)\n",
" .configure(sharey=False)\n",
" .share(y=False)\n",
" .add(so.Bar(), so.Hist())\n",
")"
]
Expand Down
3 changes: 1 addition & 2 deletions doc/nextgen/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"outputs": [],
"source": [
"import seaborn as sns\n",
"sns.set_theme()\n",
"tips = sns.load_dataset(\"tips\")\n",
"\n",
"import seaborn.objects as so\n",
Expand All @@ -31,7 +30,7 @@
" )\n",
" .facet(\"time\")\n",
" .add(so.Dots())\n",
" .configure(figsize=(7, 4))\n",
" .layout(size=(7, 4))\n",
")"
]
},
Expand Down
72 changes: 47 additions & 25 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Plot:
_layers: list[Layer]

_scales: dict[str, Scale]
_shares: dict[str, bool | str]
_limits: dict[str, tuple[Any, Any]]
_labels: dict[str, str | Callable[[str], str]]
_theme: dict[str, Any]
Expand All @@ -159,6 +160,7 @@ class Plot:

_figure_spec: dict[str, Any]
_subplot_spec: dict[str, Any]
_layout_spec: dict[str, Any]

def __init__(
self,
Expand All @@ -180,6 +182,7 @@ def __init__(
self._layers = []

self._scales = {}
self._shares = {}
self._limits = {}
self._labels = {}
self._theme = {}
Expand All @@ -189,6 +192,7 @@ def __init__(

self._figure_spec = {}
self._subplot_spec = {}
self._layout_spec = {}

self._target = None

Expand Down Expand Up @@ -250,6 +254,7 @@ def _clone(self) -> Plot:
new._layers.extend(self._layers)

new._scales.update(self._scales)
new._shares.update(self._shares)
new._limits.update(self._limits)
new._labels.update(self._labels)
new._theme.update(self._theme)
Expand All @@ -259,6 +264,7 @@ def _clone(self) -> Plot:

new._figure_spec.update(self._figure_spec)
new._subplot_spec.update(self._subplot_spec)
new._layout_spec.update(self._layout_spec)

new._target = self._target

Expand All @@ -272,7 +278,7 @@ def _theme_with_defaults(self) -> dict[str, Any]:
"xaxis", "xtick", "yaxis", "ytick",
]
base = {
k: v for k, v in mpl.rcParamsDefault.items()
k: mpl.rcParamsDefault[k] for k in mpl.rcParams
if any(k.startswith(p) for p in style_groups)
}
theme = {
Expand Down Expand Up @@ -597,6 +603,21 @@ def scale(self, **scales: Scale) -> Plot:
new._scales.update(scales)
return new

def share(self, **shares: bool | str) -> Plot:
"""
Control sharing of axis limits and ticks across subplots.
Keywords correspond to variables defined in the plot, and values can be
boolean (to share across all subplots), or one of "row" or "col" (to share
more selectively across one dimension of a grid).
Behavior for non-coordinate variables is currently undefined.
"""
new = self._clone()
new._shares.update(shares)
return new

def limit(self, **limits: tuple[Any, Any]) -> Plot:
"""
Control the range of visible data.
Expand Down Expand Up @@ -637,23 +658,22 @@ def label(self, *, title=None, **variables: str | Callable[[str], str]) -> Plot:
new._labels.update(variables)
return new

def configure(
def layout(
self,
figsize: tuple[float, float] | None = None,
sharex: bool | str | None = None,
sharey: bool | str | None = None,
*,
size: tuple[float, float] | None = None,
algo: str | None = "tight", # TODO document
) -> Plot:
"""
Control the figure size and layout.
Parameters
----------
figsize: (width, height)
Size of the resulting figure, in inches.
sharex, sharey : bool, "row", or "col"
Whether axis limits should be shared across subplots. Boolean values apply
across the entire grid, whereas `"row"` or `"col"` have a smaller scope.
Shared axes will have tick labels disabled.
size : (width, height)
Size of the resulting figure, in inches. Size is inclusive of legend when
using pyplot, but not otherwise.
algo : {{"tight", "constrained", None}}
Name of algorithm for automatically adjusting the layout to remove overlap.
"""
# TODO add an "auto" mode for figsize that roughly scales with the rcParams
Expand All @@ -663,12 +683,8 @@ def configure(

new = self._clone()

new._figure_spec["figsize"] = figsize

if sharex is not None:
new._subplot_spec["sharex"] = sharex
if sharey is not None:
new._subplot_spec["sharey"] = sharey
new._figure_spec["figsize"] = size
new._layout_spec["algo"] = algo

return new

Expand Down Expand Up @@ -894,6 +910,10 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
facet_spec = p._facet_spec.copy()
pair_spec = p._pair_spec.copy()

for axis in "xy":
if axis in p._shares:
subplot_spec[f"share{axis}"] = p._shares[axis]

for dim in ["col", "row"]:
if dim in common.frame and dim not in facet_spec["structure"]:
order = categorical_order(common.frame[dim])
Expand Down Expand Up @@ -928,7 +948,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

# ~~ Decoration visibility

# TODO there should be some override (in Plot.configure?) so that
# TODO there should be some override (in Plot.layout?) so that
# tick labels can be shown on interior shared axes
axis_obj = getattr(ax, f"{axis}axis")
visible_side = {"x": "bottom", "y": "left"}.get(axis)
Expand All @@ -948,10 +968,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
for t in getattr(axis_obj, f"get_{group}ticklabels")():
t.set_visible(show_tick_labels)

# TODO title template should be configurable
# ---- Also we want right-side titles for row facets in most cases?
# ---- Or wrapped? That can get annoying too.
# TODO should configure() accept a title= kwarg (for single subplot plots)?
# TODO we want right-side titles for row facets in most cases?
# Let's have what we currently call "margin titles" but properly using the
# ax.set_title interface (see my gist)
title_parts = []
Expand Down Expand Up @@ -1521,6 +1538,9 @@ def _make_legend(self, p: Plot) -> None:
else:
merged_contents[key] = artists.copy(), labels

# TODO explain
loc = "center right" if self._pyplot else "center left"

base_legend = None
for (name, _), (handles, labels) in merged_contents.items():

Expand All @@ -1529,7 +1549,7 @@ def _make_legend(self, p: Plot) -> None:
handles,
labels,
title=name,
loc="center left",
loc=loc,
bbox_to_anchor=(.98, .55),
)

Expand Down Expand Up @@ -1563,9 +1583,11 @@ def _finalize_figure(self, p: Plot) -> None:
hi = cast(float, hi) + 0.5
ax.set(**{f"{axis}lim": (lo, hi)})

# TODO this should be configurable
if not self._figure.get_constrained_layout():
layout_algo = p._layout_spec.get("algo", "tight")
if layout_algo == "tight":
self._figure.set_tight_layout(True)
elif layout_algo == "constrained":
self._figure.set_constrained_layout(True)


@contextmanager
Expand Down
44 changes: 33 additions & 11 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_facet_categories_unshared(self):
p = (
Plot(x=["a", "b", "a", "c"])
.facet(col=["x", "x", "y", "y"])
.configure(sharex=False)
.share(x=False)
.add(m)
.plot()
)
Expand All @@ -539,7 +539,7 @@ def test_facet_categories_single_dim_shared(self):
Plot(df, x="x")
.facet(row="row", col="col")
.add(m)
.configure(sharex="row")
.share(x="row")
.plot()
)

Expand Down Expand Up @@ -574,7 +574,7 @@ def test_pair_categories_shared(self):
data = [("a", "a"), ("b", "c")]
df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1)
m = MockMark()
p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).configure(sharex=True).plot()
p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).share(x=True).plot()

for ax in p._figure.axes:
assert ax.get_xticks() == [0, 1, 2]
Expand Down Expand Up @@ -1040,6 +1040,12 @@ def test_save(self):
tag = xml.etree.ElementTree.fromstring(buf.getvalue()).tag
assert tag == "{http://www.w3.org/2000/svg}svg"

def test_layout_size(self):

size = (4, 2)
p = Plot().layout(size=size).plot()
assert tuple(p._figure.get_size_inches()) == size

def test_on_axes(self):

ax = mpl.figure.Figure().subplots()
Expand Down Expand Up @@ -1285,11 +1291,27 @@ def test_2d_with_order(self, long_df, reorder):
p = Plot(long_df).facet(**variables, order=order)
self.check_facet_results_2d(p, long_df, variables, order)

def test_figsize(self):
@pytest.mark.parametrize("algo", ["tight", "constrained"])
def test_layout_algo(self, algo):

if algo == "constrained" and Version(mpl.__version__) < Version("3.3.0"):
pytest.skip("constrained_layout requires matplotlib>=3.3")

p = Plot().facet(["a", "b"]).limit(x=(.1, .9))

p1 = p.layout(algo=algo).plot()
p2 = p.layout(algo=None).plot()

# Force a draw (we probably need a method for this)
p1.save(io.BytesIO())
p2.save(io.BytesIO())

bb11, bb12 = [ax.get_position() for ax in p1._figure.axes]
bb21, bb22 = [ax.get_position() for ax in p2._figure.axes]

figsize = (4, 2)
p = Plot().configure(figsize=figsize).plot()
assert tuple(p._figure.get_size_inches()) == figsize
sep1 = bb12.corners()[0, 0] - bb11.corners()[2, 0]
sep2 = bb22.corners()[0, 0] - bb21.corners()[2, 0]
assert sep1 < sep2

def test_axis_sharing(self, long_df):

Expand All @@ -1303,13 +1325,13 @@ def test_axis_sharing(self, long_df):
shareset = getattr(root, f"get_shared_{axis}_axes")()
assert all(shareset.joined(root, ax) for ax in other)

p2 = p.configure(sharex=False, sharey=False).plot()
p2 = p.share(x=False, y=False).plot()
root, *other = p2._figure.axes
for axis in "xy":
shareset = getattr(root, f"get_shared_{axis}_axes")()
assert not any(shareset.joined(root, ax) for ax in other)

p3 = p.configure(sharex="col", sharey="row").plot()
p3 = p.share(x="col", y="row").plot()
shape = (
len(categorical_order(long_df[variables["row"]])),
len(categorical_order(long_df[variables["col"]])),
Expand Down Expand Up @@ -1494,7 +1516,7 @@ def test_axis_sharing(self, long_df):
y_shareset = getattr(root, "get_shared_y_axes")()
assert not any(y_shareset.joined(root, ax) for ax in other)

p2 = p.configure(sharex=False, sharey=False).plot()
p2 = p.share(x=False, y=False).plot()
root, *other = p2._figure.axes
for axis in "xy":
shareset = getattr(root, f"get_shared_{axis}_axes")()
Expand Down Expand Up @@ -1758,7 +1780,7 @@ def test_2d_unshared(self):
p = (
Plot()
.facet(col=["a", "b"], row=["x", "y"])
.configure(sharex=False, sharey=False)
.share(x=False, y=False)
.plot()
)
subplots = list(p._subplots)
Expand Down

0 comments on commit b5c4c35

Please sign in to comment.