Skip to content

Commit

Permalink
typing and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Jan 11, 2024
1 parent e86bf4d commit 3dafc2c
Showing 1 changed file with 91 additions and 28 deletions.
119 changes: 91 additions & 28 deletions pyhdx/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

def peptide_coverage_figure(
data: pd.DataFrame,
wrap: int = None,
wrap: Optional[int] = None,
cmap: Union[pplt.Colormap, mpl.colors.Colormap, str, tuple, dict] = "turbo",
norm: Type[mpl.colors.Normalize] = None,
color_field: str = "rfu",
Expand All @@ -70,7 +70,9 @@ def peptide_coverage_figure(
**figure_kwargs,
) -> tuple:
subplot_values = data[subplot_field].unique()
sub_dfs = {value: data.query(f"`{subplot_field}` == {value}") for value in subplot_values}
sub_dfs = {
value: data.query(f"`{subplot_field}` == {value}") for value in subplot_values
}

n_subplots = len(subplot_values)

Expand All @@ -86,11 +88,18 @@ def peptide_coverage_figure(
start_field, end_field = rect_fields
if wrap is None:
wrap = max(
[autowrap(sub_df[start_field], sub_df[end_field]) for sub_df in sub_dfs.values()]
[
autowrap(sub_df[start_field], sub_df[end_field])
for sub_df in sub_dfs.values()
]
)

fig, axes = pplt.subplots(
ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs
ncols=ncols,
nrows=nrows,
width=figure_width,
refaspect=refaspect,
**figure_kwargs,
)
rect_kwargs = rect_kwargs or {}
axes_iter = iter(axes)
Expand Down Expand Up @@ -159,13 +168,17 @@ def peptide_coverage(
color = cmap(norm(elem[color_field]))

width = elem[end_field] - elem[start_field]
rect = Rectangle((elem[start_field] - 0.5, i), width, 1, facecolor=color, **rect_kwargs)
rect = Rectangle(
(elem[start_field] - 0.5, i), width, 1, facecolor=color, **rect_kwargs
)
ax.add_patch(rect)
if labels:
rx, ry = rect.get_xy()
cy = ry
cx = rx
ax.annotate(str(p_num), (cx, cy), color="k", fontsize=6, va="bottom", ha="right")
ax.annotate(
str(p_num), (cx, cy), color="k", fontsize=6, va="bottom", ha="right"
)
i -= 1

ax.set_ylim(-wrap, 0)
Expand Down Expand Up @@ -244,7 +257,9 @@ def residue_time_scatter_figure(
return fig, axes, cbars


def residue_time_scatter(ax, hdx_tp, field="rfu", cmap="turbo", norm=None, cbar=True, **kwargs):
def residue_time_scatter(
ax, hdx_tp, field="rfu", cmap="turbo", norm=None, cbar=True, **kwargs
):
# update cmap, norm defaults
cmap = pplt.Colormap(cmap) # todo allow None as cmap
norm = norm or pplt.Norm("linear", vmin=0, vmax=1)
Expand Down Expand Up @@ -285,13 +300,19 @@ def residue_scatter_figure(
tps = np.unique(np.concatenate([hdxm.timepoints for hdxm in hdxm_set]))

fig, axes = pplt.subplots(
ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs
ncols=ncols,
nrows=nrows,
width=figure_width,
refaspect=refaspect,
**figure_kwargs,
)
axes_iter = iter(axes)
scatter_kwargs = scatter_kwargs or {}
for hdxm in hdxm_set:
ax = next(axes_iter)
residue_scatter(ax, hdxm, cmap=cmap, norm=norm, field=field, cbar=False, **scatter_kwargs)
residue_scatter(
ax, hdxm, cmap=cmap, norm=norm, field=field, cbar=False, **scatter_kwargs
)
ax.format(title=f"{hdxm.name}")

for ax in axes_iter:
Expand All @@ -310,7 +331,9 @@ def residue_scatter_figure(


# todo allow colorbar_scatter to take rfus
def residue_scatter(ax, hdxm, field="rfu", cmap="viridis", norm=None, cbar=True, **kwargs):
def residue_scatter(
ax, hdxm, field="rfu", cmap="viridis", norm=None, cbar=True, **kwargs
):
cmap = pplt.Colormap(cmap)
tps = hdxm.timepoints[np.nonzero(hdxm.timepoints)]
norm = norm or pplt.Norm("log", tps.min(), tps.max())
Expand Down Expand Up @@ -380,7 +403,9 @@ def dG_scatter_figure(

# Set global ylims
ylims = [lim for ax in axes if ax.axison for lim in ax.get_ylim()]
axes.format(ylim=(np.max(ylims), np.min(ylims)), yticklabelloc="none", ytickloc="none")
axes.format(
ylim=(np.max(ylims), np.min(ylims)), yticklabelloc="none", ytickloc="none"
)

cbar_kwargs = cbar_kwargs or {}
cbars = []
Expand Down Expand Up @@ -417,7 +442,9 @@ def ddG_scatter_figure(
dG_test = data.xs("dG", axis=1, level=1).drop(reference_state, axis=1)
dG_ref = data[reference_state, "dG"]
ddG = dG_test.subtract(dG_ref, axis=0)
ddG.columns = pd.MultiIndex.from_product([ddG.columns, ["ddG"]], names=["State", "quantity"])
ddG.columns = pd.MultiIndex.from_product(
[ddG.columns, ["ddG"]], names=["State", "quantity"]
)

cov_test = data.xs("covariance", axis=1, level=1).drop(reference_state, axis=1) ** 2
cov_ref = data[reference_state, "covariance"] ** 2
Expand Down Expand Up @@ -490,7 +517,9 @@ def ddG_scatter_figure(
return fig, axes, cbars


def peptide_mse_figure(peptide_mse, cmap=None, norm=None, rect_kwargs=None, **figure_kwargs):
def peptide_mse_figure(
peptide_mse, cmap=None, norm=None, rect_kwargs=None, **figure_kwargs
):
n_subplots = len(peptide_mse.columns.unique(level=0))
ncols = figure_kwargs.pop("ncols", min(cfg.plotting.ncols, n_subplots))
nrows = figure_kwargs.pop("nrows", int(np.ceil(n_subplots / ncols)))
Expand All @@ -500,7 +529,11 @@ def peptide_mse_figure(peptide_mse, cmap=None, norm=None, rect_kwargs=None, **fi
cmap = cmap or CMAP_NORM_DEFAULTS["mse"][0]

fig, axes = pplt.subplots(
ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs
ncols=ncols,
nrows=nrows,
width=figure_width,
refaspect=refaspect,
**figure_kwargs,
)
axes_iter = iter(axes)
cbars = []
Expand Down Expand Up @@ -535,7 +568,11 @@ def loss_figure(fit_result, **figure_kwargs):
) # todo loss aspect also in config?

fig, ax = pplt.subplots(
ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs
ncols=ncols,
nrows=nrows,
width=figure_width,
refaspect=refaspect,
**figure_kwargs,
)
fit_result.losses.plot(ax=ax)
# ax.plot(fit_result.losses, legend='t') # altnernative proplot plotting
Expand Down Expand Up @@ -839,7 +876,9 @@ def rainbowclouds(

strip_kwargs = _strip_kwargs.update(strip_kwargs) if strip_kwargs else _strip_kwargs
kde_kwargs = _kde_kwargs.update(strip_kwargs) if kde_kwargs else _kde_kwargs
boxplot_kwargs = _boxplot_kwargs.update(strip_kwargs) if boxplot_kwargs else _boxplot_kwargs
boxplot_kwargs = (
_boxplot_kwargs.update(strip_kwargs) if boxplot_kwargs else _boxplot_kwargs
)

stripplot(f_data, ax=ax, **strip_kwargs)
kdeplot(f_data, ax=ax, **kde_kwargs)
Expand All @@ -853,7 +892,9 @@ def rainbowclouds(
ytickloc="left",
ylim=ylim,
)
format_kwargs = _format_kwargs.update(format_kwargs) if format_kwargs else _format_kwargs
format_kwargs = (
_format_kwargs.update(format_kwargs) if format_kwargs else _format_kwargs
)

ax.format(**format_kwargs)

Expand Down Expand Up @@ -1055,7 +1096,9 @@ def add_mse_panels(

if cbar:
if fig is None:
raise ValueError("Must pass 'fig' keyword argument to add a global colorbar")
raise ValueError(
"Must pass 'fig' keyword argument to add a global colorbar"
)
cbar_kwargs = cbar_kwargs or {}
cbar_kwargs = {
"width": CBAR_KWARGS["width"],
Expand Down Expand Up @@ -1135,13 +1178,17 @@ def __init__(self):
}

colors = ["#6EA72A", "#DAD853", "#FFA842", "#A22D46", "#5D0496"][::-1]
cmap_redundancy = pplt.Colormap(colors, discrete=True, N=len(colors), listmode="discrete")
cmap_redundancy = pplt.Colormap(
colors, discrete=True, N=len(colors), listmode="discrete"
)
cmap_redundancy.set_over("#0E4A21")
cmap_redundancy.set_bad(NO_COVERAGE)
self.cmaps["redundancy"] = cmap_redundancy

colors = ["#008832", "#72D100", "#FFFF04", "#FFB917", "#FF8923"]
cmap_redundancy = pplt.Colormap(colors, discrete=True, N=len(colors), listmode="discrete")
cmap_redundancy = pplt.Colormap(
colors, discrete=True, N=len(colors), listmode="discrete"
)
cmap_redundancy.set_over("#FE2B2E")
cmap_redundancy.set_bad(NO_COVERAGE)
self.cmaps["resolution"] = cmap_redundancy
Expand Down Expand Up @@ -1216,7 +1263,9 @@ def pymol_figures(
values = values.reindex(pd.RangeIndex(rmin, rmax + 1, name="r_number"))
colors = apply_cmap(values, cmap, norm)
name = (
f"pymol_ddG_{state}_ref_{reference_state}" if reference_state else f"pymol_dG_{state}"
f"pymol_ddG_{state}_ref_{reference_state}"
if reference_state
else f"pymol_dG_{state}"
)
name += name_suffix
pymol_render(
Expand Down Expand Up @@ -1341,7 +1390,9 @@ def stripplot(

for i, (d, color) in enumerate(zip(data, color_list)):
jitter_offsets = (np.random.rand(d.size) - 0.5) * jitter
cat_var = i * np.ones_like(d) + jitter_offsets + offset # categorical axis variable
cat_var = (
i * np.ones_like(d) + jitter_offsets + offset
) # categorical axis variable
if orientation == "vertical":
ax.scatter(cat_var, d, color=color, **scatter_kwargs)
elif orientation == "horizontal":
Expand Down Expand Up @@ -1443,7 +1494,9 @@ def kdeplot(
color=color,
)
elif orientation == "vertical":
ax.fill_betweenx(kde_x, len(data) - cat_var, len(data) - cat_var_zero, color=color)
ax.fill_betweenx(
kde_x, len(data) - cat_var, len(data) - cat_var_zero, color=color
)

if fill_cmap:
fill_norm = fill_norm or pplt.Norm("linear")
Expand Down Expand Up @@ -1539,7 +1592,9 @@ def _make_figure(self, figure_name, **kwargs):
# return dictionary
# keys: either protein state name (hdxm.name) or 'All states'

figures_dict = {name: function(arg, **kwargs) for name, arg in args_dict.items()}
figures_dict = {
name: function(arg, **kwargs) for name, arg in args_dict.items()
}
return figures_dict

def make_figure(self, figure_name, **kwargs):
Expand All @@ -1550,7 +1605,9 @@ def make_figure(self, figure_name, **kwargs):
return figures_dict

def get_fit_timepoints(self):
all_timepoints = np.concatenate([hdxm.timepoints for hdxm in self.fit_result.hdxm_set])
all_timepoints = np.concatenate(
[hdxm.timepoints for hdxm in self.fit_result.hdxm_set]
)

# x_axis_type = self.settings.get('fit_time_axis', 'Log')
x_axis_type = "Log" # todo configureable
Expand Down Expand Up @@ -1622,7 +1679,9 @@ def save_figure(self, fig_name, ext=".png", **kwargs):
figures_dict = self._make_figure(fig_name, **kwargs)

if self.output_path is None:
raise ValueError(f"No output path given when `FitResultPlot` object as initialized")
raise ValueError(
f"No output path given when `FitResultPlot` object as initialized"
)
for name, fig_tup in figures_dict.items():
fig = fig_tup if isinstance(fig_tup, plt.Figure) else fig_tup[0]

Expand Down Expand Up @@ -1670,7 +1729,9 @@ def plot_fitresults(
"""

raise DeprecationWarning("This function is deprecated, use FitResultPlot.plot_all instead")
raise DeprecationWarning(
"This function is deprecated, use FitResultPlot.plot_all instead"
)
# batch results only
history_path = fitresult_path / "model_history.csv"
output_path = output_path or fitresult_path
Expand Down Expand Up @@ -1759,7 +1820,9 @@ def plot_fitresults(
plt.close(fig)

if "dG_scatter" in plots:
fig, axes, cbars = dG_scatter_figure(fitresult.output.df, cmap=dG_cmap, norm=dG_norm)
fig, axes, cbars = dG_scatter_figure(
fitresult.output.df, cmap=dG_cmap, norm=dG_norm
)
for ext in output_type:
f_out = output_path / (f"dG_scatter" + ext)
plt.savefig(f_out)
Expand Down

0 comments on commit 3dafc2c

Please sign in to comment.