diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index e0693f12..d62d144f 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -12,6 +12,7 @@ from cola import PSD from cola.ops import ( Dense, + Diagonal, LinearOperator, ) @@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator: z1 = self.compute_features(kernel, inputs) return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T))) + def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal: + r"""For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): the kernel function. + inputs (Float[Array, "N D"]): The input matrix. + + Returns + ------- + Diagonal: The computed diagonal variance entries. + """ + return super().diagonal(kernel.base_kernel, inputs) + def compute_features( self, kernel: Kernel, x: Float[Array, "N D"] ) -> Float[Array, "N L"]: