Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Multinomial robust against batches #4169

Merged
merged 4 commits into from
Oct 14, 2020

Conversation

lucianopaz
Copy link
Contributor

At the moment, our Multinomial distribution mangles the n parameter's shape. This makes it very difficult to work with batches that have more than 2 dimensions. This PR addresses the problem and adds a test for it.

Thank your for opening a PR!

Depending on what your PR does, here are a few things you might want to address in the description:

  • what are the (breaking) changes that this PR makes?
  • important background, or details about the implementation
  • are the changes—especially new features—covered by tests and docstrings?
  • consider adding/updating relevant example notebooks
  • right before it's ready to merge, mention the PR in the RELEASE-NOTES.md

@twiecki
Copy link
Member

twiecki commented Oct 13, 2020

Seems like there are legit test errors.

@lucianopaz
Copy link
Contributor Author

Yeah, I'll go through those tomorrow.

@@ -597,14 +597,10 @@ def __init__(self, n, p, *args, **kwargs):
super().__init__(*args, **kwargs)

p = p / tt.sum(p, axis=-1, keepdims=True)
n = np.squeeze(n) # works also if n is a tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty simple fix, just remove some code 👍

@twiecki
Copy link
Member

twiecki commented Oct 14, 2020

Also needs a line in the release-notes.

@codecov
Copy link

codecov bot commented Oct 14, 2020

Codecov Report

Merging #4169 into master will increase coverage by 0.01%.
The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4169      +/-   ##
==========================================
+ Coverage   88.76%   88.77%   +0.01%     
==========================================
  Files          89       89              
  Lines       14083    14079       -4     
==========================================
- Hits        12501    12499       -2     
+ Misses       1582     1580       -2     
Impacted Files Coverage Δ
pymc3/distributions/multivariate.py 81.10% <ø> (+0.17%) ⬆️

@twiecki twiecki merged commit d8bfe93 into pymc-devs:master Oct 14, 2020
@twiecki
Copy link
Member

twiecki commented Oct 14, 2020

Thanks @lucianopaz!

@lucianopaz lucianopaz deleted the batch_multivariate branch October 14, 2020 08:57
Copy link
Contributor

@AlexAndorra AlexAndorra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @lucianopaz ! Funny to see that removing code fixed the issue 😅

@bsmith89 bsmith89 mentioned this pull request Jan 3, 2021
15 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants