Skip to content

Commit

Permalink
Merge pull request #164 from JaxGaussianProcesses/jax_bump
Browse files Browse the repository at this point in the history
Address Jax/Jaxlib `v0.4.x` compatibility, incorporate CircleCI testing workflows, incorporate versioneer.
  • Loading branch information
daniel-dodd authored Dec 22, 2022
2 parents 5cf7143 + b27a74d commit beff24e
Show file tree
Hide file tree
Showing 18 changed files with 3,124 additions and 131 deletions.
117 changes: 117 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
version: 2.1

orbs:
python: circleci/python@2.1.1
codecov: codecov/codecov@3.2.2

jobs:
build-and-test:
docker:
- image: cimg/python:3.8.0
steps:
- checkout
- run:
name: Update pip
command: pip install --upgrade pip
- python/install-packages:
pkg-manager: pip-dist
path-args: .[dev]
- run:
name: Run tests
command: pytest --cov=./ --cov-report=xml
- run:
name: Upload tests to Codecov
command: |
curl -Os https://uploader.codecov.io/v0.1.0_4653/linux/codecov
chmod +x codecov
./codecov -t ${CODECOV_TOKEN}
- codecov/upload:
file: coverage.xml

publish:
docker:
- image: cimg/python:3.9.0
steps:
- checkout
- run:
name: init .pypirc
command: |
echo -e "[distutils]" >> ~/.pypirc
echo -e "index-servers = " >> ~/.pypirc
echo -e " pypi" >> ~/.pypirc
echo -e " gpjax" >> ~/.pypirc
echo -e "" >> ~/.pypirc
echo -e "[pypi]" >> ~/.pypirc
echo -e " username = thomaspinder" >> ~/.pypirc
echo -e " password = $PYPI_TOKEN" >> ~/.pypirc
echo -e "" >> ~/.pypirc
echo -e "[gpjax]" >> ~/.pypirc
echo -e " repository = https://upload.pypi.org/legacy/" >> ~/.pypirc
echo -e " username = __token__" >> ~/.pypirc
echo -e " password = $GPJAX_PYPI" >> ~/.pypirc
- run:
name: Build package
command: |
pip install -U twine
python setup.py sdist bdist_wheel
- run:
name: Upload to PyPI
command: twine upload dist/* -r gpjax --verbose

publish-nightly:
docker:
- image: cimg/python:3.9.0
steps:
- checkout
- run:
name: init .pypirc
command: |
echo -e "[distutils]" >> ~/.pypirc
echo -e "index-servers = " >> ~/.pypirc
echo -e " pypi" >> ~/.pypirc
echo -e " gpjax-nightly" >> ~/.pypirc
echo -e "" >> ~/.pypirc
echo -e "[pypi]" >> ~/.pypirc
echo -e " username = thomaspinder" >> ~/.pypirc
echo -e " password = $PYPI_TOKEN" >> ~/.pypirc
echo -e "" >> ~/.pypirc
echo -e "[gpjax-nightly]" >> ~/.pypirc
echo -e " repository = https://upload.pypi.org/legacy/" >> ~/.pypirc
echo -e " username = __token__" >> ~/.pypirc
echo -e " password = $GPJAX_NIGHTLY_PYPI" >> ~/.pypirc
- run:
name: Build package
command: |
pip install -U twine
python setup.py sdist bdist_wheel
environment:
BUILD_GPJAX_NIGHTLY: 'nightly'
- run:
name: Upload to PyPI
command: twine upload dist/* -r gpjax-nightly --verbose

workflows:
main:
jobs:
- build-and-test:
filters: # required since `deploy` has tag filters AND requires `build`
tags:
only: /.*/
- publish:
requires:
- build-and-test
filters:
tags:
only: /^v.*/ # Only run on tags starting with v
branches:
ignore: /.*/
nightly:
triggers:
- schedule:
cron: "0 0 * * *"
filters:
branches:
only:
- main
jobs:
- publish-nightly
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gpjax/_version.py export-subst
47 changes: 0 additions & 47 deletions .github/workflows/publish.yml

This file was deleted.

30 changes: 0 additions & 30 deletions .github/workflows/workflow-master.yml

This file was deleted.

1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ ignore:
- "*/tests/.*"
- "__init__.py"
- "tests/*.py"
- "*/_version.py"
8 changes: 4 additions & 4 deletions examples/barycentres.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
# We'll now independently learn Gaussian process posterior distributions for each dataset. We won't spend any time here discussing how GP hyperparameters are optimised. For advice on achieving this, see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for advice on optimisation and the [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for advice on selecting an appropriate kernel.

# %%
def fit_gp(x: jnp.DeviceArray, y: jnp.DeviceArray) -> dx.MultivariateNormalTri:
def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:
if y.ndim == 1:
y = y.reshape(-1, 1)
D = Dataset(X=x, y=y)
Expand Down Expand Up @@ -130,18 +130,18 @@ def fit_gp(x: jnp.DeviceArray, y: jnp.DeviceArray) -> dx.MultivariateNormalTri:
# In GPJax, the predictive distribution of a GP is given by a [Distrax](https://github.com/deepmind/distrax) distribution, making it straightforward to extract the mean vector and covariance matrix of each GP for learning a barycentre. We implement the fixed point scheme given in (3) in the following cell by utilising Jax's `vmap` operator to speed up large matrix operations using broadcasting in `tensordot`.

# %%
def sqrtm(A: jnp.DeviceArray):
def sqrtm(A: jax.Array):
return jnp.real(jsl.sqrtm(A))


def wasserstein_barycentres(
distributions: tp.List[dx.MultivariateNormalTri], weights: jnp.DeviceArray
distributions: tp.List[dx.MultivariateNormalTri], weights: jax.Array
):
covariances = [d.covariance() for d in distributions]
cov_stack = jnp.stack(covariances)
stack_sqrt = jax.vmap(sqrtm)(cov_stack)

def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray):
def step(covariance_candidate: jax.Array):
inner_term = jax.vmap(sqrtm)(
jnp.matmul(jnp.matmul(stack_sqrt, covariance_candidate), stack_sqrt)
)
Expand Down
1 change: 0 additions & 1 deletion examples/classification.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
# %%
import blackjax
import distrax as dx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
Expand Down
4 changes: 2 additions & 2 deletions examples/intro_to_gps.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@

d1 = dx.MultivariateNormalFullCovariance(jnp.zeros(2), jnp.eye(2))
d2 = dx.MultivariateNormalTri(
jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]]))
jnp.zeros(2), jnp.linalg.cholesky(jax.Array([[1.0, 0.9], [0.9, 1.0]]))
)
d3 = dx.MultivariateNormalTri(
jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]]))
jnp.zeros(2), jnp.linalg.cholesky(jax.Array([[1.0, -0.5], [-0.5, 1.0]]))
)

dists = [d1, d2, d3]
Expand Down
3 changes: 2 additions & 1 deletion gpjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@
WhitenedVariationalGaussian,
)
from .variational_inference import CollapsedVI, StochasticVI
from . import _version

__version__ = _version.get_versions()["version"]
__license__ = "MIT"
__description__ = "Didactic Gaussian processes in JAX"
__url__ = "https://github.com/thomaspinder/GPJax"
__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
__version__ = "0.5.5"


__all__ = [
Expand Down
Loading

0 comments on commit beff24e

Please sign in to comment.