Skip to content

Commit

Permalink
Arccosine kernel (#245)
Browse files Browse the repository at this point in the history
* WIP

* first go

* nice test

---------

Signed-off-by: Thomas Pinder <tompinder@live.co.uk>
Co-authored-by: Thomas Pinder <tompinder@live.co.uk>
  • Loading branch information
henrymoss and thomaspinder authored Apr 30, 2023
1 parent 5381567 commit c07935a
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 10 deletions.
2 changes: 2 additions & 0 deletions gpjax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from gpjax.kernels.non_euclidean import GraphKernel
from gpjax.kernels.nonstationary import (
ArcCosine,
Linear,
Polynomial,
)
Expand All @@ -45,6 +46,7 @@

__all__ = [
"AbstractKernel",
"ArcCosine",
"RBF",
"GraphKernel",
"Matern12",
Expand Down
3 changes: 2 additions & 1 deletion gpjax/kernels/nonstationary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
# ==============================================================================

from gpjax.kernels.nonstationary.arccosine import ArcCosine
from gpjax.kernels.nonstationary.linear import Linear
from gpjax.kernels.nonstationary.polynomial import Polynomial

__all__ = ["Linear", "Polynomial"]
__all__ = ["Linear", "Polynomial", "ArcCosine"]
117 changes: 117 additions & 0 deletions gpjax/kernels/nonstationary/arccosine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass

from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import Float
from simple_pytree import static_field
import tensorflow_probability.substrates.jax.bijectors as tfb

from gpjax.base import param_field
from gpjax.kernels.base import AbstractKernel
from gpjax.typing import (
Array,
ScalarFloat,
ScalarInt,
)


@dataclass
class ArcCosine(AbstractKernel):
"""The ArCosine kernel. This kernel is non-stationary and resembles the behavior
of neural networks. See Section 3.1 of https://arxiv.org/pdf/1112.3712.pdf for
additional details.
"""

order: ScalarInt = static_field(0)
variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
weight_variance: Union[ScalarFloat, Float[Array, " D"]] = param_field(
jnp.array(1.0), bijector=tfb.Softplus()
)
bias_variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())

def __post_init__(self):
if self.order not in [0, 1, 2]:
raise ValueError("ArcCosine kernel only implemented for orders 0, 1 and 2.")

self.name = f"ArcCosine (order {self.order})"

def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)`
Args:
x (Float[Array, "D"]): The left hand argument of the kernel function's
call.
y (Float[Array, "D"]): The right hand argument of the kernel function's
call
Returns
-------
ScalarFloat: The value of :math:`k(x, y)`.
"""

x = self.slice_input(x)
y = self.slice_input(y)

x_x = self._weighted_prod(x, x)
x_y = self._weighted_prod(x, y)
y_y = self._weighted_prod(y, y)

cos_theta = x_y / jnp.sqrt(x_x * y_y)
jitter = 1e-15 # improve numerical stability
theta = jnp.arccos(jitter + (1 - 2 * jitter) * cos_theta)

K = self._J(theta)
K *= jnp.sqrt(x_x) ** self.order
K *= jnp.sqrt(y_y) ** self.order
K *= self.variance / jnp.pi

return K.squeeze()

def _weighted_prod(
self, x: Float[Array, " D"], y: Float[Array, " D"]
) -> ScalarFloat:
"""Calculate the weighted product between two arguments.
Args:
x (Float[Array, "D"]): The left hand argument.
y (Float[Array, "D"]): The right hand argument.
Returns
-------
ScalarFloat: The value of the weighted product between the two arguments``.
"""
return jnp.inner(self.weight_variance * x, y) + self.bias_variance

def _J(self, theta: ScalarFloat) -> ScalarFloat:
"""Evaluate the angular dependency function corresponding to the desired order.
Args:
theta (Float[Array, "1"]): The weighted angle between inputs.
Returns
-------
Float[Array, "1"]: The value of the angular dependency function`.
"""

if self.order == 0:
return jnp.pi - theta
elif self.order == 1:
return jnp.sin(theta) + (jnp.pi - theta) * jnp.cos(theta)
else:
return 3.0 * jnp.sin(theta) * jnp.cos(theta) + (jnp.pi - theta) * (
1.0 + 2.0 * jnp.cos(theta) ** 2
)
64 changes: 55 additions & 9 deletions tests/test_kernels/test_nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
# ==============================================================================

from dataclasses import is_dataclass
from itertools import (
permutations,
product,
)
from itertools import product
from typing import List

import jax
Expand All @@ -31,13 +28,11 @@
from gpjax.kernels.base import AbstractKernel
from gpjax.kernels.computations import DenseKernelComputation
from gpjax.kernels.nonstationary import (
ArcCosine,
Linear,
Polynomial,
)
from gpjax.linops import (
LinearOperator,
identity,
)
from gpjax.linops import LinearOperator

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -101,7 +96,9 @@ def test_initialization(self, fields: dict, dim: int) -> None:
# Check meta leaves
meta = kernel._pytree__meta
assert not any(f in meta for f in self.static_fields)
assert list(meta.keys()) == sorted(set(fields) - set(self.static_fields))
assert sorted(list(meta.keys())) == sorted(
set(fields) - set(self.static_fields)
)

for field in meta:
# Bijectors
Expand Down Expand Up @@ -170,3 +167,52 @@ class TestPolynomial(BaseTestKernel):
static_fields = ["degree"]
params = {"test_initialization": fields}
default_compute_engine = DenseKernelComputation


class TestArcCosine(BaseTestKernel):
kernel = ArcCosine
fields = prod(
{
"variance": [0.1, 1.0],
"order": [0, 1, 2],
"weight_variance": [0.1, 1.0],
"bias_variance": [0.1, 1.0],
}
)
static_fields = ["order"]
params = {"test_initialization": fields}
default_compute_engine = DenseKernelComputation

@pytest.mark.parametrize("order", [-1, 3], ids=lambda x: f"order={x}")
def test_defaults(self, order: int) -> None:
with pytest.raises(ValueError):
self.kernel(order=order)

@pytest.mark.parametrize("order", [0, 1, 2], ids=lambda x: f"order={x}")
def test_values_by_monte_carlo_in_special_case(self, order: int) -> None:
"""For certain values of weight variance (1.0) and bias variance (0.0), we can test
our calculations using the Monte Carlo expansion of the arccosine kernel, e.g.
see Eq. (1) of https://cseweb.ucsd.edu/~saul/papers/nips09_kernel.pdf.
"""
kernel: AbstractKernel = self.kernel(
weight_variance=jnp.array([1.0, 1.0]), bias_variance=1e-25, order=order
)
key = jr.PRNGKey(123)

# Inputs close(ish) together
a = jnp.array([[0.0, 0.0]])
b = jnp.array([[2.0, 2.0]])

# calc cross-covariance exactly
Kab_exact = kernel.cross_covariance(a, b)

# calc cross-covariance using samples
weights = jax.random.normal(key, (10_000, 2)) # [S, d]
weights_a = jnp.matmul(weights, a.T) # [S, 1]
weights_b = jnp.matmul(weights, b.T) # [S, 1]
H_a = jnp.heaviside(weights_a, 0.5)
H_b = jnp.heaviside(weights_b, 0.5)
integrands = H_a * H_b * (weights_a**order) * (weights_b**order)
Kab_approx = 2.0 * jnp.mean(integrands)

assert jnp.max(Kab_approx - Kab_exact) < 1e-4

0 comments on commit c07935a

Please sign in to comment.