Skip to content

Commit

Permalink
Ignore finite upper limit in Nat domains.
Browse files Browse the repository at this point in the history
Move new checks to `check_logcdf`.
  • Loading branch information
ricardoV94 committed Jan 2, 2021
1 parent 7c3fb12 commit bd28eb9
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,28 +575,6 @@ def check_logcdf(
err_msg=str(pt),
)

def check_selfconsistency_discrete_logcdf(
self, distribution, domain, paramdomains, decimal=None, n_samples=100
):
"""
Check that logcdf of discrete distributions matches sum of logps up to value
"""
domains = paramdomains.copy()
domains["value"] = domain
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
for pt in product(domains, n_samples=n_samples):
params = dict(pt)
value = params.pop("value")
values = np.arange(domain.lower, value + 1)
dist = distribution.dist(**params)
assert_almost_equal(
dist.logcdf(value).tag.test_value,
logsumexp(dist.logp(values), keepdims=False).tag.test_value,
decimal=decimal,
err_msg=str(pt),
)

# Test that values below domain evaluate to -np.inf
if np.isfinite(domain.lower):
below_domain = domain.lower - 1
Expand All @@ -607,7 +585,9 @@ def check_selfconsistency_discrete_logcdf(
)

# Test that values above domain evaluate to 0
if np.isfinite(domain.upper):
# Natural domains do not have inf as the upper edge, but should also be ignored
not_nat_domain = domain not in (NatSmall, Nat, NatBig, PosNat)
if not_nat_domain and np.isfinite(domain.upper):
above_domain = domain.upper + 1
assert_equal(
dist.logcdf(above_domain).tag.test_value,
Expand All @@ -619,11 +599,31 @@ def check_selfconsistency_discrete_logcdf(
try:
dist.logcdf(np.array([value, value])).tag.test_value
except TypeError as err:
if not str(err).endswith(
".logcdf expects a scalar value but received a 1-dimensional object."
):
if not str(err).endswith(".logcdf expects a scalar value but received a 1-dimensional object."):
raise

def check_selfconsistency_discrete_logcdf(
self, distribution, domain, paramdomains, decimal=None, n_samples=100
):
"""
Check that logcdf of discrete distributions matches sum of logps up to value
"""
domains = paramdomains.copy()
domains["value"] = domain
if decimal is None:
decimal = select_by_precision(float64=6, float32=3)
for pt in product(domains, n_samples=n_samples):
params = dict(pt)
value = params.pop("value")
values = np.arange(domain.lower, value + 1)
dist = distribution.dist(**params)
assert_almost_equal(
dist.logcdf(value).tag.test_value,
logsumexp(dist.logp(values), keepdims=False).tag.test_value,
decimal=decimal,
err_msg=str(pt),
)

def check_int_to_1(self, model, value, domain, paramdomains):
pdf = model.fastfn(exp(model.logpt))
for pt in product(paramdomains, n_samples=10):
Expand Down

0 comments on commit bd28eb9

Please sign in to comment.