
Categorical kernel #345

Merged: 9 commits merged into main on Aug 8, 2023
Conversation

@ingmarschuster (Contributor) commented Aug 5, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

This implements a kernel with explicit Gram-matrix values for categorical inputs (such as a string of characters). The parametrization works well with gradient descent.
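For readers outside the diff, the parametrization described above can be sketched as follows. This is a minimal NumPy stand-in, not the PR's actual API: the names `sdev`, `corr_chol`, and `categorical_kernel` are illustrative.

```python
import numpy as np

# Hypothetical sketch: a categorical kernel stores one standard deviation per
# category plus a lower-triangular correlation Cholesky factor; the explicit
# Gram matrix over all categories is K = (sdev * L_corr) @ (sdev * L_corr).T.
sdev = np.array([1.0, 2.0, 0.5])           # per-category standard deviations
corr_chol = np.linalg.cholesky(np.array([  # correlation Cholesky factor
    [1.0, 0.3, 0.1],
    [0.3, 1.0, 0.2],
    [0.1, 0.2, 1.0],
]))

L = sdev.reshape(-1, 1) * corr_chol
gram = L @ L.T                             # explicit Gram matrix over categories

# Evaluating the kernel on two categorical inputs is a table lookup:
def categorical_kernel(x: int, y: int) -> float:
    return float(gram[x, y])
```

Because the factors are unconstrained up to the bijectors, gradients flow through `sdev` and `corr_chol` freely while `gram` stays a valid covariance, which is why this parametrization suits gradient descent.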

I'm not sure about formatting with the poetry command; it didn't work for me, so I used the black formatter instead.

@ingmarschuster ingmarschuster added enhancement New feature or request testing Testing labels Aug 5, 2023
gpjax/kernels/non_euclidean/categorical.py
jnp.eye(2), bijector=tfb.CorrelationCholesky()
)
inspace_vals: list = static_field(None)
name: str = "Dictionary Kernel"
Collaborator:

Here the name is a Dictionary Kernel, yet the object is called CatKernel. Could you shed some light on this incongruity please?

Contributor Author:

Dictionary Kernel or DictKernel was the old name. References to this old name will be fixed.

L = self.sdev.reshape(-1, 1) * self.cholesky_lower
return L @ L.T

def __call__( # TODO not consistent with general kernel interface
Collaborator:

This is True of the GraphKernel too. Not helpful, I know, but maybe there's an alternative abstraction that is more appropriate for non-Euclidean kernels.

Contributor Author:

This is because AbstractKernel.__call__(x, y) requires float arrays for x and y. Instead, maybe the right signature for the base class would be
def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat:
because then the categorical kernel could specialize to ScalarInt, if I'm not mistaken.

Collaborator:

Yes. I agree this would be the more general signature.
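A sketch of what that specialization could look like, using plain Python type hints in place of the jaxtyping annotations discussed above. The class shapes here are assumed for illustration, not GPJax's actual class hierarchy:

```python
from abc import ABC, abstractmethod

class AbstractKernel(ABC):
    # Proposed loosening: accept any numeric inputs rather than floats only,
    # so subclasses can narrow the type (e.g. to integer category indices).
    @abstractmethod
    def __call__(self, x, y) -> float: ...

class CatKernel(AbstractKernel):
    def __init__(self, gram):
        self.gram = gram  # explicit Gram matrix over categories

    def __call__(self, x: int, y: int) -> float:
        # Categorical inputs are integer indices into the Gram table.
        return float(self.gram[x][y])
```

With the generic base signature, the categorical kernel's integer-indexed `__call__` is a valid specialization rather than a violation of the interface.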

Comment on lines 72 to 75
@property
def explicit_gram(self):
L = self.sdev.reshape(-1, 1) * self.cholesky_lower
return L @ L.T
Collaborator:

How does this differ from the regular gram method?

Contributor Author:

It's actually a property. I implemented this because it can be handy to use the explicit Gram matrix rather than the internal parametrization. Adding a docstring.
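The distinction being discussed can be shown with a minimal stand-in class (hypothetical names, mirroring the snippet from the diff): the Gram matrix is rebuilt from the internal parameters on each attribute access.

```python
import numpy as np

class CatKernelSketch:
    # Hypothetical minimal stand-in for CatKernel, illustrating why
    # explicit_gram is exposed as a property rather than a method.
    def __init__(self, sdev, cholesky_lower):
        self.sdev = np.asarray(sdev)
        self.cholesky_lower = np.asarray(cholesky_lower)

    @property
    def explicit_gram(self):
        # Rebuild the Gram matrix from the internal parametrization on access,
        # so it always reflects the current (possibly optimized) parameters.
        L = self.sdev.reshape(-1, 1) * self.cholesky_lower
        return L @ L.T

k = CatKernelSketch([1.0, 2.0], np.eye(2))
```

Accessing `k.explicit_gram` gives the covariance over categories directly, which can be more convenient than working with `sdev` and `cholesky_lower` separately.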

ValueError: If the number of diagonal variance parameters does not match the number of input space values.
"""

sdev: Float[Array, " N"] = param_field(jnp.ones((2,)), bijector=tfb.Softplus())
Collaborator:

nit: elsewhere in the package we use stddev for standard deviation.

Contributor Author:

changed name.

return num_inspace_vals * (num_inspace_vals - 1) // 2

@classmethod
def gram_to_sdev_cholesky_lower(cls, gram: Float[Array, "N N"]) -> CatKernelParams:
Contributor:

Doesn't need to be a class method since cls is not used. Can be a static method.

Contributor Author:

yep, changed.
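The conversion under discussion can be sketched in NumPy. Note this is a hypothetical stand-in: the PR's method returns a CatKernelParams object, while this sketch returns a plain tuple.

```python
import numpy as np

def gram_to_sdev_cholesky_lower(gram: np.ndarray):
    # Invert the parametrization: split a Gram (covariance) matrix into
    # per-category standard deviations and a correlation Cholesky factor.
    sdev = np.sqrt(np.diag(gram))
    corr = gram / np.outer(sdev, sdev)      # normalise to a correlation matrix
    return sdev, np.linalg.cholesky(corr)

# Round trip: rebuilding the Gram matrix recovers the original.
gram = np.array([[4.0, 1.0],
                 [1.0, 9.0]])
sdev, chol = gram_to_sdev_cholesky_lower(gram)
L = sdev.reshape(-1, 1) * chol
```

Since `cls` is never needed, a `@staticmethod` (as agreed above) expresses this cleanly.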

@ingmarschuster (Contributor Author):

@thomaspinder The input_1hot field is now documented. Thanks for the catch. All other problems should be fixed.

@ingmarschuster (Contributor Author):

If I get your thumbs up I'll merge.

Comment on lines +61 to +63
cholesky_lower: Float[Array, "N N"] = param_field(
jnp.eye(2), bijector=tfb.CorrelationCholesky()
)
Contributor:

I did not know this bijector existed. It makes the code neater. It's worth thinking about a slightly messier formulation, though.

Contributor:

It would be nice to be able to control the flexibility, kind of like specifying the rank of W in the decomposition K = W * W^T + kappa.

Collaborator:

How much work would this be @ingmarschuster? If it's simple, then maybe let's add it to this PR. Otherwise, if you feel it's a good idea, then let's open an issue for it.
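The rank-controlled decomposition suggested above could look like the following. This is a hypothetical sketch, not part of the PR; the factor `W` and the diagonal `kappa` are illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
n_categories, rank = 5, 2

# Low-rank alternative to the full correlation Cholesky:
# K = W @ W.T + diag(kappa), where the chosen rank of W controls flexibility
# (rank = n_categories recovers a full covariance, rank = 1 is most rigid).
W = rng.normal(size=(n_categories, rank))
kappa = np.ones(n_categories)            # per-category variance floor

gram = W @ W.T + np.diag(kappa)          # positive definite by construction
```

The appeal is the parameter count: `n_categories * rank + n_categories` instead of the full `n_categories * (n_categories + 1) / 2` of the Cholesky parametrization.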

@thomaspinder (Collaborator):

If I get your thumbs up I'll merge.

Left two comments @ingmarschuster. They don't need resolving or actioning, and the PR can now be merged.

@thomaspinder thomaspinder added this to the v1.0.0 milestone Aug 8, 2023
@ingmarschuster ingmarschuster merged commit 77aa81c into main Aug 8, 2023
21 checks passed