-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for missing values using masked dataset #340
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly needed is a test for masked log_prob. The := operator is just a question of style
gpjax/dataset.py
Outdated
from simple_pytree import Pytree | ||
|
||
from gpjax.typing import Array | ||
|
||
|
||
class _Missing: | ||
"""Sentinel class for not-yet-computed mask""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you expand a little on this description please? To me, it doesn't fully explain the class' purpose.
Also, what is a sentinel class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On reflection, why do we even need this class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Dataset
class holds either a boolean array or None
, however at initialization jnp.isnan
will be computed every time to ensure that there are no missing values in the data. This can be a lot of overhead especially when constructing batches (see get_batch
in fit.py
): if the original Dataset has mask=None
there is no reason to do the computation for the subset. So the logic here is:
- if nothing is specified the default is this sentinel class, we compute the mask if needed
- if
None
is passed explicitly, we skip the computation and the mask is justNone
In short, we need this class to provide a clear differentiation between an intentionally provided value of None
and a default "unset" value for the mask attribute.
In these situations the "sentinel object/value" pattern is useful, see https://python-patterns.guide/python/sentinel-object/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks better to me than the solution I've always used (introduce an additional boolean variable use_mask
). Only downside is you have to introduce a class only for this.
Maybe yet another solution would be to have a "Dataset" and a "MaskedDataset" class. I would be ok with the sentinel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another solution would be to use string constants, which is what some python libs do and you don't have to introduce a new class for it.
@dataclass
class Dataset:
mask: Optional[Union[Bool[Array, "N Q"], str]] = "No Mask"
one hint: Optional[...]
is a shorthand notation for Union[..., None]
so your original type hint (Optional[Union[Bool[Array, "N Q"], None]]
) is doubling things.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sentinel would be fine to me. But see the comment about this above for other solution options I see
gpjax/dataset.py
Outdated
from simple_pytree import Pytree | ||
|
||
from gpjax.typing import Array | ||
|
||
|
||
class _Missing: | ||
"""Sentinel class for not-yet-computed mask""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks better to me than the solution I've always used (introduce an additional boolean variable use_mask
). Only downside is you have to introduce a class only for this.
Maybe yet another solution would be to have a "Dataset" and a "MaskedDataset" class. I would be ok with the sentinel.
I changed the sentinel class |
Following discussion with Ingmar on slack
Type of changes
Checklist
poetry run pre-commit run --all-files --show-diff-on-failure
before committing.Description
This PR introduces support for missing data points (missing
y
s) in a way that is compatible with JAX jit-compilation (by keeping shapes static). This could be particularly useful in meta-learning settings (where each task has a variable number of available training points), or in multi-output GPs where only some of the outputs are missing (@ingmarschuster).mask
added to theDataset
objectlog_prob
and to posterior predictions, such that masked training inputs have no influence on the resultsFor now I only tested this on a simple exact GP with gaussian likelihood.