Skip to content

Commit

Permalink
Ensures overloads are ordered from narrow to broad (python#2138)
Browse files Browse the repository at this point in the history
This commit reorders any overloads where the first overload was
"shadowing" the second, preventing it from ever being matched by type
checkers that work by selecting the first matching overload alternative.

For example, the first overload alternative below is strictly broader
then the second, preventing it from ever being selected:

    class Parent: pass
    class Child(Parent): pass

    @overload
    def foo(x: *int) -> Parent: ...
    @overload
    def foo(x: int, y: int) -> Child: ...

The correct thing to do is to either delete the second overload or
rearrange them to look like this:

    @overload
    def foo(x: int, y: int) -> Child: ...
    @overload
    def foo(x: *int) -> Parent: ...

Rationale: I'm currently [working on a proposal][0] that would amend
PEP 484 to (a) mandate type checkers check overloads in order and
(b) prohibit overloads where an earlier alternative completely shadows
a later one.

  [0]: python/typing#253 (comment)

This would prohibit overloads that look like the example below, where
the first alternative completely shadows the second.

I figured it would be a good idea to make these changes ahead of time:
if my proposal is accepted, it'd make the transition smoother. If not,
this is hopefully a relatively harmless change.

Note: I think some of these overloads could be simplified (e.g.
`reversed(...)`), but I mostly stuck with rearranging them in case I was
wrong. The only overload I actually changed was `hmac.compare_digest` --
I believe the Python 2 version actually accepts unicode.
  • Loading branch information
Michael0x2a authored and gwk committed May 29, 2018
1 parent 33cd2b2 commit a41030f
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 68 deletions.
4 changes: 2 additions & 2 deletions stdlib/2/__builtin__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -863,9 +863,9 @@ def reduce(function: Callable[[_T, _T], _T], iterable: Iterable[_T]) -> _T: ...

def reload(module: Any) -> Any: ...
@overload
def reversed(object: Reversible[_T]) -> Iterator[_T]: ...
@overload
def reversed(object: Sequence[_T]) -> Iterator[_T]: ...
@overload
def reversed(object: Reversible[_T]) -> Iterator[_T]: ...
def repr(o: object) -> str: ...
@overload
def round(number: float) -> float: ...
Expand Down
4 changes: 2 additions & 2 deletions stdlib/2/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -863,9 +863,9 @@ def reduce(function: Callable[[_T, _T], _T], iterable: Iterable[_T]) -> _T: ...

def reload(module: Any) -> Any: ...
@overload
def reversed(object: Reversible[_T]) -> Iterator[_T]: ...
@overload
def reversed(object: Sequence[_T]) -> Iterator[_T]: ...
@overload
def reversed(object: Reversible[_T]) -> Iterator[_T]: ...
def repr(o: object) -> str: ...
@overload
def round(number: float) -> float: ...
Expand Down
4 changes: 2 additions & 2 deletions stdlib/2/os/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ def renames(old: _PathType, new: _PathType) -> None: ...
def rmdir(path: _PathType) -> None: ...
def stat(path: _PathType) -> Any: ...
@overload
def stat_float_times(newvalue: bool = ...) -> None: ...
@overload
def stat_float_times() -> bool: ...
@overload
def stat_float_times(newvalue: bool) -> None: ...
def statvfs(path: _PathType) -> _StatVFS: ... # Unix only
def symlink(source: _PathType, link_name: _PathType) -> None: ...
def unlink(path: _PathType) -> None: ...
Expand Down
8 changes: 4 additions & 4 deletions stdlib/2/os/path.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ if sys.version_info < (3, 0):
@overload
def join(__p1: bytes, *p: bytes) -> bytes: ...
@overload
def join(__p1: Text, *p: _PathType) -> Text: ...
@overload
def join(__p1: bytes, __p2: Text, *p: _PathType) -> Text: ...
def join(__p1: bytes, __p2: bytes, __p3: bytes, __p4: Text, *p: _PathType) -> Text: ...
@overload
def join(__p1: bytes, __p2: bytes, __p3: Text, *p: _PathType) -> Text: ...
@overload
def join(__p1: bytes, __p2: bytes, __p3: bytes, __p4: Text, *p: _PathType) -> Text: ...
def join(__p1: bytes, __p2: Text, *p: _PathType) -> Text: ...
@overload
def join(__p1: Text, *p: _PathType) -> Text: ...
elif sys.version_info >= (3, 6):
# Mypy complains that the signatures overlap (same for relpath below), but things seem to behave correctly anyway.
@overload
Expand Down
8 changes: 3 additions & 5 deletions stdlib/2and3/hmac.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Stubs for hmac

from typing import Any, Callable, Optional, Union, overload
from typing import Any, Callable, Optional, Union, overload, AnyStr
from types import ModuleType
import sys

Expand Down Expand Up @@ -29,12 +29,10 @@ class HMAC:
def hexdigest(self) -> str: ...
def copy(self) -> HMAC: ...

@overload
def compare_digest(a: str, b: str) -> bool: ...
@overload
def compare_digest(a: bytes, b: bytes) -> bool: ...
@overload
def compare_digest(a: bytearray, b: bytearray) -> bool: ...
@overload
def compare_digest(a: AnyStr, b: AnyStr) -> bool: ...

if sys.version_info >= (3, 7):
def digest(key: _B, msg: _B, digest: str) -> bytes: ...
4 changes: 2 additions & 2 deletions stdlib/2and3/sysconfig.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from typing import overload, Any, Dict, IO, List, Optional, Tuple, Union

@overload
def get_config_vars(*args: str) -> List[Any]: ...
@overload
def get_config_vars() -> Dict[str, Any]: ...
@overload
def get_config_vars(*args: str) -> List[Any]: ...
def get_config_var(name: str) -> Optional[str]: ...
def get_scheme_names() -> Tuple[str, ...]: ...
def get_path_names() -> Tuple[str, ...]: ...
Expand Down
8 changes: 4 additions & 4 deletions stdlib/3/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -816,9 +816,9 @@ def eval(source: Union[str, bytes, CodeType], globals: Optional[Dict[str, Any]]
def exec(object: Union[str, bytes, CodeType], globals: Optional[Dict[str, Any]] = ..., locals: Optional[Mapping[str, Any]] = ...) -> Any: ...
def exit(code: Any = ...) -> NoReturn: ...
@overload
def filter(function: Callable[[_T], Any], iterable: Iterable[_T]) -> Iterator[_T]: ...
@overload
def filter(function: None, iterable: Iterable[Optional[_T]]) -> Iterator[_T]: ...
@overload
def filter(function: Callable[[_T], Any], iterable: Iterable[_T]) -> Iterator[_T]: ...
def format(o: object, format_spec: str = ...) -> str: ...
def getattr(o: Any, name: str, default: Any = ...) -> Any: ...
def globals() -> Dict[str, Any]: ...
Expand Down Expand Up @@ -908,9 +908,9 @@ def pow(x: float, y: float) -> float: ...
def pow(x: float, y: float, z: float) -> float: ...
def quit(code: Optional[int] = ...) -> None: ...
@overload
def reversed(object: Reversible[_T]) -> Iterator[_T]: ...
@overload
def reversed(object: Sequence[_T]) -> Iterator[_T]: ...
@overload
def reversed(object: Reversible[_T]) -> Iterator[_T]: ...
def repr(o: object) -> str: ...
@overload
def round(number: float) -> int: ...
Expand Down
4 changes: 2 additions & 2 deletions stdlib/3/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,10 @@ class BinaryIO(IO[bytes]):
# TODO peek?
@overload
@abstractmethod
def write(self, s: bytes) -> int: ...
def write(self, s: bytearray) -> int: ...
@overload
@abstractmethod
def write(self, s: bytearray) -> int: ...
def write(self, s: bytes) -> int: ...

@abstractmethod
def __enter__(self) -> BinaryIO: ...
Expand Down
8 changes: 4 additions & 4 deletions stdlib/3/urllib/parse.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def urlsplit(url: str, scheme: str = ..., allow_fragments: bool = ...) -> SplitR
@overload
def urlsplit(url: bytes, scheme: bytes = ..., allow_fragments: bool = ...) -> SplitResultBytes: ...

@overload
def urlunparse(components: Sequence[AnyStr]) -> AnyStr: ...
@overload
def urlunparse(components: Tuple[AnyStr, AnyStr, AnyStr, AnyStr, AnyStr, AnyStr]) -> AnyStr: ...

@overload
def urlunsplit(components: Sequence[AnyStr]) -> AnyStr: ...
def urlunparse(components: Sequence[AnyStr]) -> AnyStr: ...

@overload
def urlunsplit(components: Tuple[AnyStr, AnyStr, AnyStr, AnyStr, AnyStr]) -> AnyStr: ...
@overload
def urlunsplit(components: Sequence[AnyStr]) -> AnyStr: ...
82 changes: 41 additions & 41 deletions third_party/2and3/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,21 @@ class Attribute(Generic[_T]):
# attr(validator=<some callable>) -> Whatever the callable expects.
# This makes this type of assignments possible:
# x: int = attr(8)

#
# This form catches explicit None or no default but with no other arguments returns Any.
@overload
def attrib(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the other arguments.
@overload
def attrib(default: None = ...,
Expand Down Expand Up @@ -98,20 +112,6 @@ def attrib(default: _T,
converter: Optional[_ConverterType[_T]] = ...,
factory: Optional[Callable[[], _T]] = ...,
) -> _T: ...
# This form catches explicit None or no default but with no other arguments returns Any.
@overload
def attrib(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
) -> Any: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(default: Optional[_T] = ...,
Expand Down Expand Up @@ -216,6 +216,19 @@ def get_run_validators() -> bool: ...
# dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;)


@overload
def ib(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
) -> Any: ...
@overload
def ib(default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
Expand Down Expand Up @@ -243,19 +256,6 @@ def ib(default: _T,
factory: Optional[Callable[[], _T]] = ...,
) -> _T: ...
@overload
def ib(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
) -> Any: ...
@overload
def ib(default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
Expand All @@ -270,6 +270,19 @@ def ib(default: Optional[_T] = ...,
) -> Any: ...

@overload
def attr(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
) -> Any: ...
@overload
def attr(default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
Expand All @@ -296,19 +309,6 @@ def attr(default: _T,
factory: Optional[Callable[[], _T]] = ...,
) -> _T: ...
@overload
def attr(default: None = ...,
validator: None = ...,
repr: bool = ...,
cmp: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
convert: None = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
) -> Any: ...
@overload
def attr(default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: bool = ...,
Expand Down

0 comments on commit a41030f

Please sign in to comment.