From 93ed7d2d5d6b4e7acf2b5834735ba2de96bbb775 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Tue, 2 Apr 2024 03:54:46 -0400 Subject: [PATCH] feat(gpjax/kernels/computations/basis_functions.py): add diagonal --- gpjax/kernels/computations/basis_functions.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index e0693f12..d62d144f 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -12,6 +12,7 @@ from cola import PSD from cola.ops import ( Dense, + Diagonal, LinearOperator, ) @@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator: z1 = self.compute_features(kernel, inputs) return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T))) + def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal: + r"""For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): the kernel function. + inputs (Float[Array, "N D"]): The input matrix. + + Returns + ------- + Diagonal: The computed diagonal variance entries. + """ + return super().diagonal(kernel.base_kernel, inputs) + def compute_features( self, kernel: Kernel, x: Float[Array, "N D"] ) -> Float[Array, "N L"]: