diff --git a/.circleci/config.yml b/.circleci/config.yml index d34be04a..fa484290 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,12 +4,88 @@ orbs: python: circleci/python@2.1.1 codecov: codecov/codecov@3.2.2 +commands: + create_pypirc: + description: "Create .pypirc file" + steps: + - run: + name: init .pypirc + command: | + echo -e "[distutils]" >> ~/.pypirc + echo -e "index-servers = " >> ~/.pypirc + echo -e " pypi" >> ~/.pypirc + echo -e " gpjax-nightly" >> ~/.pypirc + echo -e "" >> ~/.pypirc + echo -e "[pypi]" >> ~/.pypirc + echo -e " username = thomaspinder" >> ~/.pypirc + echo -e " password = $PYPI_TOKEN" >> ~/.pypirc + echo -e "" >> ~/.pypirc + echo -e "[gpjax]" >> ~/.pypirc + echo -e " repository = https://upload.pypi.org/legacy/" >> ~/.pypirc + echo -e " username = __token__" >> ~/.pypirc + echo -e " password = $GPJAX_PYPI" >> ~/.pypirc + echo -e "" >> ~/.pypirc + echo -e "[gpjax-nightly]" >> ~/.pypirc + echo -e " repository = https://upload.pypi.org/legacy/" >> ~/.pypirc + echo -e " username = __token__" >> ~/.pypirc + echo -e " password = $GPJAX_NIGHTLY_PYPI" >> ~/.pypirc + publish_to_pypi: + description: "Publish a package to PyPI" + parameters: + pkgname: + description: Package name + type: string + default: gpjax + nightly: + description: Perform a nightly installation + type: string + default: None + steps: + - run: + name: Build package + command: | + pip install -U twine + python setup.py sdist bdist_wheel + environment: + BUILD_GPJAX_NIGHTLY: << parameters.nightly >> + - run: + name: Upload to PyPI + command: twine upload dist/* -r << parameters.pkgname >> --verbose + + install_pandoc: + description: "Install pandoc" + parameters: + pandoc_url: + type: string + pandoc_dest: + type: string + steps: + - restore_cache: + keys: + - pandoc-download + - run: + name: Install pandoc + command: | + if [ ! -f "~/pandoc.tar.gz" ]; then + wget << parameters.pandoc_url >> -O ~/pandoc.tar.gz + fi + sudo tar xvzf ~/pandoc.tar.gz --strip-components 1 -C << parameters.pandoc_dest >> + - save_cache: + key: pandoc-download + paths: + - ~/pandoc.tar.gz + jobs: build-and-test: docker: - image: cimg/python:3.8.0 + parallelism: 4 + resource_class: large steps: - checkout + - restore_cache: + keys: + - pip-cache - run: name: Update pip command: pip install --upgrade pip @@ -18,77 +94,45 @@ jobs: path-args: .[dev] - run: name: Run tests - command: pytest --cov=./ --cov-report=xml + command: | + TEST_FILES=$(circleci tests glob "tests/test_*.py" | circleci tests split --split-by=timings) + pytest --cov=./ --cov-report=xml --verbose $TEST_FILES - run: name: Upload tests to Codecov command: | curl -Os https://uploader.codecov.io/v0.1.0_4653/linux/codecov chmod +x codecov ./codecov -t ${CODECOV_TOKEN} + - save_cache: + key: pip-cache + paths: + - ~/.cache/pip + - store_test_results: + path: test-results + - store_artifacts: + path: test-results - codecov/upload: file: coverage.xml + publish: docker: - image: cimg/python:3.9.0 steps: - checkout - - run: - name: init .pypirc - command: | - echo -e "[distutils]" >> ~/.pypirc - echo -e "index-servers = " >> ~/.pypirc - echo -e " pypi" >> ~/.pypirc - echo -e " gpjax" >> ~/.pypirc - echo -e "" >> ~/.pypirc - echo -e "[pypi]" >> ~/.pypirc - echo -e " username = thomaspinder" >> ~/.pypirc - echo -e " password = $PYPI_TOKEN" >> ~/.pypirc - echo -e "" >> ~/.pypirc - echo -e "[gpjax]" >> ~/.pypirc - echo -e " repository = https://upload.pypi.org/legacy/" >> ~/.pypirc - echo -e " username = __token__" >> ~/.pypirc - echo -e " password = $GPJAX_PYPI" >> ~/.pypirc - - run: - name: Build package - command: | - pip install -U twine - python setup.py sdist bdist_wheel - - run: - name: Upload to PyPI - command: twine upload dist/* -r gpjax --verbose + - create_pypirc + - publish_to_pypi: + pkgname: gpjax publish-nightly: docker: - image: cimg/python:3.9.0 steps: - checkout - - run: - name: init .pypirc - command: | - echo -e "[distutils]" >> ~/.pypirc - echo -e "index-servers = " >> ~/.pypirc - echo -e " pypi" >> ~/.pypirc - echo -e " gpjax-nightly" >> ~/.pypirc - echo -e "" >> ~/.pypirc - echo -e "[pypi]" >> ~/.pypirc - echo -e " username = thomaspinder" >> ~/.pypirc - echo -e " password = $PYPI_TOKEN" >> ~/.pypirc - echo -e "" >> ~/.pypirc - echo -e "[gpjax-nightly]" >> ~/.pypirc - echo -e " repository = https://upload.pypi.org/legacy/" >> ~/.pypirc - echo -e " username = __token__" >> ~/.pypirc - echo -e " password = $GPJAX_NIGHTLY_PYPI" >> ~/.pypirc - - run: - name: Build package - command: | - pip install -U twine - python setup.py sdist bdist_wheel - environment: - BUILD_GPJAX_NIGHTLY: 'nightly' - - run: - name: Upload to PyPI - command: twine upload dist/* -r gpjax-nightly --verbose + - create_pypirc + - publish_to_pypi: + pkgname: gpjax-nigthly + nightly: nightly workflows: main: @@ -112,6 +156,6 @@ workflows: filters: branches: only: - - main + - master jobs: - publish-nightly diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml new file mode 100644 index 00000000..7e09e081 --- /dev/null +++ b/.github/workflows/documentation.yml @@ -0,0 +1,40 @@ +name: Build the documentation + +on: + pull_request: + branches: [master] + +jobs: + build: + name: Build docs (${{ matrix.python-version }}, ${{ matrix.os }}) + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + os: ["ubuntu-latest"] + python-version: ["3.9"] + steps: + - name: Checkout the branch + uses: actions/checkout@v2.3.1 + with: + persist-credentials: false + + - name: create Conda environment + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + python-version: ${{ matrix.python-version }} + + # - name: Set up Python 3.9 + # uses: actions/setup-python@v1 + # with: + # python-version: 3.9 + + - name: Build the documentation with Sphinx + run: | + pip install -r docs/requirements.txt + conda install pandoc + cd docs + make html \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 24654987..b961096e 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -4,6 +4,7 @@ SPHINXOPTS = -j auto SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build +# NCORES ?= $(shell nproc) # Put it first so that "make" without argument is like "make help". help: @@ -12,4 +13,4 @@ help: .PHONY: help Makefile %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/_api.rst b/docs/_api.rst index 175d297b..a26db0df 100644 --- a/docs/_api.rst +++ b/docs/_api.rst @@ -295,10 +295,17 @@ Configuration .. automodule:: gpjax.config .. currentmodule:: gpjax.config +.. autofunction:: add_parameter -.. autofunction:: get_defaults +.. autofunction:: get_global_config_if_exists -.. autofunction:: add_parameter +.. autofunction:: get_default_config + +.. autofunction:: update_x64_sensitive_settings + +.. autofunction:: get_global_config + +.. autofunction:: reset_global_config Quadrature diff --git a/docs/conf.py b/docs/conf.py index 24829663..06a1f705 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,8 +20,6 @@ from importlib_metadata import version -import docs.conf_sphinx_patch - def read(*names, **kwargs): """Function to decode a read files. Credit GPyTorch.""" @@ -59,7 +57,14 @@ def find_version(*file_paths): author = "Thomas Pinder" # The full version, including alpha/beta/rc tags -version = find_version("gpjax", "__init__.py") +import sys +from os.path import join, pardir, dirname + +sys.path.insert(0, join(dirname(__file__), pardir)) + +import gpjax + +version = gpjax.__version__ release = version diff --git a/docs/requirements.txt b/docs/requirements.txt index 0d137bcd..8cc65120 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,7 +10,6 @@ matplotlib==3.3.3 seaborn sphinx-copybutton networkx>=2.0.0 -pandoc sphinxcontrib.bibtex jupytext ipython diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index a23b282d..c4f3dbde 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -141,7 +141,7 @@ def wasserstein_barycentres( cov_stack = jnp.stack(covariances) stack_sqrt = jax.vmap(sqrtm)(cov_stack) - def step(covariance_candidate: jax.Array): + def step(covariance_candidate: jax.Array, idx: None): inner_term = jax.vmap(sqrtm)( jnp.matmul(jnp.matmul(stack_sqrt, covariance_candidate), stack_sqrt) ) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 5c3f5f2c..c0996daa 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -31,6 +31,7 @@ from jaxtyping import Array, Float from jaxutils import Dataset import jaxkern as jk +import jax import gpjax as gpx diff --git a/examples/data/yacht_hydrodynamics.data b/examples/data/yacht_hydrodynamics.data new file mode 100644 index 00000000..2ae2cff7 --- /dev/null +++ b/examples/data/yacht_hydrodynamics.data @@ -0,0 +1,309 @@ +-2.3 0.568 4.78 3.99 3.17 0.125 0.11 +-2.3 0.568 4.78 3.99 3.17 0.150 0.27 +-2.3 0.568 4.78 3.99 3.17 0.175 0.47 +-2.3 0.568 4.78 3.99 3.17 0.200 0.78 +-2.3 0.568 4.78 3.99 3.17 0.225 1.18 +-2.3 0.568 4.78 3.99 3.17 0.250 1.82 +-2.3 0.568 4.78 3.99 3.17 0.275 2.61 +-2.3 0.568 4.78 3.99 3.17 0.300 3.76 +-2.3 0.568 4.78 3.99 3.17 0.325 4.99 +-2.3 0.568 4.78 3.99 3.17 0.350 7.16 +-2.3 0.568 4.78 3.99 3.17 0.375 11.93 +-2.3 0.568 4.78 3.99 3.17 0.400 20.11 +-2.3 0.568 4.78 3.99 3.17 0.425 32.75 +-2.3 0.568 4.78 3.99 3.17 0.450 49.49 +-2.3 0.569 4.78 3.04 3.64 0.125 0.04 +-2.3 0.569 4.78 3.04 3.64 0.150 0.17 +-2.3 0.569 4.78 3.04 3.64 0.175 0.37 +-2.3 0.569 4.78 3.04 3.64 0.200 0.66 +-2.3 0.569 4.78 3.04 3.64 0.225 1.06 +-2.3 0.569 4.78 3.04 3.64 0.250 1.59 +-2.3 0.569 4.78 3.04 3.64 0.275 2.33 +-2.3 0.569 4.78 3.04 3.64 0.300 3.29 +-2.3 0.569 4.78 3.04 3.64 0.325 4.61 +-2.3 0.569 4.78 3.04 3.64 0.350 7.11 +-2.3 0.569 4.78 3.04 3.64 0.375 11.99 +-2.3 0.569 4.78 3.04 3.64 0.400 21.09 +-2.3 0.569 4.78 3.04 3.64 0.425 35.01 +-2.3 0.569 4.78 3.04 3.64 0.450 51.80 +-2.3 0.565 4.78 5.35 2.76 0.125 0.09 +-2.3 0.565 4.78 5.35 2.76 0.150 0.29 +-2.3 0.565 4.78 5.35 2.76 0.175 0.56 +-2.3 0.565 4.78 5.35 2.76 0.200 0.86 +-2.3 0.565 4.78 5.35 2.76 0.225 1.31 +-2.3 0.565 4.78 5.35 2.76 0.250 1.99 +-2.3 0.565 4.78 5.35 2.76 0.275 2.94 +-2.3 0.565 4.78 5.35 2.76 0.300 4.21 +-2.3 0.565 4.78 5.35 2.76 0.325 5.54 +-2.3 0.565 4.78 5.35 2.76 0.350 8.25 +-2.3 0.565 4.78 5.35 2.76 0.375 13.08 +-2.3 0.565 4.78 5.35 2.76 0.400 21.40 +-2.3 0.565 4.78 5.35 2.76 0.425 33.14 +-2.3 0.565 4.78 5.35 2.76 0.450 50.14 +-2.3 0.564 5.10 3.95 3.53 0.125 0.20 +-2.3 0.564 5.10 3.95 3.53 0.150 0.35 +-2.3 0.564 5.10 3.95 3.53 0.175 0.65 +-2.3 0.564 5.10 3.95 3.53 0.200 0.93 +-2.3 0.564 5.10 3.95 3.53 0.225 1.37 +-2.3 0.564 5.10 3.95 3.53 0.250 1.97 +-2.3 0.564 5.10 3.95 3.53 0.275 2.83 +-2.3 0.564 5.10 3.95 3.53 0.300 3.99 +-2.3 0.564 5.10 3.95 3.53 0.325 5.19 +-2.3 0.564 5.10 3.95 3.53 0.350 8.03 +-2.3 0.564 5.10 3.95 3.53 0.375 12.86 +-2.3 0.564 5.10 3.95 3.53 0.400 21.51 +-2.3 0.564 5.10 3.95 3.53 0.425 33.97 +-2.3 0.564 5.10 3.95 3.53 0.450 50.36 +-2.4 0.574 4.36 3.96 2.76 0.125 0.20 +-2.4 0.574 4.36 3.96 2.76 0.150 0.35 +-2.4 0.574 4.36 3.96 2.76 0.175 0.65 +-2.4 0.574 4.36 3.96 2.76 0.200 0.93 +-2.4 0.574 4.36 3.96 2.76 0.225 1.37 +-2.4 0.574 4.36 3.96 2.76 0.250 1.97 +-2.4 0.574 4.36 3.96 2.76 0.275 2.83 +-2.4 0.574 4.36 3.96 2.76 0.300 3.99 +-2.4 0.574 4.36 3.96 2.76 0.325 5.19 +-2.4 0.574 4.36 3.96 2.76 0.350 8.03 +-2.4 0.574 4.36 3.96 2.76 0.375 12.86 +-2.4 0.574 4.36 3.96 2.76 0.400 21.51 +-2.4 0.574 4.36 3.96 2.76 0.425 33.97 +-2.4 0.574 4.36 3.96 2.76 0.450 50.36 +-2.4 0.568 4.34 2.98 3.15 0.125 0.12 +-2.4 0.568 4.34 2.98 3.15 0.150 0.26 +-2.4 0.568 4.34 2.98 3.15 0.175 0.43 +-2.4 0.568 4.34 2.98 3.15 0.200 0.69 +-2.4 0.568 4.34 2.98 3.15 0.225 1.09 +-2.4 0.568 4.34 2.98 3.15 0.250 1.67 +-2.4 0.568 4.34 2.98 3.15 0.275 2.46 +-2.4 0.568 4.34 2.98 3.15 0.300 3.43 +-2.4 0.568 4.34 2.98 3.15 0.325 4.62 +-2.4 0.568 4.34 2.98 3.15 0.350 6.86 +-2.4 0.568 4.34 2.98 3.15 0.375 11.56 +-2.4 0.568 4.34 2.98 3.15 0.400 20.63 +-2.4 0.568 4.34 2.98 3.15 0.425 34.50 +-2.4 0.568 4.34 2.98 3.15 0.450 54.23 +-2.3 0.562 5.14 4.95 3.17 0.125 0.28 +-2.3 0.562 5.14 4.95 3.17 0.150 0.44 +-2.3 0.562 5.14 4.95 3.17 0.175 0.70 +-2.3 0.562 5.14 4.95 3.17 0.200 1.07 +-2.3 0.562 5.14 4.95 3.17 0.225 1.57 +-2.3 0.562 5.14 4.95 3.17 0.250 2.23 +-2.3 0.562 5.14 4.95 3.17 0.275 3.09 +-2.3 0.562 5.14 4.95 3.17 0.300 4.09 +-2.3 0.562 5.14 4.95 3.17 0.325 5.82 +-2.3 0.562 5.14 4.95 3.17 0.350 8.28 +-2.3 0.562 5.14 4.95 3.17 0.375 12.80 +-2.3 0.562 5.14 4.95 3.17 0.400 20.41 +-2.3 0.562 5.14 4.95 3.17 0.425 32.34 +-2.3 0.562 5.14 4.95 3.17 0.450 47.29 +-2.4 0.585 4.78 3.84 3.32 0.125 0.20 +-2.4 0.585 4.78 3.84 3.32 0.150 0.38 +-2.4 0.585 4.78 3.84 3.32 0.175 0.64 +-2.4 0.585 4.78 3.84 3.32 0.200 0.97 +-2.4 0.585 4.78 3.84 3.32 0.225 1.36 +-2.4 0.585 4.78 3.84 3.32 0.250 1.98 +-2.4 0.585 4.78 3.84 3.32 0.275 2.91 +-2.4 0.585 4.78 3.84 3.32 0.300 4.35 +-2.4 0.585 4.78 3.84 3.32 0.325 5.79 +-2.4 0.585 4.78 3.84 3.32 0.350 8.04 +-2.4 0.585 4.78 3.84 3.32 0.375 12.15 +-2.4 0.585 4.78 3.84 3.32 0.400 19.18 +-2.4 0.585 4.78 3.84 3.32 0.425 30.09 +-2.4 0.585 4.78 3.84 3.32 0.450 44.38 +-2.2 0.546 4.78 4.13 3.07 0.125 0.15 +-2.2 0.546 4.78 4.13 3.07 0.150 0.32 +-2.2 0.546 4.78 4.13 3.07 0.175 0.55 +-2.2 0.546 4.78 4.13 3.07 0.200 0.86 +-2.2 0.546 4.78 4.13 3.07 0.225 1.24 +-2.2 0.546 4.78 4.13 3.07 0.250 1.76 +-2.2 0.546 4.78 4.13 3.07 0.275 2.49 +-2.2 0.546 4.78 4.13 3.07 0.300 3.45 +-2.2 0.546 4.78 4.13 3.07 0.325 4.83 +-2.2 0.546 4.78 4.13 3.07 0.350 7.37 +-2.2 0.546 4.78 4.13 3.07 0.375 12.76 +-2.2 0.546 4.78 4.13 3.07 0.400 21.99 +-2.2 0.546 4.78 4.13 3.07 0.425 35.64 +-2.2 0.546 4.78 4.13 3.07 0.450 53.07 +0.0 0.565 4.77 3.99 3.15 0.125 0.11 +0.0 0.565 4.77 3.99 3.15 0.150 0.24 +0.0 0.565 4.77 3.99 3.15 0.175 0.49 +0.0 0.565 4.77 3.99 3.15 0.200 0.79 +0.0 0.565 4.77 3.99 3.15 0.225 1.28 +0.0 0.565 4.77 3.99 3.15 0.250 1.96 +0.0 0.565 4.77 3.99 3.15 0.275 2.88 +0.0 0.565 4.77 3.99 3.15 0.300 4.14 +0.0 0.565 4.77 3.99 3.15 0.325 5.96 +0.0 0.565 4.77 3.99 3.15 0.350 9.07 +0.0 0.565 4.77 3.99 3.15 0.375 14.93 +0.0 0.565 4.77 3.99 3.15 0.400 24.13 +0.0 0.565 4.77 3.99 3.15 0.425 38.12 +0.0 0.565 4.77 3.99 3.15 0.450 55.44 +-5.0 0.565 4.77 3.99 3.15 0.125 0.07 +-5.0 0.565 4.77 3.99 3.15 0.150 0.18 +-5.0 0.565 4.77 3.99 3.15 0.175 0.40 +-5.0 0.565 4.77 3.99 3.15 0.200 0.70 +-5.0 0.565 4.77 3.99 3.15 0.225 1.14 +-5.0 0.565 4.77 3.99 3.15 0.250 1.83 +-5.0 0.565 4.77 3.99 3.15 0.275 2.77 +-5.0 0.565 4.77 3.99 3.15 0.300 4.12 +-5.0 0.565 4.77 3.99 3.15 0.325 5.41 +-5.0 0.565 4.77 3.99 3.15 0.350 7.87 +-5.0 0.565 4.77 3.99 3.15 0.375 12.71 +-5.0 0.565 4.77 3.99 3.15 0.400 21.02 +-5.0 0.565 4.77 3.99 3.15 0.425 34.58 +-5.0 0.565 4.77 3.99 3.15 0.450 51.77 +0.0 0.565 5.10 3.94 3.51 0.125 0.08 +0.0 0.565 5.10 3.94 3.51 0.150 0.26 +0.0 0.565 5.10 3.94 3.51 0.175 0.50 +0.0 0.565 5.10 3.94 3.51 0.200 0.83 +0.0 0.565 5.10 3.94 3.51 0.225 1.28 +0.0 0.565 5.10 3.94 3.51 0.250 1.90 +0.0 0.565 5.10 3.94 3.51 0.275 2.68 +0.0 0.565 5.10 3.94 3.51 0.300 3.76 +0.0 0.565 5.10 3.94 3.51 0.325 5.57 +0.0 0.565 5.10 3.94 3.51 0.350 8.76 +0.0 0.565 5.10 3.94 3.51 0.375 14.24 +0.0 0.565 5.10 3.94 3.51 0.400 23.05 +0.0 0.565 5.10 3.94 3.51 0.425 35.46 +0.0 0.565 5.10 3.94 3.51 0.450 51.99 +-5.0 0.565 5.10 3.94 3.51 0.125 0.08 +-5.0 0.565 5.10 3.94 3.51 0.150 0.24 +-5.0 0.565 5.10 3.94 3.51 0.175 0.45 +-5.0 0.565 5.10 3.94 3.51 0.200 0.77 +-5.0 0.565 5.10 3.94 3.51 0.225 1.19 +-5.0 0.565 5.10 3.94 3.51 0.250 1.76 +-5.0 0.565 5.10 3.94 3.51 0.275 2.59 +-5.0 0.565 5.10 3.94 3.51 0.300 3.85 +-5.0 0.565 5.10 3.94 3.51 0.325 5.27 +-5.0 0.565 5.10 3.94 3.51 0.350 7.74 +-5.0 0.565 5.10 3.94 3.51 0.375 12.40 +-5.0 0.565 5.10 3.94 3.51 0.400 20.91 +-5.0 0.565 5.10 3.94 3.51 0.425 33.23 +-5.0 0.565 5.10 3.94 3.51 0.450 49.14 +-2.3 0.530 5.11 3.69 3.51 0.125 0.08 +-2.3 0.530 5.11 3.69 3.51 0.150 0.25 +-2.3 0.530 5.11 3.69 3.51 0.175 0.46 +-2.3 0.530 5.11 3.69 3.51 0.200 0.75 +-2.3 0.530 5.11 3.69 3.51 0.225 1.11 +-2.3 0.530 5.11 3.69 3.51 0.250 1.57 +-2.3 0.530 5.11 3.69 3.51 0.275 2.17 +-2.3 0.530 5.11 3.69 3.51 0.300 2.98 +-2.3 0.530 5.11 3.69 3.51 0.325 4.42 +-2.3 0.530 5.11 3.69 3.51 0.350 7.84 +-2.3 0.530 5.11 3.69 3.51 0.375 14.11 +-2.3 0.530 5.11 3.69 3.51 0.400 24.14 +-2.3 0.530 5.11 3.69 3.51 0.425 37.95 +-2.3 0.530 5.11 3.69 3.51 0.450 55.17 +-2.3 0.530 4.76 3.68 3.16 0.125 0.10 +-2.3 0.530 4.76 3.68 3.16 0.150 0.23 +-2.3 0.530 4.76 3.68 3.16 0.175 0.47 +-2.3 0.530 4.76 3.68 3.16 0.200 0.76 +-2.3 0.530 4.76 3.68 3.16 0.225 1.15 +-2.3 0.530 4.76 3.68 3.16 0.250 1.65 +-2.3 0.530 4.76 3.68 3.16 0.275 2.28 +-2.3 0.530 4.76 3.68 3.16 0.300 3.09 +-2.3 0.530 4.76 3.68 3.16 0.325 4.41 +-2.3 0.530 4.76 3.68 3.16 0.350 7.51 +-2.3 0.530 4.76 3.68 3.16 0.375 13.77 +-2.3 0.530 4.76 3.68 3.16 0.400 23.96 +-2.3 0.530 4.76 3.68 3.16 0.425 37.38 +-2.3 0.530 4.76 3.68 3.16 0.450 56.46 +-2.3 0.530 4.34 2.81 3.15 0.125 0.05 +-2.3 0.530 4.34 2.81 3.15 0.150 0.17 +-2.3 0.530 4.34 2.81 3.15 0.175 0.35 +-2.3 0.530 4.34 2.81 3.15 0.200 0.63 +-2.3 0.530 4.34 2.81 3.15 0.225 1.01 +-2.3 0.530 4.34 2.81 3.15 0.250 1.43 +-2.3 0.530 4.34 2.81 3.15 0.275 2.05 +-2.3 0.530 4.34 2.81 3.15 0.300 2.73 +-2.3 0.530 4.34 2.81 3.15 0.325 3.87 +-2.3 0.530 4.34 2.81 3.15 0.350 7.19 +-2.3 0.530 4.34 2.81 3.15 0.375 13.96 +-2.3 0.530 4.34 2.81 3.15 0.400 25.18 +-2.3 0.530 4.34 2.81 3.15 0.425 41.34 +-2.3 0.530 4.34 2.81 3.15 0.450 62.42 +0.0 0.600 4.78 4.24 3.15 0.125 0.03 +0.0 0.600 4.78 4.24 3.15 0.150 0.18 +0.0 0.600 4.78 4.24 3.15 0.175 0.40 +0.0 0.600 4.78 4.24 3.15 0.200 0.73 +0.0 0.600 4.78 4.24 3.15 0.225 1.30 +0.0 0.600 4.78 4.24 3.15 0.250 2.16 +0.0 0.600 4.78 4.24 3.15 0.275 3.35 +0.0 0.600 4.78 4.24 3.15 0.300 5.06 +0.0 0.600 4.78 4.24 3.15 0.325 7.14 +0.0 0.600 4.78 4.24 3.15 0.350 10.36 +0.0 0.600 4.78 4.24 3.15 0.375 15.25 +0.0 0.600 4.78 4.24 3.15 0.400 23.15 +0.0 0.600 4.78 4.24 3.15 0.425 34.62 +0.0 0.600 4.78 4.24 3.15 0.450 51.50 +-5.0 0.600 4.78 4.24 3.15 0.125 0.06 +-5.0 0.600 4.78 4.24 3.15 0.150 0.15 +-5.0 0.600 4.78 4.24 3.15 0.175 0.34 +-5.0 0.600 4.78 4.24 3.15 0.200 0.63 +-5.0 0.600 4.78 4.24 3.15 0.225 1.13 +-5.0 0.600 4.78 4.24 3.15 0.250 1.85 +-5.0 0.600 4.78 4.24 3.15 0.275 2.84 +-5.0 0.600 4.78 4.24 3.15 0.300 4.34 +-5.0 0.600 4.78 4.24 3.15 0.325 6.20 +-5.0 0.600 4.78 4.24 3.15 0.350 8.62 +-5.0 0.600 4.78 4.24 3.15 0.375 12.49 +-5.0 0.600 4.78 4.24 3.15 0.400 20.41 +-5.0 0.600 4.78 4.24 3.15 0.425 32.46 +-5.0 0.600 4.78 4.24 3.15 0.450 50.94 +0.0 0.530 4.78 3.75 3.15 0.125 0.16 +0.0 0.530 4.78 3.75 3.15 0.150 0.32 +0.0 0.530 4.78 3.75 3.15 0.175 0.59 +0.0 0.530 4.78 3.75 3.15 0.200 0.92 +0.0 0.530 4.78 3.75 3.15 0.225 1.37 +0.0 0.530 4.78 3.75 3.15 0.250 1.94 +0.0 0.530 4.78 3.75 3.15 0.275 2.62 +0.0 0.530 4.78 3.75 3.15 0.300 3.70 +0.0 0.530 4.78 3.75 3.15 0.325 5.45 +0.0 0.530 4.78 3.75 3.15 0.350 9.45 +0.0 0.530 4.78 3.75 3.15 0.375 16.31 +0.0 0.530 4.78 3.75 3.15 0.400 27.34 +0.0 0.530 4.78 3.75 3.15 0.425 41.77 +0.0 0.530 4.78 3.75 3.15 0.450 60.85 +-5.0 0.530 4.78 3.75 3.15 0.125 0.09 +-5.0 0.530 4.78 3.75 3.15 0.150 0.24 +-5.0 0.530 4.78 3.75 3.15 0.175 0.47 +-5.0 0.530 4.78 3.75 3.15 0.200 0.78 +-5.0 0.530 4.78 3.75 3.15 0.225 1.21 +-5.0 0.530 4.78 3.75 3.15 0.250 1.85 +-5.0 0.530 4.78 3.75 3.15 0.275 2.62 +-5.0 0.530 4.78 3.75 3.15 0.300 3.69 +-5.0 0.530 4.78 3.75 3.15 0.325 5.07 +-5.0 0.530 4.78 3.75 3.15 0.350 7.95 +-5.0 0.530 4.78 3.75 3.15 0.375 13.73 +-5.0 0.530 4.78 3.75 3.15 0.400 23.55 +-5.0 0.530 4.78 3.75 3.15 0.425 37.14 +-5.0 0.530 4.78 3.75 3.15 0.450 55.87 +-2.3 0.600 5.10 4.17 3.51 0.125 0.01 +-2.3 0.600 5.10 4.17 3.51 0.150 0.16 +-2.3 0.600 5.10 4.17 3.51 0.175 0.39 +-2.3 0.600 5.10 4.17 3.51 0.200 0.73 +-2.3 0.600 5.10 4.17 3.51 0.225 1.24 +-2.3 0.600 5.10 4.17 3.51 0.250 1.96 +-2.3 0.600 5.10 4.17 3.51 0.275 3.04 +-2.3 0.600 5.10 4.17 3.51 0.300 4.46 +-2.3 0.600 5.10 4.17 3.51 0.325 6.31 +-2.3 0.600 5.10 4.17 3.51 0.350 8.68 +-2.3 0.600 5.10 4.17 3.51 0.375 12.39 +-2.3 0.600 5.10 4.17 3.51 0.400 20.14 +-2.3 0.600 5.10 4.17 3.51 0.425 31.77 +-2.3 0.600 5.10 4.17 3.51 0.450 47.13 +-2.3 0.600 4.34 4.23 2.73 0.125 0.04 +-2.3 0.600 4.34 4.23 2.73 0.150 0.17 +-2.3 0.600 4.34 4.23 2.73 0.175 0.36 +-2.3 0.600 4.34 4.23 2.73 0.200 0.64 +-2.3 0.600 4.34 4.23 2.73 0.225 1.02 +-2.3 0.600 4.34 4.23 2.73 0.250 1.62 +-2.3 0.600 4.34 4.23 2.73 0.275 2.63 +-2.3 0.600 4.34 4.23 2.73 0.300 4.15 +-2.3 0.600 4.34 4.23 2.73 0.325 6.00 +-2.3 0.600 4.34 4.23 2.73 0.350 8.47 +-2.3 0.600 4.34 4.23 2.73 0.375 12.27 +-2.3 0.600 4.34 4.23 2.73 0.400 19.59 +-2.3 0.600 4.34 4.23 2.73 0.425 30.48 +-2.3 0.600 4.34 4.23 2.73 0.450 46.66 + diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index b547c0f3..166ef065 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -42,7 +42,6 @@ AbstractKernelComputation, AbstractKernel, ) -from gpjax.types import PRNGKeyType # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -106,12 +105,12 @@ def __call__( yt = self.network.apply(params=params, x=y) return self.base_kernel(params, xt, yt) - def initialise(self, dummy_x: Float[Array, "1 D"], key: PRNGKeyType) -> None: + def initialise(self, dummy_x: Float[Array, "1 D"], key: jr.KeyArray) -> None: nn_params = self.network.init(rng=key, x=dummy_x) base_kernel_params = self.base_kernel._initialise_params(key) self._params = {**nn_params, **base_kernel_params} - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: jr.KeyArray) -> Dict: return self._params diff --git a/examples/intro_to_gps.pct.py b/examples/intro_to_gps.pct.py index 3e8f676d..c667159f 100644 --- a/examples/intro_to_gps.pct.py +++ b/examples/intro_to_gps.pct.py @@ -170,10 +170,10 @@ d1 = dx.MultivariateNormalFullCovariance(jnp.zeros(2), jnp.eye(2)) d2 = dx.MultivariateNormalTri( - jnp.zeros(2), jnp.linalg.cholesky(jax.Array([[1.0, 0.9], [0.9, 1.0]])) + jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]])) ) d3 = dx.MultivariateNormalTri( - jnp.zeros(2), jnp.linalg.cholesky(jax.Array([[1.0, -0.5], [-0.5, 1.0]])) + jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]])) ) dists = [d1, d2, d3] diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index 76162d0f..a2fa97a3 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -45,10 +45,7 @@ # We'll be using the [Yacht](https://archive.ics.uci.edu/ml/datasets/yacht+hydrodynamics) dataset from the UCI machine learning data repository. Each observation describes the hydrodynamic performance of a yacht through its resistance. The dataset contains 6 covariates and a single positive, real valued response variable. There are 308 observations in the dataset, so we can comfortably use a conjugate regression Gaussian process here (for more more details, checkout the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). # %% -yacht = pd.read_fwf( - "https://archive.ics.uci.edu/ml/machine-learning-databases/00243/yacht_hydrodynamics.data", - header=None, -).values[:-1, :] +yacht = pd.read_fwf("data/yacht_hydrodynamics.data", header=None).values[:-1, :] X = yacht[:, :-1] y = yacht[:, -1].reshape(-1, 1) diff --git a/setup.py b/setup.py index 771443a9..53c5d678 100644 --- a/setup.py +++ b/setup.py @@ -8,17 +8,18 @@ # Handle builds of nightly release if "BUILD_GPJAX_NIGHTLY" in os.environ: - NAME += "-nightly" + if os.environ["BUILD_GPJAX_NIGHTLY"] == "nightly": + NAME += "-nightly" - from versioneer import get_versions as original_get_versions + from versioneer import get_versions as original_get_versions - def get_versions(): - from datetime import datetime, timezone + def get_versions(): + from datetime import datetime, timezone - suffix = datetime.now(timezone.utc).strftime(r".dev%Y%m%d") - versions = original_get_versions() - versions["version"] = versions["version"].split("+")[0] + suffix - return versions + suffix = datetime.now(timezone.utc).strftime(r".dev%Y%m%d") + versions = original_get_versions() + versions["version"] = versions["version"].split("+")[0] + suffix + return versions REQUIRES = [