
bug: Zero mean function doesn't necessarily return zero after optimising a GP posterior #330

Closed
Thomas-Christie opened this issue Jul 4, 2023 · 4 comments
Labels: bug

@Thomas-Christie (Contributor) commented Jul 4, 2023

Bug Report

GPJax version:

0.6.7

Current behaviour:

Currently, the Zero mean function doesn't necessarily remain zero once the posterior model has been optimised. Under the hood, it is modelled using the Constant mean function with the constant field set to 0. However, the constant field is trainable, so inspecting a GP posterior whose hyperparameters have been optimised can reveal that the Zero mean function is in fact non-zero.

Expected behaviour:

I believe that if a user explicitly chooses the Zero mean function then it should always return zero (so, in effect, the mean is not trainable). This is also consistent with existing GP modelling libraries such as GPFlow (https://github.com/GPflow/GPflow/blob/develop/gpflow/functions.py#L187).

Steps to reproduce:

Create a GP model with a Zero mean function prior, then generate an optimised posterior on some data with non-zero mean. Inspecting the posterior model's "prior" mean reveals that the constant field of the Zero mean function has changed from 0. This can be seen in the regression notebook.
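As a minimal sketch of why this happens (plain JAX, not GPJax's API; the `loss` function here is an illustrative stand-in for the marginal log-likelihood objective): the gradient of the objective with respect to a trainable constant mean is non-zero whenever the data mean is non-zero, so any gradient-based optimiser will move the constant away from 0.

```python
import jax
import jax.numpy as jnp

def loss(constant, y):
    # Squared error of a constant mean against the observations;
    # stands in for the (negative) marginal log-likelihood.
    return jnp.mean((y - constant) ** 2)

y = jnp.array([1.0, 2.0, 3.0])  # data with non-zero mean
grad = jax.grad(loss)(0.0, y)   # gradient at the "zero" mean
print(grad)                     # non-zero, so an optimiser updates the constant
```

Because the gradient at `constant = 0.0` is non-zero here (`-2 * mean(y)`), a single optimiser step already makes the "Zero" mean non-zero unless the field is excluded from training.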

Feel free to let me know if this behaviour is in fact intended!

@daniel-dodd (Member) commented:
Thanks for spotting this @Thomas-Christie.

Changing param_field's trainable=False won't fix this issue in general. The constant field really needs to be a static_field, as it should never change value under any operation, and we should set init=False to ensure its value cannot be overridden at initialisation, e.g.:

from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import Array, Float

from gpjax.base import static_field
from gpjax.mean_functions import Constant


@dataclass
class Zero(Constant):
    # static_field: excluded from the trainable parameter set;
    # init=False: the value cannot be overridden when constructing Zero().
    constant: Float[Array, "1"] = static_field(jnp.array([0.0]), init=False)

Other options are available here too.
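The init=False part of the proposal can be illustrated with a plain stdlib dataclass (an analogue only; `ZeroLike` is a hypothetical class, not GPJax code): a field declared with init=False is excluded from the generated __init__, so callers cannot set it to anything other than its default.

```python
from dataclasses import dataclass, field

@dataclass
class ZeroLike:
    # init=False removes the field from __init__, so the default
    # of 0.0 is the only value an instance can start with.
    constant: float = field(default=0.0, init=False)

z = ZeroLike()          # constant is fixed at 0.0
print(z.constant)

try:
    ZeroLike(constant=1.0)  # rejected: __init__ accepts no such argument
except TypeError as e:
    print("rejected:", e)
```

static_field(..., init=False) gives the GPJax dataclass the same guarantee, on top of removing the field from the set of trainable leaves.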

@Thomas-Christie (Contributor, Author) commented:

Thanks for the advice @daniel-dodd, I will work on this.

@Thomas-Christie Thomas-Christie self-assigned this Jul 6, 2023
@trsav (Contributor) commented Jul 30, 2023

#339, which I have just filed, feels slightly similar to this: a parameter (c) is defined as a function of a static field and should remain constant, yet it is being 'trained' (in this case resulting in an invalid Polar GP).

@Thomas-Christie (Contributor, Author) commented:

Fixed in #358
