Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typo fixes #187

Merged
merged 1 commit into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions gpjax/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


class InferenceState(PyTree):
"""Imutable class for storing optimised parameters and training history."""
"""Immutable class for storing optimised parameters and training history."""

def __init__(self, params: Dict, history: Float[Array, "num_iters"]):
self._params = params
Expand Down Expand Up @@ -97,7 +97,7 @@ def loss(params: Dict) -> Float[Array, "1"]:
params = constrain(params, bijectors)
return objective(params)

# Tranform params to unconstrained space
# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser state
Expand All @@ -122,7 +122,7 @@ def step(carry, iter_num: int):
# Run the optimisation loop
(params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums)

# Tranform final params to constrained space
# Transform final params to constrained space
params = constrain(params, bijectors)

return InferenceState(params=params, history=history)
Expand Down Expand Up @@ -166,7 +166,7 @@ def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]:
params = constrain(params, bijectors)
return objective(params, batch)

# Tranform params to unconstrained space
# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser state
Expand Down Expand Up @@ -197,7 +197,7 @@ def step(carry, iter_num__and__key):
# Run the optimisation loop
(params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys))

# Tranform final params to constrained space
# Transform final params to constrained space
params = constrain(params, bijectors)

return InferenceState(params=params, history=history)
Expand All @@ -215,7 +215,7 @@ def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
"""
x, y, n = train_data.X, train_data.y, train_data.n

# Subsample data inidicies with replacement to get the mini-batch
# Subsample data indices with replacement to get the mini-batch
indicies = jr.choice(key, n, (batch_size,), replace=True)

return Dataset(X=x[indicies], y=y[indicies])
Expand Down Expand Up @@ -257,7 +257,7 @@ def fit_natgrads(

params, trainables, bijectors = parameter_state.unpack()

# Tranform params to unconstrained space
# Transform params to unconstrained space
params = unconstrain(params, bijectors)

# Initialise optimiser states
Expand Down Expand Up @@ -302,7 +302,7 @@ def step(carry, iter_num__and__key):
step, (params, hyper_state, moment_state), (iter_nums, keys)
)

# Tranform final params to constrained space
# Transform final params to constrained space
params = constrain(params, bijectors)

return InferenceState(params=params, history=history)
Expand Down
2 changes: 1 addition & 1 deletion gpjax/gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _kl_divergence(

Args:
q (GaussianDistribution): A multivariate Gaussian distribution.
p (GaussianDistribution): A multivariate Gaussia distribution.
p (GaussianDistribution): A multivariate Gaussian distribution.

Returns:
Float[Array, "1"]: The KL divergence between q and p.
Expand Down
10 changes: 5 additions & 5 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def init_params(self, key: KeyArray) -> Dict:
details="Use the ``init_params`` method for parameter initialisation.",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
"""Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
return self.init_params(key)


Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(
def __mul__(self, other: AbstractLikelihood):
"""The product of a prior and likelihood is proportional to the
posterior distribution. By computing the product of a GP prior and a
likelihood object, a posterior GP object will be returned. Mathetically,
likelihood object, a posterior GP object will be returned. Mathematically,
this can be described by:
.. math::

Expand Down Expand Up @@ -392,7 +392,7 @@ def predict(
The conditioning set is a GPJax ``Dataset`` object, whilst predictions
are made on a regular Jax array.

£xample:
Example:
For a ``posterior`` distribution, the following code snippet will
evaluate the predictive distribution.

Expand Down Expand Up @@ -694,7 +694,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
# Lx⁻¹ Kxt
Lx_inv_Kxt = Lx.solve(Ktx.T)

# Whitened function values, wx, correponding to the inputs, x
# Whitened function values, wx, corresponding to the inputs, x
wx = params["latent"]

# μt + Ktx Lx⁻¹ wx
Expand Down Expand Up @@ -778,7 +778,7 @@ def mll(params: Dict):
# Compute the prior mean function
μx = mean_function(params["mean_function"], x)

# Whitened function values, wx, correponding to the inputs, x
# Whitened function values, wx, corresponding to the inputs, x
wx = params["latent"]

# f(x) = μx + Lx wx
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def init_params(self, key: KeyArray) -> Dict:
details="Use the ``init_params`` method for parameter initialisation.",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
"""Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
return self.init_params(key)


Expand Down
6 changes: 3 additions & 3 deletions gpjax/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def init_params(self, key: KeyArray) -> Dict:
details="Use the ``init_params`` method for parameter initialisation.",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
"""Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
return self.init_params(key)

@property
Expand All @@ -99,11 +99,11 @@ def link_function(self) -> Callable:


class Conjugate:
"""An abstract class for conjugate likelihoods with respect to a Gaussain process prior."""
"""An abstract class for conjugate likelihoods with respect to a Gaussian process prior."""


class NonConjugate:
"""An abstract class for non-conjugate likelihoods with respect to a Gaussain process prior."""
"""An abstract class for non-conjugate likelihoods with respect to a Gaussian process prior."""


# TODO: revamp this with covariance operators.
Expand Down
2 changes: 1 addition & 1 deletion gpjax/mean_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def init_params(self, key: KeyArray) -> Dict:
details="Use the ``init_params`` method for parameter initialisation.",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
"""Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
return self.init_params(key)


Expand Down
10 changes: 5 additions & 5 deletions gpjax/natural_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def natural_to_expectation(params: Dict) -> Dict:
In particular, in terms of the Gaussian mean μ and covariance matrix μ for
the Gaussian variational family,

- the natural parameteristaion is θ = (S⁻¹μ, -S⁻¹/2)
- the natural parameterisation is θ = (S⁻¹μ, -S⁻¹/2)
- the expectation parameters are η = (μ, S + μ μᵀ).

This function solves these eqautions in terms of μ and S to convert θ to η.
This function solves these equations in terms of μ and S to convert θ to η.

Writing θ = (θ₁, θ₂), we have that S⁻¹ = -2θ₂ . Taking the cholesky
decomposition of the inverse covariance, S⁻¹ = LLᵀ and defining C = L⁻¹, we
Expand Down Expand Up @@ -165,7 +165,7 @@ def _rename_natural_to_expectation(params: Dict) -> Dict:

def _get_moment_trainables(trainables: Dict) -> Dict:
"""
This function takes a trainbles dictionary, and sets non-moment parameter
This function takes a trainables dictionary, and sets non-moment parameter
training to false for gradient stopping.

Args:
Expand All @@ -185,7 +185,7 @@ def _get_moment_trainables(trainables: Dict) -> Dict:

def _get_hyperparameter_trainables(trainables: Dict) -> Dict:
"""
This function takes a trainbles dictionary, and sets moment parameter
This function takes a trainables dictionary, and sets moment parameter
training to false for gradient stopping.

Args:
Expand Down Expand Up @@ -251,7 +251,7 @@ def nat_grads_fn(params: Dict, batch: Dataset) -> Dict:
# Transform parameters to constrained space.
params = constrain(params, bijectors)

# Convert natural parameterisation θ to the expectation parametersation η.
# Convert natural parameterisation θ to the expectation parameterisation η.
expectation_params = natural_to_expectation(params)

# Compute gradient ∂L/∂η:
Expand Down
14 changes: 7 additions & 7 deletions gpjax/variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_params(self, key: KeyArray) -> Dict:
details="Use the ``init_params`` method for parameter initialisation.",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
"""Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
return self.init_params(key)

@abc.abstractmethod
Expand Down Expand Up @@ -390,7 +390,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian):

The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing inputs z
and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the
exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform
exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural parameterisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform
model inference, where T(u) = [u, uuᵀ] are the sufficient statistics.
"""

Expand Down Expand Up @@ -433,7 +433,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]:

