diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py
index fef2fe76..36bc3132 100644
--- a/gpjax/kernels/computations/basis_functions.py
+++ b/gpjax/kernels/computations/basis_functions.py
@@ -30,8 +30,7 @@ def cross_covariance(
         """
         z1 = self.compute_features(x)
         z2 = self.compute_features(y)
-        z1 /= self.kernel.num_basis_fns
-        return self.kernel.base_kernel.variance * jnp.matmul(z1, z2.T)
+        return self.scaling * jnp.matmul(z1, z2.T)
 
     def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
         r"""Compute an approximate Gram matrix.
@@ -47,9 +46,7 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
             $`N \times N`$ Gram matrix.
         """
         z1 = self.compute_features(inputs)
-        matrix = jnp.matmul(z1, z1.T)  # shape: (n_samples, n_samples)
-        matrix /= self.kernel.num_basis_fns
-        return DenseLinearOperator(self.kernel.base_kernel.variance * matrix)
+        return DenseLinearOperator(self.scaling * jnp.matmul(z1, z1.T))
 
     def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
         r"""Compute the features for the inputs.
@@ -66,3 +63,7 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
         z = jnp.matmul(x, (frequencies / scaling_factor).T)
         z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
         return z
+
+    @property
+    def scaling(self):
+        return self.kernel.base_kernel.variance / self.kernel.num_basis_fns
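
Note on the change: the new `scaling` property folds the old in-place division of the features (`z1 /= self.kernel.num_basis_fns`) and the multiplication by the base kernel's variance into the single factor `variance / num_basis_fns`, which is algebraically identical to the previous behaviour. The sketch below is a minimal, self-contained illustration of the random-Fourier-feature approximation that this scaling implements; it does not use GPJax's classes, and the names `variance`, `lengthscale`, `num_basis_fns`, and the choice of an RBF base kernel are assumptions made for the example only.

```python
# Minimal sketch (assumed names, not GPJax's API): with L random frequencies,
# K(x, y) ~= (variance / L) * phi(x) @ phi(y).T, where phi stacks cos and sin
# projections of the inputs, mirroring compute_features() in the diff above.
import jax
import jax.numpy as jnp

n, d, num_basis_fns = 5, 2, 2048
variance, lengthscale = 1.0, 1.0

x = jax.random.normal(jax.random.PRNGKey(0), (n, d))
# For an RBF base kernel, frequencies are drawn from a standard normal.
frequencies = jax.random.normal(jax.random.PRNGKey(1), (num_basis_fns, d))

def compute_features(x):
    z = jnp.matmul(x, (frequencies / lengthscale).T)            # shape (n, L)
    return jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)   # shape (n, 2L)

scaling = variance / num_basis_fns          # the new `scaling` property
z1 = compute_features(x)
approx_gram = scaling * jnp.matmul(z1, z1.T)

# Exact RBF Gram matrix for comparison.
sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
exact_gram = variance * jnp.exp(-0.5 * sq_dists / lengthscale**2)

print(jnp.max(jnp.abs(approx_gram - exact_gram)))  # small for large num_basis_fns
```

The approximation error decays at roughly O(1 / sqrt(num_basis_fns)), so the discrepancy printed above shrinks as the number of basis functions grows.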