-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement default_transform
and transform
argument for distributions
#7207
Implement default_transform
and transform
argument for distributions
#7207
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot, left some small user-quality of life suggestions
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7207 +/- ##
==========================================
- Coverage 92.29% 92.28% -0.01%
==========================================
Files 101 100 -1
Lines 16892 16906 +14
==========================================
+ Hits 15590 15602 +12
- Misses 1302 1304 +2
|
Looks like I fixed all failed test cases. Going add some new tests and changes to documentation as a next steps. |
@mkusnetsov took the documentation initiative in #7232 so we should be good to go just with docstrings and tests |
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1225764
to
7d6ecf9
Compare
@ricardoV94 I added some test cases to check warning it |
pymc/distributions/distribution.py
Outdated
@@ -397,6 +398,15 @@ def __new__( | |||
if not isinstance(name, string_types): | |||
raise TypeError(f"Name needs to be a string but got: {name}") | |||
|
|||
if transform is None and default_transform is UNSET: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This warning should be in the relevant section in pm.Model
instead of Distribution
tests/distributions/test_mixture.py
Outdated
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, transform=None) | ||
Mixture("mix4", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None) | ||
|
||
with pytest.warns( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This functionality shouldn't be tested here, since it's not specific to Mixture. Probably in model/test_core.py there should be related stuff already? I see you already have it there, so is this test needed?
tests/logprob/test_utils.py
Outdated
x = pm.Uniform("x", lower=0, upper=1, transform=transform, default_transform=None) | ||
# Operation between the variables provides a regression test for #7054 | ||
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform) | ||
z = pm.Uniform("z", lower=0, upper=y, transform=transform) | ||
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform) | ||
y = pm.Uniform( | ||
"y", lower=0, upper=pt.exp(x), transform=transform, default_transform=None | ||
) | ||
z = pm.Uniform("z", lower=0, upper=y, transform=transform, default_transform=None) | ||
w = pm.Uniform( | ||
"w", lower=0, upper=pt.square(z), transform=transform, default_transform=None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pass transform to default_transform
tests/model/test_core.py
Outdated
|
||
with pm.Model() as model: | ||
x = pm.Normal("x", transform=DummyTransform(2), default_transform=DummyTransform(1)) | ||
assert transform_order == [1, 2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use regular transforms, and simply assert the obtained transform is a Chain, which as a property like transform_list
that includes the transforms, and you can assert those are also the expected ones. The transform is available in models.rvs_to_transforms[x]
Also, would be nice to include a numerical example that would have led to nan
or -inf
probability before the change, like an ordered mixture of LogNormals
evaluated at -1
default_transform
and transform
argument for distributions
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Besides making the test_transform_order
simpler, just need some tweaks to the docstrings. They don't describe the parameter. We can also give better type hint than Any
. If mypy complaints just revert to Any
total_size=None, | ||
dims=None, | ||
transform=UNSET, | ||
default_transform=UNSET, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing description in the docstrings
pymc/model/core.py
Outdated
@@ -1288,6 +1299,7 @@ def make_obs_var( | |||
data: np.ndarray, | |||
dims, | |||
transform: Any | None, | |||
default_transform: Any | None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing description in the docstrings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also type hint isn't great, should be RVTransform | None
(or something like that, don't remember the exact class now)
pymc/model/core.py
Outdated
rv_var: TensorVariable, | ||
*, | ||
transform: Any, | ||
default_transform: Any, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update type hints and docstring
@ricardoV94 Is there any way to get with pm.Model() as model:
x1 = pm.LogNormal("x1", 0.0, 1.0)
x2 = pm.LogNormal("x2", 0.0, 1.0, default_transform=LogTransform())
assert pm.logp(x1, 1.0).eval() != pm.logp(x2, 1.0).eval() and got assertion error |
You want to use |
@ricardoV94 I added new test case with numerical example for transform args. Please let me know what do you thunk about it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @aerubanov
I left some minor comments, I think this is the last round!
tests/model/test_core.py
Outdated
with pm.Model() as model1: | ||
x1 = pm.LogNormal("x1", 0, 1, transform=Interval(-2, 2), default_transform=None) | ||
with pm.Model() as model3: | ||
x2 = pm.LogNormal("x2", 0, 1, transform=Interval(-2, 2), default_transform=log) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine but is not very realistic. What about ordered transform + log which is the example that motivated this PR?
x2 = pm.LogNormal("x2", 0, 1, transform=Interval(-2, 2), default_transform=log) | |
x2 = pm.LogNormal("x2", 0, 1, transform=Interval(-2, 2), default_transform=log) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rest of the logic of the test is spot-on!
@@ -1230,7 +1239,9 @@ def register_rv( | |||
dims : tuple | |||
Dimension names for the variable. | |||
transform | |||
A transform for the random variable in log-likelihood space. | |||
Additianal transform which may be applied after default transform. | |||
default_transform |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: show default_transform before transform (also in the signature)?
pymc/model/core.py
Outdated
transform | ||
Additianal transform which may be applied after default transform. | ||
default_transform |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above
pymc/model/core.py
Outdated
transform: Transform | ||
Additianal transform which may be applied after default transform. | ||
|
||
default_transform: Transform | ||
A transform for the random variable in log-likelihood space. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@ricardoV94 please review again |
tests/model/test_core.py
Outdated
def test_transform_order(self): | ||
with pm.Model() as model: | ||
x = pm.Normal("x", transform=Interval(0, 1), default_transform=log) | ||
assert isinstance(model.rvs_to_transforms[x], ChainedTransform) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick (feel free to ignore): Save the transform in a separate variable so you don't need to write 3 times model.rvs_to_transforms[x]
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, looks great!
We should follow up with an issue to provide informative warnings when doing prior/posterior predictive sampling of variables with custom non-default transforms, like we do with Potentials |
Description
Related Issue
default_transform
andtransform
argument for distributions #5674Type of change
📚 Documentation preview 📚: https://pymc--7207.org.readthedocs.build/en/7207/