Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example with categorical family #457

Merged
merged 5 commits into from
Mar 4, 2022
Merged

Add example with categorical family #457

merged 5 commits into from
Mar 4, 2022

Conversation

tjburch
Copy link
Contributor

@tjburch tjburch commented Feb 21, 2022

Added a notebook with examples using the categorical family to address #436. I noticed on the original merge (#426) that @tomicapretto made several very nice examples, so I aggregated several into one notebook and added some comments to tie it together.

Note that I omitted two of the examples, the satisfaction survey and the inhaler example. I made that choice because I think those two fall more in the domain of ordinal regression, rather than categorical. I'm happy to add them if you'd prefer to include them though.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter
Copy link

codecov-commenter commented Feb 21, 2022

Codecov Report

Merging #457 (bffa24d) into main (05ece83) will not change coverage.
The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #457   +/-   ##
=======================================
  Coverage   89.24%   89.24%           
=======================================
  Files          31       31           
  Lines        2491     2491           
=======================================
  Hits         2223     2223           
  Misses        268      268           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 05ece83...bffa24d. Read the comment docs.

@review-notebook-app
Copy link

review-notebook-app bot commented Feb 22, 2022

View / edit / reply to this conversation on ReviewNB

tomicapretto commented on 2022-02-22T02:17:03Z
----------------------------------------------------------------

I think we never mention this in the documentation, thanks for adding it here!


@review-notebook-app
Copy link

review-notebook-app bot commented Feb 22, 2022

View / edit / reply to this conversation on ReviewNB

tomicapretto commented on 2022-02-22T02:17:04Z
----------------------------------------------------------------

I would mention the posterior for the probability contains information about the all $n$ categories, and not only about $n - 1$ as in the posterior for the coefficients. This is what makes it straightforward to create these type of charts without having to do 1 - (p_1 + p_2)


tjburch commented on 2022-03-02T02:19:03Z
----------------------------------------------------------------

Added to the discussion below,

Of note, the posterior means predicted by Bambi contain information about all n categories (despite having only n−1 coefficients), so we can directly construct this plot, rather than manually calculating 1−(p1+p2) for the third class. 

@tomicapretto
Copy link
Collaborator

@tjburch thanks a lot for taking the initiative to work on this example notebook. Not many comments from my side as I think it already looks quite clean. On top of that, I think the last example is a good candidate to show how to use the posterior predictive sampling that I'm incorporating on #458.

@tjburch
Copy link
Contributor Author

tjburch commented Feb 22, 2022

Sounds good. Once #458 merges, I'll add a call of the posterior predictive sampling. Thanks for the good examples in the original PR.

@tomicapretto
Copy link
Collaborator

You can start from the following code. I think it would be nice to add a legend like the default in https://arviz-devs.github.io/arviz/api/generated/arviz.plot_ppc.html.

model.predict(idata, kind="pps")

draws = np.prod(idata.posterior_predictive["choice"].shape[:2])
choices = ["Other", "Invertebrates", "Fish"]

pps = idata.posterior_predictive["choice"].values.reshape(draws, 63)

y = np.zeros((draws, len(choices)))
for i in range(draws):
    for j, value in enumerate(choices):
        y[i, j] = np.sum(pps[i] == value)

y = y / y.sum(axis=1)[:, None]

fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(len(choices))

for i in range(draws):
    ax.hlines(y[i], xmin=x - 0.5, xmax=x + 0.5, alpha=0.05)
ax.hlines(y.mean(0), xmin=x - 0.5, xmax=x + 0.5, color="C0", lw=2, ls="--")

true_counts = data["choice"].value_counts().sort_index()
true_counts = true_counts / true_counts.sum()
ax.hlines(true_counts, xmin=x - 0.5, xmax=x + 0.5, color="black", lw=2)

ax.set_xticks(x)
ax.set_xticklabels(choices)
ax.set_xlabel("Choice")
ax.set_ylabel("Probability");

image

@aloctavodia do you know why az.plot_ppc does not work with discrete values? I get an error similar to the one here arviz-devs/arviz#1882

@review-notebook-app
Copy link

review-notebook-app bot commented Feb 23, 2022

View / edit / reply to this conversation on ReviewNB

aloctavodia commented on 2022-02-23T08:56:17Z
----------------------------------------------------------------

This will fail if for a number of chains other than 2. Also we can be more explicit

x_new = np.linspace(-5, 5, num=200)

model.predict(idata, data=pd.DataFrame({"x": x_new}))

p = idata.posterior["y_mean"].sel(draw=slice(0, None, 10))

for j, g in enumerate("ABC"):

   plt.plot(x_new, p.sel({"y_mean_coord":g}).stack(samples=("chain", "draw")), color=f"C{j}", alpha=0.2)

plt.xlabel("x")

plt.ylabel("y");


tjburch commented on 2022-03-02T02:17:46Z
----------------------------------------------------------------

Applied these changes.

@review-notebook-app
Copy link

review-notebook-app bot commented Feb 23, 2022

View / edit / reply to this conversation on ReviewNB

aloctavodia commented on 2022-02-23T08:56:18Z
----------------------------------------------------------------

You can use xarray features to write code that is a little bit more explicit.

