Refactoring kernels #206
Conversation
Codecov Report
```
@@           Coverage Diff            @@
##    refactor_kernels    #206    +/- ##
========================================
+ Coverage      28.29%  28.63%   +0.33%
========================================
  Files             70      70
  Lines           3262    3332      +70
========================================
+ Hits             923     954      +31
- Misses          2339    2378      +39
```
gpjax/kernels/stationary/matern52.py
Outdated
```diff
-        self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
```
We need to replace all kernel inputs of `Float[Array, "1 D"]` with `Float[Array, "D"]` for the `x`'s and `y`'s!
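To illustrate the suggested signature change, here is a minimal sketch with a toy kernel (the class and its implementation are my own for illustration, not GPJax code): each input is a single unbatched point of shape `(D,)`, i.e. `Float[Array, "D"]`, rather than a `(1, D)` row matrix.

```python
import jax.numpy as jnp

# Hypothetical toy kernel (not GPJax's actual class) illustrating the
# suggested signature: x and y are single unbatched points of shape (D,),
# i.e. Float[Array, "D"], rather than (1, D) row matrices.
class ToyRBF:
    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        # Squared-exponential covariance between two D-dimensional points;
        # the result is a scalar, not a (1, 1) matrix.
        return jnp.exp(-0.5 * jnp.sum((x - y) ** 2))

kernel = ToyRBF()
value = kernel(jnp.zeros(3), jnp.zeros(3))  # identical inputs -> covariance 1.0
```

Cross-covariance matrices can then be built from this pairwise function with `jax.vmap`, rather than by batching inside the kernel itself.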
gpjax/kernels/stationary/white.py
Outdated
```diff
-class White(AbstractKernel):
-    def __init__(
-        self,
-        compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation,
-        active_dims: Optional[List[int]] = None,
-        name: Optional[str] = "White Noise Kernel",
-    ) -> None:
-        super().__init__(compute_engine, active_dims, spectral_density=None, name=name)
-        self._stationary = True
+    variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
```
For the White kernel, the default computation should be `ConstantDiagonalKernelComputation`, so that:

```python
from simple_pytree import static_field

@dataclass
class White(AbstractKernel):
    variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
    compute_engine: AbstractKernelComputation = static_field(ConstantDiagonalKernelComputation)  # <- set the default
```
```python
def test_gram(self, dim: int, n: int) -> None:
    kernel: AbstractKernel = self.kernel()
    kernel.gram
    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
    Kxx = kernel.gram(x)
    assert isinstance(Kxx, LinearOperator)
    assert Kxx.shape == (n, n)
    assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0)
```
When you test positive definiteness, I would recommend adding a small jitter to the diagonal of the Gram matrix `Kxx`; one of the tests failed for me locally on my machine. I would also add a quick comment explaining what the line

`assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense()) > 0.0)`

is doing, to make it clear for others!
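The jitter suggestion could be sketched like this (the helper name and jitter value are my own assumptions, not GPJax code):

```python
import jax.numpy as jnp

# Hypothetical helper (name and default jitter are assumptions): add a small
# jitter to the diagonal before the eigenvalue check so that floating-point
# round-off cannot push the smallest eigenvalue of a positive-semi-definite
# Gram matrix just below zero and spuriously fail the test.
def is_positive_definite(Kxx_dense, jitter: float = 1e-6) -> bool:
    n = Kxx_dense.shape[0]
    stabilised = Kxx_dense + jitter * jnp.eye(n)
    # eigvalsh computes the eigenvalues of a symmetric matrix; the matrix is
    # positive definite iff all of its eigenvalues are strictly positive.
    return bool(jnp.all(jnp.linalg.eigvalsh(stabilised) > 0.0))
```

The jitter also makes the check pass for matrices that are positive semi-definite only up to round-off, which is usually the intended behaviour for kernel Gram matrices.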
Just contributing to #199.
Pull request type
Please check the type of change your PR introduces: