diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml index 433d4e34..265cc536 100644 --- a/.github/workflows/build_docs.yml +++ b/.github/workflows/build_docs.yml @@ -52,16 +52,11 @@ jobs: virtualenvs-in-project: false installer-parallel: true - - name: Install LaTex - run: | - sudo apt-get update - sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super - - name: Build the documentation with MKDocs run: | poetry install --all-extras --with docs conda install pandoc - poetry run mkdocs build + poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build - name: Deploy Page 🚀 uses: JamesIves/github-pages-deploy-action@v4.4.1 diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index 9dc9c55d..129de0e3 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -33,20 +33,6 @@ jobs: auto-update-conda: true python-version: ${{ matrix.python-version }} - # Install katex for math support - - name: Install NPM - uses: actions/setup-node@v3 - with: - node-version: 16 - - name: Install KaTeX - run: | - npm install katex - - - name: Install LaTex - run: | - sudo apt-get update - sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super - # Install Poetry and build the documentation - name: Install and configure Poetry uses: snok/install-poetry@v1 @@ -60,4 +46,4 @@ jobs: run: | poetry install --all-extras --with docs conda install pandoc - poetry run python docs/scripts/gen_examples.py && poetry run mkdocs build + poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build diff --git a/README.md b/README.md index 8ebf160d..721d4d74 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,9 @@ helped to shape GPJax into the package it is today. ## Notebook examples > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/) -> - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/) +> - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/) > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/) > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) -> - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference) > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation) > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/) diff --git a/docs/scripts/gen_examples.py b/docs/scripts/gen_examples.py index 8fbe71b4..632fad7f 100644 --- a/docs/scripts/gen_examples.py +++ b/docs/scripts/gen_examples.py @@ -28,7 +28,7 @@ def process_file(file: Path, out_file: Path | None = None, execute: bool = False f"| jupyter nbconvert --to markdown --execute --stdin --output {out_file}" ) else: - command = f"jupytext --to markdown {file} --output {out_file}" + command += f"jupytext --to markdown {file} --output {out_file}" subprocess.run(command, shell=True, check=False) @@ -64,21 +64,26 @@ def main(args): print(files) # process files in parallel - with ThreadPoolExecutor(max_workers=args.max_workers) as executor: - futures = [] - for file in files: - out_file = out_dir / f"{file.stem}.md" - futures.append( - executor.submit( - process_file, file, out_file=out_file, execute=args.execute + if args.parallel: + with ThreadPoolExecutor(max_workers=args.max_workers) as executor: + futures = [] + for file in files: + out_file = out_dir / f"{file.stem}.md" + futures.append( + executor.submit( + process_file, file, out_file=out_file, execute=args.execute + ) ) - ) - for future in as_completed(futures): - try: - future.result() - except Exception as e: - print(f"Error processing file: {e}") + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"Error processing file: {e}") + else: + for file in files: + out_file = out_dir / f"{file.stem}.md" + process_file(file, out_file=out_file, execute=args.execute) if __name__ == "__main__": @@ -91,6 +96,7 @@ def main(args): parser.add_argument( "--outdir", type=Path, default=project_root / "docs" / "_examples" ) + parser.add_argument("--parallel", type=bool, default=False) args = parser.parse_args() main(args) diff --git a/examples/backend.py b/examples/backend.py new file mode 100644 index 00000000..8e8bbe3a --- /dev/null +++ b/examples/backend.py @@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: gpjax +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Backend Module Design +# +# Since v0.9, GPJax is built upon Flax's +# [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition +# allows for more efficient parameter handling, improved integration with Flax and +# Flax-based libraries, and enhanced flexibility in model design. This notebook provides +# a high-level overview of the backend module design in GPJax. For an introduction to +# NNX, please refer to the [official +# documentation](https://flax.readthedocs.io/en/latest/nnx/index.html). +# + +# %% +# Enable Float64 for more stable matrix inversions. +from jax import ( + config, + grad, +) + +config.update("jax_enable_x64", True) + +import jax.numpy as jnp +from jaxtyping import ( + Float, + install_import_hook, +) +import matplotlib as mpl +import matplotlib.pyplot as plt + +from gpjax.mean_functions import Constant +from gpjax.parameters import ( + Parameter, + Real, +) + +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + +from flax import nnx + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% [markdown] +# ## Parameters +# +# The biggest change bought about by the transition to an NNX backend is the increased +# support we now provide for handling parameters. As discussed in our [Sharp Bits - +# Bijectors Doc](https://docs.jaxgaussianprocesses.com/sharp_bits/#bijectors), GPJax +# uses bijectors to transform constrained parameters to unconstrained parameters during +# optimisation. You may now register the support of a parameter using our `Parameter` +# class. To see this, consider the constant mean function who contains a single constant +# parameter whose value ordinarily exists on the real line. We can register this +# parameter as follows: + +# %% +constant_param = Parameter(value=1.0, tag=None) +meanf = Constant(constant_param) +print(meanf) + +# %% [markdown] +# However, suppose you wish your mean function's constant parameter to be strictly +# positive. This is easy to achieve by using the correct Parameter type which, in this +# case, will be the `PositiveReal`. However, any Parameter that subclasses from +# `Parameter` will be transformed by GPJax. + +# %% +from gpjax.parameters import PositiveReal + +issubclass(PositiveReal, Parameter) + +# %% [markdown] +# Injecting this newly constrained parameter into our mean function is then identical to before. + +# %% +constant_param = PositiveReal(value=1.0) +meanf = Constant(constant_param) +print(meanf) + +# %% [markdown] +# Were we to try and instantiate the `PositiveReal` class with a negative value, then an +# explicit error would be raised. + +# %% +try: + PositiveReal(value=-1.0) +except ValueError as e: + print(e) + +# %% [markdown] +# ### Parameter Transforms +# +# With a parameter instantiated, you likely wish to transform the parameter's value from +# its constrained support onto the entire real line. To do this, you can apply the +# `transform` function to the parameter. To control the bijector used to transform the +# parameter, you may pass a set of bijectors into the transform function. +# Under-the-hood, the `transform` function is looking up the bijector of a parameter +# using it's `_tag` field in the bijector dictionary, and then applying the bijector to +# the parameter's value using a tree map operation. + +# %% +print(constant_param._tag) + +# %% [markdown] +# For most users, you will not need to worry about this as we provide a set of default +# bijectors that are defined for all the parameter types we support. However, see our +# [Kernel Guide +# Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) to +# see how you can define your own bijectors and parameter types. + +# %% +from gpjax.parameters import DEFAULT_BIJECTION, transform + +print(DEFAULT_BIJECTION[constant_param._tag]) + +# %% [markdown] +# We see here that the Softplus bijector is specified as the default for strictly +# positive parameters. To apply this, we must first realise the _state_ of our model. +# This is achieved using the `split` function provided by `nnx`. + +# %% +_, _params = nnx.split(meanf, Parameter) + +tranformed_params = transform(_params, DEFAULT_BIJECTION, inverse=True) + +# %% [markdown] +# The parameter's value was changed here from 1. to 0.54132485. This is the result of +# applying the Softplus bijector to the parameter's value and projecting its value onto +# the real line. Were the parameter's value to be closer to 0, then the transformation +# would be more pronounced. + +# %% +_, _close_to_zero_state = nnx.split(Constant(PositiveReal(value=1e-6)), Parameter) + +transform(_close_to_zero_state, DEFAULT_BIJECTION, inverse=True) + +# %% [markdown] +# ### Transforming Multiple Parameters +# +# In the above, we transformed a single parameter. However, in practice your parameters +# may be nested within several functions e.g., a kernel function within a GP model. +# Fortunately, transforming several parameters is a simple operation that we here +# demonstrate for a conjugate GP posterior (see our [Regression +# Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for detailed +# explanation of this model.). + +# %% +kernel = gpx.kernels.Matern32() +meanf = gpx.mean_functions.Constant() + +prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) + +likelihood = gpx.likelihoods.Gaussian(100) +posterior = likelihood * prior +print(posterior) + +# %% [markdown] +# Now contained within the posterior PyGraph here there are four parameters: the +# kernel's lengthscale and variance, the noise variance of the likelihood, and the +# constant of the mean function. Using NNX, we may realise these parameters through the +# `nnx.split` function. The `split` function deomposes a PyGraph into a `GraphDef` and +# `State` object. As the name suggests, `State` contains information on the parameters' +# state, whilst `GraphDef` contains the information required to reconstruct a PyGraph +# from a give `State`. + +# %% +graphdef, state = nnx.split(posterior) +print(state) + +# %% [markdown] +# The `State` object behaves just like a PyTree and, consequently, we may use JAX's +# `tree_map` function to alter the values of the `State`. The updated `State` can then +# be used to reconstruct our posterior. In the below, we simply increment each +# parameter's value by 1. + +# %% +import jax.tree_util as jtu + +updated_state = jtu.tree_map(lambda x: x + 1, state) +print(updated_state) + +# %% [markdown] +# Let us now use NNX's `merge` function to reconstruct the posterior distribution using +# the updated state. + +# %% +updated_posterior = nnx.merge(graphdef, updated_state) +print(updated_posterior) + +# %% [markdown] +# However, we begun this point of conversation with bijectors in mind, so let us now see +# how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this +# is very straightforward, and we may simply use the `transform` function as before. + +# %% +transformed_state = transform(state, DEFAULT_BIJECTION, inverse=True) +print(transformed_state) + +# %% [markdown] +# We may also (re-)constrain the parameters' values by setting the `inverse` argument of +# `transform` to False. + +# %% +retransformed_state = transform(transformed_state, DEFAULT_BIJECTION, inverse=False) + +# %% [markdown] +# ### Fine-Scale Control +# +# One of the advantages of being able to split and re-merge the PyGraph is that we are +# able to gain fine-scale control over the parameters' whose state we wish to realise. +# This is by virtue of the fact that each of our parameters now inherit from +# `gpjax.parameters.Parameter`. In the former, we were simply extracting any +# `Parameter`subclass from the posterior. However, suppose we only wish to extract those +# parameters whose support is the positive real line. This is easily achieved by +# altering the way in which we invoke `nnx.split`. + +# %% +from gpjax.parameters import PositiveReal + +graphdef, positive_reals, other_params = nnx.split(posterior, PositiveReal, ...) +print(positive_reals) + +# %% [markdown] +# Now we see that we have two state objects: one containing the positive real parameters +# and the other containing the remaining parameters. This functionality is exceptionally +# useful as it allows us to efficiently operate on a subset of the parameters whilst +# leaving the others untouched. Looking forward, we hope to use this functionality in +# our [Variational Inference +# Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to +# perform more efficient updates of the variational parameters and then the model's +# hyperparameters. + +# %% [markdown] +# ## NNX Modules +# +# To conclude this notebook, we will now demonstrate the ease of use and flexibility +# offered by NNX modules. To do this, we will implement a linear mean function using the +# existing abstractions in GPJax. +# +# For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to +# \mathbb{R}$ is defined as: +# $$ +# m(x) = \alpha + \sum_{i=1}^d \beta_i x_i +# $$ +# where $\alpha \in \mathbb{R}$ and $\beta_i \in \mathbb{R}$ are the parameters of the +# mean function. Let's now implement that using the new NNX backend. + +# %% +import typing as tp + +from jaxtyping import Float, Num + +from gpjax.mean_functions import AbstractMeanFunction +from gpjax.parameters import Parameter, Real +from gpjax.typing import ScalarFloat, Array + + +class LinearMeanFunction(AbstractMeanFunction): + def __init__( + self, + intercept: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0, + slope: tp.Union[ScalarFloat, Float[Array, " D O"], Parameter] = 0.0, + ): + if isinstance(intercept, Parameter): + self.intercept = intercept + else: + self.intercept = Real(jnp.array(intercept)) + + if isinstance(slope, Parameter): + self.slope = slope + else: + self.slope = Real(jnp.array(slope)) + + def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]: + return self.intercept.value + jnp.dot(x, self.slope.value) + + +# %% [markdown] +# As we can see, the implementation is straightforward and concise. The +# `AbstractMeanFunction` module is a subclass of `nnx.Module` and may, therefore, be +# used in any `split` or `merge` call. Further, we have registered the intercept and +# slope parameters as `Real` parameter types. This registers their value in the PyGraph +# and means that they will be part of any operation applied to the PyGraph e.g., +# transforming and differentiation. +# +# To check our implementation worked, let's now plot the value of our mean function for +# a linearly spaced set of inputs. + +# %% +N = 100 +X = jnp.linspace(-5.0, 5.0, N)[:, None] + +meanf = LinearMeanFunction(intercept=1.0, slope=2.0) +plt.plot(X, meanf(X)) + +# %% [markdown] +# Looks good! To conclude this section, let's now parameterise a GP with our new mean +# function and see how gradients may be computed. + +# %% +y = jnp.sin(X) +D = gpx.Dataset(X, y) + +prior = gpx.gps.Prior(mean_function=meanf, kernel=gpx.kernels.Matern32()) +likelihood = gpx.likelihoods.Gaussian(D.n) +posterior = likelihood * prior + +# %% [markdown] +# We'll compute derivatives of the conjugate marginal log-likelihood, with respect to +# the unconstrained state of the kernel, mean function, and likelihood parameters. + +# %% +graphdef, params, others = nnx.split(posterior, Parameter, ...) +params = transform(params, DEFAULT_BIJECTION, inverse=True) + + +def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat: + params = transform(params, DEFAULT_BIJECTION) + model = nnx.merge(graphdef, params, *others) + return -gpx.objectives.conjugate_mll(model, data) + + +param_grads = grad(loss_fn)(params, D) + +# %% [markdown] +# In practice, you would wish to perform multiple iterations of gradient descent to +# learn the optimal parameter values. However, for the purposes of illustration, we use +# another `tree_map` in the below to update the parameters' state using their previously +# computed gradients. As you can see, the really beauty in having access to the model's +# state is that we have full control over the operations that we perform to the state. + +# %% +LEARNING_RATE = 0.01 +optimised_params = jtu.tree_map( + lambda _params, _grads: _params + LEARNING_RATE * _grads, params, param_grads +) + +# %% [markdown] +# Now we will plot the updated mean function alongside its initial form. To achieve +# this, we first merge the state back into the model using `merge`, and we then simply +# invoke the model as normal. + +# %% +optimised_posterior = nnx.merge(graphdef, optimised_params, *others) + +fig, ax = plt.subplots() +ax.plot(X, optimised_posterior.prior.mean_function(X), label="Updated mean function") +ax.plot(X, meanf(X), label="Initial mean function") +ax.legend() +ax.set(xlabel="x", ylabel="m(x)") + +# %% [markdown] +# ## Conclusions +# +# In this notebook we have explored how GPJax's Flax-based backend may be easily +# manipulated and extended. For a more applied look at this, see how we construct a +# kernel on polar coordinates in our [Kernel +# Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel) +# notebook. +# +# ## System configuration + +# %% +# %reload_ext watermark +# %watermark -n -u -v -iv -w -a 'Thomas Pinder' diff --git a/examples/barycentres.py b/examples/barycentres.py index 83e664de..62e06753 100644 --- a/examples/barycentres.py +++ b/examples/barycentres.py @@ -48,13 +48,12 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style key = jr.key(123) # set the default style for plotting -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +use_mpl_style() cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/bayesian_optimisation.py b/examples/bayesian_optimisation.py index fb4e6a47..ac660a69 100644 --- a/examples/bayesian_optimisation.py +++ b/examples/bayesian_optimisation.py @@ -44,10 +44,13 @@ from gpjax.typing import Array, FunctionalSample, ScalarFloat from jaxopt import ScipyBoundedMinimize +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + key = jr.key(42) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/classification.py b/examples/classification.py index 837c3723..c53cf435 100644 --- a/examples/classification.py +++ b/examples/classification.py @@ -15,17 +15,11 @@ # name: python3 # --- -# %% -# %load_ext autoreload -# %autoreload 2 - # %% [markdown] # # Classification # # In this notebook we demonstrate how to perform inference for Gaussian process models -# with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte -# Carlo (MCMC). We focus on a classification task here and use -# [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling. +# with non-Gaussian likelihoods via maximum a posteriori (MAP). We focus on a classification task here. # %% # Enable Float64 for more stable matrix inversions. @@ -54,12 +48,15 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style + tfd = tfp.distributions identity_matrix = jnp.eye -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] @@ -320,186 +317,6 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma ) ax.legend() -# %% [markdown] -# However, the Laplace approximation is still limited by considering information about -# the posterior at a single location. On the other hand, through approximate sampling, -# MCMC methods allow us to learn all information about the posterior distribution. - -# %% [markdown] -# ## MCMC inference -# -# An MCMC sampler works by starting at an initial position and -# drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The -# next step is to determine whether this sample could be considered a draw from the -# posterior. We accomplish this using an _acceptance probability_ determined via the -# sampler's _transition kernel_ which depends on the current position and the -# unnormalised target posterior distribution. If the new sample is more _likely_, we -# accept it; otherwise, we reject it and stay in our current position. Repeating these -# steps results in a Markov chain (a random sequence that depends only on the last -# state) whose stationary distribution (the long-run empirical distribution of the -# states visited) is the posterior. For a gentle introduction, see the first chapter -# of [A Handbook of Markov Chain Monte Carlo](https://www.mcmchandbook.net/HandbookChapter1.pdf). -# -# ### MCMC through BlackJax -# -# Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific -# libraries for sampling functionality. We focus on -# [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we -# recommend adopting for general applications. -# -# We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling. -# For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where -# the number of leapfrog integration steps is computed at each step of the change -# according to the NUTS algorithm. In general, samplers constructed under this -# framework are very efficient. -# -# We begin by generating _sensible_ initial positions for our sampler before defining -# an inference loop and sampling 500 values from our Markov chain. In practice, -# drawing more samples will be necessary. - -# %% -num_adapt = 600 -num_samples = 600 - -graphdef, params, *static_state = nnx.split(posterior, gpx.parameters.Parameter, ...) -params_bijection = gpx.parameters.DEFAULT_BIJECTION - -# Transform the parameters to the unconstrained space -params = gpx.parameters.transform(params, params_bijection, inverse=True) - - -def logprob_fn(params): - params = gpx.parameters.transform(params, params_bijection) - model = nnx.merge(graphdef, params, *static_state) - return gpx.objectives.log_posterior_density(model, D) - - -# jit compile -logprob_fn = jax.jit(logprob_fn) -_ = logprob_fn(params) - -adapt = blackjax.window_adaptation( - blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True -) - -# Initialise the chain -start = time() -last_state, kernel, _ = adapt.run(key, params) -print(f"Adaption time taken: {time() - start: .1f} seconds") - - -def inference_loop(rng_key, kernel, initial_state, num_samples): - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) - - keys = jax.random.split(rng_key, num_samples) - _, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10) - - return states, infos - - -# Sample from the posterior distribution -start = time() -states, infos = inference_loop(key, kernel, last_state, num_samples) -print(f"Sampling time taken: {time() - start: .1f} seconds") - -# %% [markdown] -# ### Sampler efficiency -# -# BlackJax gives us easy access to our sampler's efficiency through metrics such as the -# sampler's _acceptance probability_ (the number of times that our chain accepted a -# proposed sample, divided by the total number of steps run by the chain). For NUTS and -# Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to -# strike the right balance between having a chain which is _stuck_ and rarely moves -# versus a chain that is too jumpy with frequent small steps. - -# %% -acceptance_rate = jnp.mean(infos.acceptance_probability) -print(f"Acceptance rate: {acceptance_rate:.2f}") - -# %% [markdown] -# Our acceptance rate is slightly too large, prompting an examination of the chain's -# trace plots. A well-mixing chain will have very few (if any) flat spots in its trace -# plot whilst also not having too many steps in the same direction. In addition to -# the model's hyperparameters, there will be 500 samples for each of the 100 latent -# function values in the `states.position` dictionary. We depict the chains that -# correspond to the model hyperparameters and the first value of the latent function -# for brevity. - -# %% -fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(10, 3)) -ax0.plot(states.position.prior.kernel.lengthscale.value) -ax1.plot(states.position.prior.kernel.variance.value) -ax2.plot(states.position.latent.value[:, 1, :]) -ax0.set_title("Kernel Lengthscale") -ax1.set_title("Kernel Variance") -ax2.set_title("Latent Function (index = 1)") - -# %% [markdown] -# ## Prediction -# -# Having obtained samples from the posterior, we draw ten instances from our model's -# predictive distribution per MCMC sample. Using these draws, we will be able to -# compute credible values and expected values under our posterior distribution. -# -# An ideal Markov chain would have samples completely uncorrelated with their -# neighbours after a single lag. However, in practice, correlations often exist -# within our chain's sample set. A commonly used technique to try and reduce this -# correlation is _thinning_ whereby we select every $n$th sample where $n$ is the -# minimum lag length at which we believe the samples are uncorrelated. Although further -# analysis of the chain's autocorrelation is required to find appropriate thinning -# factors, we employ a thin factor of 10 for demonstration purposes. - -# %% -thin_factor = 20 -posterior_samples = [] - -for i in trange(0, num_samples, thin_factor, desc="Drawing posterior samples"): - sample_params = jtu.tree_map(lambda samples, i=i: samples[i], states.position) - sample_params = gpx.parameters.transform(sample_params, params_bijection) - model = nnx.merge(graphdef, sample_params, *static_state) - latent_dist = model.predict(xtest, train_data=D) - predictive_dist = model.likelihood(latent_dist) - posterior_samples.append(predictive_dist.sample(seed=key, sample_shape=(10,))) - -posterior_samples = jnp.vstack(posterior_samples) -lower_ci, upper_ci = jnp.percentile(posterior_samples, jnp.array([2.5, 97.5]), axis=0) -expected_val = jnp.mean(posterior_samples, axis=0) - -# %% [markdown] -# -# Finally, we end this tutorial by plotting the predictions obtained from our model -# against the observed data. - -# %% -fig, ax = plt.subplots() -ax.scatter(x, y, color=cols[0], label="Observations", zorder=2, alpha=0.7) -ax.plot(xtest, expected_val, color=cols[1], label="Predicted mean", zorder=1) -ax.fill_between( - xtest.flatten(), - lower_ci.flatten(), - upper_ci.flatten(), - alpha=0.2, - color=cols[1], - label="95\\% CI", -) -ax.plot( - xtest, - lower_ci.flatten(), - color=cols[1], - linestyle="--", - linewidth=1, -) -ax.plot( - xtest, - upper_ci.flatten(), - color=cols[1], - linestyle="--", - linewidth=1, -) -ax.legend() - # %% [markdown] # ## System configuration diff --git a/examples/collapsed_vi.py b/examples/collapsed_vi.py index b397e4c2..95951592 100644 --- a/examples/collapsed_vi.py +++ b/examples/collapsed_vi.py @@ -42,10 +42,13 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/constructing_new_kernels.py b/examples/constructing_new_kernels.py index 7936d777..96d08525 100644 --- a/examples/constructing_new_kernels.py +++ b/examples/constructing_new_kernels.py @@ -40,11 +40,15 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -key = jr.key(123) +from examples.utils import use_mpl_style + tfb = tfp.bijectors -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/decision_making.py b/examples/decision_making.py index 7cc97f13..e281e55d 100644 --- a/examples/decision_making.py +++ b/examples/decision_making.py @@ -65,10 +65,15 @@ Float, ) + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + key = jr.key(42) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/deep_kernels.py b/examples/deep_kernels.py index 3370e9d1..b41538c3 100644 --- a/examples/deep_kernels.py +++ b/examples/deep_kernels.py @@ -58,12 +58,16 @@ import gpjax as gpx from gpjax.kernels.base import AbstractKernel -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] +key = jr.key(42) + + # %% [markdown] # ## Dataset # diff --git a/examples/graph_kernels.py b/examples/graph_kernels.py index 4cd39779..e9a28b9a 100644 --- a/examples/graph_kernels.py +++ b/examples/graph_kernels.py @@ -42,10 +42,14 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/intro_to_gps.py b/examples/intro_to_gps.py index e04ec9ff..9fc0eeae 100644 --- a/examples/intro_to_gps.py +++ b/examples/intro_to_gps.py @@ -121,11 +121,14 @@ import pandas as pd import seaborn as sns import tensorflow_probability.substrates.jax as tfp -from docs.examples.utils import confidence_ellipse +from examples.utils import confidence_ellipse, use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] tfd = tfp.distributions diff --git a/examples/intro_to_kernels.py b/examples/intro_to_kernels.py index 8387b1c8..396aa936 100644 --- a/examples/intro_to_kernels.py +++ b/examples/intro_to_kernels.py @@ -40,10 +40,12 @@ from gpjax.typing import Array from sklearn.preprocessing import StandardScaler +from examples.utils import use_mpl_style + key = jr.key(42) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/examples/likelihoods_guide.py b/examples/likelihoods_guide.py index 2bff2fdf..10597a26 100644 --- a/examples/likelihoods_guide.py +++ b/examples/likelihoods_guide.py @@ -78,13 +78,15 @@ import matplotlib.pyplot as plt import tensorflow_probability.substrates.jax as tfp +from examples.utils import use_mpl_style + tfd = tfp.distributions -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] -key = jr.key(123) +key = jr.key(42) n = 50 x = jnp.sort(jr.uniform(key=key, shape=(n, 1), minval=-3.0, maxval=3.0), axis=0) diff --git a/examples/oceanmodelling.py b/examples/oceanmodelling.py index 7422a8f0..d46e89af 100644 --- a/examples/oceanmodelling.py +++ b/examples/oceanmodelling.py @@ -45,11 +45,13 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -# Enable Float64 for more stable matrix inversions. -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() + +key = jr.key(42) + colors = rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/poisson.py b/examples/poisson.py index 1c59b0fe..284cb54c 100644 --- a/examples/poisson.py +++ b/examples/poisson.py @@ -38,15 +38,18 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style + # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) tfd = tfp.distributions -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] +key = jr.key(42) + # %% [markdown] # ## Dataset # diff --git a/examples/regression.py b/examples/regression.py index bf777b1e..c5ef2d50 100644 --- a/examples/regression.py +++ b/examples/regression.py @@ -35,12 +35,12 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx +from examples.utils import use_mpl_style + key = jr.key(123) # set the default style for plotting -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] diff --git a/examples/uncollapsed_vi.py b/examples/uncollapsed_vi.py index 21d51f4c..073819a6 100644 --- a/examples/uncollapsed_vi.py +++ b/examples/uncollapsed_vi.py @@ -48,13 +48,17 @@ import gpjax as gpx import gpjax.kernels as jk -key = jr.key(123) +from examples.utils import use_mpl_style + tfb = tfp.bijectors -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) + +key = jr.key(123) + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + # %% [markdown] # ## Dataset # diff --git a/examples/utils.py b/examples/utils.py index b9ccaa65..b8d5a81f 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -73,3 +73,8 @@ def clean_legend(ax): by_label = dict(zip(labels, handles)) ax.legend(by_label.values(), by_label.keys()) return ax + + +def use_mpl_style(): + style_file = Path(__file__).parent / "gpjax.mplstyle" + plt.style.use(style_file) diff --git a/examples/yacht.py b/examples/yacht.py index 5e9ef0e7..940dff15 100644 --- a/examples/yacht.py +++ b/examples/yacht.py @@ -46,13 +46,14 @@ with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx -# Enable Float64 for more stable matrix inversions. -key = jr.key(123) -plt.style.use( - "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" -) +from examples.utils import use_mpl_style + +# set the default style for plotting +use_mpl_style() cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] +key = jr.key(42) + # %% [markdown] # ## Data Loading # diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 0aa74614..c54676dd 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -31,11 +31,8 @@ def transform( >>> ) >>> params_bijection = {'positive': tfb.Softplus()} >>> transformed_params = transform(params, params_bijection) - >>> transformed_params["a"] - PositiveReal( - value=Array([1.3132617], dtype=float32), - _tag='positive' - ) + >>> print(transformed_params["a"].value) + [1.3132617] ``` @@ -59,6 +56,7 @@ def _inner(param): return param gp_params, *other_params = params.split(Parameter, ...) + transformed_gp_params: nnx.State = jtu.tree_map( lambda x: _inner(x), gp_params, diff --git a/mkdocs.yml b/mkdocs.yml index 87b51b12..a1b4728b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -32,6 +32,7 @@ nav: - 📖 Guides for customisation: - Kernels: _examples/constructing_new_kernels.md - Likelihoods: _examples/likelihoods_guide.md + - Model Guide: _examples/backend.md - UCI regression: _examples/yacht.md # - 💻 Raw tutorial code: give_me_the_code.md - Community: diff --git a/poetry.lock b/poetry.lock index 1141e126..3477e7fd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -1158,6 +1158,7 @@ description = "Python AST that abstracts the underlying Python version" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" files = [ + {file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"}, {file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"}, ] @@ -1515,6 +1516,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "json5" version = "0.9.25" @@ -3723,6 +3735,51 @@ files = [ {file = "ruff-0.6.0.tar.gz", hash = "sha256:272a81830f68f9bd19d49eaf7fa01a5545c5a2e86f32a9935bb0e4bb9a1db5b8"}, ] +[[package]] +name = "scikit-learn" +version = "1.5.1" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"}, + {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"}, + {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"}, + {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"}, + {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"}, + {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"}, + {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"}, + {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"}, + {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"}, + {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"}, + {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"}, + {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"}, + {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"}, + {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"}, + {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"}, + {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"}, + {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"}, + {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"}, + {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"}, + {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"}, + {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + [[package]] name = "scipy" version = "1.14.0" @@ -3915,6 +3972,17 @@ files = [ ml-dtypes = ">=0.3.1" numpy = ">=1.22.0" +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tinycss2" version = "1.3.0" @@ -4213,4 +4281,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "4b2ae7bad45029e7becc027913e0adbf40d21a2632971f4eb50c3a4096f20766" +content-hash = "99d22602c5c323f3ea78b4a80ca493069b946cea47b23d0a6e932c2900c385a4" diff --git a/pyproject.toml b/pyproject.toml index 82a757ed..cc1b6412 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ pandas = "^1.5.3" pymdown-extensions = "^10.7.1" nbconvert = "^7.16.2" markdown-katex = "^202406.1035" +scikit-learn = "^1.5.1" [build-system] requires = ["poetry-core"] diff --git a/tests/test_markdown.py b/tests/test_markdown.py index d28e9078..e543db8e 100644 --- a/tests/test_markdown.py +++ b/tests/test_markdown.py @@ -5,12 +5,12 @@ # Ensure that code chunks within any markdown files execute without error -@pytest.mark.parametrize("fpath", pathlib.Path("gpjax/").glob("**/*.md"), ids=str) +@pytest.mark.parametrize("fpath", pathlib.Path("gpjax/").glob("*.md"), ids=str) def test_source_good(fpath): check_md_file(fpath=fpath, memory=True) -@pytest.mark.parametrize("fpath", pathlib.Path("docs").glob("**/*.md"), ids=str) +@pytest.mark.parametrize("fpath", pathlib.Path("docs").glob("*.md"), ids=str) def test_docs_good(fpath): check_md_file(fpath=fpath, memory=True)