Skip to content

Commit

Permalink
[WIP] Support Parenthesized With Statements (#584)
Browse files Browse the repository at this point in the history
On the python side, we can add parentheses from MaybeSentinel.DEFAULT if the whitespace requires it.

On the rust side, we support the new grammar but codegen will only add explicitly included parentheses for now - it should be possible to match python behavior but it's not urgent so I've left a TODO
  • Loading branch information
stroxler authored Jan 7, 2022
1 parent 9f6ff01 commit 1337022
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 33 deletions.
57 changes: 52 additions & 5 deletions libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from libcst._nodes.whitespace import (
BaseParenthesizableWhitespace,
EmptyLine,
ParenthesizedWhitespace,
SimpleWhitespace,
TrailingWhitespace,
)
Expand Down Expand Up @@ -2017,24 +2018,47 @@ class With(BaseCompoundStatement):
#: Sequence of empty lines appearing before this with statement.
leading_lines: Sequence[EmptyLine] = ()

#: Optional open parenthesis for multi-line with bindings
lpar: Union[LeftParen, MaybeSentinel] = MaybeSentinel.DEFAULT

#: Optional close parenthesis for multi-line with bindings
rpar: Union[RightParen, MaybeSentinel] = MaybeSentinel.DEFAULT

#: Whitespace after the ``with`` keyword and before the first item.
whitespace_after_with: SimpleWhitespace = SimpleWhitespace.field(" ")

#: Whitespace after the last item and before the colon.
whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("")

def _validate_parens(self) -> None:
if isinstance(self.lpar, MaybeSentinel) and isinstance(self.rpar, RightParen):
raise CSTValidationError(
"Do not mix concrete LeftParen/RightParen with MaybeSentinel."
)
if isinstance(self.lpar, LeftParen) and isinstance(self.rpar, MaybeSentinel):
raise CSTValidationError(
"Do not mix concrete LeftParen/RightParen with MaybeSentinel."
)

def _validate(self) -> None:
self._validate_parens()
if len(self.items) == 0:
raise CSTValidationError(
"A With statement must have at least one WithItem."
)
if self.items[-1].comma != MaybeSentinel.DEFAULT:
if (
isinstance(self.rpar, MaybeSentinel)
and self.items[-1].comma != MaybeSentinel.DEFAULT
):
raise CSTValidationError(
"The last WithItem in a With cannot have a trailing comma."
"The last WithItem in an unparenthesized With cannot have a trailing comma."
)
if self.whitespace_after_with.empty and not self.items[
0
].item._safe_to_use_with_word_operator(ExpressionPosition.RIGHT):
if self.whitespace_after_with.empty and not (
isinstance(self.lpar, LeftParen)
or self.items[0].item._safe_to_use_with_word_operator(
ExpressionPosition.RIGHT
)
):
raise CSTValidationError("Must have at least one space after with keyword.")

def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "With":
Expand All @@ -2048,7 +2072,9 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "With":
whitespace_after_with=visit_required(
self, "whitespace_after_with", self.whitespace_after_with, visitor
),
lpar=visit_sentinel(self, "lpar", self.lpar, visitor),
items=visit_sequence(self, "items", self.items, visitor),
rpar=visit_sentinel(self, "rpar", self.rpar, visitor),
whitespace_before_colon=visit_required(
self, "whitespace_before_colon", self.whitespace_before_colon, visitor
),
Expand All @@ -2060,15 +2086,36 @@ def _codegen_impl(self, state: CodegenState) -> None:
ll._codegen(state)
state.add_indent_tokens()

needs_paren = False
for item in self.items:
comma = item.comma
if isinstance(comma, Comma):
if isinstance(
comma.whitespace_after,
(EmptyLine, TrailingWhitespace, ParenthesizedWhitespace),
):
needs_paren = True
break

with state.record_syntactic_position(self, end_node=self.body):
asynchronous = self.asynchronous
if asynchronous is not None:
asynchronous._codegen(state)
state.add_token("with")
self.whitespace_after_with._codegen(state)
lpar = self.lpar
if isinstance(lpar, LeftParen):
lpar._codegen(state)
elif needs_paren:
state.add_token("(")
last_item = len(self.items) - 1
for i, item in enumerate(self.items):
item._codegen(state, default_comma=(i != last_item))
rpar = self.rpar
if isinstance(rpar, RightParen):
rpar._codegen(state)
elif needs_paren:
state.add_token(")")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
Expand Down
112 changes: 97 additions & 15 deletions libcst/_nodes/tests/test_with.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

import libcst as cst
from libcst import parse_statement, PartialParserConfig
from libcst._maybe_sentinel import MaybeSentinel
from libcst._nodes.tests.base import CSTNodeTest, DummyIndentedBlock, parse_statement_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider


class WithTest(CSTNodeTest):
maxDiff: int = 2000

@data_provider(
(
# Simple with block
Expand Down Expand Up @@ -138,45 +141,83 @@ class WithTest(CSTNodeTest):
"parser": parse_statement,
"expected_position": CodeRange((2, 0), (2, 24)),
},
# Weird spacing rules
# Whitespace
{
"node": cst.With(
(
cst.WithItem(
cst.Call(cst.Name("context_mgr")),
cst.AsName(
cst.Name("ctx"),
whitespace_before_as=cst.SimpleWhitespace(" "),
whitespace_after_as=cst.SimpleWhitespace(" "),
),
),
),
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_after_with=cst.SimpleWhitespace(" "),
whitespace_before_colon=cst.SimpleWhitespace(" "),
),
"code": "with context_mgr() as ctx : pass\n",
"parser": parse_statement,
"expected_position": CodeRange((1, 0), (1, 36)),
},
# Weird spacing rules, that parse differently depending on whether
# we are using a grammar that included parenthesized with statements.
{
"node": cst.With(
(
cst.WithItem(
cst.Call(
cst.Name("context_mgr"),
lpar=(cst.LeftParen(),),
rpar=(cst.RightParen(),),
lpar=() if is_native() else (cst.LeftParen(),),
rpar=() if is_native() else (cst.RightParen(),),
)
),
),
cst.SimpleStatementSuite((cst.Pass(),)),
lpar=(cst.LeftParen() if is_native() else MaybeSentinel.DEFAULT),
rpar=(cst.RightParen() if is_native() else MaybeSentinel.DEFAULT),
whitespace_after_with=cst.SimpleWhitespace(""),
),
"code": "with(context_mgr()): pass\n",
"parser": parse_statement,
"expected_position": CodeRange((1, 0), (1, 25)),
},
# Whitespace
# Multi-line parenthesized with.
{
"node": cst.With(
(
cst.WithItem(
cst.Call(cst.Name("context_mgr")),
cst.AsName(
cst.Name("ctx"),
whitespace_before_as=cst.SimpleWhitespace(" "),
whitespace_after_as=cst.SimpleWhitespace(" "),
cst.Call(cst.Name("foo")),
comma=cst.Comma(
whitespace_after=cst.ParenthesizedWhitespace(
first_line=cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(
value="",
),
comment=None,
newline=cst.Newline(
value=None,
),
),
empty_lines=[],
indent=True,
last_line=cst.SimpleWhitespace(
value=" ",
),
)
),
),
cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
),
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_after_with=cst.SimpleWhitespace(" "),
whitespace_before_colon=cst.SimpleWhitespace(" "),
lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
),
"code": "with context_mgr() as ctx : pass\n",
"parser": parse_statement,
"expected_position": CodeRange((1, 0), (1, 36)),
"code": ("with ( foo(),\n" " bar(), ): pass\n"), # noqa
"parser": parse_statement if is_native() else None,
"expected_position": CodeRange((1, 0), (2, 21)),
},
)
)
Expand All @@ -201,7 +242,8 @@ def test_valid(self, **kwargs: Any) -> None:
),
cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
),
"expected_re": "The last WithItem in a With cannot have a trailing comma",
"expected_re": "The last WithItem in an unparenthesized With cannot "
+ "have a trailing comma.",
},
{
"get_node": lambda: cst.With(
Expand All @@ -211,6 +253,26 @@ def test_valid(self, **kwargs: Any) -> None:
),
"expected_re": "Must have at least one space after with keyword",
},
{
"get_node": lambda: cst.With(
(cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_after_with=cst.SimpleWhitespace(""),
lpar=cst.LeftParen(),
),
"expected_re": "Do not mix concrete LeftParen/RightParen with "
+ "MaybeSentinel",
},
{
"get_node": lambda: cst.With(
(cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_after_with=cst.SimpleWhitespace(""),
rpar=cst.RightParen(),
),
"expected_re": "Do not mix concrete LeftParen/RightParen with "
+ "MaybeSentinel",
},
)
)
def test_invalid(self, **kwargs: Any) -> None:
Expand All @@ -234,3 +296,23 @@ def test_versions(self, **kwargs: Any) -> None:
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