For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)],

with μ and S computed from the natural paramerisation θ = (S⁻¹μ, -S⁻¹/2).
with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2).

Args:
params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated.
Expand Down Expand Up @@ -490,7 +490,7 @@ def predict(

N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ],

with μ and S computed from the natural paramerisation θ = (S⁻¹μ, -S⁻¹/2).
with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2).

Args:
params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP.
Expand Down Expand Up @@ -574,7 +574,7 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian):

The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing inputs z
and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the
exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2) and
exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural parameterisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2) and
sufficient stastics T(u) = [u, uuᵀ]. The expectation parameters are given by η = ∫ T(u) q(u) du. This gives a parameterisation,
η = (η₁, η₁) = (μ, S + uuᵀ) to perform model inference over.
"""
Expand Down Expand Up @@ -620,7 +620,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]:

For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)],

with μ and S computed from the expectation paramerisation η = (μ, S + uuᵀ).
with μ and S computed from the expectation parameterisation η = (μ, S + uuᵀ).

Args:
params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated.
Expand Down Expand Up @@ -670,7 +670,7 @@ def predict(

N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ],

with μ and S computed from the expectation paramerisation η = (μ, S + uuᵀ).
with μ and S computed from the expectation parameterisation η = (μ, S + uuᵀ).

Args:
params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP.
Expand Down
4 changes: 2 additions & 2 deletions gpjax/variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def init_params(self, key: KeyArray) -> Dict:
details="Use the ``init_params`` method for parameter initialisation.",
)
def _initialise_params(self, key: KeyArray) -> Dict:
"""Deprecated method for initialising the GP's parameters. Succeded by ``init_params``."""
"""Deprecated method for initialising the GP's parameters. Succeeded by ``init_params``."""
return self.init_params(key)

@abc.abstractmethod
Expand Down Expand Up @@ -231,7 +231,7 @@ def elbo_fn(params: Dict) -> Float[Array, "1"]:
#
# with B = I + AAᵀ and A = Lz⁻¹ Kzx / σ.
#
# Similary we apply matrix inversion lemma to invert σ²I + Q
# Similarly we apply matrix inversion lemma to invert σ²I + Q
#
# (σ²I + Q)⁻¹ = (Iσ²)⁻¹ - (Iσ²)⁻¹ Kxz Lz⁻ᵀ (I + Lz⁻¹ Kzx (Iσ²)⁻¹ Kxz Lz⁻ᵀ )⁻¹ Lz⁻¹ Kzx (Iσ²)⁻¹
# = (Iσ²)⁻¹ - (Iσ²)⁻¹ σAᵀ (I + σA (Iσ²)⁻¹ σAᵀ)⁻¹ σA (Iσ²)⁻¹
Expand Down