From 5e823945b508c36fde8a4ee28da195726524a12d Mon Sep 17 00:00:00 2001 From: bomtall Date: Sat, 15 Jun 2024 15:28:59 +0100 Subject: [PATCH] move pytest imports to be local --- pymc/testing.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pymc/testing.py b/pymc/testing.py index ddc9683db0..b359e0ea3b 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -21,7 +21,6 @@ import numpy as np import pytensor import pytensor.tensor as pt -import pytest from numpy import random as nr from numpy import testing as npt @@ -342,6 +341,8 @@ def check_logp( scipy_args : Dictionary with extra arguments needed to call scipy logp method Usually the same as extra_args """ + import pytest + if decimal is None: decimal = select_by_precision(float64=6, float32=3) @@ -388,6 +389,7 @@ def scipy_logp_with_scipy_args(**args): point[invalid_param] = np.asarray( invalid_edge, dtype=paramdomains[invalid_param].dtype ) + with pytest.raises(ParameterValueError): pymc_logp(**point) pytest.fail(f"test_params={point}") @@ -459,6 +461,8 @@ def check_logcdf( returns -inf for invalid parameter values outside the supported domain edge """ + import pytest + if decimal is None: decimal = select_by_precision(float64=6, float32=3) @@ -498,6 +502,7 @@ def check_logcdf( point = valid_params.copy() point[invalid_param] = invalid_edge + with pytest.raises(ParameterValueError): pymc_logcdf(**point) pytest.fail(f"test_params={point}") @@ -563,6 +568,8 @@ def check_icdf( returns nan for invalid parameter values outside the supported domain edge """ + import pytest + if decimal is None: decimal = select_by_precision(float64=6, float32=3) @@ -601,6 +608,7 @@ def check_icdf( point = valid_params.copy() point[invalid_param] = invalid_edge + with pytest.raises(ParameterValueError): pymc_icdf(**point) pytest.fail(f"test_params={point}") @@ -860,6 +868,8 @@ class BaseTestDistributionRandom: random_state = None def test_distribution(self): + import pytest + self.validate_tests_list() if self.pymc_dist == pm.Wishart: with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):