new_length = np.linspace(1, 4)

new_data = pd.DataFrame({"length": np.tile(new_length, 2), "sex": ["Male"] * 50 + ["Female"] * 50})

model.predict(idata, data=new_data)

p = idata.posterior["choice_mean"]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

choices = ["Other", "Invertebrates", "Fish"]

for j, choice in enumerate(choices):

   males = p.sel({"choice_mean_coord":choice, "choice_obs":slice(0, 49)})

   females = p.sel({"choice_mean_coord":choice, "choice_obs":slice(50, 100)})

   axes[0].plot(new_length, males.mean(("chain", "draw")), color=f"C{j}", lw=2)

   axes[1].plot(new_length, females.mean(("chain", "draw")), color=f"C{j}", lw=2)

   az.plot_hdi(new_length, males, color=f"C{j}", ax=axes[0])

   az.plot_hdi(new_length, females, color=f"C{j}", ax=axes[1])

axes[0].set_title("Male")

axes[1].set_title("Female")

handles = [Line2D([], [], color=f"C{j}", label=choice) for j, choice in enumerate(choices)]

fig.subplots_adjust(left=0.05, right=0.975, bottom=0.075, top=0.85)

fig.legend(

   handles,

   choices,

   loc="center right",

   ncol=3,

   bbox_to_anchor=(0.99, 0.95),

   bbox_transform=fig.transFigure

);


tjburch commented on 2022-03-02T02:19:32Z
----------------------------------------------------------------

Nice - applied.

@aloctavodia
Copy link
Collaborator

You can start from the following code. I think it would be nice to add a legend like the default in https://arviz-devs.github.io/arviz/api/generated/arviz.plot_ppc.html.

model.predict(idata, kind="pps")

draws = np.prod(idata.posterior_predictive["choice"].shape[:2])
choices = ["Other", "Invertebrates", "Fish"]

pps = idata.posterior_predictive["choice"].values.reshape(draws, 63)

y = np.zeros((draws, len(choices)))
for i in range(draws):
    for j, value in enumerate(choices):
        y[i, j] = np.sum(pps[i] == value)

y = y / y.sum(axis=1)[:, None]

fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(len(choices))

for i in range(draws):
    ax.hlines(y[i], xmin=x - 0.5, xmax=x + 0.5, alpha=0.05)
ax.hlines(y.mean(0), xmin=x - 0.5, xmax=x + 0.5, color="C0", lw=2, ls="--")

true_counts = data["choice"].value_counts().sort_index()
true_counts = true_counts / true_counts.sum()
ax.hlines(true_counts, xmin=x - 0.5, xmax=x + 0.5, color="black", lw=2)

ax.set_xticks(x)
ax.set_xticklabels(choices)
ax.set_xlabel("Choice")
ax.set_ylabel("Probability");

image

@aloctavodia do you know why az.plot_ppc does not work with discrete values? I get an error similar to the one here arviz-devs/arviz#1882

Not sure. I will check. I recently used az.plot_ppc with discrete data and it worked

Copy link
Contributor Author

tjburch commented Mar 2, 2022

Applied these changes.


View entire conversation on ReviewNB

Copy link
Contributor Author

tjburch commented Mar 2, 2022

Added to the discussion below,

Of note, the posterior means predicted by Bambi contain information about all n categories (despite having only n−1 coefficients), so we can directly construct this plot, rather than manually calculating 1−(p1+p2) for the third class. 


View entire conversation on ReviewNB

Copy link
Contributor Author

tjburch commented Mar 2, 2022

Nice - applied.


View entire conversation on ReviewNB

@tjburch
Copy link
Contributor Author

tjburch commented Mar 2, 2022

Sorry about the delay on getting these changes in, ran into some issues and been a busy week. I think it should be up-to-date with everything mentioned. If there's anything else you'd like me to add/edit, let me know!

@review-notebook-app
Copy link

review-notebook-app bot commented Mar 2, 2022

View / edit / reply to this conversation on ReviewNB

aloctavodia commented on 2022-03-02T06:40:46Z
----------------------------------------------------------------

The examples in this notebook were...


@review-notebook-app
Copy link

View / edit / reply to this conversation on ReviewNB

aloctavodia commented on 2022-03-02T06:40:47Z
----------------------------------------------------------------

If you use Bambi from main you can then get a similar figure with this code:

ax = az.plot_ppc(idata)

ax.set_xticks([0.5, 1.5, 2.5])

ax.set_xticklabels(model.response.levels)

ax.set_xlabel("Choice");


@aloctavodia
Copy link
Collaborator

Thanks, I just added two small comments. After addressing those I think this is ready to be merged!

@tjburch
Copy link
Contributor Author

tjburch commented Mar 4, 2022

Last couple of comments should be taken care of in the latest commit!

@tomicapretto
Copy link
Collaborator

Thanks @tjburch for working on another example!

@tomicapretto tomicapretto merged commit 59ac71c into bambinos:main Mar 4, 2022
@tjburch
Copy link
Contributor Author

tjburch commented Mar 4, 2022

@tomicapretto you did all the hard work, I just put it together and added some fluff :)

Happy to do it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants