Enabling kernels to use PyTree inputs
First implementation and demo notebook
ingmarschuster committed Jun 1, 2023
1 parent 6138af4 commit a4e877f
Showing 4 changed files with 125 additions and 10 deletions.
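
In brief, this change lets `Dataset` inputs and kernel arguments be arbitrary PyTrees (e.g. dicts of arrays) rather than a single array. A minimal sketch of the resulting pattern, assuming the `Dataset` semantics introduced in gpjax/dataset.py below:

import jax.numpy as jnp
import gpjax as gpx

# X is a PyTree: a dict of two arrays sharing the leading dimension N=10.
X = {"a": jnp.ones((10, 2)), "b": jnp.ones((10, 3))}
y = jnp.ones((10, 1))
D = gpx.Dataset(X=X, y=y)
print(D.n)       # 10 -- taken from the first PyTree leaf
print(D.in_dim)  # 5  -- input dimensions summed over the leaves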
89 changes: 89 additions & 0 deletions docs/examples/pytree_kernels.py
@@ -0,0 +1,89 @@
# %%
import jax.numpy as jnp
import datasets as ds
import gpjax as gpx
from jax import jit
import optax as ox
import jax.random as jr
from jaxtyping import PyTree
import matplotlib.pyplot as plt

# %% [markdown]
# Now load a graph dataset and pad each graph to a common size

# %%
gd = ds.load_dataset("graphs-datasets/AQSOL")

gd = gd.map(
lambda x: {
"num_edges": len(x["edge_index"][0]),
}
)
gd.set_format("jax")

max_num_edges = max([gd[i]["num_edges"].max() for i in gd])
max_num_nodes = max([gd[i]["num_nodes"].max() for i in gd])

small_gd = (
gd["train"]
.select(range(100))
.map(
lambda x: {
"num_edges": len(x["edge_index"][0]),
}
)
)
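
# Padding gives every graph fixed-length feature arrays, so the examples stack
# into the 2-D leaves that `Dataset` expects and keep shapes static for jit.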


def pad_edge_attr_node_feat(x):
nf = (
jnp.zeros(max_num_nodes).at[: len(x["node_feat"])].set(x["node_feat"].squeeze())
)
ea = (
jnp.zeros(max_num_edges).at[: len(x["edge_attr"])].set(x["edge_attr"].squeeze())
)
return {"node_feat": nf, "edge_attr": ea}


small_gd = small_gd.map(pad_edge_attr_node_feat)

# prepare the dataset for GPjax
D = gpx.Dataset(X={i: small_gd[i] for i in ("node_feat", "edge_attr")}, y=small_gd["y"])
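# X is here a dict of arrays, i.e. a PyTree, rather than a single array; the
# Dataset changes in gpjax/dataset.py below make n, in_dim and the shape
# checks tree-aware.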

# %% [markdown]
# Now define a naive graph kernel that operates on both node and edge features


# %%
class GraphKern(gpx.AbstractKernel):
def __call__(self, x1: PyTree, x2: PyTree, **kwargs):
return gpx.kernels.RBF()(x1["node_feat"], x2["node_feat"]) + gpx.kernels.RBF()(
x1["edge_attr"], x2["edge_attr"]
)
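

# %% [markdown]
# As a quick sanity check (a sketch, assuming kernels are evaluated pointwise
# on a pair of inputs): the sum of two positive-definite kernels is again
# positive definite, so `GraphKern` is a valid kernel on these padded graphs.

# %%
k = GraphKern()
x_a = {key: D.X[key][0] for key in ("node_feat", "edge_attr")}
x_b = {key: D.X[key][1] for key in ("node_feat", "edge_attr")}
print(k(x_a, x_b))  # scalar kernel value for a single pair of graphs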


# %% [markdown]
# And we're ready to fit a model!

# %%
meanf = gpx.mean_functions.Zero()
prior = gpx.Prior(mean_function=meanf, kernel=GraphKern())
likelihood = gpx.Gaussian(num_datapoints=D.n)
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
posterior = prior * likelihood

opt_posterior, mll_history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=600,
safe=True,
key=jr.PRNGKey(0),
)

# %%
plt.plot(mll_history)
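
# %% [markdown]
# A sketch of prediction with PyTree test inputs (hypothetical usage, reusing
# the training inputs as test points; `predict` now takes the number of test
# points from the first PyTree leaf, per the gpjax/gps.py change below):

# %%
latent_dist = opt_posterior.predict(D.X, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)
print(predictive_dist.mean().shape)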

# %%
38 changes: 31 additions & 7 deletions gpjax/dataset.py
@@ -14,9 +14,11 @@
# ==============================================================================

from dataclasses import dataclass
from typing import Any, Callable, Tuple, TypeVar, Union

from beartype.typing import Optional
import jax.numpy as jnp
import jax
from jaxtyping import Num
from simple_pytree import Pytree

@@ -43,7 +45,7 @@ def __post_init__(self) -> None:
def __repr__(self) -> str:
r"""Returns a string representation of the dataset."""
repr = (
f"- Number of observations: {self.n}\n- Input dimension:"
f"- Number of observations: {self.n}\n- Input dimension (sum over PyTree):"
f" {self.in_dim}\n- Output dimension: {self.out_dim}"
)
return repr
@@ -72,12 +74,14 @@ def __add__(self, other: "Dataset") -> "Dataset":
@property
def n(self) -> int:
r"""Number of observations."""
return self.X.shape[0]
return jax.tree_util.tree_leaves(self.X)[0].shape[0]

@property
def in_dim(self) -> int:
r"""Dimension of the inputs, $`X`$."""
return self.X.shape[1]
return jax.tree_util.tree_reduce(
lambda a, b: a + b, jax.tree_map(lambda a: a.shape[1], self.X), 0
)
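# e.g. X = {"a": Array[N, 2], "b": Array[N, 3]} gives in_dim == 2 + 3 == 5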

@property
def out_dim(self) -> int:
@@ -89,15 +93,17 @@ def _check_shape(
X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
) -> None:
r"""Checks that the shapes of $`X`$ and $`y`$ are compatible."""
if X is not None and y is not None and X.shape[0] != y.shape[0]:
    if X is not None and y is not None:
        len_ok, X_length = _check_all_leaves_const(len, len(y), X)
        if not len_ok:
raise ValueError(
"Inputs, X, and outputs, y, must have the same number of rows."
f" Got X.shape={X.shape} and y.shape={y.shape}."
f" Got len(y)={len(y)} and len(X)={X_length}."
)

if X is not None and X.ndim != 2:
    if X is not None:
        dim_ok, X_dim = _check_all_leaves_const(lambda a: a.ndim, 2, X)
        if not dim_ok:
raise ValueError(
f"Inputs, X, must be a 2-dimensional array. Got X.ndim={X.ndim}."
f"Inputs, X, must be a 2-dimensional array. Got X.ndim={X_dim}."
)

if y is not None and y.ndim != 2:
@@ -106,6 +112,24 @@ def _check_shape(
)


T = TypeVar("T")


def _check_all_leaves_const(
    extract_value: Callable[[Any], T],
    equal_to: T,
    X: Optional[Union[Pytree, Num[Array, "..."]]],
) -> Tuple[bool, Any]:
    r"""Checks that `extract_value(leaf) == equal_to` holds for every leaf of `X`.

    Returns the overall check result together with the PyTree of extracted values.
    """
    values = jax.tree_map(extract_value, X)

    return (
        jax.tree_util.tree_reduce(
            lambda a, b: a and b, jax.tree_map(lambda a: a == equal_to, values), True
        ),
        values,
    )
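
# Example: _check_all_leaves_const(lambda a: a.ndim, 2, {"u": u, "v": v})
# returns (True, {"u": 2, "v": 2}) when both leaves are 2-dimensional arrays.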


__all__ = [
"Dataset",
]
5 changes: 3 additions & 2 deletions gpjax/gps.py
@@ -23,6 +23,7 @@
Callable,
Optional,
)
import jax
import jax.numpy as jnp
from jax.random import (
PRNGKey,
@@ -481,7 +482,7 @@ def predict(
x, y, n = train_data.X, train_data.y, train_data.n

# Unpack test inputs
t, n_test = test_inputs, test_inputs.shape[0]
t, n_test = test_inputs, jax.tree_util.tree_leaves(test_inputs)[0].shape[0]
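# With PyTree test inputs, the number of test points is read from the leading
# dimension of the first leaf; all leaves are assumed to share it.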

# Observation noise o²
obs_noise = self.likelihood.obs_noise
@@ -655,7 +656,7 @@ def predict(
Lx = Kxx.to_root()

# Unpack test inputs
t, n_test = test_inputs, test_inputs.shape[0]
t, n_test = test_inputs, jax.tree_util.tree_leaves(test_inputs)[0].shape[0]

# Compute terms of the posterior predictive distribution
Ktx = kernel.cross_covariance(t, x)
3 changes: 2 additions & 1 deletion gpjax/mean_functions.py
@@ -28,6 +28,7 @@
Float,
Num,
)
import jax

from gpjax.base import (
Module,
@@ -147,7 +148,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
-------
Float[Array, "1"]: The evaluated mean function.
"""
return jnp.ones((x.shape[0], 1)) * self.constant
return jnp.ones((jax.tree_util.tree_leaves(x)[0].shape[0], 1)) * self.constant


@dataclasses.dataclass
