Skip to content

Commit

Permalink
move pytest imports to be local
Browse files Browse the repository at this point in the history
  • Loading branch information
bomtall authored and twiecki committed Jun 15, 2024
1 parent 33b24cc commit 5e82394
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from numpy import random as nr
from numpy import testing as npt
Expand Down Expand Up @@ -342,6 +341,8 @@ def check_logp(
scipy_args : Dictionary with extra arguments needed to call scipy logp method
Usually the same as extra_args
"""
import pytest

if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

Expand Down Expand Up @@ -388,6 +389,7 @@ def scipy_logp_with_scipy_args(**args):
point[invalid_param] = np.asarray(
invalid_edge, dtype=paramdomains[invalid_param].dtype
)

with pytest.raises(ParameterValueError):
pymc_logp(**point)
pytest.fail(f"test_params={point}")
Expand Down Expand Up @@ -459,6 +461,8 @@ def check_logcdf(
returns -inf for invalid parameter values outside the supported domain edge
"""
import pytest

if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

Expand Down Expand Up @@ -498,6 +502,7 @@ def check_logcdf(

point = valid_params.copy()
point[invalid_param] = invalid_edge

with pytest.raises(ParameterValueError):
pymc_logcdf(**point)
pytest.fail(f"test_params={point}")
Expand Down Expand Up @@ -563,6 +568,8 @@ def check_icdf(
returns nan for invalid parameter values outside the supported domain edge
"""
import pytest

if decimal is None:
decimal = select_by_precision(float64=6, float32=3)

Expand Down Expand Up @@ -601,6 +608,7 @@ def check_icdf(

point = valid_params.copy()
point[invalid_param] = invalid_edge

with pytest.raises(ParameterValueError):
pymc_icdf(**point)
pytest.fail(f"test_params={point}")
Expand Down Expand Up @@ -860,6 +868,8 @@ class BaseTestDistributionRandom:
random_state = None

def test_distribution(self):
import pytest

self.validate_tests_list()
if self.pymc_dist == pm.Wishart:
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
Expand Down

0 comments on commit 5e82394

Please sign in to comment.