diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py
index 5a3352e6..ac56b39d 100644
--- a/gpjax/abstractions.py
+++ b/gpjax/abstractions.py
@@ -33,7 +33,7 @@ class InferenceState(PyTree):
-    """Imutable class for storing optimised parameters and training history."""
+    """Immutable class for storing optimised parameters and training history."""
 
     def __init__(self, params: Dict, history: Float[Array, "num_iters"]):
         self._params = params
@@ -97,7 +97,7 @@ def loss(params: Dict) -> Float[Array, "1"]:
         params = constrain(params, bijectors)
         return objective(params)
 
-    # Tranform params to unconstrained space
+    # Transform params to unconstrained space
     params = unconstrain(params, bijectors)
 
     # Initialise optimiser state
@@ -122,7 +122,7 @@ def step(carry, iter_num: int):
     # Run the optimisation loop
     (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums)
 
-    # Tranform final params to constrained space
+    # Transform final params to constrained space
     params = constrain(params, bijectors)
 
     return InferenceState(params=params, history=history)
@@ -166,7 +166,7 @@ def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]:
         params = constrain(params, bijectors)
         return objective(params, batch)
 
-    # Tranform params to unconstrained space
+    # Transform params to unconstrained space
     params = unconstrain(params, bijectors)
 
     # Initialise optimiser state
@@ -197,7 +197,7 @@ def step(carry, iter_num__and__key):
     # Run the optimisation loop
     (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys))
 
-    # Tranform final params to constrained space
+    # Transform final params to constrained space
     params = constrain(params, bijectors)
 
     return InferenceState(params=params, history=history)
@@ -215,7 +215,7 @@ def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
     """
     x, y, n = train_data.X, train_data.y, train_data.n
 
-    # Subsample data inidicies with replacement to get the mini-batch
+    # Subsample data indices with replacement to get the mini-batch
     indicies = jr.choice(key, n, (batch_size,), replace=True)
 
     return Dataset(X=x[indicies], y=y[indicies])
@@ -257,7 +257,7 @@ def fit_natgrads(
     params, trainables, bijectors = parameter_state.unpack()
 
-    # Tranform params to unconstrained space
+    # Transform params to unconstrained space
     params = unconstrain(params, bijectors)
 
     # Initialise optimiser states
@@ -302,7 +302,7 @@ def step(carry, iter_num__and__key):
         step, (params, hyper_state, moment_state), (iter_nums, keys)
     )
 
-    # Tranform final params to constrained space
+    # Transform final params to constrained space
     params = constrain(params, bijectors)
 
     return InferenceState(params=params, history=history)
diff --git a/gpjax/gaussian_distribution.py b/gpjax/gaussian_distribution.py
index db5a487c..526688e2 100644
--- a/gpjax/gaussian_distribution.py
+++ b/gpjax/gaussian_distribution.py
@@ -210,7 +210,7 @@ def _kl_divergence(
     Args:
         q (GaussianDistribution): A multivariate Gaussian distribution.
-        p (GaussianDistribution): A multivariate Gaussia distribution.
+        p (GaussianDistribution): A multivariate Gaussian distribution.
 
     Returns:
         Float[Array, "1"]: The KL divergence between q and p.
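
For reference, the ``_kl_divergence`` docstring corrected above documents the standard closed-form KL divergence between two multivariate Gaussians q = N(μq, Sq) and p = N(μp, Sp). A minimal dense-matrix sketch of that formula (the helper name and plain-array interface are illustrative assumptions; GPJax's own implementation works through its covariance-operator types):

```python
import jax.numpy as jnp

def mvn_kl(mu_q, S_q, mu_p, S_p):
    """KL(N(mu_q, S_q) || N(mu_p, S_p)) for dense covariance matrices (illustrative helper)."""
    n = mu_q.shape[0]
    S_p_inv = jnp.linalg.inv(S_p)
    diff = mu_p - mu_q
    return 0.5 * (
        jnp.trace(S_p_inv @ S_q)       # tr(S_p⁻¹ S_q)
        + diff @ S_p_inv @ diff        # Mahalanobis term (μp - μq)ᵀ S_p⁻¹ (μp - μq)
        - n                            # dimensionality
        + jnp.linalg.slogdet(S_p)[1]   # log det S_p
        - jnp.linalg.slogdet(S_q)[1]   # log det S_q
    )
```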
diff --git a/gpjax/gps.py b/gpjax/gps.py
index aefa9cb6..34867ba9 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -97,7 +97,7 @@ def init_params(self, key: KeyArray) -> Dict:
         details="Use the ``init_params`` method for parameter initialisation.",
     )
     def _initialise_params(self, key: KeyArray) -> Dict:
-        """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
+        """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
         return self.init_params(key)
@@ -150,7 +150,7 @@ def __init__(
     def __mul__(self, other: AbstractLikelihood):
         """The product of a prior and likelihood is proportional to the
         posterior distribution. By computing the product of a GP prior and a
-        likelihood object, a posterior GP object will be returned. Mathetically,
+        likelihood object, a posterior GP object will be returned. Mathematically,
         this can be described by:
 
         .. math::
@@ -392,7 +392,7 @@ def predict(
         The conditioning set is a GPJax ``Dataset`` object, whilst predictions
         are made on a regular Jax array.
 
-        £xample:
+        Example:
             For a ``posterior`` distribution, the following code snippet will
             evaluate the predictive distribution.
@@ -694,7 +694,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
             # Lx⁻¹ Kxt
             Lx_inv_Kxt = Lx.solve(Ktx.T)
 
-            # Whitened function values, wx, correponding to the inputs, x
+            # Whitened function values, wx, corresponding to the inputs, x
             wx = params["latent"]
 
             # μt + Ktx Lx⁻¹ wx
@@ -778,7 +778,7 @@ def mll(params: Dict):
         # Compute the prior mean function
         μx = mean_function(params["mean_function"], x)
 
-        # Whitened function values, wx, correponding to the inputs, x
+        # Whitened function values, wx, corresponding to the inputs, x
         wx = params["latent"]
 
         # f(x) = μx + Lx wx
diff --git a/gpjax/kernels.py b/gpjax/kernels.py
index 7b0b9759..6e2034e2 100644
--- a/gpjax/kernels.py
+++ b/gpjax/kernels.py
@@ -369,7 +369,7 @@ def init_params(self, key: KeyArray) -> Dict:
         details="Use the ``init_params`` method for parameter initialisation.",
     )
     def _initialise_params(self, key: KeyArray) -> Dict:
-        """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
+        """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
         return self.init_params(key)
diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py
index 3cc96be5..ce437736 100644
--- a/gpjax/likelihoods.py
+++ b/gpjax/likelihoods.py
@@ -84,7 +84,7 @@ def init_params(self, key: KeyArray) -> Dict:
         details="Use the ``init_params`` method for parameter initialisation.",
     )
     def _initialise_params(self, key: KeyArray) -> Dict:
-        """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
+        """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
         return self.init_params(key)
 
     @property
@@ -99,11 +99,11 @@ def link_function(self) -> Callable:
 
 
 class Conjugate:
-    """An abstract class for conjugate likelihoods with respect to a Gaussain process prior."""
+    """An abstract class for conjugate likelihoods with respect to a Gaussian process prior."""
 
 
 class NonConjugate:
-    """An abstract class for non-conjugate likelihoods with respect to a Gaussain process prior."""
+    """An abstract class for non-conjugate likelihoods with respect to a Gaussian process prior."""
 
 
 # TODO: revamp this with covariance operators.
diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py
index e6073a54..75ebf43e 100644
--- a/gpjax/mean_functions.py
+++ b/gpjax/mean_functions.py
@@ -70,7 +70,7 @@ def init_params(self, key: KeyArray) -> Dict:
         details="Use the ``init_params`` method for parameter initialisation.",
     )
     def _initialise_params(self, key: KeyArray) -> Dict:
-        """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
+        """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
         return self.init_params(key)
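
The two ``wx = params["latent"]`` hunks in gps.py above sit inside GPJax's whitened (non-centred) parameterisation, where the latent values are stored as wx ~ N(0, I) and mapped to function values via f(x) = μx + Lx wx, with Lx the Cholesky factor of Kxx. A small self-contained sketch of that mapping (the kernel, sizes, and variable names here are our own illustrative choices, not library internals):

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(42)
x = jnp.linspace(-3.0, 3.0, 20)[:, None]

# Dense RBF kernel matrix with jitter for numerical stability
Kxx = jnp.exp(-0.5 * (x - x.T) ** 2) + 1e-6 * jnp.eye(20)
Lx = jnp.linalg.cholesky(Kxx)

mu_x = jnp.zeros(20)          # prior mean μx
wx = jr.normal(key, (20,))    # whitened function values, wx ~ N(0, I)
f = mu_x + Lx @ wx            # f(x) = μx + Lx wx, so f ~ N(μx, Kxx)
```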
diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py
index db399d9a..596bb659 100644
--- a/gpjax/natural_gradients.py
+++ b/gpjax/natural_gradients.py
@@ -40,10 +40,10 @@ def natural_to_expectation(params: Dict) -> Dict:
     In particular, in terms of the Gaussian mean μ and covariance matrix μ for
     the Gaussian variational family,
 
-        - the natural parameteristaion is θ = (S⁻¹μ, -S⁻¹/2)
+        - the natural parameterisation is θ = (S⁻¹μ, -S⁻¹/2)
         - the expectation parameters are η = (μ, S + μ μᵀ).
 
-    This function solves these eqautions in terms of μ and S to convert θ to η.
+    This function solves these equations in terms of μ and S to convert θ to η.
 
     Writing θ = (θ₁, θ₂), we have that S⁻¹ = -2θ₂ . Taking the cholesky
     decomposition of the inverse covariance, S⁻¹ = LLᵀ and defining C = L⁻¹, we
@@ -165,7 +165,7 @@ def _rename_natural_to_expectation(params: Dict) -> Dict:
 
 def _get_moment_trainables(trainables: Dict) -> Dict:
     """
-    This function takes a trainbles dictionary, and sets non-moment parameter
+    This function takes a trainables dictionary, and sets non-moment parameter
     training to false for gradient stopping.
 
     Args:
@@ -185,7 +185,7 @@ def _get_hyperparameter_trainables(trainables: Dict) -> Dict:
     """
-    This function takes a trainbles dictionary, and sets moment parameter
+    This function takes a trainables dictionary, and sets moment parameter
     training to false for gradient stopping.
 
     Args:
@@ -251,7 +251,7 @@ def nat_grads_fn(params: Dict, batch: Dataset) -> Dict:
         # Transform parameters to constrained space.
         params = constrain(params, bijectors)
 
-        # Convert natural parameterisation θ to the expectation parametersation η.
+        # Convert natural parameterisation θ to the expectation parameterisation η.
         expectation_params = natural_to_expectation(params)
 
         # Compute gradient ∂L/∂η:
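
The docstring fixed above describes the θ → η map that ``natural_to_expectation`` performs. The algebra it documents can be sketched on raw arrays as follows (illustrative only: GPJax operates on its nested params dictionary and uses Cholesky-based solves rather than an explicit inverse):

```python
import jax.numpy as jnp

def natural_to_expectation_sketch(theta_1, theta_2):
    """Map natural parameters θ = (S⁻¹μ, -S⁻¹/2) to expectation parameters η = (μ, S + μμᵀ)."""
    S_inv = -2.0 * theta_2         # S⁻¹ = -2θ₂
    S = jnp.linalg.inv(S_inv)      # covariance S
    mu = S @ theta_1               # mean μ = S θ₁
    eta_1 = mu
    eta_2 = S + jnp.outer(mu, mu)  # second moment S + μμᵀ
    return eta_1, eta_2
```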
diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py
index 7416fc99..7d16eb4d 100644
--- a/gpjax/variational_families.py
+++ b/gpjax/variational_families.py
@@ -75,7 +75,7 @@ def init_params(self, key: KeyArray) -> Dict:
         details="Use the ``init_params`` method for parameter initialisation.",
     )
     def _initialise_params(self, key: KeyArray) -> Dict:
-        """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
+        """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
         return self.init_params(key)
 
     @abc.abstractmethod
@@ -390,7 +390,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian):
     The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing
     inputs z and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the
-    exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform
+    exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural parameterisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform
     model inference, where T(u) = [u, uuᵀ] are the sufficient statistics.
     """
@@ -433,7 +433,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]:
         For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)],
-        with μ and S computed from the natural paramerisation θ = (S⁻¹μ, -S⁻¹/2).
+        with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2).
 
         Args:
             params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated.
@@ -490,7 +490,7 @@ def predict(
 
         N[f(t); μt + Ktz Kzz⁻¹ (μ - μz),  Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ],
 
-        with μ and S computed from the natural paramerisation θ = (S⁻¹μ, -S⁻¹/2).
+        with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2).
 
         Args:
             params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP.
@@ -574,7 +574,7 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian):
     The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing
     inputs z and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the
-    exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2) and
+    exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural parameterisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2) and
     sufficient stastics T(u) = [u, uuᵀ]. The expectation parameters are given by η = ∫ T(u) q(u) du. This gives a parameterisation,
     η = (η₁, η₁) = (μ, S + uuᵀ) to perform model inference over.
     """
@@ -620,7 +620,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]:
         For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)],
-        with μ and S computed from the expectation paramerisation η = (μ, S + uuᵀ).
+        with μ and S computed from the expectation parameterisation η = (μ, S + uuᵀ).
 
         Args:
             params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated.
@@ -670,7 +670,7 @@ def predict(
 
         N[f(t); μt + Ktz Kzz⁻¹ (μ - μz),  Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ],
 
-        with μ and S computed from the expectation paramerisation η = (μ, S + uuᵀ).
+        with μ and S computed from the expectation parameterisation η = (μ, S + uuᵀ).
 
         Args:
             params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP.
diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py
index d33755b3..308beab9 100644
--- a/gpjax/variational_inference.py
+++ b/gpjax/variational_inference.py
@@ -72,7 +72,7 @@ def init_params(self, key: KeyArray) -> Dict:
         details="Use the ``init_params`` method for parameter initialisation.",
     )
     def _initialise_params(self, key: KeyArray) -> Dict:
-        """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
+        """Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
         return self.init_params(key)
 
     @abc.abstractmethod
@@ -231,7 +231,7 @@ def elbo_fn(params: Dict) -> Float[Array, "1"]:
         #
         #   with B = I + AAᵀ and A = Lz⁻¹ Kzx / σ.
         #
-        # Similary we apply matrix inversion lemma to invert σ²I + Q
+        # Similarly we apply matrix inversion lemma to invert σ²I + Q
         #
         #   (σ²I + Q)⁻¹ = (Iσ²)⁻¹ - (Iσ²)⁻¹ Kxz Lz⁻ᵀ (I + Lz⁻¹ Kzx (Iσ²)⁻¹ Kxz Lz⁻ᵀ )⁻¹ Lz⁻¹ Kzx (Iσ²)⁻¹
         #   = (Iσ²)⁻¹ - (Iσ²)⁻¹ σAᵀ (I + σA (Iσ²)⁻¹ σAᵀ)⁻¹ σA (Iσ²)⁻¹
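
The comment corrected in this final hunk uses the matrix inversion (Woodbury) lemma to replace the n×n inverse of σ²I + Q, where Q = Kxz Kzz⁻¹ Kzx, with an m×m solve against B = I + AAᵀ for A = Lz⁻¹ Kzx / σ. A quick numerical sanity check of that identity (synthetic data and variable names of our own choosing, not library code):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)  # double precision for a tight comparison

key = jr.PRNGKey(0)
n, m, sigma = 50, 5, 0.1
X = jr.normal(key, (n, 3))
Z = X[:m]                                   # inducing inputs

rbf = lambda a, b: jnp.exp(-0.5 * jnp.sum((a[:, None] - b[None]) ** 2, -1))
Kzz = rbf(Z, Z) + 1e-8 * jnp.eye(m)
Kzx = rbf(Z, X)
Lz = jnp.linalg.cholesky(Kzz)

A = jnp.linalg.solve(Lz, Kzx) / sigma       # A = Lz⁻¹ Kzx / σ
B = jnp.eye(m) + A @ A.T                    # B = I + AAᵀ

Q = Kzx.T @ jnp.linalg.solve(Kzz, Kzx)      # Q = Kxz Kzz⁻¹ Kzx
direct = jnp.linalg.inv(sigma**2 * jnp.eye(n) + Q)
woodbury = (jnp.eye(n) - A.T @ jnp.linalg.solve(B, A)) / sigma**2

print(jnp.max(jnp.abs(direct - woodbury)))  # should be tiny (~1e-12): the two agree
```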