Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for literals and generics #6035

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:

def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
if self.context_contains_literal_like_type() and var.name() in {'True', 'False'}:
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
else:
return var.type
Expand Down Expand Up @@ -1747,14 +1747,14 @@ def analyze_external_member_access(self, member: str, base_type: Type,
def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
typ = self.named_type('builtins.int')
if is_literal_type_like(self.type_context[-1]):
if self.context_contains_literal_like_type():
return LiteralType(value=e.value, fallback=typ)
return typ

def visit_str_expr(self, e: StrExpr) -> Type:
"""Type check a string literal (trivial)."""
typ = self.named_type('builtins.str')
if is_literal_type_like(self.type_context[-1]):
if self.context_contains_literal_like_type():
return LiteralType(value=e.value, fallback=typ)
return typ

Expand Down Expand Up @@ -3346,6 +3346,10 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
return ans
return known_type

def context_contains_literal_like_type(self) -> bool:
"""Returns 'true' if the context contains anything that resembles a LiteralType"""
return any(is_literal_type_like(item) for item in self.type_context)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you checking all items in the type context? I think only type_context[-1] should be checked.

Maybe add a test that checks that in this case literal type for 42 is not inferred:

def foo(x: T) -> T: ...
x: Union[List[int], Literal['bad']] = foo([42])

(i.e. this example should type-check without errors).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made that change to support this case:

def expects_literal(x: Literal[3]) -> None: pass
def foo(x: T) -> T: ...
expects_literal(foo(3))  # Should not have an error

I found that without checking the entire context, we got an error on that final line since type_context[-1] is T, not Literal[3].

Maybe I could modify context_contains_literal_like_type so that it starts by checking type_context[-1] and only tries moving backwards if type_context[-1] happened to be a TypeVar or something?

Or alternatively, modify this method so that it only tries inferring a Literal if the fallback matches -- so since 42 is an int and Literal['bad']'s fallback is str, we wouldn't error in your example.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But what if the return type of foo is something other than T? Then it might not be correct to infer a literal type, right? I suspect that accessing type context below the top item does not really make sense, since we don't know what's "between" the type context items.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Michael0x2a The outer context doesn't propagate inside in your example because there is special casing for functions that return a plain type variable, for them, only generic instance context is used for inference. If you really want your example to work, then you can add or isinstance(ctx, LiteralType) (the current check is anyway pure heuristic). As @JukkaL explained, checking all context stack doesn't really make sense.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL and @ilevkivskyi -- ok, I updated the PR. I ended up using basically the fix Ivan suggested (and was pleasantly surprised to discover this seemed to fix all but one minor edge case).



def has_any_type(t: Type) -> bool:
"""Whether t contains an Any type"""
Expand Down Expand Up @@ -3627,5 +3631,8 @@ def is_literal_type_like(t: Optional[Type]) -> bool:
return True
elif isinstance(t, UnionType):
return any(is_literal_type_like(item) for item in t.items)
elif isinstance(t, TypeVarType):
return (is_literal_type_like(t.upper_bound)
or any(is_literal_type_like(item) for item in t.values))
else:
return False
195 changes: 195 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,18 @@ b: bt # E: Invalid type "__main__.bt"
[builtins fixtures/set.pyi]
[out]

[case testLiteralDisallowTypeVar]
from typing import TypeVar
from typing_extensions import Literal

T = TypeVar('T')

at = Literal[T] # E: Parameter 1 of Literal[...] is invalid
a: at

def foo(b: Literal[T]) -> T: pass # E: Parameter 1 of Literal[...] is invalid
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
[out]


--
-- Test mixing and matching literals with other types
Expand Down Expand Up @@ -1180,3 +1192,186 @@ b = b * a
c = c.strip() # E: Incompatible types in assignment (expression has type "str", variable has type "Literal['foo']")
[builtins fixtures/ops.pyi]
[out]


--
-- Test to make sure literals interact with generics as expected
--

[case testLiteralAndGenericsWithSimpleFunctions]
from typing import TypeVar
from typing_extensions import Literal

T = TypeVar('T')
def foo(x: T) -> T: pass
def expects_literal(x: Literal[3]) -> None: pass
def expects_int(x: int) -> None: pass

a: Literal[3]
reveal_type(foo(3)) # E: Revealed type is 'builtins.int*'
reveal_type(foo(a)) # E: Revealed type is 'Literal[3]'

expects_literal(3)
expects_literal(foo(3))
expects_literal(foo(foo(3)))

expects_literal(a)
expects_literal(foo(a))
expects_literal(foo(foo(a)))

expects_literal(5) # E: Argument 1 to "expects_literal" has incompatible type "Literal[5]"; expected "Literal[3]"
expects_literal(foo(5)) # E: Argument 1 to "expects_literal" has incompatible type "Literal[5]"; expected "Literal[3]"
expects_literal(foo(foo(5))) # E: Argument 1 to "expects_literal" has incompatible type "Literal[5]"; expected "Literal[3]"

expects_int(a)
expects_int(foo(a))
expects_int(foo(foo(a)))
[out]

[case testLiteralAndGenericsWithSimpleClasses]
from typing import TypeVar, Generic
from typing_extensions import Literal

T = TypeVar('T')
class Wrapper(Generic[T]):
def __init__(self, val: T) -> None:
self.val = val
def inner(self) -> T:
return self.val

def expects_literal(a: Literal[3]) -> None: pass
def expects_literal_wrapper(x: Wrapper[Literal[3]]) -> None: pass

a: Literal[3]
reveal_type(Wrapper(3)) # E: Revealed type is '__main__.Wrapper[builtins.int*]'
reveal_type(Wrapper[Literal[3]](3)) # E: Revealed type is '__main__.Wrapper[Literal[3]]'
reveal_type(Wrapper(a)) # E: Revealed type is '__main__.Wrapper[Literal[3]]'

expects_literal(Wrapper(3).inner())
expects_literal(Wrapper(a).inner())

expects_literal_wrapper(Wrapper(3))
expects_literal_wrapper(Wrapper(a))

expects_literal(Wrapper(5).inner()) # E: Argument 1 to "expects_literal" has incompatible type "Literal[5]"; expected "Literal[3]"
expects_literal_wrapper(Wrapper(5)) # E: Argument 1 to "Wrapper" has incompatible type "Literal[5]"; expected "Literal[3]"
[out]

[case testLiteralAndGenericsRespectsUpperBound]
from typing import TypeVar
from typing_extensions import Literal

TLiteral = TypeVar('TLiteral', bound=Literal[3])
TInt = TypeVar('TInt', bound=int)

def func1(x: TLiteral) -> TLiteral: pass
def func2(x: TInt) -> TInt: pass

def func3(x: TLiteral) -> TLiteral:
y = func2(x)
return y
def func4(x: TInt) -> TInt:
y = func1(x) # E: Value of type variable "TLiteral" of "func1" cannot be "TInt"
return y

a: Literal[3]
b: Literal[4]
c: int

reveal_type(func1) # E: Revealed type is 'def [TLiteral <: Literal[3]] (x: TLiteral`-1) -> TLiteral`-1'

reveal_type(func1(3)) # E: Revealed type is 'Literal[3]'
reveal_type(func1(a)) # E: Revealed type is 'Literal[3]'
reveal_type(func1(4)) # E: Revealed type is 'Literal[4]' \
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
reveal_type(func1(b)) # E: Revealed type is 'Literal[4]' \
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
reveal_type(func1(c)) # E: Revealed type is 'builtins.int*' \
# E: Value of type variable "TLiteral" of "func1" cannot be "int"

reveal_type(func2(3)) # E: Revealed type is 'builtins.int*'
reveal_type(func2(a)) # E: Revealed type is 'Literal[3]'
reveal_type(func2(4)) # E: Revealed type is 'builtins.int*'
reveal_type(func2(b)) # E: Revealed type is 'Literal[4]'
reveal_type(func2(c)) # E: Revealed type is 'builtins.int*'
[out]

[case testLiteralAndGenericsRespectsValueRestriction]
from typing import TypeVar
from typing_extensions import Literal

TLiteral = TypeVar('TLiteral', Literal[3], Literal['foo'])
TNormal = TypeVar('TNormal', int, str)

def func1(x: TLiteral) -> TLiteral: pass
def func2(x: TNormal) -> TNormal: pass

def func3(x: TLiteral) -> TLiteral:
y = func2(x)
return y # E: Incompatible return value type (got "int", expected "Literal[3]") \
# E: Incompatible return value type (got "str", expected "Literal['foo']")
def func4(x: TNormal) -> TNormal:
y = func1(x) # E: Value of type variable "TLiteral" of "func1" cannot be "int" \
# E: Value of type variable "TLiteral" of "func1" cannot be "str"
return y

i1: Literal[3]
i2: Literal[4]
i: int

s1: Literal['foo']
s2: Literal['bar']
s: str

reveal_type(func1) # E: Revealed type is 'def [TLiteral in (Literal[3], Literal['foo'])] (x: TLiteral`-1) -> TLiteral`-1'

reveal_type(func1(3)) # E: Revealed type is 'Literal[3]'
reveal_type(func1(i1)) # E: Revealed type is 'Literal[3]'
reveal_type(func1(4)) # E: Revealed type is 'Literal[4]' \
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
reveal_type(func1(i2)) # E: Revealed type is 'Literal[4]' \
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal[4]"
reveal_type(func1(i)) # E: Revealed type is 'builtins.int*' \
# E: Value of type variable "TLiteral" of "func1" cannot be "int"

reveal_type(func1("foo")) # E: Revealed type is 'Literal['foo']'
reveal_type(func1(s1)) # E: Revealed type is 'Literal['foo']'
reveal_type(func1("bar")) # E: Revealed type is 'Literal['bar']' \
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal['bar']"
reveal_type(func1(s2)) # E: Revealed type is 'Literal['bar']' \
# E: Value of type variable "TLiteral" of "func1" cannot be "Literal['bar']"
reveal_type(func1(s)) # E: Revealed type is 'builtins.str*' \
# E: Value of type variable "TLiteral" of "func1" cannot be "str"

reveal_type(func2(3)) # E: Revealed type is 'builtins.int*'
reveal_type(func2(i1)) # E: Revealed type is 'builtins.int*'
reveal_type(func2(4)) # E: Revealed type is 'builtins.int*'
reveal_type(func2(i2)) # E: Revealed type is 'builtins.int*'
reveal_type(func2("foo")) # E: Revealed type is 'builtins.str*'
reveal_type(func2(s1)) # E: Revealed type is 'builtins.str*'
reveal_type(func2("bar")) # E: Revealed type is 'builtins.str*'
reveal_type(func2(s2)) # E: Revealed type is 'builtins.str*'
[out]

[case testLiteralAndGenericsWithOverloads]
from typing import TypeVar, overload, Union
from typing_extensions import Literal

@overload
def func1(x: Literal[4]) -> Literal[19]: ...
@overload
def func1(x: int) -> int: ...
def func1(x: int) -> int: pass

T = TypeVar('T')
def identity(x: T) -> T: pass

a: Literal[4]
b: Literal[5]

reveal_type(func1(identity(4))) # E: Revealed type is 'Literal[19]'
reveal_type(func1(identity(5))) # E: Revealed type is 'builtins.int'
reveal_type(func1(identity(a))) # E: Revealed type is 'Literal[19]'
reveal_type(func1(identity(b))) # E: Revealed type is 'builtins.int'

[out]