diff --git a/gpjax/kernels/non_euclidean/categorical.py b/gpjax/kernels/non_euclidean/categorical.py index 0ef17632..bbb19406 100644 --- a/gpjax/kernels/non_euclidean/categorical.py +++ b/gpjax/kernels/non_euclidean/categorical.py @@ -15,7 +15,7 @@ from dataclasses import dataclass -from typing import NamedTuple, Union +from typing import NamedTuple import jax.numpy as jnp from jaxtyping import Float, Int import tensorflow_probability.substrates.jax as tfp