From 88dd0c39f0edc8309027e261d5ccfb09c1309cb4 Mon Sep 17 00:00:00 2001 From: Ben Green Date: Tue, 15 Dec 2020 08:26:35 -0500 Subject: [PATCH] Support PEP-604 style unions in decorator annotations (#429) These unions were introduced in Python 3.10 and do not define __origin__, so some extra checks are necessary to identify then. Since there is not yet a 3.10 build, a somewhat hacky test was added to simulate one of these new Unions. Resolves #414. --- libcst/matchers/_visitors.py | 12 +++++++++++- libcst/matchers/tests/test_decorators.py | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/libcst/matchers/_visitors.py b/libcst/matchers/_visitors.py index 301e675aa..be50edfd3 100644 --- a/libcst/matchers/_visitors.py +++ b/libcst/matchers/_visitors.py @@ -79,8 +79,18 @@ def _get_possible_match_classes(matcher: BaseMatcherNode) -> List[Type[cst.CSTNo return [getattr(cst, matcher.__class__.__name__)] -def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]: +def _annotation_looks_like_union(annotation: object) -> bool: if getattr(annotation, "__origin__", None) is Union: + return True + # support PEP-604 style unions introduced in Python 3.10 + return ( + annotation.__class__.__name__ == "Union" + and annotation.__class__.__module__ == "types" + ) + + +def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]: + if _annotation_looks_like_union(annotation): return getattr(annotation, "__args__", []) else: return [cast(Type[object], annotation)] diff --git a/libcst/matchers/tests/test_decorators.py b/libcst/matchers/tests/test_decorators.py index c102f2ab5..b1ff3d054 100644 --- a/libcst/matchers/tests/test_decorators.py +++ b/libcst/matchers/tests/test_decorators.py @@ -6,6 +6,7 @@ from ast import literal_eval from textwrap import dedent from typing import List, Set +from unittest.mock import Mock import libcst as cst import libcst.matchers as m @@ -993,3 +994,25 @@ def bar() -> None: # We should have only visited a select number of nodes. self.assertEqual(visitor.visits, ['"baz"']) + + +# This is meant to simulate `cst.ImportFrom | cst.RemovalSentinel` in py3.10 +FakeUnionClass: Mock = Mock() +setattr(FakeUnionClass, "__name__", "Union") +setattr(FakeUnionClass, "__module__", "types") +FakeUnion: Mock = Mock() +FakeUnion.__class__ = FakeUnionClass +FakeUnion.__args__ = [cst.ImportFrom, cst.RemovalSentinel] + + +class MatchersUnionDecoratorsTest(UnitTest): + def test_init_with_new_union_annotation(self) -> None: + class TransformerWithUnionReturnAnnotation(m.MatcherDecoratableTransformer): + @m.leave(m.ImportFrom(module=m.Name(value="typing"))) + def test( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> FakeUnion: + pass + + # assert that init (specifically _check_types on return annotation) passes + TransformerWithUnionReturnAnnotation()