Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(gpjax/kernels/base.py): add diagonal #429

Merged

Conversation

stephen-huan
Copy link
Contributor

@stephen-huan stephen-huan commented Dec 28, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

AbstractKernel has wrappers for cross_covariance and gram but not for diagonal, when all three are defined by AbstractKernelComputation (diagonal is actually concretely implemented by AbstractKernelComputation).

This PR adds diagonal to AbstractKernel.

Although the change is trivial, I believe it's important for the following reasons.

Currently, the way to get the diagonal of a kernel matrix efficiently would be code like

import jax.numpy as jnp
from gpjax.kernels import DenseKernelComputation, Matern52

kernel = Matern52()
diagonal = DenseKernelComputation().diagonal
# diag([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])
print(diagonal(kernel, jnp.zeros((10, 3))))

Other than being needlessly verbose, this obscures the fact that most of the *Computation classes (DiagonalKernelComputation, etc.) could be used in place of DenseKernelComputation without changing the underlying implementation from AbstractKernelComputation. Instead, code like

kernel.diagonal(jnp.zeros((10, 3)))

is cleaner and would allow specific kernels to provide a more efficient, specialized implementation (for example, the Matern family has a unit diagonal no matter the input data).

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for opening your first PR into GPJax!

If you have not heard from us in a while, please feel free to ping @gpjax/developers or anyone who has commented on the PR. Most of our reviewers are volunteers and sometimes things fall through the cracks.

You can also join us on Slack for real-time discussion.

For details on testing, writing docs, and our review process, please see the developer guide

We strive to be a welcoming and open project. Please follow our Code of Conduct.

@thomaspinder
Copy link
Collaborator

Hi @stephen-huan - thanks for adding this. Would you be able to include a unit test for this function please?

@stephen-huan
Copy link
Contributor Author

@thomaspinder I've added tests, sorry for the long delay!

@stephen-huan
Copy link
Contributor Author

@thomaspinder any progress on this PR?

I added consistency tests that are tangentially related to this PR: now kernel.gram(x) is checked by kernel.cross_covariance(x, x), kernel.diagonal(x) is checked by diagonal(kernel.gram(x)), and kernel.cross_covariance(a, b) is checked by kernel.gram(c) where c is vstack((a, b)).

I also noticed a check of max(a - b) < tol, when it should probably be max(abs(a - b)) < tol.

@thomaspinder
Copy link
Collaborator

Thanks @stephen-huan - nice to see this included in GPJax!

@thomaspinder thomaspinder merged commit 7ae0adf into JaxGaussianProcesses:main Aug 9, 2024
11 checks passed
@stephen-huan stephen-huan deleted the kernel-diagonal branch August 9, 2024 07:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants