diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index 71cfa926..e0963195 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -31,9 +31,11 @@ @dataclass class PoweredExponential(AbstractKernel): - r"""The powered exponential family of kernels. + r"""The powered exponential family of kernels. This also equivalent to the symmetric generalized normal distribution. - Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". + See Diggle and Ribeiro (2007) - "Model-based Geostatistics". + and + https://en.wikipedia.org/wiki/Generalized_normal_distribution#Symmetric_version """ @@ -41,7 +43,7 @@ class PoweredExponential(AbstractKernel): jnp.array(1.0), bijector=tfb.Softplus() ) variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) - power: ScalarFloat = param_field(jnp.array(1.0)) + power: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Sigmoid()) name: str = "Powered Exponential" def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat: diff --git a/mkdocs.yml b/mkdocs.yml index 2c5fa3f8..0f0c4eaa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -89,7 +89,7 @@ plugins: execute: true allow_errors: false include: ["examples/*.py"] - ignore: ["examples/utils.py"] + ignore: ["examples/utils.py", "_statch/*.py", "scripts/*.py"] # binder: true # binder_service_name: "gh" # binder_branch: "main" diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index 7091e891..2a48cd4d 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -108,7 +108,7 @@ def test_initialization(self, fields: dict, dim: int) -> None: if field in ["variance", "lengthscale", "period", "alpha"]: assert isinstance(meta[field]["bijector"], tfb.Softplus) if field in ["power"]: - assert isinstance(meta[field]["bijector"], tfb.Identity) + assert isinstance(meta[field]["bijector"], tfb.Sigmoid) # Trainability state assert meta[field]["trainable"] is True @@ -225,7 +225,7 @@ class TestPeriodic(BaseTestKernel): class TestPoweredExponential(BaseTestKernel): kernel = PoweredExponential fields = prod( - {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "power": [0.1, 2.0]} + {"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "power": [0.1, 0.9]} ) params = {"test_initialization": fields} default_compute_engine = DenseKernelComputation