Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precision check feature #347

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion gpjax/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================

from dataclasses import dataclass
import warnings

from beartype.typing import Optional
import jax.numpy as jnp
Expand All @@ -37,8 +38,10 @@ class Dataset(Pytree):
y: Optional[Num[Array, "N Q"]] = None

def __post_init__(self) -> None:
r"""Checks that the shapes of $`X`$ and $`y`$ are compatible."""
r"""Checks that the shapes of $`X`$ and $`y`$ are compatible,
and provides warnings regarding the precision of $`X`$ and $`y`$."""
_check_shape(self.X, self.y)
_check_precision(self.X, self.y)

def __repr__(self) -> str:
r"""Returns a string representation of the dataset."""
Expand Down Expand Up @@ -106,6 +109,25 @@ def _check_shape(
)


def _check_precision(
X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
) -> None:
r"""Checks the precision of $`X`$ and $`y`."""
if X is not None and X.dtype != jnp.float64:
warnings.warn(
"X is not of type float64. "
f"Got X.dtype={X.dtype}. This may lead to numerical instability. ",
stacklevel=2,
)

if y is not None and y.dtype != jnp.float64:
warnings.warn(
"y is not of type float64."
f"Got y.dtype={y.dtype}. This may lead to numerical instability.",
stacklevel=2,
)


__all__ = [
"Dataset",
]
37 changes: 37 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
except ImportError:
ValidationErrors = ValueError

from jax.config import config
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest

from gpjax.dataset import Dataset

config.update("jax_enable_x64", True)


@pytest.mark.parametrize("n", [1, 2, 10])
@pytest.mark.parametrize("out_dim", [1, 2, 10])
Expand All @@ -36,6 +39,7 @@ def test_dataset_init(n: int, in_dim: int, out_dim: int) -> None:
# Create dataset
x = jnp.ones((n, in_dim))
y = jnp.ones((n, out_dim))

D = Dataset(X=x, y=y)

# Test dataset shapes
Expand Down Expand Up @@ -154,3 +158,36 @@ def test_y_none(n: int, in_dim: int) -> None:

# Check tree flatten
assert jtu.tree_leaves(D) == [x]


@pytest.mark.parametrize(
("prec_x", "prec_y"),
[
(jnp.float32, jnp.float64),
(jnp.float64, jnp.float32),
(jnp.float32, jnp.float32),
],
)
@pytest.mark.parametrize("n", [1, 2, 10])
@pytest.mark.parametrize("in_dim", [1, 2, 10])
@pytest.mark.parametrize("out_dim", [1, 2, 10])
def test_precision_warning(
n: int, in_dim: int, out_dim: int, prec_x: jnp.dtype, prec_y: jnp.dtype
) -> None:
# Create dataset
x = jnp.ones((n, in_dim)).astype(prec_x)
y = jnp.ones((n, out_dim)).astype(prec_y)

# Check for warnings if dtypes are not float64
expected_warnings = 0
if prec_x != jnp.float64:
expected_warnings += 1
if prec_y != jnp.float64:
expected_warnings += 1

with pytest.warns(
UserWarning, match=".* is not of type float64.*"
) as record:
Dataset(X=x, y=y)

assert len(record) == expected_warnings