
Merge pull request #176 from JaxGaussianProcesses/remove_chex
Remove chex
thomaspinder authored Jan 8, 2023
2 parents 8d52d7b + bb63fea commit 3bbc8cb
Showing 34 changed files with 743 additions and 369 deletions.
20 changes: 6 additions & 14 deletions .circleci/config.yml
@@ -1,7 +1,7 @@
version: 2.1

orbs:
python: circleci/python@2.1.1
# python: circleci/python@2.1.1
codecov: codecov/codecov@3.2.2

commands:
@@ -51,7 +51,6 @@ commands:
- run:
name: Upload to PyPI
command: twine upload dist/* -r << parameters.pkgname >> --verbose

install_pandoc:
description: "Install pandoc"
parameters:
@@ -83,15 +82,12 @@ jobs:
resource_class: large
steps:
- checkout
- restore_cache:
keys:
- pip-cache
- run:
name: Update pip
command: pip install --upgrade pip
- python/install-packages:
pkg-manager: pip-dist
path-args: .[dev]
name: Install dependencies
command: |
pip install --upgrade pip
pip install -r requirements/dev.txt
pip install -e .
- run:
name: Run tests
command: |
@@ -103,10 +99,6 @@ jobs:
curl -Os https://uploader.codecov.io/v0.1.0_4653/linux/codecov
chmod +x codecov
./codecov -t ${CODECOV_TOKEN}
- save_cache:
key: pip-cache
paths:
- ~/.cache/pip
- store_test_results:
path: test-results
- store_artifacts:
2 changes: 1 addition & 1 deletion README.md
@@ -123,7 +123,7 @@ parameter_state = gpx.initialise(posterior, key=key)
Finally, we run an optimisation loop using the Adam optimiser via the `fit` callable.

```python
inference_state = gpx.fit(mll, parameter_state, opt, n_iters=500)
inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500)
```
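The returned `inference_state` can then be unpacked into the learned parameters and the optimisation history, as the example notebooks in this pull request do (a minimal sketch, reusing the state from the call above):

```python
# Recover the optimised parameters and the per-iteration objective values.
learned_params, training_history = inference_state.unpack()
```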

## 3. Making predictions
8 changes: 1 addition & 7 deletions docs/README.md
@@ -41,7 +41,6 @@ description and a code example. The docstring is concluded with a description
of the object's attributes with corresponding types.

```python
@dataclass
class Prior(AbstractPrior):
"""A Gaussian process prior object. The GP is parameterised by a
`mean <https://gpjax.readthedocs.io/en/latest/api.html#module-gpjax.mean_functions>`_
@@ -78,9 +77,4 @@ class Prior(AbstractPrior):
### Documentation syntax

A helpful cheatsheet for writing restructured text can be found
[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting
`dataclass` objects.

* Class attributes should be specified using the `Attributes:` tag.
* Method arguments should be specified using the `Args:` tag.
* All attributes and arguments should have types.
[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst).
85 changes: 85 additions & 0 deletions examples/README.md
@@ -0,0 +1,85 @@
# Where to find the docs

The GPJax documentation can be found here:
https://gpjax.readthedocs.io/en/latest/

# How to build the docs

1. Install the requirements using `pip install -r docs/requirements.txt`
2. Make sure `pandoc` is installed
3. Run `make html`

The corresponding HTML files can then be found in `docs/_build/html/`.

# How to write code documentation

Our documentation is written in reStructuredText for Sphinx. This is a
meta-language that is compiled into online documentation. For more details see
[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html).
As a result, our docstrings adhere to a specific syntax that has to be kept in
mind. Below we provide some guidelines.

## How much information to put in a docstring

A docstring should be informative. If in doubt, it is best to add more
information to a docstring rather than less. Many users will skim documentation, so
please ensure the opening sentence or two of a docstring contains the core
information. Adding examples and mathematical descriptions to documentation is
highly desirable.

We are making an active effort within GPJax to improve our documentation. If you
spot any areas where there is missing information within the existing
documentation, then please either raise an issue or
[create a pull request](https://gpjax.readthedocs.io/en/latest/contributing.html).

## An example docstring

An example docstring that adheres to the principles of GPJax is given below.
The docstring contains a simple, snappy introduction with links to auxiliary
components. More detail is then provided in the form of a mathematical
description and a code example. The docstring is concluded with a description
of the object's attributes with corresponding types.

```python
class Prior(AbstractPrior):
"""A Gaussian process prior object. The GP is parameterised by a
`mean <https://gpjax.readthedocs.io/en/latest/api.html#module-gpjax.mean_functions>`_
and `kernel <https://gpjax.readthedocs.io/en/latest/api.html#module-gpjax.kernels>`_ function.
A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel
function :math:`k(\\cdot, \\cdot)` is given by
.. math::
p(f(\\cdot)) = \\mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)).
To invoke a ``Prior`` distribution, only a kernel function is required. By default,
the mean function is set to zero. In general, this is a reasonable assumption
provided the data being modelled has been centred.
Example:
>>> import gpjax as gpx
>>>
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.Prior(kernel=kernel)
Attributes:
kernel (Kernel): The kernel function used to parameterise the prior.
mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero.
name (str): The name of the GP prior. Defaults to "GP prior".
"""

kernel: Kernel
mean_function: Optional[AbstractMeanFunction] = Zero()
name: Optional[str] = "GP prior"
```

### Documentation syntax

A helpful cheatsheet for writing restructured text can be found
[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting
classes; a minimal sketch is given after the list.

* Class attributes should be specified using the `Attributes:` tag.
* Method arguments should be specified using the `Args:` tag.
* All attributes and arguments should have types.
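A minimal, self-contained sketch of these conventions (the `Circle` class is hypothetical and serves only to illustrate the tags):

```python
import math


class Circle:
    """A circle in the plane.

    Attributes:
        radius (float): The circle's radius.
    """

    def __init__(self, radius: float):
        self.radius = radius

    def scaled_area(self, factor: float) -> float:
        """Compute the circle's area, scaled by a multiplicative factor.

        Args:
            factor (float): The scaling to apply to the area.

        Returns:
            float: The scaled area.
        """
        return factor * math.pi * self.radius**2
```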
2 changes: 1 addition & 1 deletion examples/barycentres.pct.py
@@ -115,7 +115,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=1000,
num_iters=1000,
)

learned_params, training_history = inference_state.unpack()
2 changes: 1 addition & 1 deletion examples/classification.pct.py
@@ -91,7 +91,7 @@
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=1000,
num_iters=1000,
)

map_estimate, training_history = inference_state.unpack()
2 changes: 1 addition & 1 deletion examples/collapsed_vi.pct.py
@@ -109,7 +109,7 @@
objective=negative_elbo,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=2000,
num_iters=2000,
)

learned_params, training_history = inference_state.unpack()
2 changes: 1 addition & 1 deletion examples/graph_kernels.pct.py
@@ -137,7 +137,7 @@
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=1000,
num_iters=1000,
)

learned_params, training_history = inference_state.unpack()
2 changes: 1 addition & 1 deletion examples/haiku.pct.py
@@ -185,7 +185,7 @@ def forward(x):
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=2500,
num_iters=2500,
)

learned_params, training_history = inference_state.unpack()
25 changes: 2 additions & 23 deletions examples/kernels.pct.py
@@ -228,28 +228,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
# domain is a circle, this is $2\pi$. Next we define the kernel's `__call__`
# function which is a direct implementation of Equation (1). Finally, we define
# the Kernel's parameter property which contains just one value $\tau$ that we
# initialise to 4 in the kernel's `__post_init__`.
#
# #### Aside on dataclasses
#
# One can see in the above definition of a `Polar` kernel that we decorated the
# class with a `@dataclass` command. Dataclasses are simply regular class
# objects in Python; however, much of the boilerplate code has been removed. For
# example, without a `@dataclass` decorator, the instantiation of the above
# `Polar` kernel would be done through
# ```python
# class Polar(jk.kernels.AbstractKernel):
# def __init__(self, period: float = 2*jnp.pi):
# super().__init__()
# self.period = period
# ```
# As objects become increasingly large and complex, the conciseness of a
# dataclass becomes increasingly attractive. To ensure full compatibility with
# Jax, it is crucial that the dataclass decorator is imported from Chex, not
# base Python's `dataclass` module. Functionally, the two objects are identical.
# However, unlike regular Python dataclasses, it is possible to apply operations
# such as `jit`, `vmap` and `grad` to the dataclasses given by Chex as they are
# registered PyTrees.
# initialise to 4 in the kernel's `__init__`.
#
#
# ### Custom Parameter Bijection
@@ -312,7 +291,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=1000,
num_iters=1000,
)

learned_params, training_history = inference_state.unpack()
2 changes: 1 addition & 1 deletion examples/natgrads.pct.py
@@ -95,7 +95,7 @@
natural_svgp,
parameter_state=parameter_state,
train_data=D,
n_iters=5000,
num_iters=5000,
batch_size=256,
key=jr.PRNGKey(42),
moment_optim=ox.sgd(0.01),
6 changes: 3 additions & 3 deletions examples/regression.pct.py
@@ -137,7 +137,7 @@
#
# ## Parameter state
#
# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` dataclass is returned that contains four dictionaries:
# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` object is returned that contains four dictionaries:
#
# | Dictionary | Description |
# |---|---|
@@ -185,11 +185,11 @@
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=500,
num_iters=500,
)

# %% [markdown]
# Similar to the `ParameterState` object above, the returned variable from the `fit` function is a dataclass, namely an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluation of our objective function throughout optimisation.
# Similar to the `ParameterState` object above, the variable returned from the `fit` function is an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluations of our objective function throughout optimisation.

# %%
learned_params, training_history = inference_state.unpack()
4 changes: 2 additions & 2 deletions examples/uncollapsed_vi.pct.py
@@ -171,7 +171,7 @@
parameter_state=parameter_state,
train_data=D,
optax_optim=optimiser,
n_iters=3000,
num_iters=3000,
key=jr.PRNGKey(42),
batch_size=128,
)
@@ -225,7 +225,7 @@
parameter_state=parameter_state,
train_data=D,
optax_optim=optimiser,
n_iters=3000,
num_iters=3000,
key=jr.PRNGKey(42),
batch_size=128,
)
2 changes: 1 addition & 1 deletion examples/yacht.pct.py
@@ -136,7 +136,7 @@
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=1000,
num_iters=1000,
log_rate=50,
)

2 changes: 1 addition & 1 deletion gpjax/__init__.py
@@ -28,14 +28,14 @@
from .likelihoods import Bernoulli, Gaussian
from .mean_functions import Constant, Zero
from .parameters import constrain, copy_dict_structure, initialise, unconstrain
from .types import Dataset
from .variational_families import (
CollapsedVariationalGaussian,
ExpectationVariationalGaussian,
NaturalVariationalGaussian,
VariationalGaussian,
WhitenedVariationalGaussian,
)
from .types import Dataset
from .variational_inference import CollapsedVI, StochasticVI
from . import _version
