diff --git a/pydantic/_internal/_typing_extra.py b/pydantic/_internal/_typing_extra.py index 43bcc07b35..4526cf29e7 100644 --- a/pydantic/_internal/_typing_extra.py +++ b/pydantic/_internal/_typing_extra.py @@ -307,6 +307,7 @@ def get_function_type_hints( globalns = add_module_globals(function) type_hints = {} + type_params: tuple[Any] = getattr(function, '__type_params__', ()) # type: ignore for name, value in annotations.items(): if include_keys is not None and name not in include_keys: continue @@ -315,7 +316,7 @@ def get_function_type_hints( elif isinstance(value, str): value = _make_forward_ref(value) - type_hints[name] = eval_type_backport(value, globalns, types_namespace) + type_hints[name] = eval_type_backport(value, globalns, types_namespace, type_params) return type_hints diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index 664c063013..3fae2d10e4 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -23,7 +23,13 @@ class ValidateCallWrapper: '__dict__', # required for __module__ ) - def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool): + def __init__( + self, + function: Callable[..., Any], + config: ConfigDict | None, + validate_return: bool, + namespace: dict[str, Any] | None, + ): if isinstance(function, partial): func = function.func schema_type = func @@ -36,7 +42,16 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali self.__qualname__ = function.__qualname__ self.__module__ = function.__module__ - namespace = _typing_extra.add_module_globals(function, None) + global_ns = _typing_extra.add_module_globals(function, None) + # TODO: this is a bit of a hack, we should probably have a better way to handle this + # specifically, we shouldn't be pumping the namespace full of type_params + # when we take namespace and type_params arguments in eval_type_backport + type_params = getattr(schema_type, '__type_params__', ()) + namespace = { + **{param.__name__: param for param in type_params}, + **(global_ns or {}), + **(namespace or {}), + } config_wrapper = ConfigWrapper(config) gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) schema = gen_schema.clean_schema(gen_schema.generate_schema(function)) diff --git a/pydantic/validate_call_decorator.py b/pydantic/validate_call_decorator.py index afa109c211..5314c9207c 100644 --- a/pydantic/validate_call_decorator.py +++ b/pydantic/validate_call_decorator.py @@ -5,7 +5,7 @@ import functools from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload -from ._internal import _validate_call +from ._internal import _typing_extra, _validate_call __all__ = ('validate_call',) @@ -46,12 +46,14 @@ def validate_call( Returns: The decorated function. """ + local_ns = _typing_extra.parent_frame_namespace() def validate(function: AnyCallableT) -> AnyCallableT: if isinstance(function, (classmethod, staticmethod)): name = type(function).__name__ raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)') - validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return) + + validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return, local_ns) @functools.wraps(function) def wrapper_function(*args, **kwargs): diff --git a/tests/test_validate_call.py b/tests/test_validate_call.py index f4d466ff23..63371fa7b8 100644 --- a/tests/test_validate_call.py +++ b/tests/test_validate_call.py @@ -1,6 +1,7 @@ import asyncio import inspect import re +import sys from datetime import datetime, timezone from functools import partial from typing import Any, List, Tuple @@ -803,3 +804,59 @@ def foo(bar: 'list[int | str]') -> 'list[int | str]': 'input': {'not a str or int'}, }, ] + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason='requires Python 3.12+ for PEP 695 syntax with generics') +def test_validate_call_with_pep_695_syntax() -> None: + """Note: validate_call still doesn't work properly with generics, see https://github.com/pydantic/pydantic/issues/7796. + + This test is just to ensure that the syntax is accepted and doesn't raise a NameError.""" + globs = {} + exec( + """ +from typing import Iterable +from pydantic import validate_call + +@validate_call +def find_max_no_validate_return[T](args: Iterable[T]) -> T: + return sorted(args, reverse=True)[0] + +@validate_call(validate_return=True) +def find_max_validate_return[T](args: Iterable[T]) -> T: + return sorted(args, reverse=True)[0] + """, + globs, + ) + functions = [globs['find_max_no_validate_return'], globs['find_max_validate_return']] + for find_max in functions: + assert len(find_max.__type_params__) == 1 + assert find_max([1, 2, 10, 5]) == 10 + + with pytest.raises(ValidationError): + find_max(1) + + +class M0(BaseModel): + z: int + + +M = M0 + + +def test_uses_local_ns(): + class M1(BaseModel): + y: int + + M = M1 # noqa: F841 + + def foo(): + class M2(BaseModel): + z: int + + M = M2 + + @validate_call + def bar(m: M) -> M: + return m + + assert bar({'z': 1}) == M2(z=1)