Skip to content

Commit

Permalink
Merge pull request #279 from ingmarschuster/patch-1
Browse files Browse the repository at this point in the history
Bugfix powered_exponential.py
  • Loading branch information
thomaspinder authored May 30, 2023
2 parents 134d80d + 2b72201 commit 558e797
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
8 changes: 5 additions & 3 deletions gpjax/kernels/stationary/powered_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@

@dataclass
class PoweredExponential(AbstractKernel):
r"""The powered exponential family of kernels.
r"""The powered exponential family of kernels. This also equivalent to the symmetric generalized normal distribution.
Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics".
See Diggle and Ribeiro (2007) - "Model-based Geostatistics".
and
https://en.wikipedia.org/wiki/Generalized_normal_distribution#Symmetric_version
"""

lengthscale: Union[ScalarFloat, Float[Array, " D"]] = param_field(
jnp.array(1.0), bijector=tfb.Softplus()
)
variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
power: ScalarFloat = param_field(jnp.array(1.0))
power: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Sigmoid())
name: str = "Powered Exponential"

def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ plugins:
execute: true
allow_errors: false
include: ["examples/*.py"]
ignore: ["examples/utils.py"]
ignore: ["examples/utils.py", "_statch/*.py", "scripts/*.py"]
# binder: true
# binder_service_name: "gh"
# binder_branch: "main"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_kernels/test_stationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_initialization(self, fields: dict, dim: int) -> None:
if field in ["variance", "lengthscale", "period", "alpha"]:
assert isinstance(meta[field]["bijector"], tfb.Softplus)
if field in ["power"]:
assert isinstance(meta[field]["bijector"], tfb.Identity)
assert isinstance(meta[field]["bijector"], tfb.Sigmoid)

# Trainability state
assert meta[field]["trainable"] is True
Expand Down Expand Up @@ -225,7 +225,7 @@ class TestPeriodic(BaseTestKernel):
class TestPoweredExponential(BaseTestKernel):
kernel = PoweredExponential
fields = prod(
{"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "power": [0.1, 2.0]}
{"lengthscale": [0.1, 1.0], "variance": [0.1, 1.0], "power": [0.1, 0.9]}
)
params = {"test_initialization": fields}
default_compute_engine = DenseKernelComputation
Expand Down

0 comments on commit 558e797

Please sign in to comment.