Skip to content

Commit

Permalink
feat(gpjax/kernels/computations/basis_functions.py): add diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-huan committed Apr 2, 2024
1 parent 3a0bac8 commit 93ed7d2
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions gpjax/kernels/computations/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from cola import PSD
from cola.ops import (
Dense,
Diagonal,
LinearOperator,
)

Expand Down Expand Up @@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator:
z1 = self.compute_features(kernel, inputs)
return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))

def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal:
r"""For a given kernel, compute the elementwise diagonal of the
NxN gram matrix on an input matrix of shape NxD.
Args:
kernel (AbstractKernel): the kernel function.
inputs (Float[Array, "N D"]): The input matrix.
Returns
-------
Diagonal: The computed diagonal variance entries.
"""
return super().diagonal(kernel.base_kernel, inputs)

def compute_features(
self, kernel: Kernel, x: Float[Array, "N D"]
) -> Float[Array, "N L"]:
Expand Down

0 comments on commit 93ed7d2

Please sign in to comment.