From 34d0c3272826deaa69b70b3c14dc6bfd339c84e9 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 12:02:33 +0000 Subject: [PATCH 1/8] Update docs --- docs/README.md | 86 ++++++++++ docs/_api.rst | 28 ++- docs/_static/css/gpjax_theme.css | 3 + docs/_static/custom.css | 9 - docs/_templates/autosummary/gpjax_module.rst | 0 docs/conf.py | 42 +++-- docs/conf_sphinx_patch.py | 169 +++++++++++++++++++ docs/design.md | 6 +- 8 files changed, 308 insertions(+), 35 deletions(-) create mode 100644 docs/README.md create mode 100644 docs/_static/css/gpjax_theme.css delete mode 100644 docs/_static/custom.css create mode 100644 docs/_templates/autosummary/gpjax_module.rst create mode 100644 docs/conf_sphinx_patch.py diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..336452f4 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,86 @@ +# Where to find the docs + +The FLAX documentation can be found here: +https://gpjax.readthedocs.io/en/latest/ + +# How to build the docs + +1. Install the requirements using `pip install -r docs/requirements.txt` +2. Make sure `pandoc` is installed +3. Run the make script `make html` + +The corresponding HTML files can then be found in `docs/_build/html/`. + +# How to write code documentation + +Our documentation it is written in ReStructuredText for Sphinx. This is a +meta-language that is compiled into online documentation. For more details see +[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html). +As a result, our docstrings adhere to a specific syntax that has to be kept in +mind. Below we provide some guidelines. + +## How much information to put in a docstring + +A docstring should be informative. If in doubt, then it is best to add more +information to a docstring than less. Many users will skim documentation, so +please ensure the opening sentence or two of a docstring contains the core +information. Adding examples and mathematical descriptions to documentation is +highly desirable. + +We are making an active effort within GPJax to improve our documentation. If you +spot any areas where there is missing information within the existing +documentation, then please either raise an issue or +[create a pull request](https://gpjax.readthedocs.io/en/latest/contributing.html). + +## An example docstring + +An example docstring that adheres the principles of GPJax is given below. +The docstring contains a simple, snappy introduction with links to auxillary +components. More detail is then provided in the form of a mathematical +description and a code example. The docstring is concluded with a description +of the objects attributes with corresponding types. + +```python +@dataclass +class Prior(AbstractPrior): + """A Gaussian process prior object. The GP is parameterised by a + `mean `_ + and `kernel `_ function. + + A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel + function :math:`k(\\cdot, \\cdot)` is given by + + .. math:: + + p(f(\\cdot)) = \mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)). + + To invoke a ``Prior`` distribution, only a kernel function is required. By default, + the mean function will be set to zero. In general, this assumption will be reasonable + assuming the data being modelled has been centred. + + Example: + >>> import gpjax as gpx + >>> + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.Prior(kernel = kernel) + + Attributes: + kernel (Kernel): The kernel function used to parameterise the prior. + mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero. + name (str): The name of the GP prior. Defaults to "GP prior". + """ + + kernel: Kernel + mean_function: Optional[AbstractMeanFunction] = Zero() + name: Optional[str] = "GP prior" +``` + +### Documentation syntax + +A helpful cheatsheet for writing restructured text can be found +[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting +`dataclass` objects. + +* Class attributes should be specified using the `Attributes:` tag. +* Method argument should be specified using the `Args:` tags. +* All attributes and arguments should have types. diff --git a/docs/_api.rst b/docs/_api.rst index b6404ea9..31dda5e6 100644 --- a/docs/_api.rst +++ b/docs/_api.rst @@ -8,18 +8,34 @@ Package Reference Gaussian Processes ################################# +The Gaussian process abstractions in GPJax can be segmented into two distinct +types: prior and posterior objects. This makes for a clean separation of both +code and mathematical concepts. Throught the multiplication of a +`Prior `_ +and `Likelihood `_, +GPJax will then return the appropriate `Posterior `_. + + .. automodule:: gpjax.gps .. currentmodule:: gpjax.gps - Abstract GPs ********************************* -.. autoclass:: AbstractGP +To ensure a consistent API, we subclass all Gaussian process objects from either +``AbstractPrior`` or ``AbstractPosterior``. These classes are not intended to be +used directly, but instead to provide a common interface for all downstream Gaussian +process objects. + +.. autoclass:: AbstractPrior :members: + :special-members: __call__ + :private-members: _initialise_params + :exclude-members: from_tuple, replace, to_tuple .. autoclass:: AbstractPosterior :members: + :special-members: __call__ Gaussian Process Priors @@ -27,15 +43,23 @@ Gaussian Process Priors .. autoclass:: Prior :members: + :special-members: __call__, __mul__ Gaussian Process Posteriors ********************************* .. autoclass:: ConjugatePosterior :members: + :special-members: __call__ .. autoclass:: NonConjugatePosterior :members: + :special-members: __call__ + +Posterior Constructors +********************************* + +.. autofunction:: construct_posterior Kernels diff --git a/docs/_static/css/gpjax_theme.css b/docs/_static/css/gpjax_theme.css new file mode 100644 index 00000000..47e6a841 --- /dev/null +++ b/docs/_static/css/gpjax_theme.css @@ -0,0 +1,3 @@ +nav .bd-links a:hover{ + color: #B5121B +} \ No newline at end of file diff --git a/docs/_static/custom.css b/docs/_static/custom.css deleted file mode 100644 index 170ebcfc..00000000 --- a/docs/_static/custom.css +++ /dev/null @@ -1,9 +0,0 @@ -@import "sphinx-book-theme.css"; - -a { - color: #B5121B -} - -.content-container a { - color: #B5121B -} \ No newline at end of file diff --git a/docs/_templates/autosummary/gpjax_module.rst b/docs/_templates/autosummary/gpjax_module.rst new file mode 100644 index 00000000..e69de29b diff --git a/docs/conf.py b/docs/conf.py index 2a4d5f8f..24829663 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,9 +16,11 @@ import re import sys +sys.path.insert(0, os.path.abspath("..")) + from importlib_metadata import version -# sys.path.insert(0, os.path.abspath("../..")) +import docs.conf_sphinx_patch def read(*names, **kwargs): @@ -72,7 +74,6 @@ def find_version(*file_paths): "sphinx.ext.autosummary", "sphinx.ext.napoleon", "sphinx.ext.viewcode", - "sphinx.ext.autosectionlabel", "sphinx_copybutton", "sphinxcontrib.bibtex", "sphinxext.opengraph", @@ -80,6 +81,9 @@ def find_version(*file_paths): "sphinx_tabs.tabs", ] +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + # MyST Config myst_enable_extensions = [ "amsmath", @@ -165,6 +169,15 @@ def find_version(*file_paths): ] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] +master_doc = "index" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ["_static"] +html_static_path = ["_static"] +html_css_files = ["css/gpjax_theme.css"] + # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -173,14 +186,14 @@ def find_version(*file_paths): autosummary_generate = True +autodoc_typehints = "none" napoleon_use_rtype = False autodoc_default_options = { "member-order": "bysource", "special-members": "__init__, __call__", - "undoc-members": True, - "exclude-members": "__weakref__,_abc_impl", + "exclude-members": "__weakref__,_abc_impl,from_tuple,replace,to_tuple", + "autodoc-typehints": "none", } - # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -202,22 +215,5 @@ def find_version(*file_paths): "use_repository_button": True, "use_sidenotes": True, # Turns footnotes into sidenotes - https://sphinx-book-theme.readthedocs.io/en/stable/content-blocks.html } -# html_title = f"v{version}" -# html_theme_options = { -# "light_css_variables": { -# "color-brand-primary": "#B5121B", -# "color-brand-content": "#CC3333", -# "color-admonition-background": "orange", -# "source_repository": "https://github.com/thomaspinder/GPJax/", -# "source_branch": "master", -# "source_directory": "docs/", -# }, -# } - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ["_static"] -html_static_path = ["_static"] -html_css_files = ["custom.css"] +always_document_param_types = True diff --git a/docs/conf_sphinx_patch.py b/docs/conf_sphinx_patch.py new file mode 100644 index 00000000..e2a51975 --- /dev/null +++ b/docs/conf_sphinx_patch.py @@ -0,0 +1,169 @@ +# This file is credited to the Flax authors. + +from typing import Any, Dict, List, Set, Tuple +import sphinx.ext.autosummary.generate as ag +import sphinx.ext.autodoc + + +def generate_autosummary_content( + name: str, + obj: Any, + parent: Any, + template: ag.AutosummaryRenderer, + template_name: str, + imported_members: bool, + app: Any, + recursive: bool, + context: Dict, + modname: str = None, + qualname: str = None, +) -> str: + doc = ag.get_documenter(app, obj, parent) + + def skip_member(obj: Any, name: str, objtype: str) -> bool: + try: + return app.emit_firstresult( + "autodoc-skip-member", objtype, name, obj, False, {} + ) + except Exception as exc: + ag.logger.warning( + __( + "autosummary: failed to determine %r to be documented, " + "the following exception was raised:\n%s" + ), + name, + exc, + type="autosummary", + ) + return False + + def get_class_members(obj: Any) -> Dict[str, Any]: + members = sphinx.ext.autodoc.get_class_members(obj, [qualname], ag.safe_getattr) + return {name: member.object for name, member in members.items()} + + def get_module_members(obj: Any) -> Dict[str, Any]: + members = {} + for name in ag.members_of(obj, app.config): + try: + members[name] = ag.safe_getattr(obj, name) + except AttributeError: + continue + return members + + def get_all_members(obj: Any) -> Dict[str, Any]: + if doc.objtype == "module": + return get_module_members(obj) + elif doc.objtype == "class": + return get_class_members(obj) + return {} + + def get_members( + obj: Any, types: Set[str], include_public: List[str] = [], imported: bool = True + ) -> Tuple[List[str], List[str]]: + items: List[str] = [] + public: List[str] = [] + + all_members = get_all_members(obj) + for name, value in all_members.items(): + documenter = ag.get_documenter(app, value, obj) + if documenter.objtype in types: + # skip imported members if expected + if imported or getattr(value, "__module__", None) == obj.__name__: + skipped = skip_member(value, name, documenter.objtype) + if skipped is True: + pass + elif skipped is False: + # show the member forcedly + items.append(name) + public.append(name) + else: + items.append(name) + if name in include_public or not name.startswith("_"): + # considers member as public + public.append(name) + return public, items + + def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]: + """Find module attributes with docstrings.""" + attrs, public = [], [] + try: + analyzer = ag.ModuleAnalyzer.for_module(name) + attr_docs = analyzer.find_attr_docs() + for namespace, attr_name in attr_docs: + if namespace == "" and attr_name in members: + attrs.append(attr_name) + if not attr_name.startswith("_"): + public.append(attr_name) + except ag.PycodeError: + pass # give up if ModuleAnalyzer fails to parse code + return public, attrs + + def get_modules(obj: Any) -> Tuple[List[str], List[str]]: + items: List[str] = [] + for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): + fullname = name + "." + modname + try: + module = ag.import_module(fullname) + if module and hasattr(module, "__sphinx_mock__"): + continue + except ImportError: + pass + + items.append(fullname) + public = [x for x in items if not x.split(".")[-1].startswith("_")] + return public, items + + ns: Dict[str, Any] = {} + ns.update(context) + + if doc.objtype == "module": + scanner = ag.ModuleScanner(app, obj) + ns["members"] = scanner.scan(imported_members) + ns["functions"], ns["all_functions"] = get_members( + obj, {"function"}, imported=imported_members + ) + ns["classes"], ns["all_classes"] = get_members( + obj, {"class"}, imported=imported_members + ) + ns["exceptions"], ns["all_exceptions"] = get_members( + obj, {"exception"}, imported=imported_members + ) + ns["attributes"], ns["all_attributes"] = get_module_attrs(ns["members"]) + ispackage = hasattr(obj, "__path__") + if ispackage and recursive: + ns["modules"], ns["all_modules"] = get_modules(obj) + elif doc.objtype == "class": + ns["members"] = dir(obj) + ns["inherited_members"] = set(dir(obj)) - set(obj.__dict__.keys()) + ns["methods"], ns["all_methods"] = get_members(obj, {"method"}, ["__init__"]) + ns["attributes"], ns["all_attributes"] = get_members( + obj, {"attribute", "property"} + ) + ns["annotations"] = list(getattr(obj, "__annotations__", {}).keys()) + + if modname is None or qualname is None: + modname, qualname = ag.split_full_qualified_name(name) + + if doc.objtype in ("method", "attribute", "property"): + ns["class"] = qualname.rsplit(".", 1)[0] + + if doc.objtype in ("class",): + shortname = qualname + else: + shortname = qualname.rsplit(".", 1)[-1] + + ns["fullname"] = name + ns["module"] = modname + ns["objname"] = qualname + ns["name"] = shortname + + ns["objtype"] = doc.objtype + ns["underline"] = len(name) * "=" + + if template_name: + return template.render(template_name, ns) + else: + return template.render(doc.objtype, ns) + + +ag.generate_autosummary_content = generate_autosummary_content diff --git a/docs/design.md b/docs/design.md index db17e1fb..d265b0af 100644 --- a/docs/design.md +++ b/docs/design.md @@ -2,7 +2,7 @@ `GPJax` is designed to be a Gaussian process package that provides an accurate representation of the underlying maths. Variable names are chosen to closely match the notation in {cite}`rasmussen2006gaussian`. -We here list the notation used in `GPJax` with its corresponding mathematical quantity. +We here list the notation used in `GPJax` with its corresponding mathematical quantity. ## Gaussian process notation @@ -26,3 +26,7 @@ We here list the notation used in `GPJax` with its corresponding mathematical qu | $m$ | m | Number of inducing inputs | | $\boldsymbol{z} = (z_1,\dotsc,z_{m})$ | z | Inducing inputs | | $\boldsymbol{u} = (u_1,\dotsc,u_{m})$ | u | Inducing outputs | + +## Package style + +Prior to building GPJax, the developers of GPJax have benefited greatly from the [GPFlow](https://github.com/GPflow/GPflow) and [GPyTorch](https://github.com/cornellius-gp/gpytorch) packages. As such, many of the design principles in GPJax are inspired by the excellent precursory pacakges. Documentation designs have been greatly inspired by the exceptional [Flax docs](https://flax.readthedocs.io/en/latest/index.html). \ No newline at end of file From a1d1a84af06920a701db80f1677d5acbf6b56fe2 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 12:07:52 +0000 Subject: [PATCH 2/8] Fix ref --- docs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/README.md b/docs/README.md index 336452f4..986a3132 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ # Where to find the docs -The FLAX documentation can be found here: +The GPJax documentation can be found here: https://gpjax.readthedocs.io/en/latest/ # How to build the docs From 54d3a13752013c07900b9c745aac5d2ce41d0623 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 12:08:22 +0000 Subject: [PATCH 3/8] Rename AbstractGP to AbstractPrior --- gpjax/gps.py | 231 ++++++++++++++++++++++++++++++++++++---------- tests/test_gps.py | 14 +-- 2 files changed, 189 insertions(+), 56 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 27b6ad1f..03398c92 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -32,15 +32,27 @@ from .types import Dataset, PRNGKeyType from .utils import concat_dictionaries -DEFAULT_JITTER = get_defaults()["jitter"] - @dataclass -class AbstractGP: - """Abstract Gaussian process object.""" +class AbstractPrior: + """Abstract Gaussian process prior. + + All Gaussian processes priors should inherit from this class. + + All GPJax Modules are `Chex dataclasses `_. Since + dataclasses take over ``__init__``, the ``__post_init__`` method can be used to + initialise the GP's parameters. + """ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: - """Evaluate the Gaussian process at the given points. + """Evaluate the Gaussian process at the given points. The output of this function + is a `Distrax distribution `_ from which the + the latent function's mean and covariance can be evaluated and the distribution + can be sampled. + + Under the hood, ``__call__`` is calling the objects ``predict`` method. For this + reasons, classes inheriting the ``AbstractPrior`` class, should not overwrite the + ``__call__`` method and should instead define a ``predict`` method. Args: *args (Any): The arguments to pass to the GP's `predict` method. @@ -53,7 +65,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: - """Compute the latent function's multivariate normal distribution. + """Compute the latent function's multivariate normal distribution for a + given set of parameters. For any class inheriting the ``AbstractPrior`` class, + this method must be implemented. Args: *args (Any): Arguments to the predict method. @@ -64,8 +78,13 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """ raise NotImplementedError + @abstractmethod def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Initialise the GP's parameter set. + """An initialisation method for the GP's parameters. This method should + be implemented for all classes that inherit the ``AbstractPrior`` class. + Whilst not always necessary, the method accepts a PRNG key to allow + for stochastic initialisation. The method should is most often invoked + through the ``initialise`` function given in GPJax. Args: key (PRNGKeyType): The PRNG key. @@ -79,23 +98,66 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ####################### # GP Priors ####################### -@dataclass(repr=False) -class Prior(AbstractGP): - """A Gaussian process prior object. The GP is parameterised by a mean and kernel function.""" +@dataclass +class Prior(AbstractPrior): + """A Gaussian process prior object. The GP is parameterised by a + `mean `_ + and `kernel `_ function. + + A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel + function :math:`k(\\cdot, \\cdot)` is given by + + .. math:: + + p(f(\\cdot)) = \mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)). + + To invoke a ``Prior`` distribution, only a kernel function is required. By default, + the mean function will be set to zero. In general, this assumption will be reasonable + assuming the data being modelled has been centred. + + Example: + >>> import gpjax as gpx + >>> + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.Prior(kernel = kernel) + + Attributes: + kernel (Kernel): The kernel function used to parameterise the prior. + mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero. + name (str): The name of the GP prior. Defaults to "GP prior". + """ kernel: Kernel mean_function: Optional[AbstractMeanFunction] = Zero() name: Optional[str] = "GP prior" - jitter: Optional[float] = DEFAULT_JITTER def __mul__(self, other: AbstractLikelihood): - """The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned. + """The product of a prior and likelihood is proportional to the + posterior distribution. By computing the product of a GP prior and a + likelihood object, a posterior GP object will be returned. Mathetically, + this can be described by: + .. math:: - Args: - other (Likelihood): The likelihood distribution of the observed dataset. + p(f(\\cdot) | y) \\propto p(y | f(\\cdot)) p(f(\\cdot)). - Returns: - Posterior: The relevant GP posterior for the given prior and likelihood. Special cases are accounted for where the model is conjugate. + where :math:`p(y | f(\\cdot))` is the likelihood and :math:`p(f(\\cdot))` + is the prior. + + + Example: + >>> import gpjax as gpx + >>> + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.Prior(kernel = kernel) + >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100) + >>> + >>> prior * likelihood + + Args: + other (Likelihood): The likelihood distribution of the observed dataset. + + Returns: + Posterior: The relevant GP posterior for the given prior and likelihood. Special cases are accounted for where the model is conjugate. """ return construct_posterior(prior=self, likelihood=other) @@ -113,22 +175,43 @@ def __rmul__(self, other: AbstractLikelihood): def predict( self, params: Dict ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: - """Compute the GP's prior mean and variance. + """Compute the predictive prior distribution for a given set of + parameters. The output of this function is a function that computes + a distrx distribution for a given set of inputs. + + In the following example, we compute the predictive prior distribution + and then evaluate it on the interval :math:`[0, 1]`: + + Example: + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.Prior(kernel = kernel) + >>> + >>> parameter_state = gpx.initialise(prior) + >>> prior_predictive = prior.predict(parameter_state.params) + >>> prior_predictive(jnp.linspace(0, 1, 100)) Args: - params (Dict): The specific set of parameters for which the mean function should be defined for. + params (Dict): The specific set of parameters for which the mean + function should be defined for. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. + Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A mean + function that accepts an input array for where the mean function + should be evaluated at. The mean function's value at these points is + then returned. """ gram = self.kernel.gram + jitter = get_defaults()["jitter"] def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: t = test_inputs n_test = t.shape[0] μt = self.mean_function(t, params["mean_function"]) Ktt = gram(self.kernel, t, params["kernel"]) - Ktt += self.jitter * I(n_test) + Ktt += jitter * I(n_test) Lt = Ktt.triangular_lower() return dx.MultivariateNormalTri(jnp.atleast_1d(μt.squeeze()), Lt) @@ -154,24 +237,37 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: # GP Posteriors ####################### @dataclass -class AbstractPosterior(AbstractGP): - """The base GP posterior object conditioned on an observed dataset.""" +class AbstractPosterior(AbstractPrior): + """The base GP posterior object conditioned on an observed dataset. All + posterior objects should inherit from this class. + + All GPJax Modules are `Chex dataclasses `_. Since + dataclasses take over ``__init__``, the ``__post_init__`` method can be used to + initialise the GP's parameters. + + Attributes: + prior (Prior): The prior distribution of the GP. + likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. + name (str): The name of the GP posterior. Defaults to "GP posterior". + """ prior: Prior likelihood: AbstractLikelihood name: Optional[str] = "GP posterior" - jitter: Optional[float] = DEFAULT_JITTER @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: - """Predict the GP's output given the input. + """Compute the predictive posterior distribution of the latent function + for a given set of parameters. For any class inheriting the + ``AbstractPosterior`` class, this method must be implemented. Args: - *args (Any): Arguments to the predict method. - **kwargs (Any): Keyword arguments to the predict method. + *args (Any): Arguments to the predict method. **kwargs (Any): + Keyword arguments to the predict method. Returns: - dx.Distribution: A multivariate normal random variable representation of the Gaussian process. + dx.Distribution: A multivariate normal random variable + representation of the Gaussian process. """ raise NotImplementedError @@ -192,12 +288,46 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass class ConjugatePosterior(AbstractPosterior): - """Gaussian process posterior object for models where the likelihood is Gaussian.""" + """A Gaussian process posterior distribution when the constituent likelihood + function is a Gaussian distribution. In such cases, the latent function values + :math:`f` can be analytically integrated out of the posterior distribution. + As such, many computational operations can be simplified; something we make use + of in this object. + + For a Gaussian process prior :math:`p(\mathbf{f})` and a Gaussian likelihood + :math:`p(y | \\mathbf{f}) = \\mathcal{N}(y\\mid \mathbf{f}, \\sigma^2))` where + :math:`\mathbf{f} = f(\\mathbf{x})`, the predictive posterior distribution at + a set of inputs :math:`\\mathbf{x}` is given by + + .. math:: + + p(\\mathbf{f}^{\\star}\mid \mathbf{y}) & = \\int p(\\mathbf{f}^{\\star} \\mathbf{f} \\mid \\mathbf{y})\\\\ + & =\\mathcal{N}(\\mathbf{f}^{\\star} \\boldsymbol{\mu}_{\mid \mathbf{y}}, \\boldsymbol{\Sigma}_{\mid \mathbf{y}} + where + + .. math:: + + \\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left(k(\\mathbf{x}, \\mathbf{x})+\\sigma^2\\mathbf{I}_n\\right)^{-1}\\mathbf{y} \\\\ + \\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\\mathbf{x}^{\\star}, \\mathbf{x}^{\\star}) -k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left( k(\\mathbf{x}, \\mathbf{x}) + \\sigma^2\\mathbf{I}_n \\right)^{-1}k(\\mathbf{x}, \\mathbf{x}^{\\star}). + + Example: + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> prior = gpx.Prior(kernel = gpx.kernels.RBF()) + >>> likelihood = gpx.likelihoods.Gaussian() + >>> + >>> posterior = prior * likelihood + + Attributes: + prior (Prior): The prior distribution of the GP. + likelihood (Gaussian): The Gaussian likelihood distribution of the observed dataset. + name (str): The name of the GP posterior. Defaults to "Conjugate posterior". + """ prior: Prior likelihood: Gaussian name: Optional[str] = "Conjugate posterior" - jitter: Optional[float] = DEFAULT_JITTER def predict( self, train_data: Dataset, params: Dict @@ -211,6 +341,8 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts an input array and returns the predictive distribution as a `dx.MultivariateNormalTri`. """ + jitter = get_defaults()["jitter"] + x, y, n = train_data.X, train_data.y, train_data.n gram, cross_covariance = ( self.prior.kernel.gram, @@ -223,7 +355,7 @@ def predict( # Precompute covariance matrices Kxx = gram(self.prior.kernel, x, params["kernel"]) - Kxx += I(n) * self.jitter + Kxx += I(n) * jitter # Σ = Kxx + Iσ² Sigma = Kxx + I(n) * obs_noise @@ -246,7 +378,7 @@ def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: # Ktt - Ktx (Kxx + Iσ²)⁻¹ Kxt covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() @@ -270,6 +402,7 @@ def marginal_log_likelihood( Returns: Callable[[Dict], Float[Array, "1"]]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ + jitter = get_defaults()["jitter"] x, y, n = train_data.X, train_data.y, train_data.n gram, cross_covariance = ( self.prior.kernel.gram, @@ -283,7 +416,7 @@ def mll( obs_noise = params["likelihood"]["obs_noise"] μx = self.prior.mean_function(x, params["mean_function"]) Kxx = gram(self.prior.kernel, x, params["kernel"]) - Kxx += I(n) * self.jitter + Kxx += I(n) * jitter # TODO: This implementation does not take advantage of the covariance operator structure. # Future work concerns implementation of a custom Gaussian distribution / measure object that accepts a covariance operator. @@ -317,7 +450,6 @@ class NonConjugatePosterior(AbstractPosterior): prior: Prior likelihood: AbstractLikelihood name: Optional[str] = "Non-conjugate posterior" - jitter: Optional[float] = DEFAULT_JITTER def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a non-conjugate GP posterior.""" @@ -333,14 +465,14 @@ def predict( ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function. - Args: - train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - params (Dict): A dictionary of parameters that should be used to compute the posterior. + Args: + train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. + params (Dict): A dictionary of parameters that should be used to compute the posterior. - Returns: - <<<<<<< HEAD - tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `numpyro.distributions.MultivariateNormal`. + Returns: + tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `dx.Distribution`. """ + jitter = get_defaults()["jitter"] x, n = train_data.X, train_data.n gram, cross_covariance = ( self.prior.kernel.gram, @@ -348,13 +480,13 @@ def predict( ) Kxx = gram(self.prior.kernel, x, params["kernel"]) - Kxx += I(n) * self.jitter + Kxx += I(n) * jitter def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: t = test_inputs n_test = t.shape[0] Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"]) - Ktt = gram(self.prior.kernel, t, params["kernel"]) + I(n_test) * self.jitter + Ktt = gram(self.prior.kernel, t, params["kernel"]) + I(n_test) * jitter μt = self.prior.mean_function(t, params["mean_function"]) Lx = Kxx.triangular_lower() @@ -367,7 +499,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: # Ktt - Ktx Kxx⁻¹ Kxt covariance = Ktt - covariance += I(n_test) * self.jitter + covariance += I(n_test) * jitter covariance = covariance.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) return dx.MultivariateNormalFullCovariance( @@ -392,6 +524,7 @@ def marginal_log_likelihood( Returns: Callable[[Dict], Float[Array, "1"]]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ + jitter = get_defaults()["jitter"] x, y, n = train_data.X, train_data.y, train_data.n gram = self.prior.kernel.gram if not priors: @@ -400,7 +533,7 @@ def marginal_log_likelihood( def mll(params: Dict): Kxx = gram(self.prior.kernel, x, params["kernel"]) - Kxx += I(n) * self.jitter + Kxx += I(n) * jitter Lx = Kxx.triangular_lower() μx = self.prior.mean_function(x, params["mean_function"]) @@ -450,10 +583,10 @@ def euclidean_distance( __all__ = [ - AbstractGP, - Prior, - AbstractPosterior, - ConjugatePosterior, - NonConjugatePosterior, - construct_posterior, + "AbstractPrior", + "Prior", + "AbstractPosterior", + "ConjugatePosterior", + "NonConjugatePosterior", + "construct_posterior", ] diff --git a/tests/test_gps.py b/tests/test_gps.py index dad6c157..9f23cd74 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -23,7 +23,7 @@ from gpjax import Dataset, initialise from gpjax.gps import ( - AbstractGP, + AbstractPrior, ConjugatePosterior, NonConjugatePosterior, Prior, @@ -44,7 +44,7 @@ def test_prior(num_datapoints): parameter_state = initialise(p, jr.PRNGKey(123)) params, _, _ = parameter_state.unpack() assert isinstance(p, Prior) - assert isinstance(p, AbstractGP) + assert isinstance(p, AbstractPrior) prior_rv_fn = p(params) assert isinstance(prior_rv_fn, tp.Callable) @@ -71,12 +71,12 @@ def test_conjugate_posterior(num_datapoints): lik = Gaussian(num_datapoints=num_datapoints) post = p * lik assert isinstance(post, ConjugatePosterior) - assert isinstance(post, AbstractGP) - assert isinstance(p, AbstractGP) + assert isinstance(post, AbstractPrior) + assert isinstance(p, AbstractPrior) post2 = lik * p assert isinstance(post2, ConjugatePosterior) - assert isinstance(post2, AbstractGP) + assert isinstance(post2, AbstractPrior) parameter_state = initialise(post, key) params, _, bijectors = parameter_state.unpack() @@ -116,8 +116,8 @@ def test_nonconjugate_posterior(num_datapoints, likel): lik = likel(num_datapoints=num_datapoints) post = p * lik assert isinstance(post, NonConjugatePosterior) - assert isinstance(post, AbstractGP) - assert isinstance(p, AbstractGP) + assert isinstance(post, AbstractPrior) + assert isinstance(p, AbstractPrior) parameter_state = initialise(post, key) params, _, _ = parameter_state.unpack() From 4971063b9a83c68f595dad461162b08df9231ac8 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 12:10:00 +0000 Subject: [PATCH 4/8] Add template --- docs/_templates/autosummary/gpjax_module.rst | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/_templates/autosummary/gpjax_module.rst b/docs/_templates/autosummary/gpjax_module.rst index e69de29b..5f51933c 100644 --- a/docs/_templates/autosummary/gpjax_module.rst +++ b/docs/_templates/autosummary/gpjax_module.rst @@ -0,0 +1,23 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :exclude-members: + + {% block methods %} + + .. automethod:: __call__ + + {% if methods %} + .. rubric:: Methods + + .. autosummary:: + + {% for item in methods %} + {%- if item not in inherited_members and item not in annotations and not item in ['__init__'] %} + ~{{ name }}.{{ item }} + {%- endif %} + {%- endfor %} + {% endif %} + {% endblock %} \ No newline at end of file From 9b1d89279fcd3e72108dfc4d6a2a72e7709caa76 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 15:04:25 +0000 Subject: [PATCH 5/8] Predict docstring --- gpjax/gps.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 03398c92..f0f0c310 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -332,7 +332,27 @@ class ConjugatePosterior(AbstractPosterior): def predict( self, train_data: Dataset, params: Dict ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: - """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. + """Conditional on a training data set, compute the GP's posterior + predictive distribution for a given set of parameters. The returned + function can be evaluated at a set of test inputs to compute the + corresponding predictive density. + + The conditioning set is a GPJax ``Dataset`` object, whilst predictions + are made on a regular Jax array. + + £xample: + For a ``posterior`` distribution, the following code snippet will + evaluate the predictive distribution. + + >>> import gpjax as gpx + >>> + >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) + >>> ytrain = jnp.sin(xtrain) + >>> xtest = jnp.linspace(0, 1).reshape(-1, 1) + >>> + >>> params = gpx.initialise(posterior) + >>> predictive_dist = posterior.predict(gpx.Dataset(X=xtrain, y=ytrain), params) + >>> predictive_dist(xtest) Args: train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. @@ -555,6 +575,17 @@ def mll(params: Dict): def construct_posterior( prior: Prior, likelihood: AbstractLikelihood ) -> AbstractPosterior: + """Utility function for constructing a posterior object from a prior and + likelihood. The function will automatically select the correct posterior + object based on the likelihood. + + Args: + prior (Prior): The Prior distribution. + likelihood (AbstractLikelihood): The likelihood that represents our beliefs around the distribution of the data. + + Returns: + AbstractPosterior: A posterior distribution. If the likelihood is Gaussian, then a ``ConjugatePosterior`` will be returned. Otherwise, a ``NonConjugatePosterior`` will be returned. + """ if isinstance(likelihood, Conjugate): PosteriorGP = ConjugatePosterior From 3db54e4bc3b8533922936734e393c41af974db05 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 15:33:17 +0000 Subject: [PATCH 6/8] Finish conjugate GP --- docs/_api.rst | 5 ++++ gpjax/gps.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/docs/_api.rst b/docs/_api.rst index 31dda5e6..84ba73ed 100644 --- a/docs/_api.rst +++ b/docs/_api.rst @@ -48,6 +48,11 @@ Gaussian Process Priors Gaussian Process Posteriors ********************************* +There are two main classes of posterior Gaussian process objects within GPJax. +The ``ConjugatePosterior`` class is used when the likelihood distribution is +Gaussian whilst the ``NonConjugatePosterior`` class is used when the likelihood +distribution is non-Gaussian. + .. autoclass:: ConjugatePosterior :members: :special-members: __call__ diff --git a/gpjax/gps.py b/gpjax/gps.py index f0f0c310..e0b0a7fb 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -307,8 +307,8 @@ class ConjugatePosterior(AbstractPosterior): .. math:: - \\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left(k(\\mathbf{x}, \\mathbf{x})+\\sigma^2\\mathbf{I}_n\\right)^{-1}\\mathbf{y} \\\\ - \\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\\mathbf{x}^{\\star}, \\mathbf{x}^{\\star}) -k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left( k(\\mathbf{x}, \\mathbf{x}) + \\sigma^2\\mathbf{I}_n \\right)^{-1}k(\\mathbf{x}, \\mathbf{x}^{\\star}). + \\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left(k(\\mathbf{x}, \\mathbf{x}')+\\sigma^2\\mathbf{I}_n\\right)^{-1}\\mathbf{y} \\\\ + \\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\\mathbf{x}^{\\star}, \\mathbf{x}^{\\star\\prime}) -k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left( k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_n \\right)^{-1}k(\\mathbf{x}, \\mathbf{x}^{\\star}). Example: >>> import gpjax as gpx @@ -337,6 +337,18 @@ def predict( function can be evaluated at a set of test inputs to compute the corresponding predictive density. + The predictive distribution of a conjugate GP is given by + .. math:: + + p(\\mathbf{f}^{\\star}\mid \mathbf{y}) & = \\int p(\\mathbf{f}^{\\star} \\mathbf{f} \\mid \\mathbf{y})\\\\ + & =\\mathcal{N}(\\mathbf{f}^{\\star} \\boldsymbol{\mu}_{\mid \mathbf{y}}, \\boldsymbol{\Sigma}_{\mid \mathbf{y}} + where + + .. math:: + + \\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left(k(\\mathbf{x}, \\mathbf{x}')+\\sigma^2\\mathbf{I}_n\\right)^{-1}\\mathbf{y} \\\\ + \\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\\mathbf{x}^{\\star}, \\mathbf{x}^{\\star\\prime}) -k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left( k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_n \\right)^{-1}k(\\mathbf{x}, \\mathbf{x}^{\\star}). + The conditioning set is a GPJax ``Dataset`` object, whilst predictions are made on a regular Jax array. @@ -412,7 +424,52 @@ def marginal_log_likelihood( priors: Dict = None, negative: bool = False, ) -> Callable[[Dict], Float[Array, "1"]]: - """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values. + """Compute the marginal log-likelihood function of the Gaussian process. + The returned function can then be used for gradient based optimisation + of the model's parameters or for model comparison. The implementation + given here enables exact estimation of the Gaussian process' latent + function values. + + For a training dataset :math:`\\{x_n, y_n\\}_{n=1}^N`, set of test + inputs :math:`\\mathbf{x}^{\\star}` the corresponding latent function + evaluations are given by :math:`\\mathbf{f} = f(\\mathbf{x})` + and :math:`\\mathbf{f}^{\\star} = f(\\mathbf{x}^{\\star})`, the marginal + log-likelihood is given by: + + .. math:: + + \\log p(\\mathbf{y}) & = \\int p(\\mathbf{y}\\mid\\mathbf{f})p(\\mathbf{f}, \\mathbf{f}^{\\star}\\mathrm{d}\\mathbf{f}^{\\star}\\\\ + &=0.5\\left(-\\mathbf{y}^{\\top}\\left(k(\\mathbf{x}, \\mathbf{x}') +\\sigma^2\\mathbf{I}_N \\right)^{-1}\\mathbf{y}-\\log\\lvert k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_N\\rvert - n\\log 2\\pi \\right). + + Example: + + For a given ``ConjugatePosterior`` object, the following code snippet shows + how the marginal log-likelihood can be evaluated. + + >>> import gpjax as gpx + >>> + >>> xtrain = jnp.linspace(0, 1).reshape(-1, 1) + >>> ytrain = jnp.sin(xtrain) + >>> D = gpx.Dataset(X=xtrain, y=ytrain) + >>> + >>> params = gpx.initialise(posterior) + >>> mll = posterior.marginal_log_likelihood(train_data = D) + >>> mll(params) + + Our goal is to maximise the marginal log-likelihood. Therefore, when + optimising the model's parameters with respect to the parameters, we + use the negative marginal log-likelihood. This can be realised through + + >>> mll = posterior.marginal_log_likelihood(train_data = D, negative=True) + + Further, prior distributions can be passed into the marginal log-likelihood + + >>> mll = posterior.marginal_log_likelihood(train_data = D, priors=priors) + + For optimal performance, the marginal log-likelihood should be ``jax.jit`` + compiled. + + >>> mll = jit(posterior.marginal_log_likelihood(train_data = D)) Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. From 894dac71f05e11c6a6d05ca9bbec5d13c3784770 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 16:00:13 +0000 Subject: [PATCH 7/8] GPs finished --- gpjax/gps.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index e0b0a7fb..7f09ce91 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -522,9 +522,23 @@ def mll( @dataclass class NonConjugatePosterior(AbstractPosterior): - """Generic Gaussian process posterior object for models where the likelihood is non-Gaussian.""" + """ + A Gaussian process posterior object for models where the likelihood is + non-Gaussian. Unlike the ``ConjugatePosterior`` object, the + ``NonConjugatePosterior`` object does not provide an exact marginal + log-likelihood function. Instead, the ``NonConjugatePosterior`` object + represents the posterior distributions as a function of the model's + hyperparameters and the latent function. Markov chain Monte Carlo, + variational inference, or Laplace approximations can then be used to sample + from, or optimise an approximation to, the posterior distribution. - prior: Prior + Attributes: + prior (AbstractPrior): The Gaussian process prior distribution. + likelihood (AbstractLikelihood): The likelihood function that represents the data. + name (str): The name of the posterior object. Defaults to "Non-conjugate posterior". + """ + + prior: AbstractPrior likelihood: AbstractLikelihood name: Optional[str] = "Non-conjugate posterior" @@ -540,7 +554,13 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: def predict( self, train_data: Dataset, params: Dict ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: - """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function. + """ + Conditional on a set of training data, compute the GP's posterior + predictive distribution for a given set of parameters. The returned + function can be evaluated at a set of test inputs to compute the + corresponding predictive density. Note, to gain predictions on the scale + of the original data, the returned distribution will need to be + transformed through the likelihood function's inverse link function. Args: train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. @@ -593,6 +613,8 @@ def marginal_log_likelihood( ) -> Callable[[Dict], Float[Array, "1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax. + Unlike the marginal_log_likelihood function of the ConjugatePosterior object, the marginal_log_likelihood function of the NonConjugatePosterior object does not provide an exact marginal log-likelihood function. Instead, the NonConjugatePosterior object represents the posterior distributions as a function of the model's hyperparameters and the latent function. Markov chain Monte Carlo, variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution. + Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. From ddfbfabd6121093aa96dd38fd0bc8c0d9acfff72 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 6 Nov 2022 16:25:43 +0000 Subject: [PATCH 8/8] Rename kernel --- docs/_api.rst | 2 +- examples/haiku.pct.py | 8 ++--- examples/kernels.pct.py | 4 +-- gpjax/gps.py | 6 ++-- gpjax/kernels.py | 70 ++++++++++++++++++++++------------------- tests/test_kernels.py | 34 +++++++++++--------- 6 files changed, 66 insertions(+), 58 deletions(-) diff --git a/docs/_api.rst b/docs/_api.rst index 84ba73ed..175d297b 100644 --- a/docs/_api.rst +++ b/docs/_api.rst @@ -83,7 +83,7 @@ Kernel Functions Abstract Kernels ********************************* -.. autoclass:: Kernel +.. autoclass:: AbstractKernel :members: .. autoclass:: CombinationKernel diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index 10af3987..8428d144 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -33,7 +33,7 @@ from scipy.signal import sawtooth import gpjax as gpx -from gpjax.kernels import DenseKernelComputation, Kernel +from gpjax.kernels import DenseKernelComputation, AbstractKernel # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -72,17 +72,17 @@ # # ### Implementation # -# Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `Kernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions. +# Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `AbstractKernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions. # %% @dataclass class _DeepKernelFunction: network: hk.Module - base_kernel: Kernel + base_kernel: AbstractKernel @dataclass -class DeepKernelFunction(Kernel, DenseKernelComputation, _DeepKernelFunction): +class DeepKernelFunction(AbstractKernel, DenseKernelComputation, _DeepKernelFunction): def __call__( self, x: jnp.DeviceArray, y: jnp.DeviceArray, params: dict ) -> jnp.ndarray: diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index d3ad76f3..6b553cda 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -199,7 +199,7 @@ def angular_distance(x, y, c): @dataclass -class Polar(gpx.kernels.Kernel, gpx.kernels.DenseKernelComputation): +class Polar(gpx.kernels.AbstractKernel, gpx.kernels.DenseKernelComputation): period: float = 2 * jnp.pi def __post_init__(self): @@ -233,7 +233,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: # example, without a `@dataclass` decorator, the instantiation of the above # `Polar` kernel would be done through # ```python -# class Polar(gpx.kernels.Kernel): +# class Polar(gpx.kernels.AbstractKernel): # def __init__(self, period: float = 2*jnp.pi): # super().__init__() # self.period = period diff --git a/gpjax/gps.py b/gpjax/gps.py index 7f09ce91..c6ca7f39 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -25,7 +25,7 @@ from .config import get_defaults from .covariance_operator import I -from .kernels import Kernel +from .kernels import AbstractKernel from .likelihoods import AbstractLikelihood, Conjugate, Gaussian, NonConjugate from .mean_functions import AbstractMeanFunction, Zero from .parameters import copy_dict_structure, evaluate_priors @@ -122,12 +122,12 @@ class Prior(AbstractPrior): >>> prior = gpx.Prior(kernel = kernel) Attributes: - kernel (Kernel): The kernel function used to parameterise the prior. + kernel (AbstractKernel): The kernel function used to parameterise the prior. mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero. name (str): The name of the GP prior. Defaults to "GP prior". """ - kernel: Kernel + kernel: AbstractKernel mean_function: Optional[AbstractMeanFunction] = Zero() name: Optional[str] = "GP prior" diff --git a/gpjax/kernels.py b/gpjax/kernels.py index cb32a407..349d434c 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -35,13 +35,14 @@ # Abtract classes ########################################## @dataclass(repr=False) -class Kernel: - """Base kernel class""" +class AbstractKernel: + """ + Base kernel class""" active_dims: Optional[List[int]] = None stationary: Optional[bool] = False spectral: Optional[bool] = False - name: Optional[str] = "Kernel" + name: Optional[str] = "AbstractKernel" def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) @@ -72,24 +73,24 @@ def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """ return x[..., self.active_dims] - def __add__(self, other: "Kernel") -> "Kernel": + def __add__(self, other: "AbstractKernel") -> "AbstractKernel": """Add two kernels together. Args: - other (Kernel): The kernel to be added to the current kernel. + other (AbstractKernel): The kernel to be added to the current kernel. Returns: - Kernel: A new kernel that is the sum of the two kernels. + AbstractKernel: A new kernel that is the sum of the two kernels. """ return SumKernel(kernel_set=[self, other]) - def __mul__(self, other: "Kernel") -> "Kernel": + def __mul__(self, other: "AbstractKernel") -> "AbstractKernel": """Multiply two kernels together. Args: - other (Kernel): The kernel to be multiplied with the current kernel. + other (AbstractKernel): The kernel to be multiplied with the current kernel. Returns: - Kernel: A new kernel that is the product of the two kernels. + AbstractKernel: A new kernel that is the product of the two kernels. """ return ProductKernel(kernel_set=[self, other]) @@ -122,13 +123,13 @@ class AbstractKernelComputation: @staticmethod @abc.abstractmethod def gram( - kernel: Kernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict ) -> CovarianceOperator: """Compute Gram covariance operator of the kernel function. Args: - kernel (Kernel): The kernel function to be evaluated. + kernel (AbstractKernel): The kernel function to be evaluated. inputs (Float[Array, "N N"]): The inputs to the kernel function. params (Dict): The parameters of the kernel function. @@ -140,12 +141,15 @@ def gram( @staticmethod def cross_covariance( - kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"], params: Dict + kernel: AbstractKernel, + x: Float[Array, "N D"], + y: Float[Array, "M D"], + params: Dict, ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM gram matrix on an a pair of input matrices with shape NxD and MxD. Args: - kernel (Kernel): The kernel for which the cross-covariance matrix should be computed for. + kernel (AbstractKernel): The kernel for which the cross-covariance matrix should be computed for. x (Float[Array,"N D"]): The first input matrix. y (Float[Array,"M D"]): The second input matrix. params (Dict): The kernel's parameter set. @@ -160,11 +164,11 @@ def cross_covariance( @staticmethod def diagonal( - kernel: Kernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict ) -> CovarianceOperator: """For a given kernel, compute the elementwise diagonal of the NxN gram matrix on an input matrix of shape NxD. Args: - kernel (Kernel): The kernel for which the variance vector should be computed for. + kernel (AbstractKernel): The kernel for which the variance vector should be computed for. inputs (Float[Array, "N D"]): The input matrix. params (Dict): The kernel's parameter set. Returns: @@ -181,12 +185,12 @@ class DenseKernelComputation(AbstractKernelComputation): @staticmethod def gram( - kernel: Kernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict ) -> CovarianceOperator: """For a given kernel, compute the NxN gram matrix on an input matrix of shape NxD. Args: - kernel (Kernel): The kernel for which the Gram matrix should be computed for. + kernel (AbstractKernel): The kernel for which the Gram matrix should be computed for. inputs (Float[Array,"N D"]): The input matrix. params (Dict): The kernel's parameter set. @@ -202,12 +206,12 @@ def gram( class DiagonalKernelComputation(AbstractKernelComputation): @staticmethod def gram( - kernel: Kernel, inputs: Float[Array, "N D"], params: Dict + kernel: AbstractKernel, inputs: Float[Array, "N D"], params: Dict ) -> CovarianceOperator: """For a kernel with diagonal structure, compute the NxN gram matrix on an input matrix of shape NxD. Args: - kernel (Kernel): The kernel for which the Gram matrix should be computed for. + kernel (AbstractKernel): The kernel for which the Gram matrix should be computed for. inputs (Float[Array, "N D"]): The input matrix. params (Dict): The kernel's parameter set. @@ -222,11 +226,11 @@ def gram( @dataclass class _KernelSet: - kernel_set: List[Kernel] + kernel_set: List[AbstractKernel] @dataclass -class CombinationKernel(Kernel, _KernelSet, DenseKernelComputation): +class CombinationKernel(AbstractKernel, _KernelSet, DenseKernelComputation): """A base class for products or sums of kernels.""" name: Optional[str] = "Combination kernel" @@ -236,16 +240,16 @@ def __post_init__(self): """Set the kernel set to the list of kernels passed to the constructor.""" kernels = self.kernel_set - if not all(isinstance(k, Kernel) for k in kernels): + if not all(isinstance(k, AbstractKernel) for k in kernels): raise TypeError("can only combine Kernel instances") # pragma: no cover - self.kernel_set: List[Kernel] = [] + self.kernel_set: List[AbstractKernel] = [] self._set_kernels(kernels) - def _set_kernels(self, kernels: Sequence[Kernel]) -> None: + def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: """Combine multiple kernels. Based on GPFlow's Combination kernel.""" # add kernels to a list, flattening out instances of this class therein - kernels_list: List[Kernel] = [] + kernels_list: List[AbstractKernel] = [] for k in kernels: if isinstance(k, self.__class__): kernels_list.extend(k.kernel_set) @@ -285,7 +289,7 @@ class ProductKernel(CombinationKernel): # Euclidean kernels ########################################## @dataclass(repr=False) -class RBF(Kernel, DenseKernelComputation): +class RBF(AbstractKernel, DenseKernelComputation): """The Radial Basis Function (RBF) kernel.""" name: Optional[str] = "Radial basis function kernel" @@ -327,7 +331,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) -class Matern12(Kernel, DenseKernelComputation): +class Matern12(AbstractKernel, DenseKernelComputation): """The Matérn kernel with smoothness parameter fixed at 0.5.""" name: Optional[str] = "Matern 1/2" @@ -363,7 +367,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) -class Matern32(Kernel, DenseKernelComputation): +class Matern32(AbstractKernel, DenseKernelComputation): """The Matérn kernel with smoothness parameter fixed at 1.5.""" name: Optional[str] = "Matern 3/2" @@ -405,7 +409,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) -class Matern52(Kernel, DenseKernelComputation): +class Matern52(AbstractKernel, DenseKernelComputation): """The Matérn kernel with smoothness parameter fixed at 2.5.""" name: Optional[str] = "Matern 5/2" @@ -447,7 +451,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) -class Polynomial(Kernel, DenseKernelComputation): +class Polynomial(AbstractKernel, DenseKernelComputation): """The Polynomial kernel with variable degree.""" name: Optional[str] = "Polynomial" @@ -486,7 +490,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) -class White(Kernel, DiagonalKernelComputation): +class White(AbstractKernel, DiagonalKernelComputation): def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) @@ -530,7 +534,7 @@ class _EigenKernel: @dataclass -class GraphKernel(Kernel, _EigenKernel, DenseKernelComputation): +class GraphKernel(AbstractKernel, _EigenKernel, DenseKernelComputation): name: Optional[str] = "Graph kernel" def __post_init__(self): @@ -605,7 +609,7 @@ def euclidean_distance( __all__ = [ - "Kernel", + "AbstractKernel", "CombinationKernel", "SumKernel", "ProductKernel", diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 00d60d2e..f1e890ab 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -35,7 +35,7 @@ RBF, CombinationKernel, GraphKernel, - Kernel, + AbstractKernel, Matern12, Matern32, Matern52, @@ -58,10 +58,10 @@ def test_abstract_kernel(): # Test initialising abstract kernel raises TypeError with unimplemented __call__ and _init_params methods: with pytest.raises(TypeError): - Kernel() + AbstractKernel() # Create a dummy kernel class with __call__ and _init_params methods implemented: - class DummyKernel(Kernel): + class DummyKernel(AbstractKernel): def __call__( self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: @@ -99,7 +99,7 @@ def test_euclidean_distance( @pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) @pytest.mark.parametrize("dim", [1, 2, 5]) @pytest.mark.parametrize("n", [1, 2, 10]) -def test_gram(kernel: Kernel, dim: int, n: int) -> None: +def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: # Gram constructor static method: gram = kernel.gram @@ -120,7 +120,9 @@ def test_gram(kernel: Kernel, dim: int, n: int) -> None: @pytest.mark.parametrize("num_a", [1, 2, 5]) @pytest.mark.parametrize("num_b", [1, 2, 5]) @pytest.mark.parametrize("dim", [1, 2, 5]) -def test_cross_covariance(kernel: Kernel, num_a: int, num_b: int, dim: int) -> None: +def test_cross_covariance( + kernel: AbstractKernel, num_a: int, num_b: int, dim: int +) -> None: # Cross covariance constructor static method: cross_cov = kernel.cross_covariance @@ -140,7 +142,7 @@ def test_cross_covariance(kernel: Kernel, num_a: int, num_b: int, dim: int) -> N @pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) @pytest.mark.parametrize("dim", [1, 2, 5]) -def test_call(kernel: Kernel, dim: int) -> None: +def test_call(kernel: AbstractKernel, dim: int) -> None: # Datapoint x and datapoint y: x = jnp.array([[1.0] * dim]) @@ -160,7 +162,9 @@ def test_call(kernel: Kernel, dim: int) -> None: @pytest.mark.parametrize("dim", [1, 2, 5]) @pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) @pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def(kern: Kernel, dim: int, ell: float, sigma: float, n: int) -> None: +def test_pos_def( + kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int +) -> None: # Gram constructor static method: gram = kern.gram @@ -178,7 +182,7 @@ def test_pos_def(kern: Kernel, dim: int, ell: float, sigma: float, n: int) -> No @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) @pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) -def test_initialisation(kernel: Kernel, dim: int) -> None: +def test_initialisation(kernel: AbstractKernel, dim: int) -> None: if dim is None: kern = kernel() @@ -199,7 +203,7 @@ def test_initialisation(kernel: Kernel, dim: int) -> None: @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -def test_dtype(kernel: Kernel) -> None: +def test_dtype(kernel: AbstractKernel) -> None: parameter_state = initialise(kernel(), _initialise_key) params, *_ = parameter_state.unpack() @@ -238,7 +242,7 @@ def test_polynomial( @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -def test_active_dim(kernel: Kernel) -> None: +def test_active_dim(kernel: AbstractKernel) -> None: dim_list = [0, 1, 2, 3] perm_length = 2 dim_pairs = list(permutations(dim_list, r=perm_length)) @@ -263,7 +267,7 @@ def test_active_dim(kernel: Kernel) -> None: @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52, Polynomial]) @pytest.mark.parametrize("n_kerns", [2, 3, 4]) def test_combination_kernel( - combination_type: CombinationKernel, kernel: Kernel, n_kerns: int + combination_type: CombinationKernel, kernel: AbstractKernel, n_kerns: int ) -> None: n = 20 @@ -272,7 +276,7 @@ def test_combination_kernel( assert len(c_kernel.kernel_set) == n_kerns assert len(c_kernel._initialise_params(_initialise_key)) == n_kerns assert isinstance(c_kernel.kernel_set, list) - assert isinstance(c_kernel.kernel_set[0], Kernel) + assert isinstance(c_kernel.kernel_set[0], AbstractKernel) assert isinstance(c_kernel._initialise_params(_initialise_key)[0], dict) x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) Kff = c_kernel.gram(c_kernel, x, c_kernel._initialise_params(_initialise_key)) @@ -286,7 +290,7 @@ def test_combination_kernel( @pytest.mark.parametrize( "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] ) -def test_sum_kern_value(k1: Kernel, k2: Kernel) -> None: +def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: n = 10 sum_kernel = SumKernel(kernel_set=[k1, k2]) x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) @@ -306,7 +310,7 @@ def test_sum_kern_value(k1: Kernel, k2: Kernel) -> None: @pytest.mark.parametrize( "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] ) -def test_prod_kern_value(k1: Kernel, k2: Kernel) -> None: +def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: n = 10 prod_kernel = ProductKernel(kernel_set=[k1, k2]) x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) @@ -348,7 +352,7 @@ def test_graph_kernel(): @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52, Polynomial]) -def test_combination_kernel_type(kernel: Kernel) -> None: +def test_combination_kernel_type(kernel: AbstractKernel) -> None: prod_kern = kernel() * kernel() assert isinstance(prod_kern, ProductKernel) assert isinstance(prod_kern, CombinationKernel)