Add example with categorical family #457
Conversation
Check out this pull request on ReviewNB: see visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB.
Codecov Report

```
@@           Coverage Diff           @@
##             main     #457   +/-   ##
=======================================
  Coverage   89.24%   89.24%
=======================================
  Files          31       31
  Lines        2491     2491
=======================================
  Hits         2223     2223
  Misses        268      268
```

Continue to review full report at Codecov.
tomicapretto commented on 2022-02-22T02:17:03Z: I think we never mention this in the documentation, thanks for adding it here!
tomicapretto commented on 2022-02-22T02:17:04Z: I would mention that the posterior for the probability contains information about all the categories.

tjburch commented on 2022-03-02T02:19:03Z: Added to the discussion below:

"Of note, the posterior means predicted by Bambi contain information about all n categories (despite having only n − 1 coefficients), so we can directly construct this plot, rather than manually calculating 1 − (p1 + p2) for the third class."
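To make that point concrete, here is a minimal sketch (plain NumPy, made-up numbers, not from the thread) of why all n probabilities are recoverable from n − 1 coefficient sets: the reference category's linear predictor is pinned at 0, and the softmax normalization implies the remaining probability.

```python
import numpy as np

# Hypothetical linear predictors for one observation with n = 3 categories.
# The reference category is fixed at 0; the other two come from the model's
# n - 1 coefficient sets (the values here are made up).
eta = np.array([0.0, 1.2, -0.4])

# Softmax maps the n predictors to n probabilities that sum to 1, so the
# reference category's probability is implied rather than estimated directly.
p = np.exp(eta) / np.exp(eta).sum()
print(p, p.sum())  # ≈ [0.20 0.67 0.13] 1.0
```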
Sounds good. Once #458 merges, I'll add a call to posterior predictive sampling. Thanks for the good examples in the original PR.
You can start from the following code. I think it would be nice to add a legend like the default in https://arviz-devs.github.io/arviz/api/generated/arviz.plot_ppc.html.

```python
import numpy as np
import matplotlib.pyplot as plt

# Sample from the posterior predictive distribution
model.predict(idata, kind="pps")

draws = np.prod(idata.posterior_predictive["choice"].shape[:2])
choices = ["Other", "Invertebrates", "Fish"]

# Flatten (chain, draw) into one dimension; 63 is the number of observations
pps = idata.posterior_predictive["choice"].values.reshape(draws, 63)

# Count how often each category is predicted within each draw
y = np.zeros((draws, len(choices)))
for i in range(draws):
    for j, value in enumerate(choices):
        y[i, j] = np.sum(pps[i] == value)
y = y / y.sum(axis=1)[:, None]  # convert counts to proportions

fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(len(choices))

# One faint line per posterior predictive draw
for i in range(draws):
    ax.hlines(y[i], xmin=x - 0.5, xmax=x + 0.5, alpha=0.05)

# Posterior predictive mean
ax.hlines(y.mean(0), xmin=x - 0.5, xmax=x + 0.5, color="C0", lw=2, ls="--")

# Observed proportions (assumes the category order of data["choice"]
# matches the order of `choices` after sort_index)
true_counts = data["choice"].value_counts().sort_index()
true_counts = true_counts / true_counts.sum()
ax.hlines(true_counts, xmin=x - 0.5, xmax=x + 0.5, color="black", lw=2)

ax.set_xticks(x)
ax.set_xticklabels(choices)
ax.set_xlabel("Choice")
ax.set_ylabel("Probability");
```

@aloctavodia do you know why …
aloctavodia commented on 2022-02-23T08:56:17Z: This will fail for a number of chains other than 2. Also, we can be more explicit:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

x_new = np.linspace(-5, 5, num=200)
model.predict(idata, data=pd.DataFrame({"x": x_new}))

# Thin the posterior, keeping every 10th draw
p = idata.posterior["y_mean"].sel(draw=slice(0, None, 10))

for j, g in enumerate("ABC"):
    plt.plot(
        x_new,
        p.sel({"y_mean_coord": g}).stack(samples=("chain", "draw")),
        color=f"C{j}",
        alpha=0.2,
    )
plt.xlabel("x")
plt.ylabel("y");
```

tjburch commented on 2022-03-02T02:17:46Z: Applied these changes.
aloctavodia commented on 2022-02-23T08:56:18Z: You can use xarray features to write code that is a little bit more explicit.

```python
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

new_length = np.linspace(1, 4)  # 50 points by default
new_data = pd.DataFrame(
    {"length": np.tile(new_length, 2), "sex": ["Male"] * 50 + ["Female"] * 50}
)
model.predict(idata, data=new_data)
p = idata.posterior["choice_mean"]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
choices = ["Other", "Invertebrates", "Fish"]
for j, choice in enumerate(choices):
    # Select by coordinate name instead of positional indexing
    males = p.sel({"choice_mean_coord": choice, "choice_obs": slice(0, 49)})
    females = p.sel({"choice_mean_coord": choice, "choice_obs": slice(50, 100)})
    axes[0].plot(new_length, males.mean(("chain", "draw")), color=f"C{j}", lw=2)
    axes[1].plot(new_length, females.mean(("chain", "draw")), color=f"C{j}", lw=2)
    az.plot_hdi(new_length, males, color=f"C{j}", ax=axes[0])
    az.plot_hdi(new_length, females, color=f"C{j}", ax=axes[1])

axes[0].set_title("Male")
axes[1].set_title("Female")

handles = [Line2D([], [], color=f"C{j}", label=choice) for j, choice in enumerate(choices)]
fig.subplots_adjust(left=0.05, right=0.975, bottom=0.075, top=0.85)
fig.legend(
    handles,
    choices,
    loc="center right",
    ncol=3,
    bbox_to_anchor=(0.99, 0.95),
    bbox_transform=fig.transFigure,
);
```

tjburch commented on 2022-03-02T02:19:32Z: Nice - applied.
Not sure. I will check. I recently used …
Sorry about the delay on getting these changes in; I ran into some issues and it's been a busy week. I think it should be up-to-date with everything mentioned. If there's anything else you'd like me to add or edit, let me know!
aloctavodia commented on 2022-03-02T06:40:46Z: The examples in this notebook were…
aloctavodia commented on 2022-03-02T06:40:47Z: If you use Bambi from main you can then get a similar figure with this code:

```python
ax = az.plot_ppc(idata)
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(model.response.levels)
ax.set_xlabel("Choice");
```
Thanks, I just added two small comments. After addressing those I think this is ready to be merged!

Last couple of comments should be taken care of in the latest commit!

Thanks @tjburch for working on another example!
@tomicapretto you did all the hard work, I just put it together and added some fluff :) Happy to do it. |
Added a notebook with examples using the categorical family to address #436. I noticed on the original merge (#426) that @tomicapretto made several very nice examples, so I aggregated them into one notebook and added some comments to tie everything together.

Note that I omitted two of the examples, the satisfaction survey and the inhaler example, because I think those two fall more in the domain of ordinal regression than categorical regression. I'm happy to add them if you'd prefer to include them, though.
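For readers landing on this thread, a minimal sketch of what fitting a categorical-family model in Bambi looks like; the formula and column names ("choice", "length", "sex") are taken from the snippets above, and the file path is a placeholder assumption:

```python
import bambi as bmb
import pandas as pd

# Placeholder path: any dataset with a multi-level categorical response
# ("choice") and the predictors used in the snippets above ("length", "sex")
data = pd.read_csv("gator.csv")

# family="categorical" generalizes logistic regression to n > 2 response
# levels; Bambi estimates n - 1 coefficient sets against a reference level.
model = bmb.Model("choice ~ length + sex", data, family="categorical")
idata = model.fit()
```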