Skip to content

Commit

Permalink
fix dataset static typing
Browse files Browse the repository at this point in the history
  • Loading branch information
frazane committed Aug 2, 2023
1 parent 06f82ff commit a1676a8
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions gpjax/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

from dataclasses import dataclass

from beartype.typing import Optional
from beartype.typing import (
Optional,
Union,
)
import jax.numpy as jnp
from jaxtyping import (
Bool,
Expand All @@ -30,7 +33,6 @@ class _Missing:
"""Sentinel class for not-yet-computed mask"""



@dataclass
class Dataset(Pytree):
r"""Base class for datasets.
Expand All @@ -49,7 +51,7 @@ class Dataset(Pytree):

X: Optional[Num[Array, "N D"]] = None
y: Optional[Num[Array, "N Q"]] = None
mask: Bool[Array, "N Q"] | None = _Missing()
mask: Optional[Union[Bool[Array, "N Q"], None]] = _Missing()

def __post_init__(self) -> None:
r"""Checks that the shapes of $`X`$ and $`y`$ are compatible."""
Expand Down

0 comments on commit a1676a8

Please sign in to comment.