Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some type annotations causing failing tests #456

Merged
merged 2 commits into from
Jul 1, 2024

Conversation

stephen-huan
Copy link
Contributor

Type of changes

  • Bug fix
  • New feature
  • Documentation / docstrings
  • Tests
  • Other

Checklist

  • I've formatted the new code by running poetry run pre-commit run --all-files --show-diff-on-failure before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

Fix some incorrect type annotations causing errors when running with the most recent version of beartype (0.18.5).

ImportError while loading conftest '/.../GPJax/tests/conftest.py'.
tests/conftest.py:8: in <module>
    import gpjax  # noqa: F401
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module
    return super().exec_module(module)
gpjax/__init__.py:15: in <module>
    from gpjax import (
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module
    return super().exec_module(module)
gpjax/base/__init__.py:16: in <module>
    from gpjax.base.module import (
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_import_hook.py:223: in exec_module
    return super().exec_module(module)
gpjax/base/module.py:1: in <module>
    ???
/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_decorator.py:397: in jaxtyped
    full_fn = typechecker(full_fn)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcache.py:77: in beartype
    return beartype_object(obj, conf)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcore.py:87: in beartype_object
    _beartype_object_fatal(obj, conf=conf, **kwargs)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/decorcore.py:136: in _beartype_object_fatal
    beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/_decornontype.py:182: in beartype_nontype
    return beartype_func(obj, **kwargs)  # type: ignore[return-value]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/_decornontype.py:247: in beartype_func
    func_wrapper_code = generate_code(bear_call)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/wrapmain.py:122: in generate_code
    code_check_return = _code_check_return(bear_call)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapreturn.py:237: in code_check_return
    reraise_exception_placeholder(
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/error/utilerrraise.py:138: in reraise_exception_placeholder
    raise exception.with_traceback(exception.__traceback__)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapreturn.py:174: in code_check_return
    ) = make_code_raiser_func_pith_check(  # type: ignore[assignment]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:250: in _callable_cached
    raise exception
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:242: in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_check/checkmake.py:311: in make_code_raiser_func_pith_check
    ) = make_check_expr(hint, conf, cls_stack)
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:250: in _callable_cached
    raise exception
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py:242: in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_check/code/codemake.py:1578: in make_check_expr
    hint_childs = get_hint_pep484585_args(  # type: ignore[assignment]
/nix/store/anb9by7z6p5fmpf2wd47yc5wjv8yf1l9-python3.11-beartype-0.18.5/lib/python3.11/site-packages/beartype/_util/hint/pep/proposal/pep484585/utilpep484585.py:158: in get_hint_pep484585_args
    raise BeartypeDecorHintPep585Exception(
E   beartype.roar.BeartypeDecorHintPep585Exception: Class method gpjax.base.module.check_return() return PEP 585 type hint dict[str] not subscripted (indexed) by 2 arguments (i.e., subscripted by 1 != 2 arguments).
=================================== FAILURES ===================================
_________ test_thompson_sampling_non_conjugate_posterior_raises_error __________

args = (ThompsonSampling(num_features=100),)
kwargs = {'datasets': {'OBJECTIVE': - Number of observations: 10
- Input dimension: 1}, 'key': Array((), dtype=key<fry>) overla...       [-0.35197108],
       [-0.46512772],
       [ 0.11289125]], dtype=float64), key=Array([ 0, 42], dtype=uint32))}}
bound = <BoundArguments (self=ThompsonSampling(num_features=100), posteriors={'OBJECTIVE': NonConjugatePosterior(prior=Prior(k...s={'OBJECTIVE': - Number of observations: 10
- Input dimension: 1}, key=Array((), dtype=key<fry>) overlaying:
[ 0 42])>
memos = ({}, {}, {}, {'datasets': {'OBJECTIVE': - Number of observations: 10
- Input dimension: 1}, 'key': Array((), dtype=key...      [ 0.11289125]], dtype=float64), key=Array([ 0, 42], dtype=uint32))}, 'self': ThompsonSampling(num_features=100)})
argmsg = "\nThe problem arose whilst typechecking parameter 'posteriors'.\nActual value: { 'OBJECTIVE': NonConjugatePosterior(p...       key=Array([ 0, 42], dtype=uint32))}\nExpected
 type: collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior]."
name = 'build_utility_function'
param_values = "{ 'datasets': {'OBJECTIVE': - Number of observations: 10\n- Input dimension: 1},\n  'key': Array((), dtype=key<fry>) ...                                   key=Array([
 0, 42], dtype=uint32))},\n  'self': ThompsonSampling(num_features=100)}"
param_hints = "(self, posteriors: collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior], datasets: collections.abc.Mapping[str, gpjax.dataset.Dataset], key: Union[UInt32[Array,
'2'], Key[Array, '']])"
msg = "Type-check error whilst checking the parameters of build_utility_function.\nThe problem arose whilst typechecking par...or], datasets: collections.abc.Mapping[str, gpjax.datas
et.Dataset], key: Union[UInt32[Array, '2'], Key[Array, '']]).\n"

    @ft.wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if config.jaxtyping_disable:
            return fn(*args, **kwargs)

        # Raise bind-time errors before we do any shape analysis. (I.e. skip
        # the pointless jaxtyping information for a non-typechecking failure.)
        bound = param_signature.bind(*args, **kwargs)

        memos = push_shape_memo(bound.arguments)
        try:
            # First type-check just the parameters before the function is
            # called.
            try:
>               param_fn(*args, **kwargs)
E               beartype.roar.BeartypeCallHintParamViolation: Method gpjax.decision_making.utility_functions.thompson_sampling.check_params() parameter posteriors={'OBJECTIVE': NonCo
njugatePosterior(prior=Prior(kernel=RBF(compute_engine=DenseKernelCompu...))} violates type hint collections.abc.Mapping[str, gpjax.gps.ConjugatePosterior], as dict key str 'OBJECTIV
E' value <protocol "gpjax.gps.NonConjugatePosterior"> "NonConjugatePosterior(prior=Prior(kernel=RBF(compute_engine=DenseKernelComputation(), act...))" not instance of <protocol "gpja
x.gps.ConjugatePosterior">.

/nix/store/kf64ch0bfisffrcy9w44dd2z3c6b7cq9-python3.11-jaxtyping-0.2.28/lib/python3.11/site-packages/jaxtyping/_decorator.py:418: BeartypeCallHintParamViolation

@thomaspinder thomaspinder merged commit 817c153 into JaxGaussianProcesses:main Jul 1, 2024
11 checks passed
@stephen-huan stephen-huan deleted the fix-tests branch July 1, 2024 19:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants