From fac885bf0afd2cb183afc207e1c3af52761b4be1 Mon Sep 17 00:00:00 2001 From: Tom Savage Date: Sun, 6 Aug 2023 21:02:04 +0100 Subject: [PATCH 1/2] added precision warning feature on dataset creation --- gpjax/dataset.py | 24 +++++++++++++++++++++++- tests/test_dataset.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/gpjax/dataset.py b/gpjax/dataset.py index 4ffd93b7..45b5a694 100644 --- a/gpjax/dataset.py +++ b/gpjax/dataset.py @@ -14,6 +14,7 @@ # ============================================================================== from dataclasses import dataclass +import warnings from beartype.typing import Optional import jax.numpy as jnp @@ -37,8 +38,10 @@ class Dataset(Pytree): y: Optional[Num[Array, "N Q"]] = None def __post_init__(self) -> None: - r"""Checks that the shapes of $`X`$ and $`y`$ are compatible.""" + r"""Checks that the shapes of $`X`$ and $`y`$ are compatible, + and provides warnings regarding the precision of $`X`$ and $`y`$.""" _check_shape(self.X, self.y) + _check_precision(self.X, self.y) def __repr__(self) -> str: r"""Returns a string representation of the dataset.""" @@ -106,6 +109,25 @@ def _check_shape( ) +def _check_precision( + X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]] +) -> None: + r"""Checks the precision of $`X`$ and $`y`.""" + if X is not None and X.dtype != jnp.float64: + warnings.warn( + "Warning: X is not of type float64. " + f"Got X.dtype={X.dtype}. This may lead to numerical instability. ", + stacklevel=2, + ) + + if y is not None and y.dtype != jnp.float64: + warnings.warn( + "Warning: y is not of type float64." + f"Got y.dtype={y.dtype}. This may lead to numerical instability.", + stacklevel=2, + ) + + __all__ = [ "Dataset", ] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 5a76a8b9..ee75a339 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -22,12 +22,15 @@ except ImportError: ValidationErrors = ValueError +from jax.config import config import jax.numpy as jnp import jax.tree_util as jtu import pytest from gpjax.dataset import Dataset +config.update("jax_enable_x64", True) + @pytest.mark.parametrize("n", [1, 2, 10]) @pytest.mark.parametrize("out_dim", [1, 2, 10]) @@ -36,6 +39,7 @@ def test_dataset_init(n: int, in_dim: int, out_dim: int) -> None: # Create dataset x = jnp.ones((n, in_dim)) y = jnp.ones((n, out_dim)) + D = Dataset(X=x, y=y) # Test dataset shapes @@ -154,3 +158,36 @@ def test_y_none(n: int, in_dim: int) -> None: # Check tree flatten assert jtu.tree_leaves(D) == [x] + + +@pytest.mark.parametrize( + ("prec_x", "prec_y"), + [ + (jnp.float32, jnp.float64), + (jnp.float64, jnp.float32), + (jnp.float32, jnp.float32), + ], +) +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("in_dim", [1, 2, 10]) +@pytest.mark.parametrize("out_dim", [1, 2, 10]) +def test_precision_warning( + n: int, in_dim: int, out_dim: int, prec_x: jnp.dtype, prec_y: jnp.dtype +) -> None: + # Create dataset + x = jnp.ones((n, in_dim)).astype(prec_x) + y = jnp.ones((n, out_dim)).astype(prec_y) + + # Check for warnings if dtypes are not float64 + expected_warnings = 0 + if prec_x != jnp.float64: + expected_warnings += 1 + if prec_y != jnp.float64: + expected_warnings += 1 + + with pytest.warns( + UserWarning, match="Warning:.*is not of type float64.*" + ) as record: + Dataset(X=x, y=y) + + assert len(record) == expected_warnings From c02f5b7065ccea9b0982a349bae6f96e0f991f90 Mon Sep 17 00:00:00 2001 From: Tom Savage Date: Sun, 6 Aug 2023 22:05:06 +0100 Subject: [PATCH 2/2] changed warning message and updated matching test --- gpjax/dataset.py | 4 ++-- tests/test_dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gpjax/dataset.py b/gpjax/dataset.py index 45b5a694..d2364b28 100644 --- a/gpjax/dataset.py +++ b/gpjax/dataset.py @@ -115,14 +115,14 @@ def _check_precision( r"""Checks the precision of $`X`$ and $`y`.""" if X is not None and X.dtype != jnp.float64: warnings.warn( - "Warning: X is not of type float64. " + "X is not of type float64. " f"Got X.dtype={X.dtype}. This may lead to numerical instability. ", stacklevel=2, ) if y is not None and y.dtype != jnp.float64: warnings.warn( - "Warning: y is not of type float64." + "y is not of type float64." f"Got y.dtype={y.dtype}. This may lead to numerical instability.", stacklevel=2, ) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ee75a339..38c27752 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -186,7 +186,7 @@ def test_precision_warning( expected_warnings += 1 with pytest.warns( - UserWarning, match="Warning:.*is not of type float64.*" + UserWarning, match=".* is not of type float64.*" ) as record: Dataset(X=x, y=y)