Categorical kernel #345
Conversation
First implementation and demo notebook
Merge in main
This reverts commit a4e877f.
    jnp.eye(2), bijector=tfb.CorrelationCholesky()
)
inspace_vals: list = static_field(None)
name: str = "Dictionary Kernel"
Here the name is "Dictionary Kernel", yet the object is called CatKernel. Could you shed some light on this incongruity, please?
Dictionary Kernel (or DictKernel) was the old name. References to this old name will be fixed.
L = self.sdev.reshape(-1, 1) * self.cholesky_lower
return L @ L.T

def __call__(  # TODO not consistent with general kernel interface
This is true of the GraphKernel too. Not helpful, I know, but maybe there's an alternative abstraction that is more appropriate for non-Euclidean kernels.
This is because AbstractKernel.__call__(x, y) requires float arrays for x and y. Instead, maybe the right signature for this base class would be
def __call__(self, x: Num[Array, " D"], y: Num[Array, " D"]) -> ScalarFloat:
because then the categorical kernel could specialize to ScalarInt, if I'm not mistaken.
Yes. I agree this would be the more general signature.
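To make the signature discussion concrete, here is a minimal sketch of how a base class accepting generic numeric inputs would let a categorical kernel specialize to integer category indices. All class and attribute names here are illustrative, not the library's actual API.

```python
import jax.numpy as jnp

class AbstractKernel:
    """Illustrative base class: inputs are numeric, not necessarily float."""
    def __call__(self, x, y):
        raise NotImplementedError

class CategoricalKernel(AbstractKernel):
    def __init__(self, gram):
        # explicit (N, N) gram matrix over the N category values
        self.gram = jnp.asarray(gram)

    def __call__(self, x, y):
        # x and y are integer category indices rather than float vectors
        return self.gram[x, y]

k = CategoricalKernel(jnp.array([[1.0, 0.2], [0.2, 1.0]]))
value = k(0, 1)  # cross-covariance between categories 0 and 1
```

With a Num-typed base signature, this subclass override would be a legal narrowing rather than an interface violation.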
@property
def explicit_gram(self):
    L = self.sdev.reshape(-1, 1) * self.cholesky_lower
    return L @ L.T
How does this differ from the regular gram method?
It's actually a property. I implemented it because it can be handy to use this rather than the internal parametrization. Adding a docstring.
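A minimal numeric sketch of what the property computes: the gram matrix is assembled from per-category standard deviations and a unit-row-norm correlation Cholesky factor, so the diagonal recovers the variances. The values below are made up for illustration.

```python
import jax.numpy as jnp

# assumed parametrization, following the snippet above
sdev = jnp.array([1.0, 2.0])
cholesky_lower = jnp.array([[1.0, 0.0],
                            [0.5, 0.75 ** 0.5]])  # rows have unit norm

L = sdev.reshape(-1, 1) * cholesky_lower
explicit_gram = L @ L.T
# diagonal entries equal sdev**2; off-diagonals are the correlations
# scaled by the corresponding standard deviations
```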
    ValueError: If the number of diagonal variance parameters does not match the number of input space values.
    """

sdev: Float[Array, " N"] = param_field(jnp.ones((2,)), bijector=tfb.Softplus())
nit: elsewhere in the package we use stddev for standard deviation.
Changed the name.
return num_inspace_vals * (num_inspace_vals - 1) // 2

@classmethod
def gram_to_sdev_cholesky_lower(cls, gram: Float[Array, "N N"]) -> CatKernelParams:
Doesn't need to be a class method since cls is not used. Can be a static method.
yep, changed.
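For context, a hedged sketch of the inverse mapping such a static method could implement: splitting a gram matrix into standard deviations and a correlation Cholesky factor. The real method's name and return type (CatKernelParams) may differ from this illustration.

```python
import jax.numpy as jnp

def gram_to_sdev_cholesky_lower(gram):
    # standard deviations are the square roots of the diagonal variances
    sdev = jnp.sqrt(jnp.diag(gram))
    # normalize to a correlation matrix, then take its Cholesky factor
    corr = gram / jnp.outer(sdev, sdev)
    return sdev, jnp.linalg.cholesky(corr)

gram = jnp.array([[1.0, 1.0], [1.0, 4.0]])
sdev, L_corr = gram_to_sdev_cholesky_lower(gram)
# round trip: (sdev[:, None] * L_corr) @ (sdev[:, None] * L_corr).T == gram
```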
@thomaspinder The
If I get your thumbs up I'll merge.
cholesky_lower: Float[Array, "N N"] = param_field(
    jnp.eye(2), bijector=tfb.CorrelationCholesky()
)
I did not know this bijector existed. It makes the code neater. It's worth thinking about a slightly messier formulation, though:
It would be nice to be able to control the flexibility, kind of like when you specify the rank of W in the decomposition K = W W^T + kappa.
How much work would this be @ingmarschuster? If it's simple, then maybe let's add it to this PR. Otherwise, if you feel it's a good idea, then let's open an issue for it.
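An illustrative sketch (not from this PR) of the rank-controlled parametrization K = W W^T + diag(kappa) proposed above: the chosen rank bounds the flexibility of the off-diagonal structure, and kappa > 0 keeps K positive definite.

```python
import jax
import jax.numpy as jnp

def low_rank_gram(W, kappa):
    # K = W @ W.T + diag(kappa); rank of W controls flexibility
    return W @ W.T + jnp.diag(kappa)

num_categories, rank = 4, 2
W = jax.random.normal(jax.random.PRNGKey(0), (num_categories, rank))
kappa = 0.1 * jnp.ones(num_categories)
K = low_rank_gram(W, kappa)
# K is symmetric positive definite for any W whenever kappa > 0
```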
Left two comments @ingmarschuster. They don't need resolving or actioning, and the PR can now be merged.
Type of changes
Checklist
poetry run pre-commit run --all-files --show-diff-on-failure before committing.

Description
This implements a kernel with explicit gram values for categorical inputs (such as a string of characters). The parametrization works very well for gradient descent.
I'm not sure about formatting with the poetry command; it didn't work for me. I'm using the black formatter.
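To illustrate the intended usage described above, here is a hypothetical end-to-end sketch: categorical inputs (characters here) are mapped to integer indices, and kernel evaluations are lookups into the explicit gram matrix over the input space values. Names and values are illustrative, not the library's API.

```python
import jax.numpy as jnp

# hypothetical input space of category values and an explicit gram over it
inspace_vals = ["a", "b", "c"]
index = {v: i for i, v in enumerate(inspace_vals)}
gram = jnp.array([[1.0, 0.3, 0.1],
                  [0.3, 1.0, 0.2],
                  [0.1, 0.2, 1.0]])

def k(x, y):
    # kernel evaluation is a lookup into the explicit gram matrix
    return gram[index[x], index[y]]

similarity = k("a", "c")  # gram entry for the pair of categories
```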