def test_adding_parens(self) -> None:
node = cst.With(
(
cst.WithItem(
cst.Call(cst.Name("foo")),
comma=cst.Comma(
whitespace_after=cst.ParenthesizedWhitespace(),
),
),
cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
),
cst.SimpleStatementSuite((cst.Pass(),)),
lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
)
module = cst.Module([])
self.assertEqual(
module.code_for_node(node), ("with ( foo(),\n" "bar(), ): pass\n") # noqa
)
16 changes: 16 additions & 0 deletions libcst/_typed_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5379,6 +5379,22 @@ def visit_With_leading_lines(self, node: "With") -> None:
def leave_With_leading_lines(self, node: "With") -> None:
pass

@mark_no_op
def visit_With_lpar(self, node: "With") -> None:
pass

@mark_no_op
def leave_With_lpar(self, node: "With") -> None:
pass

@mark_no_op
def visit_With_rpar(self, node: "With") -> None:
pass

@mark_no_op
def leave_With_rpar(self, node: "With") -> None:
pass

@mark_no_op
def visit_With_whitespace_after_with(self, node: "With") -> None:
pass
Expand Down
12 changes: 12 additions & 0 deletions libcst/matchers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15411,6 +15411,18 @@ class With(BaseCompoundStatement, BaseStatement, BaseMatcherNode):
]
],
] = DoNotCare()
lpar: Union[
LeftParenMatchType,
DoNotCareSentinel,
OneOf[LeftParenMatchType],
AllOf[LeftParenMatchType],
] = DoNotCare()
rpar: Union[
RightParenMatchType,
DoNotCareSentinel,
OneOf[RightParenMatchType],
AllOf[RightParenMatchType],
] = DoNotCare()
whitespace_after_with: Union[
SimpleWhitespaceMatchType,
DoNotCareSentinel,
Expand Down
Loading

0 comments on commit 1337022

Please sign in to comment.