Skip to content

Commit

Permalink
Use correct type for AugAssign and AnnAssign target (#396)
Browse files Browse the repository at this point in the history
* Use correct type

* Add tests

* Suppress intentional type errors in pyre
  • Loading branch information
cdonovick authored Oct 1, 2020
1 parent 10d6451 commit 6731aa5
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 18 deletions.
4 changes: 2 additions & 2 deletions libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ class AnnAssign(BaseSmallStatement):
"""

#: The target that is being annotated and possibly assigned to.
target: BaseExpression
target: BaseAssignTargetExpression

#: The annotation for the target.
annotation: Annotation
Expand Down Expand Up @@ -1393,7 +1393,7 @@ class AugAssign(BaseSmallStatement):
"""

#: Target that is being operated on and assigned to.
target: BaseExpression
target: BaseAssignTargetExpression

#: The augmented assignment operation being performed.
operator: BaseAugOp
Expand Down
6 changes: 6 additions & 0 deletions libcst/_nodes/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def assert_invalid(
with self.assertRaisesRegex(cst.CSTValidationError, expected_re):
get_node()

def assert_invalid_types(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
with self.assertRaisesRegex(TypeError, expected_re):
get_node().validate_types_shallow()

def __assert_codegen(
self,
node: cst.CSTNode,
Expand Down
72 changes: 72 additions & 0 deletions libcst/_nodes/tests/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ def test_valid(self, **kwargs: Any) -> None:
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)

@data_provider(
(
{
"get_node": (
lambda: cst.Assign(
# pyre-ignore: Incompatible parameter type [6]
targets=[
cst.BinaryOperation(
left=cst.Name("x"),
operator=cst.Add(),
right=cst.Integer("1"),
),
],
value=cst.Name("y"),
)
),
"expected_re": "Expected an instance of .*statement.AssignTarget.*",
},
)
)
def test_invalid_types(self, **kwargs: Any) -> None:
self.assert_invalid_types(**kwargs)


class AnnAssignTest(CSTNodeTest):
@data_provider(
Expand Down Expand Up @@ -284,6 +307,31 @@ def test_valid(self, **kwargs: Any) -> None:
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)

@data_provider(
(
{
"get_node": (
lambda: cst.AnnAssign(
# pyre-ignore: Incompatible parameter type [6]
target=cst.BinaryOperation(
left=cst.Name("x"),
operator=cst.Add(),
right=cst.Integer("1"),
),
annotation=cst.Annotation(cst.Name("int")),
equal=cst.AssignEqual(),
value=cst.Name("y"),
)
),
"expected_re": (
"Expected an instance of .*BaseAssignTargetExpression.*"
),
},
)
)
def test_invalid_types(self, **kwargs: Any) -> None:
self.assert_invalid_types(**kwargs)


class AugAssignTest(CSTNodeTest):
@data_provider(
Expand Down Expand Up @@ -362,3 +410,27 @@ class AugAssignTest(CSTNodeTest):
)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)

@data_provider(
(
{
"get_node": (
lambda: cst.AugAssign(
# pyre-ignore: Incompatible parameter type [6]
target=cst.BinaryOperation(
left=cst.Name("x"),
operator=cst.Add(),
right=cst.Integer("1"),
),
operator=cst.Add(),
value=cst.Name("y"),
)
),
"expected_re": (
"Expected an instance of .*BaseAssignTargetExpression.*"
),
},
)
)
def test_invalid_types(self, **kwargs: Any) -> None:
self.assert_invalid_types(**kwargs)
32 changes: 16 additions & 16 deletions libcst/matchers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ class And(BaseBooleanOp, BaseMatcherNode):
] = DoNotCare()


BaseExpressionMatchType = Union[
"BaseExpression",
BaseAssignTargetExpressionMatchType = Union[
"BaseAssignTargetExpression",
MetadataMatchType,
MatchIfTrue[Callable[[cst.BaseExpression], bool]],
MatchIfTrue[Callable[[cst.BaseAssignTargetExpression], bool]],
]
AnnotationMatchType = Union[
"Annotation", MetadataMatchType, MatchIfTrue[Callable[[cst.Annotation], bool]]
Expand All @@ -248,10 +248,10 @@ class And(BaseBooleanOp, BaseMatcherNode):
@dataclass(frozen=True, eq=False, unsafe_hash=False)
class AnnAssign(BaseSmallStatement, BaseMatcherNode):
target: Union[
BaseExpressionMatchType,
BaseAssignTargetExpressionMatchType,
DoNotCareSentinel,
OneOf[BaseExpressionMatchType],
AllOf[BaseExpressionMatchType],
OneOf[BaseAssignTargetExpressionMatchType],
AllOf[BaseAssignTargetExpressionMatchType],
] = DoNotCare()
annotation: Union[
AnnotationMatchType,
Expand Down Expand Up @@ -285,6 +285,13 @@ class AnnAssign(BaseSmallStatement, BaseMatcherNode):
] = DoNotCare()


BaseExpressionMatchType = Union[
"BaseExpression",
MetadataMatchType,
MatchIfTrue[Callable[[cst.BaseExpression], bool]],
]


@dataclass(frozen=True, eq=False, unsafe_hash=False)
class Annotation(BaseMatcherNode):
annotation: Union[
Expand Down Expand Up @@ -597,13 +604,6 @@ class AssignEqual(BaseMatcherNode):
] = DoNotCare()


BaseAssignTargetExpressionMatchType = Union[
"BaseAssignTargetExpression",
MetadataMatchType,
MatchIfTrue[Callable[[cst.BaseAssignTargetExpression], bool]],
]


@dataclass(frozen=True, eq=False, unsafe_hash=False)
class AssignTarget(BaseMatcherNode):
target: Union[
Expand Down Expand Up @@ -852,10 +852,10 @@ class Attribute(
@dataclass(frozen=True, eq=False, unsafe_hash=False)
class AugAssign(BaseSmallStatement, BaseMatcherNode):
target: Union[
BaseExpressionMatchType,
BaseAssignTargetExpressionMatchType,
DoNotCareSentinel,
OneOf[BaseExpressionMatchType],
AllOf[BaseExpressionMatchType],
OneOf[BaseAssignTargetExpressionMatchType],
AllOf[BaseAssignTargetExpressionMatchType],
] = DoNotCare()
operator: Union[
BaseAugOpMatchType,
Expand Down

0 comments on commit 6731aa5

Please sign in to comment.