Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for missing values using masked dataset #340

Merged
merged 15 commits into from
Aug 25, 2023
Merged

Support for missing values using masked dataset #340

merged 15 commits into from
Aug 25, 2023

Conversation

frazane
Copy link
Contributor

@frazane frazane commented Jul 31, 2023

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

This PR introduces support for missing data points (missing ys) in a way that is compatible with JAX jit-compilation (by keeping shapes static). This could be particularly useful in meta-learning settings (where each task has a variable number of available training points), or in multi-output GPs where only some of the outputs are missing (@ingmarschuster).

  • new attribute mask added to the Datasetobject
  • added code to computations of log_prob and to posterior predictions, such that masked training inputs have no influence on the results

For now I only tested this on a simple exact GP with gaussian likelihood.

@frazane frazane added the enhancement New feature or request label Jul 31, 2023
Copy link
Contributor

@ingmarschuster ingmarschuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly needed is a test for masked log_prob. The := operator is just a question of style

gpjax/dataset.py Outdated Show resolved Hide resolved
gpjax/gaussian_distribution.py Show resolved Hide resolved
gpjax/dataset.py Outdated Show resolved Hide resolved
gpjax/gaussian_distribution.py Show resolved Hide resolved
gpjax/dataset.py Outdated
from simple_pytree import Pytree

from gpjax.typing import Array


class _Missing:
"""Sentinel class for not-yet-computed mask"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you expand a little on this description please? To me, it doesn't fully explain the class' purpose.

Also, what is a sentinel class?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On reflection, why do we even need this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Dataset class holds either a boolean array or None, however at initialization jnp.isnan will be computed every time to ensure that there are no missing values in the data. This can be a lot of overhead especially when constructing batches (see get_batch in fit.py): if the original Dataset has mask=None there is no reason to do the computation for the subset. So the logic here is:

  • if nothing is specified the default is this sentinel class, we compute the mask if needed
  • if None is passed explicitly, we skip the computation and the mask is just None

In short, we need this class to provide a clear differentiation between an intentionally provided value of None and a default "unset" value for the mask attribute.

In these situations the "sentinel object/value" pattern is useful, see https://python-patterns.guide/python/sentinel-object/

Copy link
Contributor

@ingmarschuster ingmarschuster Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks better to me than the solution I've always used (introduce an additional boolean variable use_mask). Only downside is you have to introduce a class only for this.
Maybe yet another solution would be to have a "Dataset" and a "MaskedDataset" class. I would be ok with the sentinel.

Copy link
Contributor

@ingmarschuster ingmarschuster Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another solution would be to use string constants, which is what some python libs do and you don't have to introduce a new class for it.

@dataclass
class Dataset:
    mask: Optional[Union[Bool[Array, "N Q"], str]] = "No Mask"

one hint: Optional[...] is a shorthand notation for Union[..., None] so your original type hint (Optional[Union[Bool[Array, "N Q"], None]]) is doubling things.

gpjax/dataset.py Outdated Show resolved Hide resolved
@thomaspinder thomaspinder added this to the v1.0.0 milestone Aug 8, 2023
@frazane frazane marked this pull request as ready for review August 8, 2023 11:54
Copy link
Contributor

@ingmarschuster ingmarschuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sentinel would be fine to me. But see the comment about this above for other solution options I see

gpjax/dataset.py Outdated
from simple_pytree import Pytree

from gpjax.typing import Array


class _Missing:
"""Sentinel class for not-yet-computed mask"""
Copy link
Contributor

@ingmarschuster ingmarschuster Aug 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks better to me than the solution I've always used (introduce an additional boolean variable use_mask). Only downside is you have to introduce a class only for this.
Maybe yet another solution would be to have a "Dataset" and a "MaskedDataset" class. I would be ok with the sentinel.

@ingmarschuster
Copy link
Contributor

I changed the sentinel class _Missing to a string literal alternative. I believe this serves the same purpose and might be more GPJax-style simple.
Accidentally pushed to the upstream branch. @frazane and @thomaspinder, what do you think? Shall we merge this version?

@frazane frazane dismissed thomaspinder’s stale review August 25, 2023 14:45

Following discussion with Ingmar on slack

@ingmarschuster ingmarschuster removed the request for review from thomaspinder August 25, 2023 15:22
@frazane frazane merged commit 7c85891 into main Aug 25, 2023
21 checks passed
@frazane frazane deleted the mask-missing branch December 23, 2023 15:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants